remove unused source files
This commit is contained in:
48
README.md
48
README.md
@@ -2,7 +2,30 @@
|
||||
|
||||
Google Imagen 4를 사용한 AI 이미지 생성 MCP 서버입니다. 하나의 `main.py` 파일로 모든 기능이 통합되어 있으며, **한글 프롬프트를 완벽 지원**합니다.
|
||||
|
||||
## 🚀 주요 기능
|
||||
## 🚨 주의사항 - MCP 프로토콜 호환성
|
||||
|
||||
> **중요**: 이 서버는 MCP (Model Context Protocol)를 사용합니다. MCP는 stdout을 통해 JSON 메시지를 주고받으므로, 서버 코드에서 **절대로 stdout에 직접 출력하면 안됩니다**.
|
||||
|
||||
### ❌ 금지사항
|
||||
```python
|
||||
print("message") # ❌ MCP JSON 프로토콜 방해
|
||||
print("message", file=sys.stdout) # ❌ MCP JSON 프로토콜 방해
|
||||
sys.stdout.write("message") # ❌ MCP JSON 프로토콜 방해
|
||||
```
|
||||
|
||||
### ✅ 권장사항
|
||||
```python
|
||||
print("message", file=sys.stderr) # ✅ 로그용 출력
|
||||
logging.info("message") # ✅ 로깅 시스템 사용
|
||||
logger.error("message") # ✅ 로깅 시스템 사용
|
||||
```
|
||||
|
||||
### 최신 수정사항 (2025-08-26)
|
||||
- **UTF-8 테스트 출력 비활성화**: `[UTF8-TEST]` 메시지가 MCP JSON 파싱을 방해하던 문제 수정
|
||||
- **MCP 안전 실행**: `run_mcp_safe.py` 및 `run_mcp_safe.bat` 추가
|
||||
- **호환성 테스트**: `test_mcp_compatibility.py` 추가
|
||||
|
||||
**개발 시 반드시 `CLAUDE.md` 파일의 개발 가이드라인을 참조하세요!**
|
||||
|
||||
- **고품질 이미지 생성**: 2048x2048 PNG 이미지 생성
|
||||
- **미리보기 이미지**: 512x512 JPEG 미리보기를 base64로 제공
|
||||
@@ -58,16 +81,16 @@ OUTPUT_PATH=./generated_images
|
||||
|
||||
## 🎮 사용법
|
||||
|
||||
### 서버 실행
|
||||
### 서버 실행 (업데이트됨!)
|
||||
|
||||
#### Windows (유니코드 지원)
|
||||
#### MCP 안전 실행 (추천)
|
||||
```bash
|
||||
run_unicode.bat
|
||||
run_mcp_safe.bat
|
||||
```
|
||||
|
||||
#### 기본 실행
|
||||
또는 Python으로 직접:
|
||||
```bash
|
||||
run.bat
|
||||
python run_mcp_safe.py
|
||||
```
|
||||
|
||||
#### 직접 실행 (유니코드 환경 설정)
|
||||
@@ -85,24 +108,28 @@ python main.py
|
||||
|
||||
### Claude Desktop 설정
|
||||
|
||||
`claude_desktop_config.json` 내용을 Claude Desktop 설정에 추가:
|
||||
Claude Desktop 설정에 다음 내용을 추가하거나, `claude_desktop_config_mcp_safe.json` 파일을 사용하세요:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"imagen4": {
|
||||
"command": "python",
|
||||
"args": ["main.py"],
|
||||
"args": ["run_mcp_safe.py"],
|
||||
"cwd": "D:\\Project\\imagen4",
|
||||
"env": {
|
||||
"PYTHONPATH": "D:\\Project\\imagen4",
|
||||
"PYTHONIOENCODING": "utf-8",
|
||||
"PYTHONUTF8": "1"
|
||||
"PYTHONUTF8": "1",
|
||||
"LC_ALL": "C.UTF-8"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> **중요**: MCP 호환성을 위해 `run_mcp_safe.py`를 사용하세요!
|
||||
|
||||
## 🛠️ 사용 가능한 도구
|
||||
|
||||
### 1. generate_image
|
||||
@@ -195,7 +222,8 @@ def ensure_unicode_string(value):
|
||||
한글 프롬프트 지원을 테스트하려면:
|
||||
|
||||
```bash
|
||||
python test_main.py
|
||||
python test_mcp_compatibility.py # MCP 호환성 테스트 (신규)
|
||||
python test_main.py # 기본 기능 테스트
|
||||
```
|
||||
|
||||
이 스크립트는 다음을 확인합니다:
|
||||
|
||||
@@ -2,10 +2,13 @@
|
||||
"mcpServers": {
|
||||
"imagen4": {
|
||||
"command": "python",
|
||||
"args": ["main.py"],
|
||||
"args": ["run_mcp_safe.py"],
|
||||
"cwd": "D:\\Project\\imagen4",
|
||||
"env": {
|
||||
"PYTHONPATH": "D:\\Project\\imagen4"
|
||||
"PYTHONPATH": "D:\\Project\\imagen4",
|
||||
"PYTHONIOENCODING": "utf-8",
|
||||
"PYTHONUTF8": "1",
|
||||
"LC_ALL": "C.UTF-8"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
27
main.py
27
main.py
@@ -74,12 +74,15 @@ if sys.platform.startswith('win'):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Verify UTF-8 setup
|
||||
# Verify UTF-8 setup (disabled for MCP compatibility)
|
||||
# Note: UTF-8 test output disabled to prevent MCP protocol interference
|
||||
try:
|
||||
test_unicode = "Test UTF-8: 한글 테스트 ✓"
|
||||
print(f"[UTF8-TEST] {test_unicode}")
|
||||
# print(f"[UTF8-TEST] {test_unicode}") # Disabled: interferes with MCP JSON protocol
|
||||
except UnicodeEncodeError as e:
|
||||
print(f"[UTF8-ERROR] Unicode test failed: {e}")
|
||||
# Only log errors to stderr, not stdout
|
||||
import sys
|
||||
print(f"[UTF8-ERROR] Unicode test failed: {e}", file=sys.stderr)
|
||||
|
||||
# ==================== Regular Imports (after UTF-8 setup) ====================
|
||||
|
||||
@@ -97,7 +100,9 @@ try:
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
except ImportError:
|
||||
print("Warning: python-dotenv not installed", file=sys.stderr)
|
||||
# Use logger for dotenv warning instead of direct print
|
||||
import logging
|
||||
logging.getLogger("imagen4-mcp-server").warning("python-dotenv not installed")
|
||||
|
||||
# Add current directory to PYTHONPATH
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -111,8 +116,11 @@ try:
|
||||
from mcp.server import Server
|
||||
from mcp.types import Tool, TextContent
|
||||
except ImportError as e:
|
||||
print(f"Error importing MCP: {e}", file=sys.stderr)
|
||||
print("Please install required packages: pip install mcp", file=sys.stderr)
|
||||
# Use logger for MCP import errors instead of direct print
|
||||
import logging
|
||||
logger = logging.getLogger("imagen4-mcp-server")
|
||||
logger.error(f"Error importing MCP: {e}")
|
||||
logger.error("Please install required packages: pip install mcp")
|
||||
sys.exit(1)
|
||||
|
||||
# Connector imports
|
||||
@@ -124,8 +132,11 @@ from src.connector.utils import save_generated_images
|
||||
try:
|
||||
from PIL import Image
|
||||
except ImportError as e:
|
||||
print(f"Error importing Pillow: {e}", file=sys.stderr)
|
||||
print("Please install Pillow: pip install Pillow", file=sys.stderr)
|
||||
# Use logger for Pillow import errors instead of direct print
|
||||
import logging
|
||||
logger = logging.getLogger("imagen4-mcp-server")
|
||||
logger.error(f"Error importing Pillow: {e}")
|
||||
logger.error("Please install Pillow: pip install Pillow")
|
||||
sys.exit(1)
|
||||
|
||||
# ==================== Unicode-Safe Logging Setup ====================
|
||||
|
||||
56
run.bat
56
run.bat
@@ -1,56 +0,0 @@
|
||||
@echo off
|
||||
chcp 65001 > nul 2>&1
|
||||
echo Starting Imagen4 MCP Server with Preview Image Support (Unicode Enabled)...
|
||||
echo Features: 512x512 JPEG preview images, base64 encoding, Unicode logging support
|
||||
echo.
|
||||
|
||||
REM Set UTF-8 environment for Python
|
||||
set PYTHONIOENCODING=utf-8
|
||||
set PYTHONUTF8=1
|
||||
|
||||
REM Check if virtual environment exists
|
||||
if not exist "venv\Scripts\activate.bat" (
|
||||
echo Creating virtual environment...
|
||||
python -m venv venv
|
||||
if errorlevel 1 (
|
||||
echo Failed to create virtual environment
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
)
|
||||
|
||||
REM Activate virtual environment
|
||||
echo Activating virtual environment...
|
||||
call venv\Scripts\activate.bat
|
||||
|
||||
REM Check if requirements are installed
|
||||
python -c "import PIL" 2>nul
|
||||
if errorlevel 1 (
|
||||
echo Installing requirements...
|
||||
pip install -r requirements.txt
|
||||
if errorlevel 1 (
|
||||
echo Failed to install requirements
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
)
|
||||
|
||||
REM Check if .env file exists
|
||||
if not exist ".env" (
|
||||
echo Warning: .env file not found
|
||||
echo Please copy .env.example to .env and configure your API credentials
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
REM Run server with UTF-8 support
|
||||
echo Starting MCP server with Unicode support...
|
||||
echo 한글 프롬프트 지원이 활성화되었습니다.
|
||||
python main.py
|
||||
|
||||
REM Keep window open if there's an error
|
||||
if errorlevel 1 (
|
||||
echo.
|
||||
echo Server exited with error
|
||||
pause
|
||||
)
|
||||
@@ -1,28 +0,0 @@
|
||||
@echo off
|
||||
echo Starting Imagen4 MCP Server with forced UTF-8 mode...
|
||||
echo.
|
||||
|
||||
REM Set console code page to UTF-8
|
||||
chcp 65001 >nul 2>&1
|
||||
|
||||
REM Set all UTF-8 environment variables
|
||||
set PYTHONIOENCODING=utf-8
|
||||
set PYTHONUTF8=1
|
||||
set LC_ALL=C.UTF-8
|
||||
set LANG=C.UTF-8
|
||||
|
||||
REM Set console properties
|
||||
title Imagen4 MCP Server (UTF-8 Forced)
|
||||
|
||||
echo Environment Variables Set:
|
||||
echo PYTHONIOENCODING=%PYTHONIOENCODING%
|
||||
echo PYTHONUTF8=%PYTHONUTF8%
|
||||
echo LC_ALL=%LC_ALL%
|
||||
echo.
|
||||
|
||||
REM Run Python with explicit UTF-8 mode flag
|
||||
python -X utf8 main.py
|
||||
|
||||
echo.
|
||||
echo Server stopped. Press any key to exit...
|
||||
pause >nul
|
||||
22
run_mcp_safe.bat
Normal file
22
run_mcp_safe.bat
Normal file
@@ -0,0 +1,22 @@
|
||||
@echo off
|
||||
REM MCP-Safe UTF-8 runner for imagen4 server
|
||||
REM This prevents stdout interference with MCP JSON protocol
|
||||
|
||||
echo [STARTUP] Starting Imagen4 MCP Server with safe UTF-8 handling...
|
||||
|
||||
REM Change to script directory
|
||||
cd /d "%~dp0"
|
||||
|
||||
REM Set UTF-8 code page
|
||||
chcp 65001 >nul 2>&1
|
||||
|
||||
REM Set UTF-8 environment variables
|
||||
set PYTHONIOENCODING=utf-8
|
||||
set PYTHONUTF8=1
|
||||
set LC_ALL=C.UTF-8
|
||||
|
||||
REM Run the safe MCP runner
|
||||
python run_mcp_safe.py
|
||||
|
||||
echo [SHUTDOWN] Imagen4 MCP Server stopped.
|
||||
pause
|
||||
66
run_mcp_safe.py
Normal file
66
run_mcp_safe.py
Normal file
@@ -0,0 +1,66 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
MCP Protocol Compatible UTF-8 Runner for imagen4
|
||||
|
||||
This script ensures proper UTF-8 handling without interfering with MCP protocol
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import locale
|
||||
|
||||
def setup_utf8_environment():
|
||||
"""Setup UTF-8 environment variables without stdout interference"""
|
||||
# Set UTF-8 environment variables
|
||||
env = os.environ.copy()
|
||||
env['PYTHONIOENCODING'] = 'utf-8'
|
||||
env['PYTHONUTF8'] = '1'
|
||||
env['LC_ALL'] = 'C.UTF-8'
|
||||
|
||||
# Windows-specific setup
|
||||
if sys.platform.startswith('win'):
|
||||
# Set console code page to UTF-8 (suppress output)
|
||||
try:
|
||||
os.system('chcp 65001 >nul 2>&1')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Set locale
|
||||
try:
|
||||
locale.setlocale(locale.LC_ALL, 'C.UTF-8')
|
||||
except locale.Error:
|
||||
try:
|
||||
locale.setlocale(locale.LC_ALL, '')
|
||||
except locale.Error:
|
||||
pass
|
||||
|
||||
return env
|
||||
|
||||
def main():
|
||||
"""Main function - run imagen4 MCP server with proper UTF-8 setup"""
|
||||
# Setup UTF-8 environment
|
||||
env = setup_utf8_environment()
|
||||
|
||||
# Get the directory of this script
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
main_py = os.path.join(script_dir, 'main.py')
|
||||
|
||||
# Run main.py with proper environment
|
||||
try:
|
||||
# Use subprocess to run with clean environment
|
||||
result = subprocess.run([
|
||||
sys.executable, main_py
|
||||
], env=env, cwd=script_dir)
|
||||
|
||||
return result.returncode
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Server stopped by user", file=sys.stderr)
|
||||
return 0
|
||||
except Exception as e:
|
||||
print(f"Error running server: {e}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
16
run_utf8.bat
16
run_utf8.bat
@@ -1,16 +0,0 @@
|
||||
@echo off
|
||||
REM Force UTF-8 encoding for Python and console
|
||||
chcp 65001 >nul 2>&1
|
||||
|
||||
REM Set Python UTF-8 environment variables
|
||||
set PYTHONIOENCODING=utf-8
|
||||
set PYTHONUTF8=1
|
||||
set LC_ALL=C.UTF-8
|
||||
|
||||
REM Set console properties
|
||||
title Imagen4 MCP Server (UTF-8)
|
||||
|
||||
REM Run the Python script
|
||||
python main.py
|
||||
|
||||
pause
|
||||
@@ -1,24 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Emergency UTF-8 override script
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Force UTF-8 mode at Python startup
|
||||
if not hasattr(sys, '_utf8_mode') or not sys._utf8_mode:
|
||||
os.environ['PYTHONUTF8'] = '1'
|
||||
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
||||
|
||||
# Restart Python with UTF-8 mode
|
||||
import subprocess
|
||||
result = subprocess.run([sys.executable, '-X', 'utf8', __file__] + sys.argv[1:])
|
||||
sys.exit(result.returncode)
|
||||
|
||||
# Now run the actual main.py
|
||||
if __name__ == "__main__":
|
||||
import importlib.util
|
||||
import sys
|
||||
|
||||
spec = importlib.util.spec_from_file_location("main", "main.py")
|
||||
main_module = importlib.util.module_from_spec(spec)
|
||||
sys.modules["main"] = main_module
|
||||
spec.loader.exec_module(main_module)
|
||||
@@ -1,9 +0,0 @@
|
||||
"""
|
||||
Async Task Management Module for MCP Server
|
||||
"""
|
||||
|
||||
from .models import TaskStatus, TaskResult, generate_task_id
|
||||
from .task_manager import TaskManager
|
||||
from .worker_pool import WorkerPool
|
||||
|
||||
__all__ = ['TaskManager', 'TaskStatus', 'TaskResult', 'WorkerPool']
|
||||
@@ -1,48 +0,0 @@
|
||||
"""
|
||||
Task Status and Result Models
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Dict
|
||||
|
||||
|
||||
class TaskStatus(Enum):
|
||||
"""Task execution status"""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskResult:
|
||||
"""Task execution result"""
|
||||
task_id: str
|
||||
status: TaskStatus
|
||||
result: Optional[Any] = None
|
||||
error: Optional[str] = None
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def duration(self) -> Optional[float]:
|
||||
"""Calculate task execution duration in seconds"""
|
||||
if self.started_at and self.completed_at:
|
||||
return (self.completed_at - self.started_at).total_seconds()
|
||||
return None
|
||||
|
||||
@property
|
||||
def is_finished(self) -> bool:
|
||||
"""Check if task is finished (completed, failed, or cancelled)"""
|
||||
return self.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]
|
||||
|
||||
|
||||
def generate_task_id() -> str:
|
||||
"""Generate unique task ID"""
|
||||
return str(uuid.uuid4())
|
||||
@@ -1,324 +0,0 @@
|
||||
"""
|
||||
Task Manager for MCP Server
|
||||
Manages background task execution with status tracking and result retrieval
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Callable, Optional, Dict, List
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from .models import TaskResult, TaskStatus, generate_task_id
|
||||
from .worker_pool import WorkerPool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskManager:
|
||||
"""Main task manager for MCP server"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_workers: Optional[int] = None,
|
||||
use_process_pool: bool = False,
|
||||
default_timeout: float = 600.0, # 10 minutes
|
||||
result_retention_hours: int = 24, # Keep results for 24 hours
|
||||
max_retained_results: int = 200
|
||||
):
|
||||
"""
|
||||
Initialize task manager
|
||||
|
||||
Args:
|
||||
max_workers: Maximum number of worker threads/processes
|
||||
use_process_pool: Use process pool instead of thread pool
|
||||
default_timeout: Default timeout for tasks in seconds
|
||||
result_retention_hours: How long to keep finished task results
|
||||
max_retained_results: Maximum number of results to retain
|
||||
"""
|
||||
self.worker_pool = WorkerPool(
|
||||
max_workers=max_workers,
|
||||
use_process_pool=use_process_pool,
|
||||
task_timeout=default_timeout
|
||||
)
|
||||
|
||||
self.result_retention_hours = result_retention_hours
|
||||
self.max_retained_results = max_retained_results
|
||||
|
||||
# Start cleanup task
|
||||
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
|
||||
|
||||
logger.info(f"TaskManager initialized with {self.worker_pool.max_workers} workers")
|
||||
|
||||
async def submit_task(
|
||||
self,
|
||||
func: Callable,
|
||||
*args,
|
||||
task_id: Optional[str] = None,
|
||||
timeout: Optional[float] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
Submit a task for background execution
|
||||
|
||||
Args:
|
||||
func: Function to execute
|
||||
*args: Function arguments
|
||||
task_id: Optional custom task ID
|
||||
timeout: Task timeout in seconds
|
||||
metadata: Additional metadata for the task
|
||||
**kwargs: Function keyword arguments
|
||||
|
||||
Returns:
|
||||
str: Task ID for tracking progress
|
||||
"""
|
||||
task_id = await self.worker_pool.submit_task(
|
||||
func, *args, task_id=task_id, timeout=timeout, **kwargs
|
||||
)
|
||||
|
||||
# Add additional metadata if provided
|
||||
if metadata:
|
||||
result = self.worker_pool.get_task_result(task_id)
|
||||
if result:
|
||||
result.metadata.update(metadata)
|
||||
|
||||
return task_id
|
||||
|
||||
def get_task_status(self, task_id: str) -> Optional[TaskStatus]:
|
||||
"""Get task status by ID"""
|
||||
result = self.worker_pool.get_task_result(task_id)
|
||||
return result.status if result else None
|
||||
|
||||
def get_task_result(self, task_id: str) -> Optional[TaskResult]:
|
||||
"""Get complete task result by ID with circular reference protection"""
|
||||
try:
|
||||
raw_result = self.worker_pool.get_task_result(task_id)
|
||||
if not raw_result:
|
||||
return None
|
||||
|
||||
# Create a safe copy to avoid circular references
|
||||
safe_result = TaskResult(
|
||||
task_id=raw_result.task_id,
|
||||
status=raw_result.status,
|
||||
result=None, # We'll set this safely
|
||||
error=raw_result.error,
|
||||
created_at=raw_result.created_at,
|
||||
started_at=raw_result.started_at,
|
||||
completed_at=raw_result.completed_at,
|
||||
metadata=raw_result.metadata.copy() if raw_result.metadata else {}
|
||||
)
|
||||
|
||||
# Safely copy the result data
|
||||
if raw_result.result is not None:
|
||||
try:
|
||||
# If it's a dict, create a clean copy
|
||||
if isinstance(raw_result.result, dict):
|
||||
safe_result.result = {
|
||||
'success': raw_result.result.get('success', False),
|
||||
'error': raw_result.result.get('error'),
|
||||
'images_b64': raw_result.result.get('images_b64', []),
|
||||
'saved_files': raw_result.result.get('saved_files', []),
|
||||
'request': raw_result.result.get('request', {}),
|
||||
'image_count': raw_result.result.get('image_count', 0)
|
||||
}
|
||||
else:
|
||||
# For non-dict results, just copy directly
|
||||
safe_result.result = raw_result.result
|
||||
except Exception as copy_error:
|
||||
logger.error(f"Error creating safe copy of result: {str(copy_error)}")
|
||||
safe_result.result = {
|
||||
'success': False,
|
||||
'error': f'Result data corrupted: {str(copy_error)}'
|
||||
}
|
||||
|
||||
return safe_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_task_result: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
def get_task_progress(self, task_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get task progress information
|
||||
|
||||
Returns:
|
||||
Dict with task status, timing info, and metadata
|
||||
"""
|
||||
result = self.get_task_result(task_id)
|
||||
if not result:
|
||||
return {'error': 'Task not found'}
|
||||
|
||||
progress = {
|
||||
'task_id': result.task_id,
|
||||
'status': result.status.value,
|
||||
'created_at': result.created_at.isoformat(),
|
||||
'metadata': result.metadata
|
||||
}
|
||||
|
||||
if result.started_at:
|
||||
progress['started_at'] = result.started_at.isoformat()
|
||||
|
||||
if result.status == TaskStatus.RUNNING:
|
||||
elapsed = (datetime.now() - result.started_at).total_seconds()
|
||||
progress['elapsed_seconds'] = elapsed
|
||||
|
||||
if result.completed_at:
|
||||
progress['completed_at'] = result.completed_at.isoformat()
|
||||
progress['duration_seconds'] = result.duration
|
||||
|
||||
if result.error:
|
||||
progress['error'] = result.error
|
||||
|
||||
return progress
|
||||
|
||||
def list_tasks(
|
||||
self,
|
||||
status_filter: Optional[TaskStatus] = None,
|
||||
limit: int = 50
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
List tasks with optional filtering
|
||||
|
||||
Args:
|
||||
status_filter: Filter by task status
|
||||
limit: Maximum number of tasks to return
|
||||
|
||||
Returns:
|
||||
List of task progress dictionaries
|
||||
"""
|
||||
all_results = self.worker_pool.get_all_results()
|
||||
|
||||
# Filter by status if requested
|
||||
if status_filter:
|
||||
filtered_results = {
|
||||
k: v for k, v in all_results.items()
|
||||
if v.status == status_filter
|
||||
}
|
||||
else:
|
||||
filtered_results = all_results
|
||||
|
||||
# Sort by creation time (most recent first)
|
||||
sorted_items = sorted(
|
||||
filtered_results.items(),
|
||||
key=lambda x: x[1].created_at,
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# Apply limit and return progress info
|
||||
return [
|
||||
self.get_task_progress(task_id)
|
||||
for task_id, _ in sorted_items[:limit]
|
||||
]
|
||||
|
||||
def cancel_task(self, task_id: str) -> bool:
|
||||
"""Cancel a running task"""
|
||||
return self.worker_pool.cancel_task(task_id)
|
||||
|
||||
async def wait_for_task(
|
||||
self,
|
||||
task_id: str,
|
||||
timeout: Optional[float] = None,
|
||||
poll_interval: float = 1.0
|
||||
) -> TaskResult:
|
||||
"""
|
||||
Wait for a task to complete
|
||||
|
||||
Args:
|
||||
task_id: Task ID to wait for
|
||||
timeout: Maximum time to wait in seconds
|
||||
poll_interval: How often to check task status
|
||||
|
||||
Returns:
|
||||
TaskResult: Final task result
|
||||
|
||||
Raises:
|
||||
asyncio.TimeoutError: If timeout is reached
|
||||
ValueError: If task doesn't exist
|
||||
"""
|
||||
start_time = datetime.now()
|
||||
|
||||
while True:
|
||||
result = self.get_task_result(task_id)
|
||||
if not result:
|
||||
raise ValueError(f"Task {task_id} not found")
|
||||
|
||||
if result.is_finished:
|
||||
return result
|
||||
|
||||
# Check timeout
|
||||
if timeout:
|
||||
elapsed = (datetime.now() - start_time).total_seconds()
|
||||
if elapsed >= timeout:
|
||||
raise asyncio.TimeoutError(f"Timeout waiting for task {task_id}")
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
async def _periodic_cleanup(self) -> None:
|
||||
"""Periodic cleanup of old task results"""
|
||||
while True:
|
||||
try:
|
||||
# Wait 1 hour between cleanups
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
# Clean up old results
|
||||
cutoff_time = datetime.now() - timedelta(hours=self.result_retention_hours)
|
||||
all_results = self.worker_pool.get_all_results()
|
||||
|
||||
to_remove = []
|
||||
for task_id, result in all_results.items():
|
||||
if (result.is_finished and
|
||||
result.completed_at and
|
||||
result.completed_at < cutoff_time):
|
||||
to_remove.append(task_id)
|
||||
|
||||
# Remove old results
|
||||
for task_id in to_remove:
|
||||
if task_id in self.worker_pool._results:
|
||||
del self.worker_pool._results[task_id]
|
||||
|
||||
if to_remove:
|
||||
logger.info(f"Cleaned up {len(to_remove)} old task results")
|
||||
|
||||
# Also clean up excess results
|
||||
self.worker_pool.cleanup_finished_tasks(self.max_retained_results)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in periodic cleanup: {e}")
|
||||
|
||||
async def shutdown(self, wait: bool = True) -> None:
|
||||
"""Shutdown task manager"""
|
||||
logger.info("Shutting down task manager...")
|
||||
|
||||
# Cancel cleanup task
|
||||
if not self._cleanup_task.done():
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Shutdown worker pool
|
||||
await self.worker_pool.shutdown(wait=wait)
|
||||
|
||||
logger.info("Task manager shutdown complete")
|
||||
|
||||
@property
|
||||
def stats(self) -> Dict[str, Any]:
|
||||
"""Get task manager statistics"""
|
||||
all_results = self.worker_pool.get_all_results()
|
||||
|
||||
status_counts = {}
|
||||
for status in TaskStatus:
|
||||
status_counts[status.value] = sum(
|
||||
1 for r in all_results.values() if r.status == status
|
||||
)
|
||||
|
||||
return {
|
||||
'total_tasks': len(all_results),
|
||||
'active_tasks': self.worker_pool.active_task_count,
|
||||
'max_workers': self.worker_pool.max_workers,
|
||||
'worker_type': 'process' if self.worker_pool.use_process_pool else 'thread',
|
||||
'status_counts': status_counts
|
||||
}
|
||||
@@ -1,258 +0,0 @@
|
||||
"""
|
||||
Worker Pool for Background Task Execution
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import multiprocessing
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from datetime import datetime
|
||||
|
||||
from .models import TaskResult, TaskStatus, generate_task_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkerPool:
|
||||
"""Worker pool for executing background tasks"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_workers: Optional[int] = None,
|
||||
use_process_pool: bool = False,
|
||||
task_timeout: float = 600.0 # 10 minutes default
|
||||
):
|
||||
"""
|
||||
Initialize worker pool
|
||||
|
||||
Args:
|
||||
max_workers: Maximum number of worker threads/processes
|
||||
use_process_pool: Use ProcessPoolExecutor instead of ThreadPoolExecutor
|
||||
task_timeout: Default timeout for tasks in seconds
|
||||
"""
|
||||
self.max_workers = max_workers or min(32, (multiprocessing.cpu_count() or 1) + 4)
|
||||
self.use_process_pool = use_process_pool
|
||||
self.task_timeout = task_timeout
|
||||
|
||||
# Initialize executor
|
||||
if use_process_pool:
|
||||
self.executor = ProcessPoolExecutor(max_workers=self.max_workers)
|
||||
logger.info(f"Initialized ProcessPoolExecutor with {self.max_workers} workers")
|
||||
else:
|
||||
self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
|
||||
logger.info(f"Initialized ThreadPoolExecutor with {self.max_workers} workers")
|
||||
|
||||
# Track running tasks
|
||||
self._running_tasks: dict[str, asyncio.Task] = {}
|
||||
self._results: dict[str, TaskResult] = {}
|
||||
|
||||
async def submit_task(
|
||||
self,
|
||||
func: Callable,
|
||||
*args,
|
||||
task_id: Optional[str] = None,
|
||||
timeout: Optional[float] = None,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
Submit a task for background execution
|
||||
|
||||
Args:
|
||||
func: Function to execute
|
||||
*args: Function arguments
|
||||
task_id: Optional custom task ID
|
||||
timeout: Task timeout in seconds
|
||||
**kwargs: Function keyword arguments
|
||||
|
||||
Returns:
|
||||
str: Task ID for tracking
|
||||
"""
|
||||
if task_id is None:
|
||||
task_id = generate_task_id()
|
||||
|
||||
timeout = timeout or self.task_timeout
|
||||
|
||||
# Create task result
|
||||
task_result = TaskResult(
|
||||
task_id=task_id,
|
||||
status=TaskStatus.PENDING,
|
||||
metadata={
|
||||
'function_name': func.__name__,
|
||||
'timeout': timeout,
|
||||
'worker_type': 'process' if self.use_process_pool else 'thread'
|
||||
}
|
||||
)
|
||||
|
||||
self._results[task_id] = task_result
|
||||
|
||||
# Create and start background task
|
||||
background_task = asyncio.create_task(
|
||||
self._execute_task(func, args, kwargs, task_result, timeout)
|
||||
)
|
||||
|
||||
self._running_tasks[task_id] = background_task
|
||||
|
||||
logger.info(f"Task {task_id} submitted for execution (timeout: {timeout}s)")
|
||||
return task_id
|
||||
|
||||
async def _execute_task(
|
||||
self,
|
||||
func: Callable,
|
||||
args: tuple,
|
||||
kwargs: dict,
|
||||
task_result: TaskResult,
|
||||
timeout: float
|
||||
) -> None:
|
||||
"""Execute task in background"""
|
||||
try:
|
||||
task_result.status = TaskStatus.RUNNING
|
||||
task_result.started_at = datetime.now()
|
||||
|
||||
logger.info(f"Starting task {task_result.task_id}: {func.__name__}")
|
||||
|
||||
# Execute function with timeout
|
||||
if self.use_process_pool:
|
||||
# For process pool, function must be pickleable
|
||||
future = self.executor.submit(func, *args, **kwargs)
|
||||
result = await asyncio.wait_for(
|
||||
asyncio.wrap_future(future),
|
||||
timeout=timeout
|
||||
)
|
||||
else:
|
||||
# For thread pool, can use lambda/closure
|
||||
result = await asyncio.wait_for(
|
||||
asyncio.to_thread(func, *args, **kwargs),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
task_result.result = result
|
||||
task_result.status = TaskStatus.COMPLETED
|
||||
task_result.completed_at = datetime.now()
|
||||
|
||||
logger.info(f"Task {task_result.task_id} completed successfully in {task_result.duration:.2f}s")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
task_result.status = TaskStatus.FAILED
|
||||
task_result.error = f"Task timed out after {timeout} seconds"
|
||||
task_result.completed_at = datetime.now()
|
||||
logger.error(f"Task {task_result.task_id} timed out after {timeout}s")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
task_result.status = TaskStatus.CANCELLED
|
||||
task_result.error = "Task was cancelled"
|
||||
task_result.completed_at = datetime.now()
|
||||
logger.info(f"Task {task_result.task_id} was cancelled")
|
||||
|
||||
except Exception as e:
|
||||
task_result.status = TaskStatus.FAILED
|
||||
task_result.error = str(e)
|
||||
task_result.completed_at = datetime.now()
|
||||
logger.error(f"Task {task_result.task_id} failed: {str(e)}", exc_info=True)
|
||||
|
||||
finally:
|
||||
# Clean up running task reference
|
||||
if task_result.task_id in self._running_tasks:
|
||||
del self._running_tasks[task_result.task_id]
|
||||
|
||||
def get_task_result(self, task_id: str) -> Optional[TaskResult]:
|
||||
"""Get task result by ID"""
|
||||
return self._results.get(task_id)
|
||||
|
||||
def get_all_results(self) -> dict[str, TaskResult]:
|
||||
"""Get all task results"""
|
||||
return self._results.copy()
|
||||
|
||||
def cancel_task(self, task_id: str) -> bool:
|
||||
"""
|
||||
Cancel a running task
|
||||
|
||||
Args:
|
||||
task_id: Task ID to cancel
|
||||
|
||||
Returns:
|
||||
bool: True if task was cancelled, False if not found or already finished
|
||||
"""
|
||||
if task_id in self._running_tasks:
|
||||
task = self._running_tasks[task_id]
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
logger.info(f"Task {task_id} cancellation requested")
|
||||
return True
|
||||
return False
|
||||
|
||||
def cleanup_finished_tasks(self, max_results: int = 100) -> int:
|
||||
"""
|
||||
Clean up finished task results to prevent memory buildup
|
||||
|
||||
Args:
|
||||
max_results: Maximum number of results to keep
|
||||
|
||||
Returns:
|
||||
int: Number of results cleaned up
|
||||
"""
|
||||
if len(self._results) <= max_results:
|
||||
return 0
|
||||
|
||||
# Sort by completion time, keep most recent
|
||||
finished_results = {
|
||||
k: v for k, v in self._results.items()
|
||||
if v.is_finished and v.completed_at
|
||||
}
|
||||
|
||||
if len(finished_results) <= max_results:
|
||||
return 0
|
||||
|
||||
sorted_results = sorted(
|
||||
finished_results.items(),
|
||||
key=lambda x: x[1].completed_at,
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# Keep most recent max_results
|
||||
to_keep = set(k for k, _ in sorted_results[:max_results])
|
||||
|
||||
# Remove old results
|
||||
to_remove = [k for k in finished_results.keys() if k not in to_keep]
|
||||
for task_id in to_remove:
|
||||
del self._results[task_id]
|
||||
|
||||
cleaned_count = len(to_remove)
|
||||
if cleaned_count > 0:
|
||||
logger.info(f"Cleaned up {cleaned_count} old task results")
|
||||
|
||||
return cleaned_count
|
||||
|
||||
async def shutdown(self, wait: bool = True) -> None:
|
||||
"""
|
||||
Shutdown worker pool
|
||||
|
||||
Args:
|
||||
wait: Whether to wait for running tasks to complete
|
||||
"""
|
||||
logger.info("Shutting down worker pool...")
|
||||
|
||||
if wait:
|
||||
# Cancel all running tasks
|
||||
for task_id, task in self._running_tasks.items():
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
logger.info(f"Cancelled task {task_id}")
|
||||
|
||||
# Wait for tasks to complete cancellation
|
||||
if self._running_tasks:
|
||||
await asyncio.gather(*self._running_tasks.values(), return_exceptions=True)
|
||||
|
||||
# Shutdown executor
|
||||
self.executor.shutdown(wait=wait)
|
||||
logger.info("Worker pool shutdown complete")
|
||||
|
||||
@property
|
||||
def active_task_count(self) -> int:
|
||||
"""Get number of currently running tasks"""
|
||||
return len([t for t in self._running_tasks.values() if not t.done()])
|
||||
|
||||
@property
|
||||
def total_task_count(self) -> int:
|
||||
"""Get total number of tracked tasks"""
|
||||
return len(self._results)
|
||||
@@ -1,26 +0,0 @@
|
||||
"""
|
||||
Imagen4 Server Package
|
||||
|
||||
MCP 서버 구현을 담당하는 모듈
|
||||
"""
|
||||
|
||||
# Import basic server components
|
||||
from src.server.mcp_server import Imagen4MCPServer
|
||||
from src.server.handlers import ToolHandlers
|
||||
from src.server.models import MCPToolDefinitions
|
||||
|
||||
# Import enhanced components
|
||||
try:
|
||||
from src.server.enhanced_mcp_server import EnhancedImagen4MCPServer
|
||||
from src.server.enhanced_handlers import EnhancedToolHandlers
|
||||
from src.server.enhanced_models import ImageGenerationResult, PreviewImageResponse
|
||||
|
||||
__all__ = [
|
||||
'Imagen4MCPServer', 'ToolHandlers', 'MCPToolDefinitions',
|
||||
'EnhancedImagen4MCPServer', 'EnhancedToolHandlers',
|
||||
'ImageGenerationResult', 'PreviewImageResponse'
|
||||
]
|
||||
except ImportError as e:
|
||||
# Fallback if enhanced modules are not available
|
||||
print(f"Warning: Enhanced features not available: {e}")
|
||||
__all__ = ['Imagen4MCPServer', 'ToolHandlers', 'MCPToolDefinitions']
|
||||
@@ -1,344 +0,0 @@
|
||||
"""
|
||||
Enhanced Tool Handlers for MCP Server with Preview Image Support and Unicode Logging
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import random
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from mcp.types import TextContent, ImageContent
|
||||
|
||||
from src.connector import Imagen4Client, Config
|
||||
from src.connector.imagen4_client import ImageGenerationRequest
|
||||
from src.connector.utils import save_generated_images
|
||||
from src.utils.image_utils import create_preview_image_b64, get_image_info
|
||||
from src.server.enhanced_models import ImageGenerationResult, PreviewImageResponse
|
||||
|
||||
# Configure Unicode logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Ensure proper Unicode handling for all string operations
|
||||
def ensure_unicode_string(value):
|
||||
"""Ensure a value is a proper Unicode string"""
|
||||
if isinstance(value, bytes):
|
||||
return value.decode('utf-8', errors='replace')
|
||||
elif isinstance(value, str):
|
||||
return value
|
||||
else:
|
||||
return str(value)
|
||||
|
||||
|
||||
def sanitize_args_for_logging(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Remove or truncate sensitive data from arguments for safe logging with Unicode support"""
|
||||
safe_args = {}
|
||||
for key, value in arguments.items():
|
||||
if isinstance(value, str):
|
||||
# Ensure proper Unicode handling
|
||||
value = ensure_unicode_string(value)
|
||||
|
||||
# Check if it's likely base64 image data
|
||||
if (key in ['data', 'image_data', 'base64', 'preview_image_b64'] or
|
||||
(len(value) > 100 and value.startswith(('iVBORw0KGgo', '/9j/', 'R0lGOD')))):
|
||||
# Truncate long image data
|
||||
safe_args[key] = f"<image_data:{len(value)} chars>"
|
||||
elif len(value) > 1000:
|
||||
# Truncate any very long strings but preserve Unicode characters
|
||||
truncated = value[:100]
|
||||
safe_args[key] = f"{truncated}...<truncated:{len(value)} total chars>"
|
||||
else:
|
||||
# Keep the original string (including Korean/Unicode characters)
|
||||
safe_args[key] = value
|
||||
else:
|
||||
safe_args[key] = value
|
||||
return safe_args
|
||||
|
||||
|
||||
class EnhancedToolHandlers:
|
||||
"""Enhanced MCP tool handler class with preview image support and Unicode logging"""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
"""Initialize handler"""
|
||||
self.config = config
|
||||
self.client = Imagen4Client(config)
|
||||
|
||||
async def handle_generate_random_seed(self, arguments: Dict[str, Any]) -> List[TextContent]:
|
||||
"""Random seed generation handler"""
|
||||
try:
|
||||
random_seed = random.randint(0, 2**32 - 1)
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Generated random seed: {random_seed}"
|
||||
)]
|
||||
except Exception as e:
|
||||
logger.error(f"Random seed generation error: {str(e)}")
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Error occurred during random seed generation: {str(e)}"
|
||||
)]
|
||||
|
||||
async def handle_regenerate_from_json(self, arguments: Dict[str, Any]) -> List[TextContent]:
|
||||
"""Image regeneration from JSON file handler with preview support"""
|
||||
try:
|
||||
json_file_path = arguments.get("json_file_path")
|
||||
save_to_file = arguments.get("save_to_file", True)
|
||||
|
||||
if not json_file_path:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text="Error: JSON file path is required."
|
||||
)]
|
||||
|
||||
# Load parameters from JSON file with proper UTF-8 encoding
|
||||
try:
|
||||
with open(json_file_path, 'r', encoding='utf-8') as f:
|
||||
params = json.load(f)
|
||||
except FileNotFoundError:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Error: Cannot find JSON file: {json_file_path}"
|
||||
)]
|
||||
except json.JSONDecodeError as e:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Error: JSON parsing error: {str(e)}"
|
||||
)]
|
||||
|
||||
# Check required parameters
|
||||
required_params = ['prompt', 'seed']
|
||||
missing_params = [p for p in required_params if p not in params]
|
||||
|
||||
if missing_params:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Error: Required parameters are missing from JSON file: {', '.join(missing_params)}"
|
||||
)]
|
||||
|
||||
# Create image generation request object
|
||||
request = ImageGenerationRequest(
|
||||
prompt=params.get('prompt'),
|
||||
negative_prompt=params.get('negative_prompt', ''),
|
||||
number_of_images=params.get('number_of_images', 1),
|
||||
seed=params.get('seed'),
|
||||
aspect_ratio=params.get('aspect_ratio', '1:1'),
|
||||
model=params.get('model', self.config.default_model)
|
||||
)
|
||||
|
||||
logger.info(f"Loaded parameters from JSON: {json_file_path}")
|
||||
|
||||
# Generate image
|
||||
response = await self.client.generate_image(request)
|
||||
|
||||
if not response.success:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Error occurred during image regeneration: {response.error_message}"
|
||||
)]
|
||||
|
||||
# Create preview images
|
||||
preview_images_b64 = []
|
||||
for image_data in response.images_data:
|
||||
preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85)
|
||||
if preview_b64:
|
||||
preview_images_b64.append(preview_b64)
|
||||
logger.info(f"Created preview image: {len(preview_b64)} chars base64 JPEG")
|
||||
|
||||
# Save files (optional)
|
||||
saved_files = []
|
||||
if save_to_file:
|
||||
regeneration_params = {
|
||||
"prompt": request.prompt,
|
||||
"negative_prompt": request.negative_prompt,
|
||||
"number_of_images": request.number_of_images,
|
||||
"seed": request.seed,
|
||||
"aspect_ratio": request.aspect_ratio,
|
||||
"model": request.model,
|
||||
"regenerated_from": json_file_path,
|
||||
"original_generated_at": params.get('generated_at', 'unknown')
|
||||
}
|
||||
|
||||
saved_files = save_generated_images(
|
||||
images_data=response.images_data,
|
||||
save_directory=self.config.output_path,
|
||||
seed=request.seed,
|
||||
generation_params=regeneration_params,
|
||||
filename_prefix="imagen4_regen"
|
||||
)
|
||||
|
||||
# Create result with preview images
|
||||
result = ImageGenerationResult(
|
||||
success=True,
|
||||
message=f"✅ Images have been successfully regenerated from {json_file_path}",
|
||||
original_images_count=len(response.images_data),
|
||||
preview_images_b64=preview_images_b64,
|
||||
saved_files=saved_files,
|
||||
generation_params={
|
||||
"prompt": request.prompt,
|
||||
"seed": request.seed,
|
||||
"aspect_ratio": request.aspect_ratio,
|
||||
"number_of_images": request.number_of_images,
|
||||
"model": request.model
|
||||
}
|
||||
)
|
||||
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=result.to_text_content()
|
||||
)]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred during image regeneration: {str(e)}")
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Error occurred during image regeneration: {str(e)}"
|
||||
)]
|
||||
|
||||
async def handle_generate_image(self, arguments: Dict[str, Any]) -> List[TextContent]:
|
||||
"""Enhanced image generation handler with preview image support and proper Unicode logging"""
|
||||
try:
|
||||
# Log arguments safely without exposing image data but preserve Unicode
|
||||
safe_args = sanitize_args_for_logging(arguments)
|
||||
logger.info(f"handle_generate_image called with arguments: {safe_args}")
|
||||
|
||||
# Extract and validate arguments
|
||||
prompt = arguments.get("prompt")
|
||||
if not prompt:
|
||||
logger.error("No prompt provided")
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text="Error: Prompt is required."
|
||||
)]
|
||||
|
||||
# Ensure prompt is properly handled as Unicode
|
||||
prompt = ensure_unicode_string(prompt)
|
||||
|
||||
seed = arguments.get("seed")
|
||||
if seed is None:
|
||||
logger.error("No seed provided")
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text="Error: Seed value is required. You can use the generate_random_seed tool to generate a random seed."
|
||||
)]
|
||||
|
||||
# Create image generation request object
|
||||
request = ImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=ensure_unicode_string(arguments.get("negative_prompt", "")),
|
||||
number_of_images=arguments.get("number_of_images", 1),
|
||||
seed=seed,
|
||||
aspect_ratio=arguments.get("aspect_ratio", "1:1"),
|
||||
model=arguments.get("model", self.config.default_model)
|
||||
)
|
||||
|
||||
save_to_file = arguments.get("save_to_file", True)
|
||||
|
||||
# Log with proper Unicode handling for Korean/international text
|
||||
prompt_preview = prompt[:50] + "..." if len(prompt) > 50 else prompt
|
||||
logger.info(f"Starting image generation: '{prompt_preview}', Seed: {seed}")
|
||||
|
||||
# Generate image with timeout
|
||||
try:
|
||||
logger.info("Calling client.generate_image()...")
|
||||
response = await asyncio.wait_for(
|
||||
self.client.generate_image(request),
|
||||
timeout=360.0 # 6 minute timeout
|
||||
)
|
||||
logger.info(f"Image generation completed. Success: {response.success}")
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Image generation timed out after 6 minutes")
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text="Error: Image generation timed out after 6 minutes. Please try again."
|
||||
)]
|
||||
|
||||
if not response.success:
|
||||
logger.error(f"Image generation failed: {response.error_message}")
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Error occurred during image generation: {response.error_message}"
|
||||
)]
|
||||
|
||||
logger.info(f"Generated {len(response.images_data)} images successfully")
|
||||
|
||||
# Create preview images (512x512 JPEG base64)
|
||||
preview_images_b64 = []
|
||||
for i, image_data in enumerate(response.images_data):
|
||||
# Log original image info
|
||||
image_info = get_image_info(image_data)
|
||||
if image_info:
|
||||
logger.info(f"Original image {i+1}: {image_info}")
|
||||
|
||||
# Create preview
|
||||
preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85)
|
||||
if preview_b64:
|
||||
preview_images_b64.append(preview_b64)
|
||||
logger.info(f"Created preview image {i+1}: {len(preview_b64)} chars base64 JPEG")
|
||||
else:
|
||||
logger.warning(f"Failed to create preview for image {i+1}")
|
||||
|
||||
# Save files if requested
|
||||
saved_files = []
|
||||
if save_to_file:
|
||||
logger.info("Saving files to disk...")
|
||||
generation_params = {
|
||||
"prompt": request.prompt,
|
||||
"negative_prompt": request.negative_prompt,
|
||||
"number_of_images": request.number_of_images,
|
||||
"seed": request.seed,
|
||||
"aspect_ratio": request.aspect_ratio,
|
||||
"model": request.model,
|
||||
"guidance_scale": 7.5,
|
||||
"safety_filter_level": "block_only_high",
|
||||
"person_generation": "allow_all",
|
||||
"add_watermark": False
|
||||
}
|
||||
|
||||
saved_files = save_generated_images(
|
||||
images_data=response.images_data,
|
||||
save_directory=self.config.output_path,
|
||||
seed=request.seed,
|
||||
generation_params=generation_params
|
||||
)
|
||||
logger.info(f"Files saved: {saved_files}")
|
||||
|
||||
# Verify files were created
|
||||
for file_path in saved_files:
|
||||
if os.path.exists(file_path):
|
||||
size = os.path.getsize(file_path)
|
||||
logger.info(f" ✓ Verified: {file_path} ({size} bytes)")
|
||||
else:
|
||||
logger.error(f" ❌ Missing: {file_path}")
|
||||
|
||||
# Create enhanced result with preview images
|
||||
result = ImageGenerationResult(
|
||||
success=True,
|
||||
message=f"✅ Images have been successfully generated!",
|
||||
original_images_count=len(response.images_data),
|
||||
preview_images_b64=preview_images_b64,
|
||||
saved_files=saved_files,
|
||||
generation_params={
|
||||
"prompt": request.prompt,
|
||||
"seed": request.seed,
|
||||
"aspect_ratio": request.aspect_ratio,
|
||||
"number_of_images": request.number_of_images,
|
||||
"negative_prompt": request.negative_prompt,
|
||||
"model": request.model
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Returning response with {len(preview_images_b64)} preview images")
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=result.to_text_content()
|
||||
)]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred during image generation: {str(e)}", exc_info=True)
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Error occurred during image generation: {str(e)}"
|
||||
)]
|
||||
@@ -1,63 +0,0 @@
|
||||
"""
|
||||
Enhanced MCP Server implementation for Imagen4 with Preview Image Support
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from mcp.server import Server
|
||||
from mcp.types import Tool, TextContent
|
||||
|
||||
from src.connector import Config
|
||||
from src.server.models import MCPToolDefinitions
|
||||
from src.server.enhanced_handlers import EnhancedToolHandlers, sanitize_args_for_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnhancedImagen4MCPServer:
|
||||
"""Enhanced Imagen4 MCP server class with preview image support"""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
"""Initialize server"""
|
||||
self.config = config
|
||||
self.server = Server("imagen4-enhanced-mcp-server")
|
||||
self.handlers = EnhancedToolHandlers(config)
|
||||
|
||||
# Register handlers
|
||||
self._register_handlers()
|
||||
|
||||
def _register_handlers(self) -> None:
|
||||
"""Register MCP handlers"""
|
||||
|
||||
@self.server.list_tools()
|
||||
async def handle_list_tools() -> List[Tool]:
|
||||
"""Return list of available tools"""
|
||||
logger.info("Listing available tools (enhanced version)")
|
||||
return MCPToolDefinitions.get_all_tools()
|
||||
|
||||
@self.server.call_tool()
|
||||
async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]:
|
||||
"""Handle tool calls with enhanced preview image support"""
|
||||
# Log tool call safely without exposing sensitive data
|
||||
safe_args = sanitize_args_for_logging(arguments)
|
||||
logger.info(f"Enhanced tool called: {name} with arguments: {safe_args}")
|
||||
|
||||
if name == "generate_random_seed":
|
||||
return await self.handlers.handle_generate_random_seed(arguments)
|
||||
elif name == "regenerate_from_json":
|
||||
return await self.handlers.handle_regenerate_from_json(arguments)
|
||||
elif name == "generate_image":
|
||||
return await self.handlers.handle_generate_image(arguments)
|
||||
else:
|
||||
raise ValueError(f"Unknown tool: {name}")
|
||||
|
||||
def get_server(self) -> Server:
|
||||
"""Return MCP server instance"""
|
||||
return self.server
|
||||
|
||||
|
||||
# Create a factory function for easier import
|
||||
def create_enhanced_server(config: Config) -> EnhancedImagen4MCPServer:
|
||||
"""Factory function to create enhanced server"""
|
||||
return EnhancedImagen4MCPServer(config)
|
||||
@@ -1,71 +0,0 @@
|
||||
"""
|
||||
Enhanced MCP response models with preview image support
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageGenerationResult:
|
||||
"""Enhanced image generation result with preview support"""
|
||||
success: bool
|
||||
message: str
|
||||
original_images_count: int
|
||||
preview_images_b64: Optional[list[str]] = None # List of base64 JPEG previews (512x512)
|
||||
saved_files: Optional[list[str]] = None
|
||||
generation_params: Optional[Dict[str, Any]] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
def to_text_content(self) -> str:
|
||||
"""Convert result to text format for MCP response"""
|
||||
lines = [self.message]
|
||||
|
||||
if self.success and self.preview_images_b64:
|
||||
lines.append(f"\n🖼️ Preview Images Generated: {len(self.preview_images_b64)} images (512x512 JPEG)")
|
||||
for i, preview_b64 in enumerate(self.preview_images_b64):
|
||||
lines.append(f"Preview {i+1} (base64 JPEG): {preview_b64[:50]}...({len(preview_b64)} chars)")
|
||||
|
||||
if self.saved_files:
|
||||
lines.append(f"\n📁 Files saved:")
|
||||
for filepath in self.saved_files:
|
||||
lines.append(f" - {filepath}")
|
||||
|
||||
if self.generation_params:
|
||||
lines.append(f"\n⚙️ Generation Parameters:")
|
||||
for key, value in self.generation_params.items():
|
||||
if key == 'prompt' and len(str(value)) > 100:
|
||||
lines.append(f" - {key}: {str(value)[:100]}...")
|
||||
else:
|
||||
lines.append(f" - {key}: {value}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreviewImageResponse:
|
||||
"""Response containing preview images in base64 format"""
|
||||
preview_images_b64: list[str] # Base64 encoded JPEG images (512x512)
|
||||
original_count: int
|
||||
message: str
|
||||
|
||||
@classmethod
|
||||
def from_image_data(cls, images_data: list[bytes], message: str = "") -> 'PreviewImageResponse':
|
||||
"""Create response from original PNG image data"""
|
||||
from src.utils.image_utils import create_preview_image_b64
|
||||
|
||||
preview_images = []
|
||||
for i, image_data in enumerate(images_data):
|
||||
preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85)
|
||||
if preview_b64:
|
||||
preview_images.append(preview_b64)
|
||||
else:
|
||||
# Fallback - use original if preview creation fails
|
||||
import base64
|
||||
preview_images.append(base64.b64encode(image_data).decode('utf-8'))
|
||||
|
||||
return cls(
|
||||
preview_images_b64=preview_images,
|
||||
original_count=len(images_data),
|
||||
message=message or f"Generated {len(preview_images)} preview images"
|
||||
)
|
||||
@@ -1,308 +0,0 @@
|
||||
"""
|
||||
Tool Handlers for MCP Server
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import random
|
||||
import logging
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from mcp.types import TextContent, ImageContent
|
||||
|
||||
from src.connector import Imagen4Client, Config
|
||||
from src.connector.imagen4_client import ImageGenerationRequest
|
||||
from src.connector.utils import save_generated_images
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def sanitize_args_for_logging(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Remove or truncate sensitive data from arguments for safe logging"""
|
||||
safe_args = {}
|
||||
for key, value in arguments.items():
|
||||
if isinstance(value, str):
|
||||
# Check if it's likely base64 image data
|
||||
if (key in ['data', 'image_data', 'base64'] or
|
||||
(len(value) > 100 and value.startswith(('iVBORw0KGgo', '/9j/', 'R0lGOD')))):
|
||||
# Truncate long image data
|
||||
safe_args[key] = f"<image_data:{len(value)} chars>"
|
||||
elif len(value) > 1000:
|
||||
# Truncate any very long strings
|
||||
safe_args[key] = f"{value[:100]}...<truncated:{len(value)} total chars>"
|
||||
else:
|
||||
safe_args[key] = value
|
||||
else:
|
||||
safe_args[key] = value
|
||||
return safe_args
|
||||
|
||||
|
||||
class ToolHandlers:
|
||||
"""MCP tool handler class"""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
"""Initialize handler"""
|
||||
self.config = config
|
||||
self.client = Imagen4Client(config)
|
||||
|
||||
async def handle_generate_random_seed(self, arguments: Dict[str, Any]) -> List[TextContent]:
|
||||
"""Random seed generation handler"""
|
||||
try:
|
||||
random_seed = random.randint(0, 2**32 - 1)
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Generated random seed: {random_seed}"
|
||||
)]
|
||||
except Exception as e:
|
||||
logger.error(f"Random seed generation error: {str(e)}")
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Error occurred during random seed generation: {str(e)}"
|
||||
)]
|
||||
|
||||
async def handle_regenerate_from_json(self, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]:
|
||||
"""Image regeneration from JSON file handler"""
|
||||
try:
|
||||
json_file_path = arguments.get("json_file_path")
|
||||
save_to_file = arguments.get("save_to_file", True)
|
||||
|
||||
if not json_file_path:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text="Error: JSON file path is required."
|
||||
)]
|
||||
|
||||
# Load parameters from JSON file
|
||||
try:
|
||||
with open(json_file_path, 'r', encoding='utf-8') as f:
|
||||
params = json.load(f)
|
||||
except FileNotFoundError:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Error: Cannot find JSON file: {json_file_path}"
|
||||
)]
|
||||
except json.JSONDecodeError as e:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Error: JSON parsing error: {str(e)}"
|
||||
)]
|
||||
|
||||
# Check required parameters
|
||||
required_params = ['prompt', 'seed']
|
||||
missing_params = [p for p in required_params if p not in params]
|
||||
|
||||
if missing_params:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Error: Required parameters are missing from JSON file: {', '.join(missing_params)}"
|
||||
)]
|
||||
|
||||
# Create image generation request object
|
||||
request = ImageGenerationRequest(
|
||||
prompt=params.get('prompt'),
|
||||
negative_prompt=params.get('negative_prompt', ''),
|
||||
number_of_images=params.get('number_of_images', 1),
|
||||
seed=params.get('seed'),
|
||||
aspect_ratio=params.get('aspect_ratio', '1:1')
|
||||
)
|
||||
|
||||
logger.info(f"Loaded parameters from JSON: {json_file_path}")
|
||||
|
||||
# Generate image
|
||||
response = await self.client.generate_image(request)
|
||||
|
||||
if not response.success:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Error occurred during image regeneration: {response.error_message}"
|
||||
)]
|
||||
|
||||
# Save files (optional)
|
||||
saved_files = []
|
||||
if save_to_file:
|
||||
regeneration_params = {
|
||||
"prompt": request.prompt,
|
||||
"negative_prompt": request.negative_prompt,
|
||||
"number_of_images": request.number_of_images,
|
||||
"seed": request.seed,
|
||||
"aspect_ratio": request.aspect_ratio,
|
||||
"regenerated_from": json_file_path,
|
||||
"original_generated_at": params.get('generated_at', 'unknown')
|
||||
}
|
||||
|
||||
saved_files = save_generated_images(
|
||||
images_data=response.images_data,
|
||||
save_directory=self.config.output_path,
|
||||
seed=request.seed,
|
||||
generation_params=regeneration_params,
|
||||
filename_prefix="imagen4_regen"
|
||||
)
|
||||
|
||||
# Generate result message
|
||||
message_parts = [
|
||||
f"Images have been successfully regenerated.",
|
||||
f"JSON file: {json_file_path}",
|
||||
f"Size: 2048x2048 PNG",
|
||||
f"Count: {len(response.images_data)} images",
|
||||
f"Seed: {request.seed}",
|
||||
f"Prompt: {request.prompt}"
|
||||
]
|
||||
|
||||
if saved_files:
|
||||
message_parts.append(f"\nRegenerated files saved successfully:")
|
||||
for filepath in saved_files:
|
||||
message_parts.append(f"- {filepath}")
|
||||
|
||||
# Generate result
|
||||
result = [
|
||||
TextContent(
|
||||
type="text",
|
||||
text="\n".join(message_parts)
|
||||
)
|
||||
]
|
||||
|
||||
# Add Base64 encoded images
|
||||
for i, image_data in enumerate(response.images_data):
|
||||
image_base64 = base64.b64encode(image_data).decode('utf-8')
|
||||
result.append(
|
||||
ImageContent(
|
||||
type="image",
|
||||
data=image_base64,
|
||||
mimeType="image/png"
|
||||
)
|
||||
)
|
||||
logger.info(f"Added image {i+1} to response: {len(image_base64)} chars base64 data")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred during image regeneration: {str(e)}")
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Error occurred during image regeneration: {str(e)}"
|
||||
)]
|
||||
|
||||
async def handle_generate_image(self, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]:
|
||||
"""Image generation handler with synchronous processing"""
|
||||
try:
|
||||
# Log arguments safely without exposing image data
|
||||
safe_args = sanitize_args_for_logging(arguments)
|
||||
logger.info(f"handle_generate_image called with arguments: {safe_args}")
|
||||
|
||||
# Extract and validate arguments
|
||||
prompt = arguments.get("prompt")
|
||||
if not prompt:
|
||||
logger.error("No prompt provided")
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text="Error: Prompt is required."
|
||||
)]
|
||||
|
||||
seed = arguments.get("seed")
|
||||
if seed is None:
|
||||
logger.error("No seed provided")
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text="Error: Seed value is required. You can use the generate_random_seed tool to generate a random seed."
|
||||
)]
|
||||
|
||||
# Create image generation request object
|
||||
request = ImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=arguments.get("negative_prompt", ""),
|
||||
number_of_images=arguments.get("number_of_images", 1),
|
||||
seed=seed,
|
||||
aspect_ratio=arguments.get("aspect_ratio", "1:1")
|
||||
)
|
||||
|
||||
save_to_file = arguments.get("save_to_file", True)
|
||||
|
||||
logger.info(f"Starting SYNCHRONOUS image generation: '{prompt[:50]}...', Seed: {seed}")
|
||||
|
||||
# Generate image synchronously with longer timeout
|
||||
try:
|
||||
logger.info("Calling client.generate_image() synchronously...")
|
||||
response = await asyncio.wait_for(
|
||||
self.client.generate_image(request),
|
||||
timeout=360.0 # 6 minute timeout
|
||||
)
|
||||
logger.info(f"Image generation completed. Success: {response.success}")
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Image generation timed out after 6 minutes")
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text="Error: Image generation timed out after 6 minutes. Please try again."
|
||||
)]
|
||||
|
||||
if not response.success:
|
||||
logger.error(f"Image generation failed: {response.error_message}")
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Error occurred during image generation: {response.error_message}"
|
||||
)]
|
||||
|
||||
logger.info(f"Generated {len(response.images_data)} images successfully")
|
||||
|
||||
# Save files if requested
|
||||
saved_files = []
|
||||
if save_to_file:
|
||||
logger.info("Saving files to disk...")
|
||||
generation_params = {
|
||||
"prompt": request.prompt,
|
||||
"negative_prompt": request.negative_prompt,
|
||||
"number_of_images": request.number_of_images,
|
||||
"seed": request.seed,
|
||||
"aspect_ratio": request.aspect_ratio,
|
||||
"guidance_scale": 7.5,
|
||||
"safety_filter_level": "block_only_high",
|
||||
"person_generation": "allow_all",
|
||||
"add_watermark": False
|
||||
}
|
||||
|
||||
saved_files = save_generated_images(
|
||||
images_data=response.images_data,
|
||||
save_directory=self.config.output_path,
|
||||
seed=request.seed,
|
||||
generation_params=generation_params
|
||||
)
|
||||
logger.info(f"Files saved: {saved_files}")
|
||||
|
||||
# Verify files were created
|
||||
import os
|
||||
for file_path in saved_files:
|
||||
if os.path.exists(file_path):
|
||||
size = os.path.getsize(file_path)
|
||||
logger.info(f" ✓ Verified: {file_path} ({size} bytes)")
|
||||
else:
|
||||
logger.error(f" ❌ Missing: {file_path}")
|
||||
|
||||
# Generate result message
|
||||
message_parts = [
|
||||
f"✅ Images have been successfully generated!",
|
||||
f"Prompt: {request.prompt}",
|
||||
f"Seed: {request.seed}",
|
||||
f"Size: 2048x2048 PNG",
|
||||
f"Count: {len(response.images_data)} images"
|
||||
]
|
||||
|
||||
if saved_files:
|
||||
message_parts.append(f"\n📁 Files saved successfully:")
|
||||
for filepath in saved_files:
|
||||
message_parts.append(f" - {filepath}")
|
||||
else:
|
||||
message_parts.append(f"\nℹ️ Images generated but not saved to file. Use save_to_file=true to save.")
|
||||
|
||||
logger.info("Returning synchronous response")
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text="\n".join(message_parts)
|
||||
)]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred during image generation: {str(e)}", exc_info=True)
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Error occurred during image generation: {str(e)}"
|
||||
)]
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
"""
|
||||
MCP Server implementation for Imagen4
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from mcp.server import Server
|
||||
from mcp.types import Tool, TextContent, ImageContent
|
||||
|
||||
from src.connector import Config
|
||||
from src.server.models import MCPToolDefinitions
|
||||
from src.server.handlers import ToolHandlers, sanitize_args_for_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Imagen4MCPServer:
|
||||
"""Imagen4 MCP server class"""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
"""Initialize server"""
|
||||
self.config = config
|
||||
self.server = Server("imagen4-mcp-server")
|
||||
self.handlers = ToolHandlers(config)
|
||||
|
||||
# Register handlers
|
||||
self._register_handlers()
|
||||
|
||||
def _register_handlers(self) -> None:
|
||||
"""Register MCP handlers"""
|
||||
|
||||
@self.server.list_tools()
|
||||
async def handle_list_tools() -> List[Tool]:
|
||||
"""Return list of available tools"""
|
||||
logger.info("Listing available tools")
|
||||
return MCPToolDefinitions.get_all_tools()
|
||||
|
||||
@self.server.call_tool()
|
||||
async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]:
|
||||
"""Handle tool calls"""
|
||||
# Log tool call safely without exposing sensitive data
|
||||
safe_args = sanitize_args_for_logging(arguments)
|
||||
logger.info(f"Tool called: {name} with arguments: {safe_args}")
|
||||
|
||||
if name == "generate_random_seed":
|
||||
return await self.handlers.handle_generate_random_seed(arguments)
|
||||
elif name == "regenerate_from_json":
|
||||
return await self.handlers.handle_regenerate_from_json(arguments)
|
||||
elif name == "generate_image":
|
||||
return await self.handlers.handle_generate_image(arguments)
|
||||
else:
|
||||
raise ValueError(f"Unknown tool: {name}")
|
||||
|
||||
def get_server(self) -> Server:
|
||||
"""Return MCP server instance"""
|
||||
return self.server
|
||||
@@ -1,39 +0,0 @@
|
||||
"""
|
||||
Minimal Task Result Handler for Emergency Use
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from mcp.types import TextContent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def minimal_get_task_result(task_manager, task_id: str) -> List[TextContent]:
|
||||
"""Absolutely minimal task result handler"""
|
||||
try:
|
||||
logger.info(f"Minimal handler for task: {task_id}")
|
||||
|
||||
# Just return the task status for now
|
||||
status = task_manager.get_task_status(task_id)
|
||||
|
||||
if status is None:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"❌ Task '{task_id}' not found."
|
||||
)]
|
||||
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"📋 Task '{task_id}' Status: {status.value}\n"
|
||||
f"Note: This is a minimal emergency handler.\n"
|
||||
f"If you see this message, the original handler has a serious issue."
|
||||
)]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Even minimal handler failed: {str(e)}", exc_info=True)
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Complete failure: {str(e)}"
|
||||
)]
|
||||
@@ -1,132 +0,0 @@
|
||||
"""
|
||||
MCP Tool Models and Definitions
|
||||
"""
|
||||
|
||||
from mcp.types import Tool
|
||||
|
||||
|
||||
class MCPToolDefinitions:
|
||||
"""MCP tool definition class"""
|
||||
|
||||
@staticmethod
|
||||
def get_generate_image_tool() -> Tool:
|
||||
"""Image generation tool definition"""
|
||||
return Tool(
|
||||
name="generate_image",
|
||||
description="Generate 2048x2048 PNG images from text prompts using Google Imagen 4 AI.",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "Text prompt for image generation (Korean or English). Maximum 480 tokens allowed.",
|
||||
"maxLength": 2000 # Approximate character limit for safety
|
||||
},
|
||||
"negative_prompt": {
|
||||
"type": "string",
|
||||
"description": "Negative prompt specifying elements not to generate (optional). Maximum 240 tokens allowed.",
|
||||
"default": "",
|
||||
"maxLength": 1000 # Approximate character limit for safety
|
||||
},
|
||||
"number_of_images": {
|
||||
"type": "integer",
|
||||
"description": "Number of images to generate (1 or 2 only)",
|
||||
"enum": [1, 2],
|
||||
"default": 1
|
||||
},
|
||||
"seed": {
|
||||
"type": "integer",
|
||||
"description": "Seed value for reproducible results (required, 0 ~ 4294967295 range)",
|
||||
"minimum": 0,
|
||||
"maximum": 4294967295
|
||||
},
|
||||
"aspect_ratio": {
|
||||
"type": "string",
|
||||
"description": "Image aspect ratio",
|
||||
"enum": ["1:1", "9:16", "16:9", "3:4", "4:3"],
|
||||
"default": "1:1"
|
||||
},
|
||||
"save_to_file": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to save generated images to files (default: true)",
|
||||
"default": True
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"description": "Imagen 4 model to use for generation",
|
||||
"enum": ["imagen-4.0-generate-001", "imagen-4.0-ultra-generate-001"],
|
||||
"default": "imagen-4.0-generate-001"
|
||||
}
|
||||
},
|
||||
"required": ["prompt", "seed"]
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_regenerate_from_json_tool() -> Tool:
|
||||
"""Regenerate from JSON tool definition"""
|
||||
return Tool(
|
||||
name="regenerate_from_json",
|
||||
description="Read parameters from JSON file and regenerate images with the same settings.",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"json_file_path": {
|
||||
"type": "string",
|
||||
"description": "Path to JSON file containing saved parameters"
|
||||
},
|
||||
"save_to_file": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to save regenerated images to files (default: true)",
|
||||
"default": True
|
||||
}
|
||||
},
|
||||
"required": ["json_file_path"]
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_generate_random_seed_tool() -> Tool:
|
||||
"""Random seed generation tool definition"""
|
||||
return Tool(
|
||||
name="generate_random_seed",
|
||||
description="Generate random seed value for image generation.",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_validate_prompt_tool() -> Tool:
|
||||
"""Prompt validation tool definition"""
|
||||
return Tool(
|
||||
name="validate_prompt",
|
||||
description="Validate prompt length and estimate token count before image generation.",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "Text prompt to validate"
|
||||
},
|
||||
"negative_prompt": {
|
||||
"type": "string",
|
||||
"description": "Negative prompt to validate (optional)",
|
||||
"default": ""
|
||||
}
|
||||
},
|
||||
"required": ["prompt"]
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_all_tools(cls) -> list[Tool]:
|
||||
"""Return all tool definitions"""
|
||||
return [
|
||||
cls.get_generate_image_tool(),
|
||||
cls.get_regenerate_from_json_tool(),
|
||||
cls.get_generate_random_seed_tool(),
|
||||
cls.get_validate_prompt_tool()
|
||||
]
|
||||
@@ -1,80 +0,0 @@
|
||||
"""
|
||||
Safe Get Task Result Handler - Debug Version
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from mcp.types import TextContent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def safe_get_task_result(task_manager, task_id: str) -> List[TextContent]:
|
||||
"""Minimal safe version of get_task_result without image processing"""
|
||||
try:
|
||||
logger.info(f"Safe get_task_result called for: {task_id}")
|
||||
|
||||
# Step 1: Try to get raw result
|
||||
try:
|
||||
raw_result = task_manager.worker_pool.get_task_result(task_id)
|
||||
logger.info(f"Raw result type: {type(raw_result)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get raw result: {str(e)}")
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Error accessing task data: {str(e)}"
|
||||
)]
|
||||
|
||||
if not raw_result:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"❌ Task '{task_id}' not found."
|
||||
)]
|
||||
|
||||
# Step 2: Check status safely
|
||||
try:
|
||||
status = raw_result.status
|
||||
logger.info(f"Task status: {status}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get status: {str(e)}")
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Error reading task status: {str(e)}"
|
||||
)]
|
||||
|
||||
# Step 3: Return basic info without touching result.result
|
||||
try:
|
||||
duration = raw_result.duration if hasattr(raw_result, 'duration') else None
|
||||
created = raw_result.created_at.isoformat() if hasattr(raw_result, 'created_at') and raw_result.created_at else "unknown"
|
||||
|
||||
message = f"✅ Task '{task_id}' Status Report:\n"
|
||||
message += f"Status: {status.value}\n"
|
||||
message += f"Created: {created}\n"
|
||||
|
||||
if duration:
|
||||
message += f"Duration: {duration:.2f}s\n"
|
||||
|
||||
if hasattr(raw_result, 'error') and raw_result.error:
|
||||
message += f"Error: {raw_result.error}\n"
|
||||
|
||||
message += "\nNote: This is a safe diagnostic version."
|
||||
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=message
|
||||
)]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to format response: {str(e)}")
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Task exists but data formatting failed: {str(e)}"
|
||||
)]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Critical error in safe_get_task_result: {str(e)}", exc_info=True)
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"Critical error: {str(e)}"
|
||||
)]
|
||||
@@ -1,107 +0,0 @@
|
||||
"""
|
||||
Image processing utilities for preview generation
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
from typing import Optional
|
||||
from PIL import Image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_preview_image_b64(image_data: bytes, target_size: int = 512, quality: int = 85) -> Optional[str]:
|
||||
"""
|
||||
Convert PNG image data to JPEG preview with specified size and return as base64
|
||||
|
||||
Args:
|
||||
image_data: Original PNG image data in bytes
|
||||
target_size: Target size for the preview (default: 512x512)
|
||||
quality: JPEG quality (1-100, default: 85)
|
||||
|
||||
Returns:
|
||||
Base64 encoded JPEG image string, or None if conversion fails
|
||||
"""
|
||||
try:
|
||||
# Open image from bytes
|
||||
with Image.open(io.BytesIO(image_data)) as img:
|
||||
# Convert to RGB if necessary (PNG might have alpha channel)
|
||||
if img.mode in ('RGBA', 'LA', 'P'):
|
||||
# Create white background for transparent images
|
||||
background = Image.new('RGB', img.size, (255, 255, 255))
|
||||
if img.mode == 'P':
|
||||
img = img.convert('RGBA')
|
||||
background.paste(img, mask=img.split()[-1] if img.mode in ('RGBA', 'LA') else None)
|
||||
img = background
|
||||
elif img.mode != 'RGB':
|
||||
img = img.convert('RGB')
|
||||
|
||||
# Resize to target size maintaining aspect ratio
|
||||
img.thumbnail((target_size, target_size), Image.Resampling.LANCZOS)
|
||||
|
||||
# If image is smaller than target size, pad it to exact size
|
||||
if img.size != (target_size, target_size):
|
||||
# Create new image with target size and white background
|
||||
new_img = Image.new('RGB', (target_size, target_size), (255, 255, 255))
|
||||
# Center the resized image
|
||||
x = (target_size - img.size[0]) // 2
|
||||
y = (target_size - img.size[1]) // 2
|
||||
new_img.paste(img, (x, y))
|
||||
img = new_img
|
||||
|
||||
# Convert to JPEG and encode as base64
|
||||
output_buffer = io.BytesIO()
|
||||
img.save(output_buffer, format='JPEG', quality=quality, optimize=True)
|
||||
jpeg_data = output_buffer.getvalue()
|
||||
|
||||
# Encode to base64
|
||||
base64_data = base64.b64encode(jpeg_data).decode('utf-8')
|
||||
|
||||
logger.info(f"Preview image created: {target_size}x{target_size} JPEG, {len(jpeg_data)} bytes, base64 length: {len(base64_data)}")
|
||||
return base64_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create preview image: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def validate_image_data(image_data: bytes) -> bool:
|
||||
"""
|
||||
Validate if image data is a valid image
|
||||
|
||||
Args:
|
||||
image_data: Image data in bytes
|
||||
|
||||
Returns:
|
||||
True if valid image, False otherwise
|
||||
"""
|
||||
try:
|
||||
with Image.open(io.BytesIO(image_data)) as img:
|
||||
img.verify() # Verify image integrity
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def get_image_info(image_data: bytes) -> Optional[dict]:
|
||||
"""
|
||||
Get image information
|
||||
|
||||
Args:
|
||||
image_data: Image data in bytes
|
||||
|
||||
Returns:
|
||||
Dictionary with image info (format, size, mode) or None if invalid
|
||||
"""
|
||||
try:
|
||||
with Image.open(io.BytesIO(image_data)) as img:
|
||||
return {
|
||||
'format': img.format,
|
||||
'size': img.size,
|
||||
'mode': img.mode,
|
||||
'bytes': len(image_data)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get image info: {str(e)}")
|
||||
return None
|
||||
@@ -1,200 +0,0 @@
|
||||
"""
|
||||
Token counting utilities for Imagen4 prompts
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 기본 토큰 추정 상수
|
||||
AVERAGE_CHARS_PER_TOKEN = 4 # 영어 기준 평균값
|
||||
KOREAN_CHARS_PER_TOKEN = 2 # 한글 기준 평균값
|
||||
MAX_PROMPT_TOKENS = 480 # 최대 프롬프트 토큰 수
|
||||
|
||||
|
||||
def estimate_token_count(text: str) -> int:
|
||||
"""
|
||||
텍스트의 토큰 수를 추정합니다.
|
||||
|
||||
정확한 토큰 계산을 위해서는 실제 토크나이저가 필요하지만,
|
||||
API 호출 전 빠른 검증을 위해 추정값을 사용합니다.
|
||||
|
||||
Args:
|
||||
text: 토큰 수를 계산할 텍스트
|
||||
|
||||
Returns:
|
||||
int: 추정된 토큰 수
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
|
||||
# 텍스트 정리
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return 0
|
||||
|
||||
# 한글과 영어 문자 분리
|
||||
korean_chars = len(re.findall(r'[가-힣]', text))
|
||||
english_chars = len(re.findall(r'[a-zA-Z]', text))
|
||||
other_chars = len(text) - korean_chars - english_chars
|
||||
|
||||
# 토큰 수 추정
|
||||
korean_tokens = korean_chars / KOREAN_CHARS_PER_TOKEN
|
||||
english_tokens = english_chars / AVERAGE_CHARS_PER_TOKEN
|
||||
other_tokens = other_chars / AVERAGE_CHARS_PER_TOKEN
|
||||
|
||||
estimated_tokens = int(korean_tokens + english_tokens + other_tokens)
|
||||
|
||||
# 최소 1토큰은 보장
|
||||
return max(1, estimated_tokens)
|
||||
|
||||
|
||||
def validate_prompt_length(prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> tuple[bool, int, Optional[str]]:
|
||||
"""
|
||||
프롬프트 길이를 검증합니다.
|
||||
|
||||
Args:
|
||||
prompt: 검증할 프롬프트
|
||||
max_tokens: 최대 허용 토큰 수
|
||||
|
||||
Returns:
|
||||
tuple: (유효성, 토큰 수, 오류 메시지)
|
||||
"""
|
||||
if not prompt:
|
||||
return False, 0, "프롬프트가 비어있습니다."
|
||||
|
||||
token_count = estimate_token_count(prompt)
|
||||
|
||||
if token_count > max_tokens:
|
||||
error_msg = (
|
||||
f"프롬프트가 너무 깁니다. "
|
||||
f"현재: {token_count}토큰, 최대: {max_tokens}토큰. "
|
||||
f"프롬프트를 {token_count - max_tokens}토큰 줄여주세요."
|
||||
)
|
||||
return False, token_count, error_msg
|
||||
|
||||
return True, token_count, None
|
||||
|
||||
|
||||
def truncate_prompt(prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> str:
|
||||
"""
|
||||
프롬프트를 지정된 토큰 수로 자릅니다.
|
||||
|
||||
Args:
|
||||
prompt: 자를 프롬프트
|
||||
max_tokens: 최대 토큰 수
|
||||
|
||||
Returns:
|
||||
str: 잘린 프롬프트
|
||||
"""
|
||||
if not prompt:
|
||||
return ""
|
||||
|
||||
current_tokens = estimate_token_count(prompt)
|
||||
if current_tokens <= max_tokens:
|
||||
return prompt
|
||||
|
||||
# 대략적인 비율로 텍스트 자르기
|
||||
ratio = max_tokens / current_tokens
|
||||
target_length = int(len(prompt) * ratio * 0.9) # 여유분 10%
|
||||
|
||||
truncated = prompt[:target_length]
|
||||
|
||||
# 단어/문장 경계에서 자르기
|
||||
if len(truncated) < len(prompt):
|
||||
# 마지막 완전한 단어까지만 유지
|
||||
last_space = truncated.rfind(' ')
|
||||
last_korean = truncated.rfind('다') # 한글 어미
|
||||
last_punct = max(truncated.rfind('.'), truncated.rfind(','), truncated.rfind('!'))
|
||||
|
||||
cut_point = max(last_space, last_korean, last_punct)
|
||||
if cut_point > target_length * 0.8: # 너무 많이 잘리지 않도록
|
||||
truncated = truncated[:cut_point]
|
||||
|
||||
# 최종 검증
|
||||
final_tokens = estimate_token_count(truncated)
|
||||
if final_tokens > max_tokens:
|
||||
# 강제로 문자 단위로 자르기
|
||||
chars_per_token = len(truncated) / final_tokens
|
||||
target_chars = int(max_tokens * chars_per_token * 0.95)
|
||||
truncated = truncated[:target_chars]
|
||||
|
||||
return truncated.strip()
|
||||
|
||||
|
||||
def get_prompt_stats(prompt: str) -> dict:
|
||||
"""
|
||||
프롬프트 통계 정보를 반환합니다.
|
||||
|
||||
Args:
|
||||
prompt: 분석할 프롬프트
|
||||
|
||||
Returns:
|
||||
dict: 프롬프트 통계
|
||||
"""
|
||||
if not prompt:
|
||||
return {
|
||||
"character_count": 0,
|
||||
"estimated_tokens": 0,
|
||||
"korean_chars": 0,
|
||||
"english_chars": 0,
|
||||
"other_chars": 0,
|
||||
"is_valid": False,
|
||||
"remaining_tokens": MAX_PROMPT_TOKENS
|
||||
}
|
||||
|
||||
char_count = len(prompt)
|
||||
korean_chars = len(re.findall(r'[가-힣]', prompt))
|
||||
english_chars = len(re.findall(r'[a-zA-Z]', prompt))
|
||||
other_chars = char_count - korean_chars - english_chars
|
||||
estimated_tokens = estimate_token_count(prompt)
|
||||
is_valid = estimated_tokens <= MAX_PROMPT_TOKENS
|
||||
remaining_tokens = MAX_PROMPT_TOKENS - estimated_tokens
|
||||
|
||||
return {
|
||||
"character_count": char_count,
|
||||
"estimated_tokens": estimated_tokens,
|
||||
"korean_chars": korean_chars,
|
||||
"english_chars": english_chars,
|
||||
"other_chars": other_chars,
|
||||
"is_valid": is_valid,
|
||||
"remaining_tokens": remaining_tokens,
|
||||
"max_tokens": MAX_PROMPT_TOKENS
|
||||
}
|
||||
|
||||
|
||||
# 실제 토크나이저 사용 시 대체할 수 있는 인터페이스
|
||||
class TokenCounter:
|
||||
"""토큰 카운터 인터페이스"""
|
||||
|
||||
def __init__(self, tokenizer_name: Optional[str] = None):
|
||||
"""
|
||||
토큰 카운터 초기화
|
||||
|
||||
Args:
|
||||
tokenizer_name: 사용할 토크나이저 이름 (향후 확장용)
|
||||
"""
|
||||
self.tokenizer_name = tokenizer_name or "estimate"
|
||||
logger.info(f"토큰 카운터 초기화: {self.tokenizer_name}")
|
||||
|
||||
def count_tokens(self, text: str) -> int:
|
||||
"""텍스트의 토큰 수 계산"""
|
||||
if self.tokenizer_name == "estimate":
|
||||
return estimate_token_count(text)
|
||||
else:
|
||||
# 향후 실제 토크나이저 구현
|
||||
raise NotImplementedError(f"토크나이저 '{self.tokenizer_name}'는 아직 구현되지 않았습니다.")
|
||||
|
||||
def validate_prompt(self, prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> tuple[bool, int, Optional[str]]:
|
||||
"""프롬프트 검증"""
|
||||
return validate_prompt_length(prompt, max_tokens)
|
||||
|
||||
def truncate_prompt(self, prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> str:
|
||||
"""프롬프트 자르기"""
|
||||
return truncate_prompt(prompt, max_tokens)
|
||||
|
||||
|
||||
# 전역 토큰 카운터 인스턴스
|
||||
default_token_counter = TokenCounter()
|
||||
40
test_mcp_compatibility.py
Normal file
40
test_mcp_compatibility.py
Normal file
@@ -0,0 +1,40 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify MCP protocol compatibility
|
||||
"""
|
||||
|
||||
import sys
|
||||
import json
|
||||
import os
|
||||
|
||||
# Add current directory to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
def test_json_output():
|
||||
"""Test that we can output valid JSON without interference"""
|
||||
# This should work without any encoding issues
|
||||
test_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized",
|
||||
"params": {}
|
||||
}
|
||||
|
||||
# Output JSON to stdout (this is what MCP protocol expects)
|
||||
print(json.dumps(test_data))
|
||||
|
||||
def test_unicode_handling():
|
||||
"""Test Unicode handling in stderr only"""
|
||||
# This should go to stderr only, not interfere with stdout JSON
|
||||
test_unicode = "Test UTF-8: 한글 테스트 ✓"
|
||||
print(f"[UTF8-TEST] {test_unicode}", file=sys.stderr)
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("[MCP-TEST] Testing JSON output compatibility...", file=sys.stderr)
|
||||
test_json_output()
|
||||
print("[MCP-TEST] JSON output test completed", file=sys.stderr)
|
||||
|
||||
print("[MCP-TEST] Testing Unicode handling...", file=sys.stderr)
|
||||
test_unicode_handling()
|
||||
print("[MCP-TEST] Unicode test completed", file=sys.stderr)
|
||||
|
||||
print("[MCP-TEST] All tests passed!", file=sys.stderr)
|
||||
@@ -1,178 +0,0 @@
|
||||
"""
|
||||
Tests for Imagen4 Connector
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from src.connector import Config, Imagen4Client
|
||||
from src.connector.imagen4_client import ImageGenerationRequest
|
||||
|
||||
|
||||
class TestConfig:
|
||||
"""Config 클래스 테스트"""
|
||||
|
||||
def test_config_creation(self):
|
||||
"""설정 생성 테스트"""
|
||||
config = Config(
|
||||
project_id="test-project",
|
||||
location="us-central1",
|
||||
output_path="./test_images"
|
||||
)
|
||||
|
||||
assert config.project_id == "test-project"
|
||||
assert config.location == "us-central1"
|
||||
assert config.output_path == "./test_images"
|
||||
|
||||
def test_config_validation(self):
|
||||
"""설정 유효성 검사 테스트"""
|
||||
# 유효한 설정
|
||||
valid_config = Config(project_id="test-project")
|
||||
assert valid_config.validate() is True
|
||||
|
||||
# 무효한 설정
|
||||
invalid_config = Config(project_id="")
|
||||
assert invalid_config.validate() is False
|
||||
|
||||
@patch.dict('os.environ', {
|
||||
'GOOGLE_CLOUD_PROJECT_ID': 'test-project',
|
||||
'GOOGLE_CLOUD_LOCATION': 'us-west1',
|
||||
'GENERATED_IMAGES_PATH': './custom_path'
|
||||
})
|
||||
def test_config_from_env(self):
|
||||
"""환경 변수에서 설정 로드 테스트"""
|
||||
config = Config.from_env()
|
||||
|
||||
assert config.project_id == "test-project"
|
||||
assert config.location == "us-west1"
|
||||
assert config.output_path == "./custom_path"
|
||||
|
||||
@patch.dict('os.environ', {}, clear=True)
|
||||
def test_config_from_env_missing_project_id(self):
|
||||
"""필수 환경 변수 누락 테스트"""
|
||||
with pytest.raises(ValueError, match="GOOGLE_CLOUD_PROJECT_ID environment variable is required"):
|
||||
Config.from_env()
|
||||
|
||||
|
||||
class TestImageGenerationRequest:
|
||||
"""ImageGenerationRequest 테스트"""
|
||||
|
||||
def test_valid_request(self):
|
||||
"""유효한 요청 테스트"""
|
||||
request = ImageGenerationRequest(
|
||||
prompt="A beautiful landscape",
|
||||
seed=12345
|
||||
)
|
||||
|
||||
# 유효성 검사가 예외 없이 완료되어야 함
|
||||
request.validate()
|
||||
|
||||
def test_invalid_prompt(self):
|
||||
"""무효한 프롬프트 테스트"""
|
||||
request = ImageGenerationRequest(
|
||||
prompt="",
|
||||
seed=12345
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Prompt is required"):
|
||||
request.validate()
|
||||
|
||||
def test_invalid_seed(self):
|
||||
"""무효한 시드값 테스트"""
|
||||
request = ImageGenerationRequest(
|
||||
prompt="Test prompt",
|
||||
seed=-1
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Seed value must be an integer between 0 and 4294967295"):
|
||||
request.validate()
|
||||
|
||||
def test_invalid_number_of_images(self):
|
||||
"""무효한 이미지 개수 테스트"""
|
||||
request = ImageGenerationRequest(
|
||||
prompt="Test prompt",
|
||||
seed=12345,
|
||||
number_of_images=3
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Number of images must be 1 or 2 only"):
|
||||
request.validate()
|
||||
|
||||
|
||||
class TestImagen4Client:
|
||||
"""Imagen4Client 테스트"""
|
||||
|
||||
def test_client_initialization(self):
|
||||
"""클라이언트 초기화 테스트"""
|
||||
config = Config(project_id="test-project")
|
||||
|
||||
with patch('src.connector.imagen4_client.genai.Client'):
|
||||
client = Imagen4Client(config)
|
||||
assert client.config == config
|
||||
|
||||
def test_client_invalid_config(self):
|
||||
"""무효한 설정으로 클라이언트 초기화 테스트"""
|
||||
invalid_config = Config(project_id="")
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid configuration"):
|
||||
Imagen4Client(invalid_config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_image_success(self):
|
||||
"""이미지 생성 성공 테스트"""
|
||||
config = Config(project_id="test-project")
|
||||
|
||||
# Create mock response
|
||||
mock_image_data = b"fake_image_data"
|
||||
mock_response = Mock()
|
||||
mock_response.generated_images = [Mock()]
|
||||
mock_response.generated_images[0].image = Mock()
|
||||
mock_response.generated_images[0].image.image_bytes = mock_image_data
|
||||
|
||||
with patch('src.connector.imagen4_client.genai.Client') as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.models.generate_images = Mock(return_value=mock_response)
|
||||
|
||||
client = Imagen4Client(config)
|
||||
|
||||
request = ImageGenerationRequest(
|
||||
prompt="Test prompt",
|
||||
seed=12345
|
||||
)
|
||||
|
||||
with patch('asyncio.to_thread', return_value=mock_response):
|
||||
response = await client.generate_image(request)
|
||||
|
||||
assert response.success is True
|
||||
assert len(response.images_data) == 1
|
||||
assert response.images_data[0] == mock_image_data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_image_failure(self):
|
||||
"""Image generation failure test"""
|
||||
config = Config(project_id="test-project")
|
||||
|
||||
with patch('src.connector.imagen4_client.genai.Client') as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
client = Imagen4Client(config)
|
||||
|
||||
request = ImageGenerationRequest(
|
||||
prompt="Test prompt",
|
||||
seed=12345
|
||||
)
|
||||
|
||||
# Simulate API call failure
|
||||
with patch('asyncio.to_thread', side_effect=Exception("API Error")):
|
||||
response = await client.generate_image(request)
|
||||
|
||||
assert response.success is False
|
||||
assert "API Error" in response.error_message
|
||||
assert len(response.images_data) == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
@@ -1,136 +0,0 @@
|
||||
"""
|
||||
Tests for Imagen4 Server
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
|
||||
from src.connector import Config
|
||||
from src.server import Imagen4MCPServer, ToolHandlers, MCPToolDefinitions
|
||||
from src.server.models import MCPToolDefinitions
|
||||
|
||||
|
||||
class TestMCPToolDefinitions:
|
||||
"""MCPToolDefinitions tests"""
|
||||
|
||||
def test_get_generate_image_tool(self):
|
||||
"""Image generation tool definition test"""
|
||||
tool = MCPToolDefinitions.get_generate_image_tool()
|
||||
|
||||
assert tool.name == "generate_image"
|
||||
assert "Google Imagen 4" in tool.description
|
||||
assert "prompt" in tool.inputSchema["properties"]
|
||||
assert "seed" in tool.inputSchema["required"]
|
||||
|
||||
def test_get_regenerate_from_json_tool(self):
|
||||
"""JSON regeneration tool definition test"""
|
||||
tool = MCPToolDefinitions.get_regenerate_from_json_tool()
|
||||
|
||||
assert tool.name == "regenerate_from_json"
|
||||
assert "JSON file" in tool.description
|
||||
assert "json_file_path" in tool.inputSchema["properties"]
|
||||
|
||||
def test_get_generate_random_seed_tool(self):
|
||||
"""Random seed generation tool definition test"""
|
||||
tool = MCPToolDefinitions.get_generate_random_seed_tool()
|
||||
|
||||
assert tool.name == "generate_random_seed"
|
||||
assert "random seed" in tool.description
|
||||
|
||||
def test_get_all_tools(self):
|
||||
"""All tools return test"""
|
||||
tools = MCPToolDefinitions.get_all_tools()
|
||||
|
||||
assert len(tools) == 3
|
||||
tool_names = [tool.name for tool in tools]
|
||||
assert "generate_image" in tool_names
|
||||
assert "regenerate_from_json" in tool_names
|
||||
assert "generate_random_seed" in tool_names
|
||||
|
||||
|
||||
class TestToolHandlers:
|
||||
"""ToolHandlers tests"""
|
||||
|
||||
@pytest.fixture
|
||||
def config(self):
|
||||
"""Test configuration"""
|
||||
return Config(project_id="test-project")
|
||||
|
||||
@pytest.fixture
|
||||
def handlers(self, config):
|
||||
"""Test handlers"""
|
||||
with patch('src.server.handlers.Imagen4Client'):
|
||||
return ToolHandlers(config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_generate_random_seed(self, handlers):
|
||||
"""Random seed generation handler test"""
|
||||
result = await handlers.handle_generate_random_seed({})
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].type == "text"
|
||||
assert "Generated random seed:" in result[0].text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_generate_image_missing_prompt(self, handlers):
|
||||
"""Missing prompt test"""
|
||||
arguments = {"seed": 12345}
|
||||
|
||||
result = await handlers.handle_generate_image(arguments)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].type == "text"
|
||||
assert "Prompt is required" in result[0].text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_generate_image_missing_seed(self, handlers):
|
||||
"""Missing seed value test"""
|
||||
arguments = {"prompt": "Test prompt"}
|
||||
|
||||
result = await handlers.handle_generate_image(arguments)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].type == "text"
|
||||
assert "Seed value is required" in result[0].text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_regenerate_from_json_missing_path(self, handlers):
|
||||
"""Missing JSON file path test"""
|
||||
arguments = {}
|
||||
|
||||
result = await handlers.handle_regenerate_from_json(arguments)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].type == "text"
|
||||
assert "JSON file path is required" in result[0].text
|
||||
|
||||
|
||||
class TestImagen4MCPServer:
|
||||
"""Imagen4MCPServer tests"""
|
||||
|
||||
@pytest.fixture
|
||||
def config(self):
|
||||
"""Test configuration"""
|
||||
return Config(project_id="test-project")
|
||||
|
||||
def test_server_initialization(self, config):
|
||||
"""Server initialization test"""
|
||||
with patch('src.server.mcp_server.ToolHandlers'):
|
||||
server = Imagen4MCPServer(config)
|
||||
|
||||
assert server.config == config
|
||||
assert server.server is not None
|
||||
assert server.handlers is not None
|
||||
|
||||
def test_get_server(self, config):
|
||||
"""Server instance return test"""
|
||||
with patch('src.server.mcp_server.ToolHandlers'):
|
||||
mcp_server = Imagen4MCPServer(config)
|
||||
server = mcp_server.get_server()
|
||||
|
||||
assert server is not None
|
||||
assert server.name == "imagen4-mcp-server"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user