From 276648a7abc7214780f1340a8a647305aeec9eb0 Mon Sep 17 00:00:00 2001 From: ened Date: Tue, 26 Aug 2025 02:40:23 +0900 Subject: [PATCH] remove unused source files --- README.md | 48 ++++- claude_desktop_config.json | 7 +- main.py | 27 ++- run.bat | 56 ----- run_forced_utf8.bat | 28 --- run_mcp_safe.bat | 22 ++ run_mcp_safe.py | 66 ++++++ run_utf8.bat | 16 -- run_utf8_forced.py | 24 --- src/async_task/__init__.py | 9 - src/async_task/models.py | 48 ----- src/async_task/task_manager.py | 324 ---------------------------- src/async_task/worker_pool.py | 258 ---------------------- src/server/__init__.py | 26 --- src/server/enhanced_handlers.py | 344 ------------------------------ src/server/enhanced_mcp_server.py | 63 ------ src/server/enhanced_models.py | 71 ------ src/server/handlers.py | 308 -------------------------- src/server/mcp_server.py | 57 ----- src/server/minimal_handler.py | 39 ---- src/server/models.py | 132 ------------ src/server/safe_result_handler.py | 80 ------- src/utils/image_utils.py | 107 ---------- src/utils/token_utils.py | 200 ----------------- test_mcp_compatibility.py | 40 ++++ tests/test_connector.py | 178 ---------------- tests/test_server.py | 136 ------------ 27 files changed, 190 insertions(+), 2524 deletions(-) delete mode 100644 run.bat delete mode 100644 run_forced_utf8.bat create mode 100644 run_mcp_safe.bat create mode 100644 run_mcp_safe.py delete mode 100644 run_utf8.bat delete mode 100644 run_utf8_forced.py delete mode 100644 src/async_task/__init__.py delete mode 100644 src/async_task/models.py delete mode 100644 src/async_task/task_manager.py delete mode 100644 src/async_task/worker_pool.py delete mode 100644 src/server/__init__.py delete mode 100644 src/server/enhanced_handlers.py delete mode 100644 src/server/enhanced_mcp_server.py delete mode 100644 src/server/enhanced_models.py delete mode 100644 src/server/handlers.py delete mode 100644 src/server/mcp_server.py delete mode 100644 src/server/minimal_handler.py delete mode 100644 src/server/models.py delete mode 100644 src/server/safe_result_handler.py delete mode 100644 src/utils/image_utils.py delete mode 100644 src/utils/token_utils.py create mode 100644 test_mcp_compatibility.py delete mode 100644 tests/test_connector.py delete mode 100644 tests/test_server.py diff --git a/README.md b/README.md index dc145b9..5726505 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,30 @@ Google Imagen 4를 사용한 AI 이미지 생성 MCP 서버입니다. 하나의 `main.py` 파일로 모든 기능이 통합되어 있으며, **한글 프롬프트를 완벽 지원**합니다. -## 🚀 주요 기능 +## 🚨 주의사항 - MCP 프로토콜 호환성 + +> **중요**: 이 서버는 MCP (Model Context Protocol)를 사용합니다. MCP는 stdout을 통해 JSON 메시지를 주고받으므로, 서버 코드에서 **절대로 stdout에 직접 출력하면 안됩니다**. + +### ❌ 금지사항 +```python +print("message") # ❌ MCP JSON 프로토콜 방해 +print("message", file=sys.stdout) # ❌ MCP JSON 프로토콜 방해 +sys.stdout.write("message") # ❌ MCP JSON 프로토콜 방해 +``` + +### ✅ 권장사항 +```python +print("message", file=sys.stderr) # ✅ 로그용 출력 +logging.info("message") # ✅ 로깅 시스템 사용 +logger.error("message") # ✅ 로깅 시스템 사용 +``` + +### 최신 수정사항 (2025-08-26) +- **UTF-8 테스트 출력 비활성화**: `[UTF8-TEST]` 메시지가 MCP JSON 파싱을 방해하던 문제 수정 +- **MCP 안전 실행**: `run_mcp_safe.py` 및 `run_mcp_safe.bat` 추가 +- **호환성 테스트**: `test_mcp_compatibility.py` 추가 + +**개발 시 반드시 `CLAUDE.md` 파일의 개발 가이드라인을 참조하세요!** - **고품질 이미지 생성**: 2048x2048 PNG 이미지 생성 - **미리보기 이미지**: 512x512 JPEG 미리보기를 base64로 제공 @@ -58,16 +81,16 @@ OUTPUT_PATH=./generated_images ## 🎮 사용법 -### 서버 실행 +### 서버 실행 (업데이트됨!) -#### Windows (유니코드 지원) +#### MCP 안전 실행 (추천) ```bash -run_unicode.bat +run_mcp_safe.bat ``` -#### 기본 실행 +또는 Python으로 직접: ```bash -run.bat +python run_mcp_safe.py ``` #### 직접 실행 (유니코드 환경 설정) @@ -85,24 +108,28 @@ python main.py ### Claude Desktop 설정 -`claude_desktop_config.json` 내용을 Claude Desktop 설정에 추가: +Claude Desktop 설정에 다음 내용을 추가하거나, `claude_desktop_config_mcp_safe.json` 파일을 사용하세요: ```json { "mcpServers": { "imagen4": { "command": "python", - "args": ["main.py"], + "args": ["run_mcp_safe.py"], "cwd": "D:\\Project\\imagen4", "env": { + "PYTHONPATH": "D:\\Project\\imagen4", "PYTHONIOENCODING": "utf-8", - "PYTHONUTF8": "1" + "PYTHONUTF8": "1", + "LC_ALL": "C.UTF-8" } } } } ``` +> **중요**: MCP 호환성을 위해 `run_mcp_safe.py`를 사용하세요! + ## 🛠️ 사용 가능한 도구 ### 1. generate_image @@ -195,7 +222,8 @@ def ensure_unicode_string(value): 한글 프롬프트 지원을 테스트하려면: ```bash -python test_main.py +python test_mcp_compatibility.py # MCP 호환성 테스트 (신규) +python test_main.py # 기본 기능 테스트 ``` 이 스크립트는 다음을 확인합니다: diff --git a/claude_desktop_config.json b/claude_desktop_config.json index 4aba09e..63169e7 100644 --- a/claude_desktop_config.json +++ b/claude_desktop_config.json @@ -2,10 +2,13 @@ "mcpServers": { "imagen4": { "command": "python", - "args": ["main.py"], + "args": ["run_mcp_safe.py"], "cwd": "D:\\Project\\imagen4", "env": { - "PYTHONPATH": "D:\\Project\\imagen4" + "PYTHONPATH": "D:\\Project\\imagen4", + "PYTHONIOENCODING": "utf-8", + "PYTHONUTF8": "1", + "LC_ALL": "C.UTF-8" } } } diff --git a/main.py b/main.py index 0cc1c02..86ca799 100644 --- a/main.py +++ b/main.py @@ -74,12 +74,15 @@ if sys.platform.startswith('win'): except Exception: pass -# Verify UTF-8 setup +# Verify UTF-8 setup (disabled for MCP compatibility) +# Note: UTF-8 test output disabled to prevent MCP protocol interference try: test_unicode = "Test UTF-8: 한글 테스트 ✓" - print(f"[UTF8-TEST] {test_unicode}") + # print(f"[UTF8-TEST] {test_unicode}") # Disabled: interferes with MCP JSON protocol except UnicodeEncodeError as e: - print(f"[UTF8-ERROR] Unicode test failed: {e}") + # Only log errors to stderr, not stdout + import sys + print(f"[UTF8-ERROR] Unicode test failed: {e}", file=sys.stderr) # ==================== Regular Imports (after UTF-8 setup) ==================== @@ -97,7 +100,9 @@ try: from dotenv import load_dotenv load_dotenv() except ImportError: - print("Warning: python-dotenv not installed", file=sys.stderr) + # Use logger for dotenv warning instead of direct print + import logging + logging.getLogger("imagen4-mcp-server").warning("python-dotenv not installed") # Add current directory to PYTHONPATH current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -111,8 +116,11 @@ try: from mcp.server import Server from mcp.types import Tool, TextContent except ImportError as e: - print(f"Error importing MCP: {e}", file=sys.stderr) - print("Please install required packages: pip install mcp", file=sys.stderr) + # Use logger for MCP import errors instead of direct print + import logging + logger = logging.getLogger("imagen4-mcp-server") + logger.error(f"Error importing MCP: {e}") + logger.error("Please install required packages: pip install mcp") sys.exit(1) # Connector imports @@ -124,8 +132,11 @@ from src.connector.utils import save_generated_images try: from PIL import Image except ImportError as e: - print(f"Error importing Pillow: {e}", file=sys.stderr) - print("Please install Pillow: pip install Pillow", file=sys.stderr) + # Use logger for Pillow import errors instead of direct print + import logging + logger = logging.getLogger("imagen4-mcp-server") + logger.error(f"Error importing Pillow: {e}") + logger.error("Please install Pillow: pip install Pillow") sys.exit(1) # ==================== Unicode-Safe Logging Setup ==================== diff --git a/run.bat b/run.bat deleted file mode 100644 index c7eaa58..0000000 --- a/run.bat +++ /dev/null @@ -1,56 +0,0 @@ -@echo off -chcp 65001 > nul 2>&1 -echo Starting Imagen4 MCP Server with Preview Image Support (Unicode Enabled)... -echo Features: 512x512 JPEG preview images, base64 encoding, Unicode logging support -echo. - -REM Set UTF-8 environment for Python -set PYTHONIOENCODING=utf-8 -set PYTHONUTF8=1 - -REM Check if virtual environment exists -if not exist "venv\Scripts\activate.bat" ( - echo Creating virtual environment... - python -m venv venv - if errorlevel 1 ( - echo Failed to create virtual environment - pause - exit /b 1 - ) -) - -REM Activate virtual environment -echo Activating virtual environment... -call venv\Scripts\activate.bat - -REM Check if requirements are installed -python -c "import PIL" 2>nul -if errorlevel 1 ( - echo Installing requirements... - pip install -r requirements.txt - if errorlevel 1 ( - echo Failed to install requirements - pause - exit /b 1 - ) -) - -REM Check if .env file exists -if not exist ".env" ( - echo Warning: .env file not found - echo Please copy .env.example to .env and configure your API credentials - pause - exit /b 1 -) - -REM Run server with UTF-8 support -echo Starting MCP server with Unicode support... -echo 한글 프롬프트 지원이 활성화되었습니다. -python main.py - -REM Keep window open if there's an error -if errorlevel 1 ( - echo. - echo Server exited with error - pause -) diff --git a/run_forced_utf8.bat b/run_forced_utf8.bat deleted file mode 100644 index 7bed40f..0000000 --- a/run_forced_utf8.bat +++ /dev/null @@ -1,28 +0,0 @@ -@echo off -echo Starting Imagen4 MCP Server with forced UTF-8 mode... -echo. - -REM Set console code page to UTF-8 -chcp 65001 >nul 2>&1 - -REM Set all UTF-8 environment variables -set PYTHONIOENCODING=utf-8 -set PYTHONUTF8=1 -set LC_ALL=C.UTF-8 -set LANG=C.UTF-8 - -REM Set console properties -title Imagen4 MCP Server (UTF-8 Forced) - -echo Environment Variables Set: -echo PYTHONIOENCODING=%PYTHONIOENCODING% -echo PYTHONUTF8=%PYTHONUTF8% -echo LC_ALL=%LC_ALL% -echo. - -REM Run Python with explicit UTF-8 mode flag -python -X utf8 main.py - -echo. -echo Server stopped. Press any key to exit... -pause >nul diff --git a/run_mcp_safe.bat b/run_mcp_safe.bat new file mode 100644 index 0000000..4d1a629 --- /dev/null +++ b/run_mcp_safe.bat @@ -0,0 +1,22 @@ +@echo off +REM MCP-Safe UTF-8 runner for imagen4 server +REM This prevents stdout interference with MCP JSON protocol + +echo [STARTUP] Starting Imagen4 MCP Server with safe UTF-8 handling... + +REM Change to script directory +cd /d "%~dp0" + +REM Set UTF-8 code page +chcp 65001 >nul 2>&1 + +REM Set UTF-8 environment variables +set PYTHONIOENCODING=utf-8 +set PYTHONUTF8=1 +set LC_ALL=C.UTF-8 + +REM Run the safe MCP runner +python run_mcp_safe.py + +echo [SHUTDOWN] Imagen4 MCP Server stopped. +pause diff --git a/run_mcp_safe.py b/run_mcp_safe.py new file mode 100644 index 0000000..1c6985f --- /dev/null +++ b/run_mcp_safe.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +""" +MCP Protocol Compatible UTF-8 Runner for imagen4 + +This script ensures proper UTF-8 handling without interfering with MCP protocol +""" + +import os +import sys +import subprocess +import locale + +def setup_utf8_environment(): + """Setup UTF-8 environment variables without stdout interference""" + # Set UTF-8 environment variables + env = os.environ.copy() + env['PYTHONIOENCODING'] = 'utf-8' + env['PYTHONUTF8'] = '1' + env['LC_ALL'] = 'C.UTF-8' + + # Windows-specific setup + if sys.platform.startswith('win'): + # Set console code page to UTF-8 (suppress output) + try: + os.system('chcp 65001 >nul 2>&1') + except Exception: + pass + + # Set locale + try: + locale.setlocale(locale.LC_ALL, 'C.UTF-8') + except locale.Error: + try: + locale.setlocale(locale.LC_ALL, '') + except locale.Error: + pass + + return env + +def main(): + """Main function - run imagen4 MCP server with proper UTF-8 setup""" + # Setup UTF-8 environment + env = setup_utf8_environment() + + # Get the directory of this script + script_dir = os.path.dirname(os.path.abspath(__file__)) + main_py = os.path.join(script_dir, 'main.py') + + # Run main.py with proper environment + try: + # Use subprocess to run with clean environment + result = subprocess.run([ + sys.executable, main_py + ], env=env, cwd=script_dir) + + return result.returncode + + except KeyboardInterrupt: + print("Server stopped by user", file=sys.stderr) + return 0 + except Exception as e: + print(f"Error running server: {e}", file=sys.stderr) + return 1 + +if __name__ == "__main__": + sys.exit(main()) diff --git a/run_utf8.bat b/run_utf8.bat deleted file mode 100644 index 3acdb1f..0000000 --- a/run_utf8.bat +++ /dev/null @@ -1,16 +0,0 @@ -@echo off -REM Force UTF-8 encoding for Python and console -chcp 65001 >nul 2>&1 - -REM Set Python UTF-8 environment variables -set PYTHONIOENCODING=utf-8 -set PYTHONUTF8=1 -set LC_ALL=C.UTF-8 - -REM Set console properties -title Imagen4 MCP Server (UTF-8) - -REM Run the Python script -python main.py - -pause diff --git a/run_utf8_forced.py b/run_utf8_forced.py deleted file mode 100644 index 40a1aea..0000000 --- a/run_utf8_forced.py +++ /dev/null @@ -1,24 +0,0 @@ -#!/usr/bin/env python3 -# Emergency UTF-8 override script -import sys -import os - -# Force UTF-8 mode at Python startup -if not hasattr(sys, '_utf8_mode') or not sys._utf8_mode: - os.environ['PYTHONUTF8'] = '1' - os.environ['PYTHONIOENCODING'] = 'utf-8' - - # Restart Python with UTF-8 mode - import subprocess - result = subprocess.run([sys.executable, '-X', 'utf8', __file__] + sys.argv[1:]) - sys.exit(result.returncode) - -# Now run the actual main.py -if __name__ == "__main__": - import importlib.util - import sys - - spec = importlib.util.spec_from_file_location("main", "main.py") - main_module = importlib.util.module_from_spec(spec) - sys.modules["main"] = main_module - spec.loader.exec_module(main_module) diff --git a/src/async_task/__init__.py b/src/async_task/__init__.py deleted file mode 100644 index 2c7d4c0..0000000 --- a/src/async_task/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -""" -Async Task Management Module for MCP Server -""" - -from .models import TaskStatus, TaskResult, generate_task_id -from .task_manager import TaskManager -from .worker_pool import WorkerPool - -__all__ = ['TaskManager', 'TaskStatus', 'TaskResult', 'WorkerPool'] diff --git a/src/async_task/models.py b/src/async_task/models.py deleted file mode 100644 index f9d06fb..0000000 --- a/src/async_task/models.py +++ /dev/null @@ -1,48 +0,0 @@ -""" -Task Status and Result Models -""" - -import uuid -from dataclasses import dataclass, field -from datetime import datetime -from enum import Enum -from typing import Any, Optional, Dict - - -class TaskStatus(Enum): - """Task execution status""" - PENDING = "pending" - RUNNING = "running" - COMPLETED = "completed" - FAILED = "failed" - CANCELLED = "cancelled" - - -@dataclass -class TaskResult: - """Task execution result""" - task_id: str - status: TaskStatus - result: Optional[Any] = None - error: Optional[str] = None - created_at: datetime = field(default_factory=datetime.now) - started_at: Optional[datetime] = None - completed_at: Optional[datetime] = None - metadata: Dict[str, Any] = field(default_factory=dict) - - @property - def duration(self) -> Optional[float]: - """Calculate task execution duration in seconds""" - if self.started_at and self.completed_at: - return (self.completed_at - self.started_at).total_seconds() - return None - - @property - def is_finished(self) -> bool: - """Check if task is finished (completed, failed, or cancelled)""" - return self.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED] - - -def generate_task_id() -> str: - """Generate unique task ID""" - return str(uuid.uuid4()) diff --git a/src/async_task/task_manager.py b/src/async_task/task_manager.py deleted file mode 100644 index 3c4e590..0000000 --- a/src/async_task/task_manager.py +++ /dev/null @@ -1,324 +0,0 @@ -""" -Task Manager for MCP Server -Manages background task execution with status tracking and result retrieval -""" - -import asyncio -import logging -from typing import Any, Callable, Optional, Dict, List -from datetime import datetime, timedelta - -from .models import TaskResult, TaskStatus, generate_task_id -from .worker_pool import WorkerPool - -logger = logging.getLogger(__name__) - - -class TaskManager: - """Main task manager for MCP server""" - - def __init__( - self, - max_workers: Optional[int] = None, - use_process_pool: bool = False, - default_timeout: float = 600.0, # 10 minutes - result_retention_hours: int = 24, # Keep results for 24 hours - max_retained_results: int = 200 - ): - """ - Initialize task manager - - Args: - max_workers: Maximum number of worker threads/processes - use_process_pool: Use process pool instead of thread pool - default_timeout: Default timeout for tasks in seconds - result_retention_hours: How long to keep finished task results - max_retained_results: Maximum number of results to retain - """ - self.worker_pool = WorkerPool( - max_workers=max_workers, - use_process_pool=use_process_pool, - task_timeout=default_timeout - ) - - self.result_retention_hours = result_retention_hours - self.max_retained_results = max_retained_results - - # Start cleanup task - self._cleanup_task = asyncio.create_task(self._periodic_cleanup()) - - logger.info(f"TaskManager initialized with {self.worker_pool.max_workers} workers") - - async def submit_task( - self, - func: Callable, - *args, - task_id: Optional[str] = None, - timeout: Optional[float] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs - ) -> str: - """ - Submit a task for background execution - - Args: - func: Function to execute - *args: Function arguments - task_id: Optional custom task ID - timeout: Task timeout in seconds - metadata: Additional metadata for the task - **kwargs: Function keyword arguments - - Returns: - str: Task ID for tracking progress - """ - task_id = await self.worker_pool.submit_task( - func, *args, task_id=task_id, timeout=timeout, **kwargs - ) - - # Add additional metadata if provided - if metadata: - result = self.worker_pool.get_task_result(task_id) - if result: - result.metadata.update(metadata) - - return task_id - - def get_task_status(self, task_id: str) -> Optional[TaskStatus]: - """Get task status by ID""" - result = self.worker_pool.get_task_result(task_id) - return result.status if result else None - - def get_task_result(self, task_id: str) -> Optional[TaskResult]: - """Get complete task result by ID with circular reference protection""" - try: - raw_result = self.worker_pool.get_task_result(task_id) - if not raw_result: - return None - - # Create a safe copy to avoid circular references - safe_result = TaskResult( - task_id=raw_result.task_id, - status=raw_result.status, - result=None, # We'll set this safely - error=raw_result.error, - created_at=raw_result.created_at, - started_at=raw_result.started_at, - completed_at=raw_result.completed_at, - metadata=raw_result.metadata.copy() if raw_result.metadata else {} - ) - - # Safely copy the result data - if raw_result.result is not None: - try: - # If it's a dict, create a clean copy - if isinstance(raw_result.result, dict): - safe_result.result = { - 'success': raw_result.result.get('success', False), - 'error': raw_result.result.get('error'), - 'images_b64': raw_result.result.get('images_b64', []), - 'saved_files': raw_result.result.get('saved_files', []), - 'request': raw_result.result.get('request', {}), - 'image_count': raw_result.result.get('image_count', 0) - } - else: - # For non-dict results, just copy directly - safe_result.result = raw_result.result - except Exception as copy_error: - logger.error(f"Error creating safe copy of result: {str(copy_error)}") - safe_result.result = { - 'success': False, - 'error': f'Result data corrupted: {str(copy_error)}' - } - - return safe_result - - except Exception as e: - logger.error(f"Error in get_task_result: {str(e)}", exc_info=True) - return None - - def get_task_progress(self, task_id: str) -> Dict[str, Any]: - """ - Get task progress information - - Returns: - Dict with task status, timing info, and metadata - """ - result = self.get_task_result(task_id) - if not result: - return {'error': 'Task not found'} - - progress = { - 'task_id': result.task_id, - 'status': result.status.value, - 'created_at': result.created_at.isoformat(), - 'metadata': result.metadata - } - - if result.started_at: - progress['started_at'] = result.started_at.isoformat() - - if result.status == TaskStatus.RUNNING: - elapsed = (datetime.now() - result.started_at).total_seconds() - progress['elapsed_seconds'] = elapsed - - if result.completed_at: - progress['completed_at'] = result.completed_at.isoformat() - progress['duration_seconds'] = result.duration - - if result.error: - progress['error'] = result.error - - return progress - - def list_tasks( - self, - status_filter: Optional[TaskStatus] = None, - limit: int = 50 - ) -> List[Dict[str, Any]]: - """ - List tasks with optional filtering - - Args: - status_filter: Filter by task status - limit: Maximum number of tasks to return - - Returns: - List of task progress dictionaries - """ - all_results = self.worker_pool.get_all_results() - - # Filter by status if requested - if status_filter: - filtered_results = { - k: v for k, v in all_results.items() - if v.status == status_filter - } - else: - filtered_results = all_results - - # Sort by creation time (most recent first) - sorted_items = sorted( - filtered_results.items(), - key=lambda x: x[1].created_at, - reverse=True - ) - - # Apply limit and return progress info - return [ - self.get_task_progress(task_id) - for task_id, _ in sorted_items[:limit] - ] - - def cancel_task(self, task_id: str) -> bool: - """Cancel a running task""" - return self.worker_pool.cancel_task(task_id) - - async def wait_for_task( - self, - task_id: str, - timeout: Optional[float] = None, - poll_interval: float = 1.0 - ) -> TaskResult: - """ - Wait for a task to complete - - Args: - task_id: Task ID to wait for - timeout: Maximum time to wait in seconds - poll_interval: How often to check task status - - Returns: - TaskResult: Final task result - - Raises: - asyncio.TimeoutError: If timeout is reached - ValueError: If task doesn't exist - """ - start_time = datetime.now() - - while True: - result = self.get_task_result(task_id) - if not result: - raise ValueError(f"Task {task_id} not found") - - if result.is_finished: - return result - - # Check timeout - if timeout: - elapsed = (datetime.now() - start_time).total_seconds() - if elapsed >= timeout: - raise asyncio.TimeoutError(f"Timeout waiting for task {task_id}") - - await asyncio.sleep(poll_interval) - - async def _periodic_cleanup(self) -> None: - """Periodic cleanup of old task results""" - while True: - try: - # Wait 1 hour between cleanups - await asyncio.sleep(3600) - - # Clean up old results - cutoff_time = datetime.now() - timedelta(hours=self.result_retention_hours) - all_results = self.worker_pool.get_all_results() - - to_remove = [] - for task_id, result in all_results.items(): - if (result.is_finished and - result.completed_at and - result.completed_at < cutoff_time): - to_remove.append(task_id) - - # Remove old results - for task_id in to_remove: - if task_id in self.worker_pool._results: - del self.worker_pool._results[task_id] - - if to_remove: - logger.info(f"Cleaned up {len(to_remove)} old task results") - - # Also clean up excess results - self.worker_pool.cleanup_finished_tasks(self.max_retained_results) - - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"Error in periodic cleanup: {e}") - - async def shutdown(self, wait: bool = True) -> None: - """Shutdown task manager""" - logger.info("Shutting down task manager...") - - # Cancel cleanup task - if not self._cleanup_task.done(): - self._cleanup_task.cancel() - try: - await self._cleanup_task - except asyncio.CancelledError: - pass - - # Shutdown worker pool - await self.worker_pool.shutdown(wait=wait) - - logger.info("Task manager shutdown complete") - - @property - def stats(self) -> Dict[str, Any]: - """Get task manager statistics""" - all_results = self.worker_pool.get_all_results() - - status_counts = {} - for status in TaskStatus: - status_counts[status.value] = sum( - 1 for r in all_results.values() if r.status == status - ) - - return { - 'total_tasks': len(all_results), - 'active_tasks': self.worker_pool.active_task_count, - 'max_workers': self.worker_pool.max_workers, - 'worker_type': 'process' if self.worker_pool.use_process_pool else 'thread', - 'status_counts': status_counts - } diff --git a/src/async_task/worker_pool.py b/src/async_task/worker_pool.py deleted file mode 100644 index c40cb61..0000000 --- a/src/async_task/worker_pool.py +++ /dev/null @@ -1,258 +0,0 @@ -""" -Worker Pool for Background Task Execution -""" - -import asyncio -import logging -import multiprocessing -from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor -from typing import Any, Callable, Optional, Union -from datetime import datetime - -from .models import TaskResult, TaskStatus, generate_task_id - -logger = logging.getLogger(__name__) - - -class WorkerPool: - """Worker pool for executing background tasks""" - - def __init__( - self, - max_workers: Optional[int] = None, - use_process_pool: bool = False, - task_timeout: float = 600.0 # 10 minutes default - ): - """ - Initialize worker pool - - Args: - max_workers: Maximum number of worker threads/processes - use_process_pool: Use ProcessPoolExecutor instead of ThreadPoolExecutor - task_timeout: Default timeout for tasks in seconds - """ - self.max_workers = max_workers or min(32, (multiprocessing.cpu_count() or 1) + 4) - self.use_process_pool = use_process_pool - self.task_timeout = task_timeout - - # Initialize executor - if use_process_pool: - self.executor = ProcessPoolExecutor(max_workers=self.max_workers) - logger.info(f"Initialized ProcessPoolExecutor with {self.max_workers} workers") - else: - self.executor = ThreadPoolExecutor(max_workers=self.max_workers) - logger.info(f"Initialized ThreadPoolExecutor with {self.max_workers} workers") - - # Track running tasks - self._running_tasks: dict[str, asyncio.Task] = {} - self._results: dict[str, TaskResult] = {} - - async def submit_task( - self, - func: Callable, - *args, - task_id: Optional[str] = None, - timeout: Optional[float] = None, - **kwargs - ) -> str: - """ - Submit a task for background execution - - Args: - func: Function to execute - *args: Function arguments - task_id: Optional custom task ID - timeout: Task timeout in seconds - **kwargs: Function keyword arguments - - Returns: - str: Task ID for tracking - """ - if task_id is None: - task_id = generate_task_id() - - timeout = timeout or self.task_timeout - - # Create task result - task_result = TaskResult( - task_id=task_id, - status=TaskStatus.PENDING, - metadata={ - 'function_name': func.__name__, - 'timeout': timeout, - 'worker_type': 'process' if self.use_process_pool else 'thread' - } - ) - - self._results[task_id] = task_result - - # Create and start background task - background_task = asyncio.create_task( - self._execute_task(func, args, kwargs, task_result, timeout) - ) - - self._running_tasks[task_id] = background_task - - logger.info(f"Task {task_id} submitted for execution (timeout: {timeout}s)") - return task_id - - async def _execute_task( - self, - func: Callable, - args: tuple, - kwargs: dict, - task_result: TaskResult, - timeout: float - ) -> None: - """Execute task in background""" - try: - task_result.status = TaskStatus.RUNNING - task_result.started_at = datetime.now() - - logger.info(f"Starting task {task_result.task_id}: {func.__name__}") - - # Execute function with timeout - if self.use_process_pool: - # For process pool, function must be pickleable - future = self.executor.submit(func, *args, **kwargs) - result = await asyncio.wait_for( - asyncio.wrap_future(future), - timeout=timeout - ) - else: - # For thread pool, can use lambda/closure - result = await asyncio.wait_for( - asyncio.to_thread(func, *args, **kwargs), - timeout=timeout - ) - - task_result.result = result - task_result.status = TaskStatus.COMPLETED - task_result.completed_at = datetime.now() - - logger.info(f"Task {task_result.task_id} completed successfully in {task_result.duration:.2f}s") - - except asyncio.TimeoutError: - task_result.status = TaskStatus.FAILED - task_result.error = f"Task timed out after {timeout} seconds" - task_result.completed_at = datetime.now() - logger.error(f"Task {task_result.task_id} timed out after {timeout}s") - - except asyncio.CancelledError: - task_result.status = TaskStatus.CANCELLED - task_result.error = "Task was cancelled" - task_result.completed_at = datetime.now() - logger.info(f"Task {task_result.task_id} was cancelled") - - except Exception as e: - task_result.status = TaskStatus.FAILED - task_result.error = str(e) - task_result.completed_at = datetime.now() - logger.error(f"Task {task_result.task_id} failed: {str(e)}", exc_info=True) - - finally: - # Clean up running task reference - if task_result.task_id in self._running_tasks: - del self._running_tasks[task_result.task_id] - - def get_task_result(self, task_id: str) -> Optional[TaskResult]: - """Get task result by ID""" - return self._results.get(task_id) - - def get_all_results(self) -> dict[str, TaskResult]: - """Get all task results""" - return self._results.copy() - - def cancel_task(self, task_id: str) -> bool: - """ - Cancel a running task - - Args: - task_id: Task ID to cancel - - Returns: - bool: True if task was cancelled, False if not found or already finished - """ - if task_id in self._running_tasks: - task = self._running_tasks[task_id] - if not task.done(): - task.cancel() - logger.info(f"Task {task_id} cancellation requested") - return True - return False - - def cleanup_finished_tasks(self, max_results: int = 100) -> int: - """ - Clean up finished task results to prevent memory buildup - - Args: - max_results: Maximum number of results to keep - - Returns: - int: Number of results cleaned up - """ - if len(self._results) <= max_results: - return 0 - - # Sort by completion time, keep most recent - finished_results = { - k: v for k, v in self._results.items() - if v.is_finished and v.completed_at - } - - if len(finished_results) <= max_results: - return 0 - - sorted_results = sorted( - finished_results.items(), - key=lambda x: x[1].completed_at, - reverse=True - ) - - # Keep most recent max_results - to_keep = set(k for k, _ in sorted_results[:max_results]) - - # Remove old results - to_remove = [k for k in finished_results.keys() if k not in to_keep] - for task_id in to_remove: - del self._results[task_id] - - cleaned_count = len(to_remove) - if cleaned_count > 0: - logger.info(f"Cleaned up {cleaned_count} old task results") - - return cleaned_count - - async def shutdown(self, wait: bool = True) -> None: - """ - Shutdown worker pool - - Args: - wait: Whether to wait for running tasks to complete - """ - logger.info("Shutting down worker pool...") - - if wait: - # Cancel all running tasks - for task_id, task in self._running_tasks.items(): - if not task.done(): - task.cancel() - logger.info(f"Cancelled task {task_id}") - - # Wait for tasks to complete cancellation - if self._running_tasks: - await asyncio.gather(*self._running_tasks.values(), return_exceptions=True) - - # Shutdown executor - self.executor.shutdown(wait=wait) - logger.info("Worker pool shutdown complete") - - @property - def active_task_count(self) -> int: - """Get number of currently running tasks""" - return len([t for t in self._running_tasks.values() if not t.done()]) - - @property - def total_task_count(self) -> int: - """Get total number of tracked tasks""" - return len(self._results) diff --git a/src/server/__init__.py b/src/server/__init__.py deleted file mode 100644 index 0aa1108..0000000 --- a/src/server/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -Imagen4 Server Package - -MCP 서버 구현을 담당하는 모듈 -""" - -# Import basic server components -from src.server.mcp_server import Imagen4MCPServer -from src.server.handlers import ToolHandlers -from src.server.models import MCPToolDefinitions - -# Import enhanced components -try: - from src.server.enhanced_mcp_server import EnhancedImagen4MCPServer - from src.server.enhanced_handlers import EnhancedToolHandlers - from src.server.enhanced_models import ImageGenerationResult, PreviewImageResponse - - __all__ = [ - 'Imagen4MCPServer', 'ToolHandlers', 'MCPToolDefinitions', - 'EnhancedImagen4MCPServer', 'EnhancedToolHandlers', - 'ImageGenerationResult', 'PreviewImageResponse' - ] -except ImportError as e: - # Fallback if enhanced modules are not available - print(f"Warning: Enhanced features not available: {e}") - __all__ = ['Imagen4MCPServer', 'ToolHandlers', 'MCPToolDefinitions'] diff --git a/src/server/enhanced_handlers.py b/src/server/enhanced_handlers.py deleted file mode 100644 index 2da7c21..0000000 --- a/src/server/enhanced_handlers.py +++ /dev/null @@ -1,344 +0,0 @@ -""" -Enhanced Tool Handlers for MCP Server with Preview Image Support and Unicode Logging -""" - -import asyncio -import base64 -import json -import random -import logging -import sys -import os -from typing import List, Dict, Any - -from mcp.types import TextContent, ImageContent - -from src.connector import Imagen4Client, Config -from src.connector.imagen4_client import ImageGenerationRequest -from src.connector.utils import save_generated_images -from src.utils.image_utils import create_preview_image_b64, get_image_info -from src.server.enhanced_models import ImageGenerationResult, PreviewImageResponse - -# Configure Unicode logging -logger = logging.getLogger(__name__) - -# Ensure proper Unicode handling for all string operations -def ensure_unicode_string(value): - """Ensure a value is a proper Unicode string""" - if isinstance(value, bytes): - return value.decode('utf-8', errors='replace') - elif isinstance(value, str): - return value - else: - return str(value) - - -def sanitize_args_for_logging(arguments: Dict[str, Any]) -> Dict[str, Any]: - """Remove or truncate sensitive data from arguments for safe logging with Unicode support""" - safe_args = {} - for key, value in arguments.items(): - if isinstance(value, str): - # Ensure proper Unicode handling - value = ensure_unicode_string(value) - - # Check if it's likely base64 image data - if (key in ['data', 'image_data', 'base64', 'preview_image_b64'] or - (len(value) > 100 and value.startswith(('iVBORw0KGgo', '/9j/', 'R0lGOD')))): - # Truncate long image data - safe_args[key] = f"" - elif len(value) > 1000: - # Truncate any very long strings but preserve Unicode characters - truncated = value[:100] - safe_args[key] = f"{truncated}..." - else: - # Keep the original string (including Korean/Unicode characters) - safe_args[key] = value - else: - safe_args[key] = value - return safe_args - - -class EnhancedToolHandlers: - """Enhanced MCP tool handler class with preview image support and Unicode logging""" - - def __init__(self, config: Config): - """Initialize handler""" - self.config = config - self.client = Imagen4Client(config) - - async def handle_generate_random_seed(self, arguments: Dict[str, Any]) -> List[TextContent]: - """Random seed generation handler""" - try: - random_seed = random.randint(0, 2**32 - 1) - return [TextContent( - type="text", - text=f"Generated random seed: {random_seed}" - )] - except Exception as e: - logger.error(f"Random seed generation error: {str(e)}") - return [TextContent( - type="text", - text=f"Error occurred during random seed generation: {str(e)}" - )] - - async def handle_regenerate_from_json(self, arguments: Dict[str, Any]) -> List[TextContent]: - """Image regeneration from JSON file handler with preview support""" - try: - json_file_path = arguments.get("json_file_path") - save_to_file = arguments.get("save_to_file", True) - - if not json_file_path: - return [TextContent( - type="text", - text="Error: JSON file path is required." - )] - - # Load parameters from JSON file with proper UTF-8 encoding - try: - with open(json_file_path, 'r', encoding='utf-8') as f: - params = json.load(f) - except FileNotFoundError: - return [TextContent( - type="text", - text=f"Error: Cannot find JSON file: {json_file_path}" - )] - except json.JSONDecodeError as e: - return [TextContent( - type="text", - text=f"Error: JSON parsing error: {str(e)}" - )] - - # Check required parameters - required_params = ['prompt', 'seed'] - missing_params = [p for p in required_params if p not in params] - - if missing_params: - return [TextContent( - type="text", - text=f"Error: Required parameters are missing from JSON file: {', '.join(missing_params)}" - )] - - # Create image generation request object - request = ImageGenerationRequest( - prompt=params.get('prompt'), - negative_prompt=params.get('negative_prompt', ''), - number_of_images=params.get('number_of_images', 1), - seed=params.get('seed'), - aspect_ratio=params.get('aspect_ratio', '1:1'), - model=params.get('model', self.config.default_model) - ) - - logger.info(f"Loaded parameters from JSON: {json_file_path}") - - # Generate image - response = await self.client.generate_image(request) - - if not response.success: - return [TextContent( - type="text", - text=f"Error occurred during image regeneration: {response.error_message}" - )] - - # Create preview images - preview_images_b64 = [] - for image_data in response.images_data: - preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85) - if preview_b64: - preview_images_b64.append(preview_b64) - logger.info(f"Created preview image: {len(preview_b64)} chars base64 JPEG") - - # Save files (optional) - saved_files = [] - if save_to_file: - regeneration_params = { - "prompt": request.prompt, - "negative_prompt": request.negative_prompt, - "number_of_images": request.number_of_images, - "seed": request.seed, - "aspect_ratio": request.aspect_ratio, - "model": request.model, - "regenerated_from": json_file_path, - "original_generated_at": params.get('generated_at', 'unknown') - } - - saved_files = save_generated_images( - images_data=response.images_data, - save_directory=self.config.output_path, - seed=request.seed, - generation_params=regeneration_params, - filename_prefix="imagen4_regen" - ) - - # Create result with preview images - result = ImageGenerationResult( - success=True, - message=f"✅ Images have been successfully regenerated from {json_file_path}", - original_images_count=len(response.images_data), - preview_images_b64=preview_images_b64, - saved_files=saved_files, - generation_params={ - "prompt": request.prompt, - "seed": request.seed, - "aspect_ratio": request.aspect_ratio, - "number_of_images": request.number_of_images, - "model": request.model - } - ) - - return [TextContent( - type="text", - text=result.to_text_content() - )] - - except Exception as e: - logger.error(f"Error occurred during image regeneration: {str(e)}") - return [TextContent( - type="text", - text=f"Error occurred during image regeneration: {str(e)}" - )] - - async def handle_generate_image(self, arguments: Dict[str, Any]) -> List[TextContent]: - """Enhanced image generation handler with preview image support and proper Unicode logging""" - try: - # Log arguments safely without exposing image data but preserve Unicode - safe_args = sanitize_args_for_logging(arguments) - logger.info(f"handle_generate_image called with arguments: {safe_args}") - - # Extract and validate arguments - prompt = arguments.get("prompt") - if not prompt: - logger.error("No prompt provided") - return [TextContent( - type="text", - text="Error: Prompt is required." - )] - - # Ensure prompt is properly handled as Unicode - prompt = ensure_unicode_string(prompt) - - seed = arguments.get("seed") - if seed is None: - logger.error("No seed provided") - return [TextContent( - type="text", - text="Error: Seed value is required. You can use the generate_random_seed tool to generate a random seed." - )] - - # Create image generation request object - request = ImageGenerationRequest( - prompt=prompt, - negative_prompt=ensure_unicode_string(arguments.get("negative_prompt", "")), - number_of_images=arguments.get("number_of_images", 1), - seed=seed, - aspect_ratio=arguments.get("aspect_ratio", "1:1"), - model=arguments.get("model", self.config.default_model) - ) - - save_to_file = arguments.get("save_to_file", True) - - # Log with proper Unicode handling for Korean/international text - prompt_preview = prompt[:50] + "..." if len(prompt) > 50 else prompt - logger.info(f"Starting image generation: '{prompt_preview}', Seed: {seed}") - - # Generate image with timeout - try: - logger.info("Calling client.generate_image()...") - response = await asyncio.wait_for( - self.client.generate_image(request), - timeout=360.0 # 6 minute timeout - ) - logger.info(f"Image generation completed. Success: {response.success}") - except asyncio.TimeoutError: - logger.error("Image generation timed out after 6 minutes") - return [TextContent( - type="text", - text="Error: Image generation timed out after 6 minutes. Please try again." - )] - - if not response.success: - logger.error(f"Image generation failed: {response.error_message}") - return [TextContent( - type="text", - text=f"Error occurred during image generation: {response.error_message}" - )] - - logger.info(f"Generated {len(response.images_data)} images successfully") - - # Create preview images (512x512 JPEG base64) - preview_images_b64 = [] - for i, image_data in enumerate(response.images_data): - # Log original image info - image_info = get_image_info(image_data) - if image_info: - logger.info(f"Original image {i+1}: {image_info}") - - # Create preview - preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85) - if preview_b64: - preview_images_b64.append(preview_b64) - logger.info(f"Created preview image {i+1}: {len(preview_b64)} chars base64 JPEG") - else: - logger.warning(f"Failed to create preview for image {i+1}") - - # Save files if requested - saved_files = [] - if save_to_file: - logger.info("Saving files to disk...") - generation_params = { - "prompt": request.prompt, - "negative_prompt": request.negative_prompt, - "number_of_images": request.number_of_images, - "seed": request.seed, - "aspect_ratio": request.aspect_ratio, - "model": request.model, - "guidance_scale": 7.5, - "safety_filter_level": "block_only_high", - "person_generation": "allow_all", - "add_watermark": False - } - - saved_files = save_generated_images( - images_data=response.images_data, - save_directory=self.config.output_path, - seed=request.seed, - generation_params=generation_params - ) - logger.info(f"Files saved: {saved_files}") - - # Verify files were created - for file_path in saved_files: - if os.path.exists(file_path): - size = os.path.getsize(file_path) - logger.info(f" ✓ Verified: {file_path} ({size} bytes)") - else: - logger.error(f" ❌ Missing: {file_path}") - - # Create enhanced result with preview images - result = ImageGenerationResult( - success=True, - message=f"✅ Images have been successfully generated!", - original_images_count=len(response.images_data), - preview_images_b64=preview_images_b64, - saved_files=saved_files, - generation_params={ - "prompt": request.prompt, - "seed": request.seed, - "aspect_ratio": request.aspect_ratio, - "number_of_images": request.number_of_images, - "negative_prompt": request.negative_prompt, - "model": request.model - } - ) - - logger.info(f"Returning response with {len(preview_images_b64)} preview images") - return [TextContent( - type="text", - text=result.to_text_content() - )] - - except Exception as e: - logger.error(f"Error occurred during image generation: {str(e)}", exc_info=True) - return [TextContent( - type="text", - text=f"Error occurred during image generation: {str(e)}" - )] diff --git a/src/server/enhanced_mcp_server.py b/src/server/enhanced_mcp_server.py deleted file mode 100644 index 3e25689..0000000 --- a/src/server/enhanced_mcp_server.py +++ /dev/null @@ -1,63 +0,0 @@ -""" -Enhanced MCP Server implementation for Imagen4 with Preview Image Support -""" - -import logging -from typing import Dict, Any, List - -from mcp.server import Server -from mcp.types import Tool, TextContent - -from src.connector import Config -from src.server.models import MCPToolDefinitions -from src.server.enhanced_handlers import EnhancedToolHandlers, sanitize_args_for_logging - -logger = logging.getLogger(__name__) - - -class EnhancedImagen4MCPServer: - """Enhanced Imagen4 MCP server class with preview image support""" - - def __init__(self, config: Config): - """Initialize server""" - self.config = config - self.server = Server("imagen4-enhanced-mcp-server") - self.handlers = EnhancedToolHandlers(config) - - # Register handlers - self._register_handlers() - - def _register_handlers(self) -> None: - """Register MCP handlers""" - - @self.server.list_tools() - async def handle_list_tools() -> List[Tool]: - """Return list of available tools""" - logger.info("Listing available tools (enhanced version)") - return MCPToolDefinitions.get_all_tools() - - @self.server.call_tool() - async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: - """Handle tool calls with enhanced preview image support""" - # Log tool call safely without exposing sensitive data - safe_args = sanitize_args_for_logging(arguments) - logger.info(f"Enhanced tool called: {name} with arguments: {safe_args}") - - if name == "generate_random_seed": - return await self.handlers.handle_generate_random_seed(arguments) - elif name == "regenerate_from_json": - return await self.handlers.handle_regenerate_from_json(arguments) - elif name == "generate_image": - return await self.handlers.handle_generate_image(arguments) - else: - raise ValueError(f"Unknown tool: {name}") - - def get_server(self) -> Server: - """Return MCP server instance""" - return self.server - - -# Create a factory function for easier import -def create_enhanced_server(config: Config) -> EnhancedImagen4MCPServer: - """Factory function to create enhanced server""" - return EnhancedImagen4MCPServer(config) diff --git a/src/server/enhanced_models.py b/src/server/enhanced_models.py deleted file mode 100644 index 670a64f..0000000 --- a/src/server/enhanced_models.py +++ /dev/null @@ -1,71 +0,0 @@ -""" -Enhanced MCP response models with preview image support -""" - -from typing import Optional, Dict, Any -from dataclasses import dataclass - - -@dataclass -class ImageGenerationResult: - """Enhanced image generation result with preview support""" - success: bool - message: str - original_images_count: int - preview_images_b64: Optional[list[str]] = None # List of base64 JPEG previews (512x512) - saved_files: Optional[list[str]] = None - generation_params: Optional[Dict[str, Any]] = None - error_message: Optional[str] = None - - def to_text_content(self) -> str: - """Convert result to text format for MCP response""" - lines = [self.message] - - if self.success and self.preview_images_b64: - lines.append(f"\n🖼️ Preview Images Generated: {len(self.preview_images_b64)} images (512x512 JPEG)") - for i, preview_b64 in enumerate(self.preview_images_b64): - lines.append(f"Preview {i+1} (base64 JPEG): {preview_b64[:50]}...({len(preview_b64)} chars)") - - if self.saved_files: - lines.append(f"\n📁 Files saved:") - for filepath in self.saved_files: - lines.append(f" - {filepath}") - - if self.generation_params: - lines.append(f"\n⚙️ Generation Parameters:") - for key, value in self.generation_params.items(): - if key == 'prompt' and len(str(value)) > 100: - lines.append(f" - {key}: {str(value)[:100]}...") - else: - lines.append(f" - {key}: {value}") - - return "\n".join(lines) - - -@dataclass -class PreviewImageResponse: - """Response containing preview images in base64 format""" - preview_images_b64: list[str] # Base64 encoded JPEG images (512x512) - original_count: int - message: str - - @classmethod - def from_image_data(cls, images_data: list[bytes], message: str = "") -> 'PreviewImageResponse': - """Create response from original PNG image data""" - from src.utils.image_utils import create_preview_image_b64 - - preview_images = [] - for i, image_data in enumerate(images_data): - preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85) - if preview_b64: - preview_images.append(preview_b64) - else: - # Fallback - use original if preview creation fails - import base64 - preview_images.append(base64.b64encode(image_data).decode('utf-8')) - - return cls( - preview_images_b64=preview_images, - original_count=len(images_data), - message=message or f"Generated {len(preview_images)} preview images" - ) diff --git a/src/server/handlers.py b/src/server/handlers.py deleted file mode 100644 index ead0186..0000000 --- a/src/server/handlers.py +++ /dev/null @@ -1,308 +0,0 @@ -""" -Tool Handlers for MCP Server -""" - -import asyncio -import base64 -import json -import random -import logging -from typing import List, Dict, Any - -from mcp.types import TextContent, ImageContent - -from src.connector import Imagen4Client, Config -from src.connector.imagen4_client import ImageGenerationRequest -from src.connector.utils import save_generated_images - -logger = logging.getLogger(__name__) - - -def sanitize_args_for_logging(arguments: Dict[str, Any]) -> Dict[str, Any]: - """Remove or truncate sensitive data from arguments for safe logging""" - safe_args = {} - for key, value in arguments.items(): - if isinstance(value, str): - # Check if it's likely base64 image data - if (key in ['data', 'image_data', 'base64'] or - (len(value) > 100 and value.startswith(('iVBORw0KGgo', '/9j/', 'R0lGOD')))): - # Truncate long image data - safe_args[key] = f"" - elif len(value) > 1000: - # Truncate any very long strings - safe_args[key] = f"{value[:100]}..." - else: - safe_args[key] = value - else: - safe_args[key] = value - return safe_args - - -class ToolHandlers: - """MCP tool handler class""" - - def __init__(self, config: Config): - """Initialize handler""" - self.config = config - self.client = Imagen4Client(config) - - async def handle_generate_random_seed(self, arguments: Dict[str, Any]) -> List[TextContent]: - """Random seed generation handler""" - try: - random_seed = random.randint(0, 2**32 - 1) - return [TextContent( - type="text", - text=f"Generated random seed: {random_seed}" - )] - except Exception as e: - logger.error(f"Random seed generation error: {str(e)}") - return [TextContent( - type="text", - text=f"Error occurred during random seed generation: {str(e)}" - )] - - async def handle_regenerate_from_json(self, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]: - """Image regeneration from JSON file handler""" - try: - json_file_path = arguments.get("json_file_path") - save_to_file = arguments.get("save_to_file", True) - - if not json_file_path: - return [TextContent( - type="text", - text="Error: JSON file path is required." - )] - - # Load parameters from JSON file - try: - with open(json_file_path, 'r', encoding='utf-8') as f: - params = json.load(f) - except FileNotFoundError: - return [TextContent( - type="text", - text=f"Error: Cannot find JSON file: {json_file_path}" - )] - except json.JSONDecodeError as e: - return [TextContent( - type="text", - text=f"Error: JSON parsing error: {str(e)}" - )] - - # Check required parameters - required_params = ['prompt', 'seed'] - missing_params = [p for p in required_params if p not in params] - - if missing_params: - return [TextContent( - type="text", - text=f"Error: Required parameters are missing from JSON file: {', '.join(missing_params)}" - )] - - # Create image generation request object - request = ImageGenerationRequest( - prompt=params.get('prompt'), - negative_prompt=params.get('negative_prompt', ''), - number_of_images=params.get('number_of_images', 1), - seed=params.get('seed'), - aspect_ratio=params.get('aspect_ratio', '1:1') - ) - - logger.info(f"Loaded parameters from JSON: {json_file_path}") - - # Generate image - response = await self.client.generate_image(request) - - if not response.success: - return [TextContent( - type="text", - text=f"Error occurred during image regeneration: {response.error_message}" - )] - - # Save files (optional) - saved_files = [] - if save_to_file: - regeneration_params = { - "prompt": request.prompt, - "negative_prompt": request.negative_prompt, - "number_of_images": request.number_of_images, - "seed": request.seed, - "aspect_ratio": request.aspect_ratio, - "regenerated_from": json_file_path, - "original_generated_at": params.get('generated_at', 'unknown') - } - - saved_files = save_generated_images( - images_data=response.images_data, - save_directory=self.config.output_path, - seed=request.seed, - generation_params=regeneration_params, - filename_prefix="imagen4_regen" - ) - - # Generate result message - message_parts = [ - f"Images have been successfully regenerated.", - f"JSON file: {json_file_path}", - f"Size: 2048x2048 PNG", - f"Count: {len(response.images_data)} images", - f"Seed: {request.seed}", - f"Prompt: {request.prompt}" - ] - - if saved_files: - message_parts.append(f"\nRegenerated files saved successfully:") - for filepath in saved_files: - message_parts.append(f"- {filepath}") - - # Generate result - result = [ - TextContent( - type="text", - text="\n".join(message_parts) - ) - ] - - # Add Base64 encoded images - for i, image_data in enumerate(response.images_data): - image_base64 = base64.b64encode(image_data).decode('utf-8') - result.append( - ImageContent( - type="image", - data=image_base64, - mimeType="image/png" - ) - ) - logger.info(f"Added image {i+1} to response: {len(image_base64)} chars base64 data") - - return result - - except Exception as e: - logger.error(f"Error occurred during image regeneration: {str(e)}") - return [TextContent( - type="text", - text=f"Error occurred during image regeneration: {str(e)}" - )] - - async def handle_generate_image(self, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]: - """Image generation handler with synchronous processing""" - try: - # Log arguments safely without exposing image data - safe_args = sanitize_args_for_logging(arguments) - logger.info(f"handle_generate_image called with arguments: {safe_args}") - - # Extract and validate arguments - prompt = arguments.get("prompt") - if not prompt: - logger.error("No prompt provided") - return [TextContent( - type="text", - text="Error: Prompt is required." - )] - - seed = arguments.get("seed") - if seed is None: - logger.error("No seed provided") - return [TextContent( - type="text", - text="Error: Seed value is required. You can use the generate_random_seed tool to generate a random seed." - )] - - # Create image generation request object - request = ImageGenerationRequest( - prompt=prompt, - negative_prompt=arguments.get("negative_prompt", ""), - number_of_images=arguments.get("number_of_images", 1), - seed=seed, - aspect_ratio=arguments.get("aspect_ratio", "1:1") - ) - - save_to_file = arguments.get("save_to_file", True) - - logger.info(f"Starting SYNCHRONOUS image generation: '{prompt[:50]}...', Seed: {seed}") - - # Generate image synchronously with longer timeout - try: - logger.info("Calling client.generate_image() synchronously...") - response = await asyncio.wait_for( - self.client.generate_image(request), - timeout=360.0 # 6 minute timeout - ) - logger.info(f"Image generation completed. Success: {response.success}") - except asyncio.TimeoutError: - logger.error("Image generation timed out after 6 minutes") - return [TextContent( - type="text", - text="Error: Image generation timed out after 6 minutes. Please try again." - )] - - if not response.success: - logger.error(f"Image generation failed: {response.error_message}") - return [TextContent( - type="text", - text=f"Error occurred during image generation: {response.error_message}" - )] - - logger.info(f"Generated {len(response.images_data)} images successfully") - - # Save files if requested - saved_files = [] - if save_to_file: - logger.info("Saving files to disk...") - generation_params = { - "prompt": request.prompt, - "negative_prompt": request.negative_prompt, - "number_of_images": request.number_of_images, - "seed": request.seed, - "aspect_ratio": request.aspect_ratio, - "guidance_scale": 7.5, - "safety_filter_level": "block_only_high", - "person_generation": "allow_all", - "add_watermark": False - } - - saved_files = save_generated_images( - images_data=response.images_data, - save_directory=self.config.output_path, - seed=request.seed, - generation_params=generation_params - ) - logger.info(f"Files saved: {saved_files}") - - # Verify files were created - import os - for file_path in saved_files: - if os.path.exists(file_path): - size = os.path.getsize(file_path) - logger.info(f" ✓ Verified: {file_path} ({size} bytes)") - else: - logger.error(f" ❌ Missing: {file_path}") - - # Generate result message - message_parts = [ - f"✅ Images have been successfully generated!", - f"Prompt: {request.prompt}", - f"Seed: {request.seed}", - f"Size: 2048x2048 PNG", - f"Count: {len(response.images_data)} images" - ] - - if saved_files: - message_parts.append(f"\n📁 Files saved successfully:") - for filepath in saved_files: - message_parts.append(f" - {filepath}") - else: - message_parts.append(f"\nℹ️ Images generated but not saved to file. Use save_to_file=true to save.") - - logger.info("Returning synchronous response") - return [TextContent( - type="text", - text="\n".join(message_parts) - )] - - except Exception as e: - logger.error(f"Error occurred during image generation: {str(e)}", exc_info=True) - return [TextContent( - type="text", - text=f"Error occurred during image generation: {str(e)}" - )] - diff --git a/src/server/mcp_server.py b/src/server/mcp_server.py deleted file mode 100644 index f2f884b..0000000 --- a/src/server/mcp_server.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -MCP Server implementation for Imagen4 -""" - -import logging -from typing import Dict, Any, List - -from mcp.server import Server -from mcp.types import Tool, TextContent, ImageContent - -from src.connector import Config -from src.server.models import MCPToolDefinitions -from src.server.handlers import ToolHandlers, sanitize_args_for_logging - -logger = logging.getLogger(__name__) - - -class Imagen4MCPServer: - """Imagen4 MCP server class""" - - def __init__(self, config: Config): - """Initialize server""" - self.config = config - self.server = Server("imagen4-mcp-server") - self.handlers = ToolHandlers(config) - - # Register handlers - self._register_handlers() - - def _register_handlers(self) -> None: - """Register MCP handlers""" - - @self.server.list_tools() - async def handle_list_tools() -> List[Tool]: - """Return list of available tools""" - logger.info("Listing available tools") - return MCPToolDefinitions.get_all_tools() - - @self.server.call_tool() - async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]: - """Handle tool calls""" - # Log tool call safely without exposing sensitive data - safe_args = sanitize_args_for_logging(arguments) - logger.info(f"Tool called: {name} with arguments: {safe_args}") - - if name == "generate_random_seed": - return await self.handlers.handle_generate_random_seed(arguments) - elif name == "regenerate_from_json": - return await self.handlers.handle_regenerate_from_json(arguments) - elif name == "generate_image": - return await self.handlers.handle_generate_image(arguments) - else: - raise ValueError(f"Unknown tool: {name}") - - def get_server(self) -> Server: - """Return MCP server instance""" - return self.server diff --git a/src/server/minimal_handler.py b/src/server/minimal_handler.py deleted file mode 100644 index d3a3bd8..0000000 --- a/src/server/minimal_handler.py +++ /dev/null @@ -1,39 +0,0 @@ -""" -Minimal Task Result Handler for Emergency Use -""" - -import logging -from typing import List, Dict, Any - -from mcp.types import TextContent - -logger = logging.getLogger(__name__) - - -async def minimal_get_task_result(task_manager, task_id: str) -> List[TextContent]: - """Absolutely minimal task result handler""" - try: - logger.info(f"Minimal handler for task: {task_id}") - - # Just return the task status for now - status = task_manager.get_task_status(task_id) - - if status is None: - return [TextContent( - type="text", - text=f"❌ Task '{task_id}' not found." - )] - - return [TextContent( - type="text", - text=f"📋 Task '{task_id}' Status: {status.value}\n" - f"Note: This is a minimal emergency handler.\n" - f"If you see this message, the original handler has a serious issue." - )] - - except Exception as e: - logger.error(f"Even minimal handler failed: {str(e)}", exc_info=True) - return [TextContent( - type="text", - text=f"Complete failure: {str(e)}" - )] diff --git a/src/server/models.py b/src/server/models.py deleted file mode 100644 index ddd7662..0000000 --- a/src/server/models.py +++ /dev/null @@ -1,132 +0,0 @@ -""" -MCP Tool Models and Definitions -""" - -from mcp.types import Tool - - -class MCPToolDefinitions: - """MCP tool definition class""" - - @staticmethod - def get_generate_image_tool() -> Tool: - """Image generation tool definition""" - return Tool( - name="generate_image", - description="Generate 2048x2048 PNG images from text prompts using Google Imagen 4 AI.", - inputSchema={ - "type": "object", - "properties": { - "prompt": { - "type": "string", - "description": "Text prompt for image generation (Korean or English). Maximum 480 tokens allowed.", - "maxLength": 2000 # Approximate character limit for safety - }, - "negative_prompt": { - "type": "string", - "description": "Negative prompt specifying elements not to generate (optional). Maximum 240 tokens allowed.", - "default": "", - "maxLength": 1000 # Approximate character limit for safety - }, - "number_of_images": { - "type": "integer", - "description": "Number of images to generate (1 or 2 only)", - "enum": [1, 2], - "default": 1 - }, - "seed": { - "type": "integer", - "description": "Seed value for reproducible results (required, 0 ~ 4294967295 range)", - "minimum": 0, - "maximum": 4294967295 - }, - "aspect_ratio": { - "type": "string", - "description": "Image aspect ratio", - "enum": ["1:1", "9:16", "16:9", "3:4", "4:3"], - "default": "1:1" - }, - "save_to_file": { - "type": "boolean", - "description": "Whether to save generated images to files (default: true)", - "default": True - }, - "model": { - "type": "string", - "description": "Imagen 4 model to use for generation", - "enum": ["imagen-4.0-generate-001", "imagen-4.0-ultra-generate-001"], - "default": "imagen-4.0-generate-001" - } - }, - "required": ["prompt", "seed"] - } - ) - - @staticmethod - def get_regenerate_from_json_tool() -> Tool: - """Regenerate from JSON tool definition""" - return Tool( - name="regenerate_from_json", - description="Read parameters from JSON file and regenerate images with the same settings.", - inputSchema={ - "type": "object", - "properties": { - "json_file_path": { - "type": "string", - "description": "Path to JSON file containing saved parameters" - }, - "save_to_file": { - "type": "boolean", - "description": "Whether to save regenerated images to files (default: true)", - "default": True - } - }, - "required": ["json_file_path"] - } - ) - - @staticmethod - def get_generate_random_seed_tool() -> Tool: - """Random seed generation tool definition""" - return Tool( - name="generate_random_seed", - description="Generate random seed value for image generation.", - inputSchema={ - "type": "object", - "properties": {}, - "required": [] - } - ) - - @staticmethod - def get_validate_prompt_tool() -> Tool: - """Prompt validation tool definition""" - return Tool( - name="validate_prompt", - description="Validate prompt length and estimate token count before image generation.", - inputSchema={ - "type": "object", - "properties": { - "prompt": { - "type": "string", - "description": "Text prompt to validate" - }, - "negative_prompt": { - "type": "string", - "description": "Negative prompt to validate (optional)", - "default": "" - } - }, - "required": ["prompt"] - } - ) - - @classmethod - def get_all_tools(cls) -> list[Tool]: - """Return all tool definitions""" - return [ - cls.get_generate_image_tool(), - cls.get_regenerate_from_json_tool(), - cls.get_generate_random_seed_tool(), - cls.get_validate_prompt_tool() - ] diff --git a/src/server/safe_result_handler.py b/src/server/safe_result_handler.py deleted file mode 100644 index 2c6e727..0000000 --- a/src/server/safe_result_handler.py +++ /dev/null @@ -1,80 +0,0 @@ -""" -Safe Get Task Result Handler - Debug Version -""" - -import logging -from typing import List, Dict, Any - -from mcp.types import TextContent - -logger = logging.getLogger(__name__) - - -async def safe_get_task_result(task_manager, task_id: str) -> List[TextContent]: - """Minimal safe version of get_task_result without image processing""" - try: - logger.info(f"Safe get_task_result called for: {task_id}") - - # Step 1: Try to get raw result - try: - raw_result = task_manager.worker_pool.get_task_result(task_id) - logger.info(f"Raw result type: {type(raw_result)}") - except Exception as e: - logger.error(f"Failed to get raw result: {str(e)}") - return [TextContent( - type="text", - text=f"Error accessing task data: {str(e)}" - )] - - if not raw_result: - return [TextContent( - type="text", - text=f"❌ Task '{task_id}' not found." - )] - - # Step 2: Check status safely - try: - status = raw_result.status - logger.info(f"Task status: {status}") - except Exception as e: - logger.error(f"Failed to get status: {str(e)}") - return [TextContent( - type="text", - text=f"Error reading task status: {str(e)}" - )] - - # Step 3: Return basic info without touching result.result - try: - duration = raw_result.duration if hasattr(raw_result, 'duration') else None - created = raw_result.created_at.isoformat() if hasattr(raw_result, 'created_at') and raw_result.created_at else "unknown" - - message = f"✅ Task '{task_id}' Status Report:\n" - message += f"Status: {status.value}\n" - message += f"Created: {created}\n" - - if duration: - message += f"Duration: {duration:.2f}s\n" - - if hasattr(raw_result, 'error') and raw_result.error: - message += f"Error: {raw_result.error}\n" - - message += "\nNote: This is a safe diagnostic version." - - return [TextContent( - type="text", - text=message - )] - - except Exception as e: - logger.error(f"Failed to format response: {str(e)}") - return [TextContent( - type="text", - text=f"Task exists but data formatting failed: {str(e)}" - )] - - except Exception as e: - logger.error(f"Critical error in safe_get_task_result: {str(e)}", exc_info=True) - return [TextContent( - type="text", - text=f"Critical error: {str(e)}" - )] diff --git a/src/utils/image_utils.py b/src/utils/image_utils.py deleted file mode 100644 index 1364cb3..0000000 --- a/src/utils/image_utils.py +++ /dev/null @@ -1,107 +0,0 @@ -""" -Image processing utilities for preview generation -""" - -import base64 -import io -import logging -from typing import Optional -from PIL import Image - -logger = logging.getLogger(__name__) - - -def create_preview_image_b64(image_data: bytes, target_size: int = 512, quality: int = 85) -> Optional[str]: - """ - Convert PNG image data to JPEG preview with specified size and return as base64 - - Args: - image_data: Original PNG image data in bytes - target_size: Target size for the preview (default: 512x512) - quality: JPEG quality (1-100, default: 85) - - Returns: - Base64 encoded JPEG image string, or None if conversion fails - """ - try: - # Open image from bytes - with Image.open(io.BytesIO(image_data)) as img: - # Convert to RGB if necessary (PNG might have alpha channel) - if img.mode in ('RGBA', 'LA', 'P'): - # Create white background for transparent images - background = Image.new('RGB', img.size, (255, 255, 255)) - if img.mode == 'P': - img = img.convert('RGBA') - background.paste(img, mask=img.split()[-1] if img.mode in ('RGBA', 'LA') else None) - img = background - elif img.mode != 'RGB': - img = img.convert('RGB') - - # Resize to target size maintaining aspect ratio - img.thumbnail((target_size, target_size), Image.Resampling.LANCZOS) - - # If image is smaller than target size, pad it to exact size - if img.size != (target_size, target_size): - # Create new image with target size and white background - new_img = Image.new('RGB', (target_size, target_size), (255, 255, 255)) - # Center the resized image - x = (target_size - img.size[0]) // 2 - y = (target_size - img.size[1]) // 2 - new_img.paste(img, (x, y)) - img = new_img - - # Convert to JPEG and encode as base64 - output_buffer = io.BytesIO() - img.save(output_buffer, format='JPEG', quality=quality, optimize=True) - jpeg_data = output_buffer.getvalue() - - # Encode to base64 - base64_data = base64.b64encode(jpeg_data).decode('utf-8') - - logger.info(f"Preview image created: {target_size}x{target_size} JPEG, {len(jpeg_data)} bytes, base64 length: {len(base64_data)}") - return base64_data - - except Exception as e: - logger.error(f"Failed to create preview image: {str(e)}") - return None - - -def validate_image_data(image_data: bytes) -> bool: - """ - Validate if image data is a valid image - - Args: - image_data: Image data in bytes - - Returns: - True if valid image, False otherwise - """ - try: - with Image.open(io.BytesIO(image_data)) as img: - img.verify() # Verify image integrity - return True - except Exception: - return False - - -def get_image_info(image_data: bytes) -> Optional[dict]: - """ - Get image information - - Args: - image_data: Image data in bytes - - Returns: - Dictionary with image info (format, size, mode) or None if invalid - """ - try: - with Image.open(io.BytesIO(image_data)) as img: - return { - 'format': img.format, - 'size': img.size, - 'mode': img.mode, - 'bytes': len(image_data) - } - except Exception as e: - logger.error(f"Failed to get image info: {str(e)}") - return None diff --git a/src/utils/token_utils.py b/src/utils/token_utils.py deleted file mode 100644 index 1f63890..0000000 --- a/src/utils/token_utils.py +++ /dev/null @@ -1,200 +0,0 @@ -""" -Token counting utilities for Imagen4 prompts -""" - -import re -import logging -from typing import Optional - -logger = logging.getLogger(__name__) - -# 기본 토큰 추정 상수 -AVERAGE_CHARS_PER_TOKEN = 4 # 영어 기준 평균값 -KOREAN_CHARS_PER_TOKEN = 2 # 한글 기준 평균값 -MAX_PROMPT_TOKENS = 480 # 최대 프롬프트 토큰 수 - - -def estimate_token_count(text: str) -> int: - """ - 텍스트의 토큰 수를 추정합니다. - - 정확한 토큰 계산을 위해서는 실제 토크나이저가 필요하지만, - API 호출 전 빠른 검증을 위해 추정값을 사용합니다. - - Args: - text: 토큰 수를 계산할 텍스트 - - Returns: - int: 추정된 토큰 수 - """ - if not text: - return 0 - - # 텍스트 정리 - text = text.strip() - if not text: - return 0 - - # 한글과 영어 문자 분리 - korean_chars = len(re.findall(r'[가-힣]', text)) - english_chars = len(re.findall(r'[a-zA-Z]', text)) - other_chars = len(text) - korean_chars - english_chars - - # 토큰 수 추정 - korean_tokens = korean_chars / KOREAN_CHARS_PER_TOKEN - english_tokens = english_chars / AVERAGE_CHARS_PER_TOKEN - other_tokens = other_chars / AVERAGE_CHARS_PER_TOKEN - - estimated_tokens = int(korean_tokens + english_tokens + other_tokens) - - # 최소 1토큰은 보장 - return max(1, estimated_tokens) - - -def validate_prompt_length(prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> tuple[bool, int, Optional[str]]: - """ - 프롬프트 길이를 검증합니다. - - Args: - prompt: 검증할 프롬프트 - max_tokens: 최대 허용 토큰 수 - - Returns: - tuple: (유효성, 토큰 수, 오류 메시지) - """ - if not prompt: - return False, 0, "프롬프트가 비어있습니다." - - token_count = estimate_token_count(prompt) - - if token_count > max_tokens: - error_msg = ( - f"프롬프트가 너무 깁니다. " - f"현재: {token_count}토큰, 최대: {max_tokens}토큰. " - f"프롬프트를 {token_count - max_tokens}토큰 줄여주세요." - ) - return False, token_count, error_msg - - return True, token_count, None - - -def truncate_prompt(prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> str: - """ - 프롬프트를 지정된 토큰 수로 자릅니다. - - Args: - prompt: 자를 프롬프트 - max_tokens: 최대 토큰 수 - - Returns: - str: 잘린 프롬프트 - """ - if not prompt: - return "" - - current_tokens = estimate_token_count(prompt) - if current_tokens <= max_tokens: - return prompt - - # 대략적인 비율로 텍스트 자르기 - ratio = max_tokens / current_tokens - target_length = int(len(prompt) * ratio * 0.9) # 여유분 10% - - truncated = prompt[:target_length] - - # 단어/문장 경계에서 자르기 - if len(truncated) < len(prompt): - # 마지막 완전한 단어까지만 유지 - last_space = truncated.rfind(' ') - last_korean = truncated.rfind('다') # 한글 어미 - last_punct = max(truncated.rfind('.'), truncated.rfind(','), truncated.rfind('!')) - - cut_point = max(last_space, last_korean, last_punct) - if cut_point > target_length * 0.8: # 너무 많이 잘리지 않도록 - truncated = truncated[:cut_point] - - # 최종 검증 - final_tokens = estimate_token_count(truncated) - if final_tokens > max_tokens: - # 강제로 문자 단위로 자르기 - chars_per_token = len(truncated) / final_tokens - target_chars = int(max_tokens * chars_per_token * 0.95) - truncated = truncated[:target_chars] - - return truncated.strip() - - -def get_prompt_stats(prompt: str) -> dict: - """ - 프롬프트 통계 정보를 반환합니다. - - Args: - prompt: 분석할 프롬프트 - - Returns: - dict: 프롬프트 통계 - """ - if not prompt: - return { - "character_count": 0, - "estimated_tokens": 0, - "korean_chars": 0, - "english_chars": 0, - "other_chars": 0, - "is_valid": False, - "remaining_tokens": MAX_PROMPT_TOKENS - } - - char_count = len(prompt) - korean_chars = len(re.findall(r'[가-힣]', prompt)) - english_chars = len(re.findall(r'[a-zA-Z]', prompt)) - other_chars = char_count - korean_chars - english_chars - estimated_tokens = estimate_token_count(prompt) - is_valid = estimated_tokens <= MAX_PROMPT_TOKENS - remaining_tokens = MAX_PROMPT_TOKENS - estimated_tokens - - return { - "character_count": char_count, - "estimated_tokens": estimated_tokens, - "korean_chars": korean_chars, - "english_chars": english_chars, - "other_chars": other_chars, - "is_valid": is_valid, - "remaining_tokens": remaining_tokens, - "max_tokens": MAX_PROMPT_TOKENS - } - - -# 실제 토크나이저 사용 시 대체할 수 있는 인터페이스 -class TokenCounter: - """토큰 카운터 인터페이스""" - - def __init__(self, tokenizer_name: Optional[str] = None): - """ - 토큰 카운터 초기화 - - Args: - tokenizer_name: 사용할 토크나이저 이름 (향후 확장용) - """ - self.tokenizer_name = tokenizer_name or "estimate" - logger.info(f"토큰 카운터 초기화: {self.tokenizer_name}") - - def count_tokens(self, text: str) -> int: - """텍스트의 토큰 수 계산""" - if self.tokenizer_name == "estimate": - return estimate_token_count(text) - else: - # 향후 실제 토크나이저 구현 - raise NotImplementedError(f"토크나이저 '{self.tokenizer_name}'는 아직 구현되지 않았습니다.") - - def validate_prompt(self, prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> tuple[bool, int, Optional[str]]: - """프롬프트 검증""" - return validate_prompt_length(prompt, max_tokens) - - def truncate_prompt(self, prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> str: - """프롬프트 자르기""" - return truncate_prompt(prompt, max_tokens) - - -# 전역 토큰 카운터 인스턴스 -default_token_counter = TokenCounter() diff --git a/test_mcp_compatibility.py b/test_mcp_compatibility.py new file mode 100644 index 0000000..b0d1c22 --- /dev/null +++ b/test_mcp_compatibility.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +""" +Test script to verify MCP protocol compatibility +""" + +import sys +import json +import os + +# Add current directory to path +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +def test_json_output(): + """Test that we can output valid JSON without interference""" + # This should work without any encoding issues + test_data = { + "jsonrpc": "2.0", + "method": "notifications/initialized", + "params": {} + } + + # Output JSON to stdout (this is what MCP protocol expects) + print(json.dumps(test_data)) + +def test_unicode_handling(): + """Test Unicode handling in stderr only""" + # This should go to stderr only, not interfere with stdout JSON + test_unicode = "Test UTF-8: 한글 테스트 ✓" + print(f"[UTF8-TEST] {test_unicode}", file=sys.stderr) + +if __name__ == "__main__": + print("[MCP-TEST] Testing JSON output compatibility...", file=sys.stderr) + test_json_output() + print("[MCP-TEST] JSON output test completed", file=sys.stderr) + + print("[MCP-TEST] Testing Unicode handling...", file=sys.stderr) + test_unicode_handling() + print("[MCP-TEST] Unicode test completed", file=sys.stderr) + + print("[MCP-TEST] All tests passed!", file=sys.stderr) diff --git a/tests/test_connector.py b/tests/test_connector.py deleted file mode 100644 index f95dcba..0000000 --- a/tests/test_connector.py +++ /dev/null @@ -1,178 +0,0 @@ -""" -Tests for Imagen4 Connector -""" - -import pytest -import asyncio -from unittest.mock import Mock, patch - -from src.connector import Config, Imagen4Client -from src.connector.imagen4_client import ImageGenerationRequest - - -class TestConfig: - """Config 클래스 테스트""" - - def test_config_creation(self): - """설정 생성 테스트""" - config = Config( - project_id="test-project", - location="us-central1", - output_path="./test_images" - ) - - assert config.project_id == "test-project" - assert config.location == "us-central1" - assert config.output_path == "./test_images" - - def test_config_validation(self): - """설정 유효성 검사 테스트""" - # 유효한 설정 - valid_config = Config(project_id="test-project") - assert valid_config.validate() is True - - # 무효한 설정 - invalid_config = Config(project_id="") - assert invalid_config.validate() is False - - @patch.dict('os.environ', { - 'GOOGLE_CLOUD_PROJECT_ID': 'test-project', - 'GOOGLE_CLOUD_LOCATION': 'us-west1', - 'GENERATED_IMAGES_PATH': './custom_path' - }) - def test_config_from_env(self): - """환경 변수에서 설정 로드 테스트""" - config = Config.from_env() - - assert config.project_id == "test-project" - assert config.location == "us-west1" - assert config.output_path == "./custom_path" - - @patch.dict('os.environ', {}, clear=True) - def test_config_from_env_missing_project_id(self): - """필수 환경 변수 누락 테스트""" - with pytest.raises(ValueError, match="GOOGLE_CLOUD_PROJECT_ID environment variable is required"): - Config.from_env() - - -class TestImageGenerationRequest: - """ImageGenerationRequest 테스트""" - - def test_valid_request(self): - """유효한 요청 테스트""" - request = ImageGenerationRequest( - prompt="A beautiful landscape", - seed=12345 - ) - - # 유효성 검사가 예외 없이 완료되어야 함 - request.validate() - - def test_invalid_prompt(self): - """무효한 프롬프트 테스트""" - request = ImageGenerationRequest( - prompt="", - seed=12345 - ) - - with pytest.raises(ValueError, match="Prompt is required"): - request.validate() - - def test_invalid_seed(self): - """무효한 시드값 테스트""" - request = ImageGenerationRequest( - prompt="Test prompt", - seed=-1 - ) - - with pytest.raises(ValueError, match="Seed value must be an integer between 0 and 4294967295"): - request.validate() - - def test_invalid_number_of_images(self): - """무효한 이미지 개수 테스트""" - request = ImageGenerationRequest( - prompt="Test prompt", - seed=12345, - number_of_images=3 - ) - - with pytest.raises(ValueError, match="Number of images must be 1 or 2 only"): - request.validate() - - -class TestImagen4Client: - """Imagen4Client 테스트""" - - def test_client_initialization(self): - """클라이언트 초기화 테스트""" - config = Config(project_id="test-project") - - with patch('src.connector.imagen4_client.genai.Client'): - client = Imagen4Client(config) - assert client.config == config - - def test_client_invalid_config(self): - """무효한 설정으로 클라이언트 초기화 테스트""" - invalid_config = Config(project_id="") - - with pytest.raises(ValueError, match="Invalid configuration"): - Imagen4Client(invalid_config) - - @pytest.mark.asyncio - async def test_generate_image_success(self): - """이미지 생성 성공 테스트""" - config = Config(project_id="test-project") - - # Create mock response - mock_image_data = b"fake_image_data" - mock_response = Mock() - mock_response.generated_images = [Mock()] - mock_response.generated_images[0].image = Mock() - mock_response.generated_images[0].image.image_bytes = mock_image_data - - with patch('src.connector.imagen4_client.genai.Client') as mock_client_class: - mock_client = Mock() - mock_client_class.return_value = mock_client - mock_client.models.generate_images = Mock(return_value=mock_response) - - client = Imagen4Client(config) - - request = ImageGenerationRequest( - prompt="Test prompt", - seed=12345 - ) - - with patch('asyncio.to_thread', return_value=mock_response): - response = await client.generate_image(request) - - assert response.success is True - assert len(response.images_data) == 1 - assert response.images_data[0] == mock_image_data - - @pytest.mark.asyncio - async def test_generate_image_failure(self): - """Image generation failure test""" - config = Config(project_id="test-project") - - with patch('src.connector.imagen4_client.genai.Client') as mock_client_class: - mock_client = Mock() - mock_client_class.return_value = mock_client - - client = Imagen4Client(config) - - request = ImageGenerationRequest( - prompt="Test prompt", - seed=12345 - ) - - # Simulate API call failure - with patch('asyncio.to_thread', side_effect=Exception("API Error")): - response = await client.generate_image(request) - - assert response.success is False - assert "API Error" in response.error_message - assert len(response.images_data) == 0 - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/tests/test_server.py b/tests/test_server.py deleted file mode 100644 index 2c488ad..0000000 --- a/tests/test_server.py +++ /dev/null @@ -1,136 +0,0 @@ -""" -Tests for Imagen4 Server -""" - -import pytest -from unittest.mock import Mock, AsyncMock, patch - -from src.connector import Config -from src.server import Imagen4MCPServer, ToolHandlers, MCPToolDefinitions -from src.server.models import MCPToolDefinitions - - -class TestMCPToolDefinitions: - """MCPToolDefinitions tests""" - - def test_get_generate_image_tool(self): - """Image generation tool definition test""" - tool = MCPToolDefinitions.get_generate_image_tool() - - assert tool.name == "generate_image" - assert "Google Imagen 4" in tool.description - assert "prompt" in tool.inputSchema["properties"] - assert "seed" in tool.inputSchema["required"] - - def test_get_regenerate_from_json_tool(self): - """JSON regeneration tool definition test""" - tool = MCPToolDefinitions.get_regenerate_from_json_tool() - - assert tool.name == "regenerate_from_json" - assert "JSON file" in tool.description - assert "json_file_path" in tool.inputSchema["properties"] - - def test_get_generate_random_seed_tool(self): - """Random seed generation tool definition test""" - tool = MCPToolDefinitions.get_generate_random_seed_tool() - - assert tool.name == "generate_random_seed" - assert "random seed" in tool.description - - def test_get_all_tools(self): - """All tools return test""" - tools = MCPToolDefinitions.get_all_tools() - - assert len(tools) == 3 - tool_names = [tool.name for tool in tools] - assert "generate_image" in tool_names - assert "regenerate_from_json" in tool_names - assert "generate_random_seed" in tool_names - - -class TestToolHandlers: - """ToolHandlers tests""" - - @pytest.fixture - def config(self): - """Test configuration""" - return Config(project_id="test-project") - - @pytest.fixture - def handlers(self, config): - """Test handlers""" - with patch('src.server.handlers.Imagen4Client'): - return ToolHandlers(config) - - @pytest.mark.asyncio - async def test_handle_generate_random_seed(self, handlers): - """Random seed generation handler test""" - result = await handlers.handle_generate_random_seed({}) - - assert len(result) == 1 - assert result[0].type == "text" - assert "Generated random seed:" in result[0].text - - @pytest.mark.asyncio - async def test_handle_generate_image_missing_prompt(self, handlers): - """Missing prompt test""" - arguments = {"seed": 12345} - - result = await handlers.handle_generate_image(arguments) - - assert len(result) == 1 - assert result[0].type == "text" - assert "Prompt is required" in result[0].text - - @pytest.mark.asyncio - async def test_handle_generate_image_missing_seed(self, handlers): - """Missing seed value test""" - arguments = {"prompt": "Test prompt"} - - result = await handlers.handle_generate_image(arguments) - - assert len(result) == 1 - assert result[0].type == "text" - assert "Seed value is required" in result[0].text - - @pytest.mark.asyncio - async def test_handle_regenerate_from_json_missing_path(self, handlers): - """Missing JSON file path test""" - arguments = {} - - result = await handlers.handle_regenerate_from_json(arguments) - - assert len(result) == 1 - assert result[0].type == "text" - assert "JSON file path is required" in result[0].text - - -class TestImagen4MCPServer: - """Imagen4MCPServer tests""" - - @pytest.fixture - def config(self): - """Test configuration""" - return Config(project_id="test-project") - - def test_server_initialization(self, config): - """Server initialization test""" - with patch('src.server.mcp_server.ToolHandlers'): - server = Imagen4MCPServer(config) - - assert server.config == config - assert server.server is not None - assert server.handlers is not None - - def test_get_server(self, config): - """Server instance return test""" - with patch('src.server.mcp_server.ToolHandlers'): - mcp_server = Imagen4MCPServer(config) - server = mcp_server.get_server() - - assert server is not None - assert server.name == "imagen4-mcp-server" - - -if __name__ == "__main__": - pytest.main([__file__])