remove unused source files
This commit is contained in:
48
README.md
48
README.md
@@ -2,7 +2,30 @@
|
|||||||
|
|
||||||
Google Imagen 4를 사용한 AI 이미지 생성 MCP 서버입니다. 하나의 `main.py` 파일로 모든 기능이 통합되어 있으며, **한글 프롬프트를 완벽 지원**합니다.
|
Google Imagen 4를 사용한 AI 이미지 생성 MCP 서버입니다. 하나의 `main.py` 파일로 모든 기능이 통합되어 있으며, **한글 프롬프트를 완벽 지원**합니다.
|
||||||
|
|
||||||
## 🚀 주요 기능
|
## 🚨 주의사항 - MCP 프로토콜 호환성
|
||||||
|
|
||||||
|
> **중요**: 이 서버는 MCP (Model Context Protocol)를 사용합니다. MCP는 stdout을 통해 JSON 메시지를 주고받으므로, 서버 코드에서 **절대로 stdout에 직접 출력하면 안됩니다**.
|
||||||
|
|
||||||
|
### ❌ 금지사항
|
||||||
|
```python
|
||||||
|
print("message") # ❌ MCP JSON 프로토콜 방해
|
||||||
|
print("message", file=sys.stdout) # ❌ MCP JSON 프로토콜 방해
|
||||||
|
sys.stdout.write("message") # ❌ MCP JSON 프로토콜 방해
|
||||||
|
```
|
||||||
|
|
||||||
|
### ✅ 권장사항
|
||||||
|
```python
|
||||||
|
print("message", file=sys.stderr) # ✅ 로그용 출력
|
||||||
|
logging.info("message") # ✅ 로깅 시스템 사용
|
||||||
|
logger.error("message") # ✅ 로깅 시스템 사용
|
||||||
|
```
|
||||||
|
|
||||||
|
### 최신 수정사항 (2025-08-26)
|
||||||
|
- **UTF-8 테스트 출력 비활성화**: `[UTF8-TEST]` 메시지가 MCP JSON 파싱을 방해하던 문제 수정
|
||||||
|
- **MCP 안전 실행**: `run_mcp_safe.py` 및 `run_mcp_safe.bat` 추가
|
||||||
|
- **호환성 테스트**: `test_mcp_compatibility.py` 추가
|
||||||
|
|
||||||
|
**개발 시 반드시 `CLAUDE.md` 파일의 개발 가이드라인을 참조하세요!**
|
||||||
|
|
||||||
- **고품질 이미지 생성**: 2048x2048 PNG 이미지 생성
|
- **고품질 이미지 생성**: 2048x2048 PNG 이미지 생성
|
||||||
- **미리보기 이미지**: 512x512 JPEG 미리보기를 base64로 제공
|
- **미리보기 이미지**: 512x512 JPEG 미리보기를 base64로 제공
|
||||||
@@ -58,16 +81,16 @@ OUTPUT_PATH=./generated_images
|
|||||||
|
|
||||||
## 🎮 사용법
|
## 🎮 사용법
|
||||||
|
|
||||||
### 서버 실행
|
### 서버 실행 (업데이트됨!)
|
||||||
|
|
||||||
#### Windows (유니코드 지원)
|
#### MCP 안전 실행 (추천)
|
||||||
```bash
|
```bash
|
||||||
run_unicode.bat
|
run_mcp_safe.bat
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 기본 실행
|
또는 Python으로 직접:
|
||||||
```bash
|
```bash
|
||||||
run.bat
|
python run_mcp_safe.py
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 직접 실행 (유니코드 환경 설정)
|
#### 직접 실행 (유니코드 환경 설정)
|
||||||
@@ -85,24 +108,28 @@ python main.py
|
|||||||
|
|
||||||
### Claude Desktop 설정
|
### Claude Desktop 설정
|
||||||
|
|
||||||
`claude_desktop_config.json` 내용을 Claude Desktop 설정에 추가:
|
Claude Desktop 설정에 다음 내용을 추가하거나, `claude_desktop_config_mcp_safe.json` 파일을 사용하세요:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"mcpServers": {
|
"mcpServers": {
|
||||||
"imagen4": {
|
"imagen4": {
|
||||||
"command": "python",
|
"command": "python",
|
||||||
"args": ["main.py"],
|
"args": ["run_mcp_safe.py"],
|
||||||
"cwd": "D:\\Project\\imagen4",
|
"cwd": "D:\\Project\\imagen4",
|
||||||
"env": {
|
"env": {
|
||||||
|
"PYTHONPATH": "D:\\Project\\imagen4",
|
||||||
"PYTHONIOENCODING": "utf-8",
|
"PYTHONIOENCODING": "utf-8",
|
||||||
"PYTHONUTF8": "1"
|
"PYTHONUTF8": "1",
|
||||||
|
"LC_ALL": "C.UTF-8"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> **중요**: MCP 호환성을 위해 `run_mcp_safe.py`를 사용하세요!
|
||||||
|
|
||||||
## 🛠️ 사용 가능한 도구
|
## 🛠️ 사용 가능한 도구
|
||||||
|
|
||||||
### 1. generate_image
|
### 1. generate_image
|
||||||
@@ -195,7 +222,8 @@ def ensure_unicode_string(value):
|
|||||||
한글 프롬프트 지원을 테스트하려면:
|
한글 프롬프트 지원을 테스트하려면:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python test_main.py
|
python test_mcp_compatibility.py # MCP 호환성 테스트 (신규)
|
||||||
|
python test_main.py # 기본 기능 테스트
|
||||||
```
|
```
|
||||||
|
|
||||||
이 스크립트는 다음을 확인합니다:
|
이 스크립트는 다음을 확인합니다:
|
||||||
|
|||||||
@@ -2,10 +2,13 @@
|
|||||||
"mcpServers": {
|
"mcpServers": {
|
||||||
"imagen4": {
|
"imagen4": {
|
||||||
"command": "python",
|
"command": "python",
|
||||||
"args": ["main.py"],
|
"args": ["run_mcp_safe.py"],
|
||||||
"cwd": "D:\\Project\\imagen4",
|
"cwd": "D:\\Project\\imagen4",
|
||||||
"env": {
|
"env": {
|
||||||
"PYTHONPATH": "D:\\Project\\imagen4"
|
"PYTHONPATH": "D:\\Project\\imagen4",
|
||||||
|
"PYTHONIOENCODING": "utf-8",
|
||||||
|
"PYTHONUTF8": "1",
|
||||||
|
"LC_ALL": "C.UTF-8"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
27
main.py
27
main.py
@@ -74,12 +74,15 @@ if sys.platform.startswith('win'):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Verify UTF-8 setup
|
# Verify UTF-8 setup (disabled for MCP compatibility)
|
||||||
|
# Note: UTF-8 test output disabled to prevent MCP protocol interference
|
||||||
try:
|
try:
|
||||||
test_unicode = "Test UTF-8: 한글 테스트 ✓"
|
test_unicode = "Test UTF-8: 한글 테스트 ✓"
|
||||||
print(f"[UTF8-TEST] {test_unicode}")
|
# print(f"[UTF8-TEST] {test_unicode}") # Disabled: interferes with MCP JSON protocol
|
||||||
except UnicodeEncodeError as e:
|
except UnicodeEncodeError as e:
|
||||||
print(f"[UTF8-ERROR] Unicode test failed: {e}")
|
# Only log errors to stderr, not stdout
|
||||||
|
import sys
|
||||||
|
print(f"[UTF8-ERROR] Unicode test failed: {e}", file=sys.stderr)
|
||||||
|
|
||||||
# ==================== Regular Imports (after UTF-8 setup) ====================
|
# ==================== Regular Imports (after UTF-8 setup) ====================
|
||||||
|
|
||||||
@@ -97,7 +100,9 @@ try:
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("Warning: python-dotenv not installed", file=sys.stderr)
|
# Use logger for dotenv warning instead of direct print
|
||||||
|
import logging
|
||||||
|
logging.getLogger("imagen4-mcp-server").warning("python-dotenv not installed")
|
||||||
|
|
||||||
# Add current directory to PYTHONPATH
|
# Add current directory to PYTHONPATH
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
@@ -111,8 +116,11 @@ try:
|
|||||||
from mcp.server import Server
|
from mcp.server import Server
|
||||||
from mcp.types import Tool, TextContent
|
from mcp.types import Tool, TextContent
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(f"Error importing MCP: {e}", file=sys.stderr)
|
# Use logger for MCP import errors instead of direct print
|
||||||
print("Please install required packages: pip install mcp", file=sys.stderr)
|
import logging
|
||||||
|
logger = logging.getLogger("imagen4-mcp-server")
|
||||||
|
logger.error(f"Error importing MCP: {e}")
|
||||||
|
logger.error("Please install required packages: pip install mcp")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# Connector imports
|
# Connector imports
|
||||||
@@ -124,8 +132,11 @@ from src.connector.utils import save_generated_images
|
|||||||
try:
|
try:
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(f"Error importing Pillow: {e}", file=sys.stderr)
|
# Use logger for Pillow import errors instead of direct print
|
||||||
print("Please install Pillow: pip install Pillow", file=sys.stderr)
|
import logging
|
||||||
|
logger = logging.getLogger("imagen4-mcp-server")
|
||||||
|
logger.error(f"Error importing Pillow: {e}")
|
||||||
|
logger.error("Please install Pillow: pip install Pillow")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# ==================== Unicode-Safe Logging Setup ====================
|
# ==================== Unicode-Safe Logging Setup ====================
|
||||||
|
|||||||
56
run.bat
56
run.bat
@@ -1,56 +0,0 @@
|
|||||||
@echo off
|
|
||||||
chcp 65001 > nul 2>&1
|
|
||||||
echo Starting Imagen4 MCP Server with Preview Image Support (Unicode Enabled)...
|
|
||||||
echo Features: 512x512 JPEG preview images, base64 encoding, Unicode logging support
|
|
||||||
echo.
|
|
||||||
|
|
||||||
REM Set UTF-8 environment for Python
|
|
||||||
set PYTHONIOENCODING=utf-8
|
|
||||||
set PYTHONUTF8=1
|
|
||||||
|
|
||||||
REM Check if virtual environment exists
|
|
||||||
if not exist "venv\Scripts\activate.bat" (
|
|
||||||
echo Creating virtual environment...
|
|
||||||
python -m venv venv
|
|
||||||
if errorlevel 1 (
|
|
||||||
echo Failed to create virtual environment
|
|
||||||
pause
|
|
||||||
exit /b 1
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
REM Activate virtual environment
|
|
||||||
echo Activating virtual environment...
|
|
||||||
call venv\Scripts\activate.bat
|
|
||||||
|
|
||||||
REM Check if requirements are installed
|
|
||||||
python -c "import PIL" 2>nul
|
|
||||||
if errorlevel 1 (
|
|
||||||
echo Installing requirements...
|
|
||||||
pip install -r requirements.txt
|
|
||||||
if errorlevel 1 (
|
|
||||||
echo Failed to install requirements
|
|
||||||
pause
|
|
||||||
exit /b 1
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
REM Check if .env file exists
|
|
||||||
if not exist ".env" (
|
|
||||||
echo Warning: .env file not found
|
|
||||||
echo Please copy .env.example to .env and configure your API credentials
|
|
||||||
pause
|
|
||||||
exit /b 1
|
|
||||||
)
|
|
||||||
|
|
||||||
REM Run server with UTF-8 support
|
|
||||||
echo Starting MCP server with Unicode support...
|
|
||||||
echo 한글 프롬프트 지원이 활성화되었습니다.
|
|
||||||
python main.py
|
|
||||||
|
|
||||||
REM Keep window open if there's an error
|
|
||||||
if errorlevel 1 (
|
|
||||||
echo.
|
|
||||||
echo Server exited with error
|
|
||||||
pause
|
|
||||||
)
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
@echo off
|
|
||||||
echo Starting Imagen4 MCP Server with forced UTF-8 mode...
|
|
||||||
echo.
|
|
||||||
|
|
||||||
REM Set console code page to UTF-8
|
|
||||||
chcp 65001 >nul 2>&1
|
|
||||||
|
|
||||||
REM Set all UTF-8 environment variables
|
|
||||||
set PYTHONIOENCODING=utf-8
|
|
||||||
set PYTHONUTF8=1
|
|
||||||
set LC_ALL=C.UTF-8
|
|
||||||
set LANG=C.UTF-8
|
|
||||||
|
|
||||||
REM Set console properties
|
|
||||||
title Imagen4 MCP Server (UTF-8 Forced)
|
|
||||||
|
|
||||||
echo Environment Variables Set:
|
|
||||||
echo PYTHONIOENCODING=%PYTHONIOENCODING%
|
|
||||||
echo PYTHONUTF8=%PYTHONUTF8%
|
|
||||||
echo LC_ALL=%LC_ALL%
|
|
||||||
echo.
|
|
||||||
|
|
||||||
REM Run Python with explicit UTF-8 mode flag
|
|
||||||
python -X utf8 main.py
|
|
||||||
|
|
||||||
echo.
|
|
||||||
echo Server stopped. Press any key to exit...
|
|
||||||
pause >nul
|
|
||||||
22
run_mcp_safe.bat
Normal file
22
run_mcp_safe.bat
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
@echo off
|
||||||
|
REM MCP-Safe UTF-8 runner for imagen4 server
|
||||||
|
REM This prevents stdout interference with MCP JSON protocol
|
||||||
|
|
||||||
|
echo [STARTUP] Starting Imagen4 MCP Server with safe UTF-8 handling...
|
||||||
|
|
||||||
|
REM Change to script directory
|
||||||
|
cd /d "%~dp0"
|
||||||
|
|
||||||
|
REM Set UTF-8 code page
|
||||||
|
chcp 65001 >nul 2>&1
|
||||||
|
|
||||||
|
REM Set UTF-8 environment variables
|
||||||
|
set PYTHONIOENCODING=utf-8
|
||||||
|
set PYTHONUTF8=1
|
||||||
|
set LC_ALL=C.UTF-8
|
||||||
|
|
||||||
|
REM Run the safe MCP runner
|
||||||
|
python run_mcp_safe.py
|
||||||
|
|
||||||
|
echo [SHUTDOWN] Imagen4 MCP Server stopped.
|
||||||
|
pause
|
||||||
66
run_mcp_safe.py
Normal file
66
run_mcp_safe.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
MCP Protocol Compatible UTF-8 Runner for imagen4
|
||||||
|
|
||||||
|
This script ensures proper UTF-8 handling without interfering with MCP protocol
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import subprocess
|
||||||
|
import locale
|
||||||
|
|
||||||
|
def setup_utf8_environment():
|
||||||
|
"""Setup UTF-8 environment variables without stdout interference"""
|
||||||
|
# Set UTF-8 environment variables
|
||||||
|
env = os.environ.copy()
|
||||||
|
env['PYTHONIOENCODING'] = 'utf-8'
|
||||||
|
env['PYTHONUTF8'] = '1'
|
||||||
|
env['LC_ALL'] = 'C.UTF-8'
|
||||||
|
|
||||||
|
# Windows-specific setup
|
||||||
|
if sys.platform.startswith('win'):
|
||||||
|
# Set console code page to UTF-8 (suppress output)
|
||||||
|
try:
|
||||||
|
os.system('chcp 65001 >nul 2>&1')
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Set locale
|
||||||
|
try:
|
||||||
|
locale.setlocale(locale.LC_ALL, 'C.UTF-8')
|
||||||
|
except locale.Error:
|
||||||
|
try:
|
||||||
|
locale.setlocale(locale.LC_ALL, '')
|
||||||
|
except locale.Error:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return env
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main function - run imagen4 MCP server with proper UTF-8 setup"""
|
||||||
|
# Setup UTF-8 environment
|
||||||
|
env = setup_utf8_environment()
|
||||||
|
|
||||||
|
# Get the directory of this script
|
||||||
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
main_py = os.path.join(script_dir, 'main.py')
|
||||||
|
|
||||||
|
# Run main.py with proper environment
|
||||||
|
try:
|
||||||
|
# Use subprocess to run with clean environment
|
||||||
|
result = subprocess.run([
|
||||||
|
sys.executable, main_py
|
||||||
|
], env=env, cwd=script_dir)
|
||||||
|
|
||||||
|
return result.returncode
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("Server stopped by user", file=sys.stderr)
|
||||||
|
return 0
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error running server: {e}", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
16
run_utf8.bat
16
run_utf8.bat
@@ -1,16 +0,0 @@
|
|||||||
@echo off
|
|
||||||
REM Force UTF-8 encoding for Python and console
|
|
||||||
chcp 65001 >nul 2>&1
|
|
||||||
|
|
||||||
REM Set Python UTF-8 environment variables
|
|
||||||
set PYTHONIOENCODING=utf-8
|
|
||||||
set PYTHONUTF8=1
|
|
||||||
set LC_ALL=C.UTF-8
|
|
||||||
|
|
||||||
REM Set console properties
|
|
||||||
title Imagen4 MCP Server (UTF-8)
|
|
||||||
|
|
||||||
REM Run the Python script
|
|
||||||
python main.py
|
|
||||||
|
|
||||||
pause
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Emergency UTF-8 override script
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Force UTF-8 mode at Python startup
|
|
||||||
if not hasattr(sys, '_utf8_mode') or not sys._utf8_mode:
|
|
||||||
os.environ['PYTHONUTF8'] = '1'
|
|
||||||
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
|
||||||
|
|
||||||
# Restart Python with UTF-8 mode
|
|
||||||
import subprocess
|
|
||||||
result = subprocess.run([sys.executable, '-X', 'utf8', __file__] + sys.argv[1:])
|
|
||||||
sys.exit(result.returncode)
|
|
||||||
|
|
||||||
# Now run the actual main.py
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import importlib.util
|
|
||||||
import sys
|
|
||||||
|
|
||||||
spec = importlib.util.spec_from_file_location("main", "main.py")
|
|
||||||
main_module = importlib.util.module_from_spec(spec)
|
|
||||||
sys.modules["main"] = main_module
|
|
||||||
spec.loader.exec_module(main_module)
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
"""
|
|
||||||
Async Task Management Module for MCP Server
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .models import TaskStatus, TaskResult, generate_task_id
|
|
||||||
from .task_manager import TaskManager
|
|
||||||
from .worker_pool import WorkerPool
|
|
||||||
|
|
||||||
__all__ = ['TaskManager', 'TaskStatus', 'TaskResult', 'WorkerPool']
|
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
"""
|
|
||||||
Task Status and Result Models
|
|
||||||
"""
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Optional, Dict
|
|
||||||
|
|
||||||
|
|
||||||
class TaskStatus(Enum):
|
|
||||||
"""Task execution status"""
|
|
||||||
PENDING = "pending"
|
|
||||||
RUNNING = "running"
|
|
||||||
COMPLETED = "completed"
|
|
||||||
FAILED = "failed"
|
|
||||||
CANCELLED = "cancelled"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TaskResult:
|
|
||||||
"""Task execution result"""
|
|
||||||
task_id: str
|
|
||||||
status: TaskStatus
|
|
||||||
result: Optional[Any] = None
|
|
||||||
error: Optional[str] = None
|
|
||||||
created_at: datetime = field(default_factory=datetime.now)
|
|
||||||
started_at: Optional[datetime] = None
|
|
||||||
completed_at: Optional[datetime] = None
|
|
||||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def duration(self) -> Optional[float]:
|
|
||||||
"""Calculate task execution duration in seconds"""
|
|
||||||
if self.started_at and self.completed_at:
|
|
||||||
return (self.completed_at - self.started_at).total_seconds()
|
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_finished(self) -> bool:
|
|
||||||
"""Check if task is finished (completed, failed, or cancelled)"""
|
|
||||||
return self.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]
|
|
||||||
|
|
||||||
|
|
||||||
def generate_task_id() -> str:
|
|
||||||
"""Generate unique task ID"""
|
|
||||||
return str(uuid.uuid4())
|
|
||||||
@@ -1,324 +0,0 @@
|
|||||||
"""
|
|
||||||
Task Manager for MCP Server
|
|
||||||
Manages background task execution with status tracking and result retrieval
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from typing import Any, Callable, Optional, Dict, List
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
from .models import TaskResult, TaskStatus, generate_task_id
|
|
||||||
from .worker_pool import WorkerPool
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class TaskManager:
|
|
||||||
"""Main task manager for MCP server"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
max_workers: Optional[int] = None,
|
|
||||||
use_process_pool: bool = False,
|
|
||||||
default_timeout: float = 600.0, # 10 minutes
|
|
||||||
result_retention_hours: int = 24, # Keep results for 24 hours
|
|
||||||
max_retained_results: int = 200
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize task manager
|
|
||||||
|
|
||||||
Args:
|
|
||||||
max_workers: Maximum number of worker threads/processes
|
|
||||||
use_process_pool: Use process pool instead of thread pool
|
|
||||||
default_timeout: Default timeout for tasks in seconds
|
|
||||||
result_retention_hours: How long to keep finished task results
|
|
||||||
max_retained_results: Maximum number of results to retain
|
|
||||||
"""
|
|
||||||
self.worker_pool = WorkerPool(
|
|
||||||
max_workers=max_workers,
|
|
||||||
use_process_pool=use_process_pool,
|
|
||||||
task_timeout=default_timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
self.result_retention_hours = result_retention_hours
|
|
||||||
self.max_retained_results = max_retained_results
|
|
||||||
|
|
||||||
# Start cleanup task
|
|
||||||
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
|
|
||||||
|
|
||||||
logger.info(f"TaskManager initialized with {self.worker_pool.max_workers} workers")
|
|
||||||
|
|
||||||
async def submit_task(
|
|
||||||
self,
|
|
||||||
func: Callable,
|
|
||||||
*args,
|
|
||||||
task_id: Optional[str] = None,
|
|
||||||
timeout: Optional[float] = None,
|
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
|
||||||
**kwargs
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Submit a task for background execution
|
|
||||||
|
|
||||||
Args:
|
|
||||||
func: Function to execute
|
|
||||||
*args: Function arguments
|
|
||||||
task_id: Optional custom task ID
|
|
||||||
timeout: Task timeout in seconds
|
|
||||||
metadata: Additional metadata for the task
|
|
||||||
**kwargs: Function keyword arguments
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Task ID for tracking progress
|
|
||||||
"""
|
|
||||||
task_id = await self.worker_pool.submit_task(
|
|
||||||
func, *args, task_id=task_id, timeout=timeout, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add additional metadata if provided
|
|
||||||
if metadata:
|
|
||||||
result = self.worker_pool.get_task_result(task_id)
|
|
||||||
if result:
|
|
||||||
result.metadata.update(metadata)
|
|
||||||
|
|
||||||
return task_id
|
|
||||||
|
|
||||||
def get_task_status(self, task_id: str) -> Optional[TaskStatus]:
|
|
||||||
"""Get task status by ID"""
|
|
||||||
result = self.worker_pool.get_task_result(task_id)
|
|
||||||
return result.status if result else None
|
|
||||||
|
|
||||||
def get_task_result(self, task_id: str) -> Optional[TaskResult]:
|
|
||||||
"""Get complete task result by ID with circular reference protection"""
|
|
||||||
try:
|
|
||||||
raw_result = self.worker_pool.get_task_result(task_id)
|
|
||||||
if not raw_result:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Create a safe copy to avoid circular references
|
|
||||||
safe_result = TaskResult(
|
|
||||||
task_id=raw_result.task_id,
|
|
||||||
status=raw_result.status,
|
|
||||||
result=None, # We'll set this safely
|
|
||||||
error=raw_result.error,
|
|
||||||
created_at=raw_result.created_at,
|
|
||||||
started_at=raw_result.started_at,
|
|
||||||
completed_at=raw_result.completed_at,
|
|
||||||
metadata=raw_result.metadata.copy() if raw_result.metadata else {}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Safely copy the result data
|
|
||||||
if raw_result.result is not None:
|
|
||||||
try:
|
|
||||||
# If it's a dict, create a clean copy
|
|
||||||
if isinstance(raw_result.result, dict):
|
|
||||||
safe_result.result = {
|
|
||||||
'success': raw_result.result.get('success', False),
|
|
||||||
'error': raw_result.result.get('error'),
|
|
||||||
'images_b64': raw_result.result.get('images_b64', []),
|
|
||||||
'saved_files': raw_result.result.get('saved_files', []),
|
|
||||||
'request': raw_result.result.get('request', {}),
|
|
||||||
'image_count': raw_result.result.get('image_count', 0)
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
# For non-dict results, just copy directly
|
|
||||||
safe_result.result = raw_result.result
|
|
||||||
except Exception as copy_error:
|
|
||||||
logger.error(f"Error creating safe copy of result: {str(copy_error)}")
|
|
||||||
safe_result.result = {
|
|
||||||
'success': False,
|
|
||||||
'error': f'Result data corrupted: {str(copy_error)}'
|
|
||||||
}
|
|
||||||
|
|
||||||
return safe_result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in get_task_result: {str(e)}", exc_info=True)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_task_progress(self, task_id: str) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Get task progress information
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict with task status, timing info, and metadata
|
|
||||||
"""
|
|
||||||
result = self.get_task_result(task_id)
|
|
||||||
if not result:
|
|
||||||
return {'error': 'Task not found'}
|
|
||||||
|
|
||||||
progress = {
|
|
||||||
'task_id': result.task_id,
|
|
||||||
'status': result.status.value,
|
|
||||||
'created_at': result.created_at.isoformat(),
|
|
||||||
'metadata': result.metadata
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.started_at:
|
|
||||||
progress['started_at'] = result.started_at.isoformat()
|
|
||||||
|
|
||||||
if result.status == TaskStatus.RUNNING:
|
|
||||||
elapsed = (datetime.now() - result.started_at).total_seconds()
|
|
||||||
progress['elapsed_seconds'] = elapsed
|
|
||||||
|
|
||||||
if result.completed_at:
|
|
||||||
progress['completed_at'] = result.completed_at.isoformat()
|
|
||||||
progress['duration_seconds'] = result.duration
|
|
||||||
|
|
||||||
if result.error:
|
|
||||||
progress['error'] = result.error
|
|
||||||
|
|
||||||
return progress
|
|
||||||
|
|
||||||
def list_tasks(
|
|
||||||
self,
|
|
||||||
status_filter: Optional[TaskStatus] = None,
|
|
||||||
limit: int = 50
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
List tasks with optional filtering
|
|
||||||
|
|
||||||
Args:
|
|
||||||
status_filter: Filter by task status
|
|
||||||
limit: Maximum number of tasks to return
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of task progress dictionaries
|
|
||||||
"""
|
|
||||||
all_results = self.worker_pool.get_all_results()
|
|
||||||
|
|
||||||
# Filter by status if requested
|
|
||||||
if status_filter:
|
|
||||||
filtered_results = {
|
|
||||||
k: v for k, v in all_results.items()
|
|
||||||
if v.status == status_filter
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
filtered_results = all_results
|
|
||||||
|
|
||||||
# Sort by creation time (most recent first)
|
|
||||||
sorted_items = sorted(
|
|
||||||
filtered_results.items(),
|
|
||||||
key=lambda x: x[1].created_at,
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply limit and return progress info
|
|
||||||
return [
|
|
||||||
self.get_task_progress(task_id)
|
|
||||||
for task_id, _ in sorted_items[:limit]
|
|
||||||
]
|
|
||||||
|
|
||||||
def cancel_task(self, task_id: str) -> bool:
|
|
||||||
"""Cancel a running task"""
|
|
||||||
return self.worker_pool.cancel_task(task_id)
|
|
||||||
|
|
||||||
async def wait_for_task(
|
|
||||||
self,
|
|
||||||
task_id: str,
|
|
||||||
timeout: Optional[float] = None,
|
|
||||||
poll_interval: float = 1.0
|
|
||||||
) -> TaskResult:
|
|
||||||
"""
|
|
||||||
Wait for a task to complete
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id: Task ID to wait for
|
|
||||||
timeout: Maximum time to wait in seconds
|
|
||||||
poll_interval: How often to check task status
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
TaskResult: Final task result
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
asyncio.TimeoutError: If timeout is reached
|
|
||||||
ValueError: If task doesn't exist
|
|
||||||
"""
|
|
||||||
start_time = datetime.now()
|
|
||||||
|
|
||||||
while True:
|
|
||||||
result = self.get_task_result(task_id)
|
|
||||||
if not result:
|
|
||||||
raise ValueError(f"Task {task_id} not found")
|
|
||||||
|
|
||||||
if result.is_finished:
|
|
||||||
return result
|
|
||||||
|
|
||||||
# Check timeout
|
|
||||||
if timeout:
|
|
||||||
elapsed = (datetime.now() - start_time).total_seconds()
|
|
||||||
if elapsed >= timeout:
|
|
||||||
raise asyncio.TimeoutError(f"Timeout waiting for task {task_id}")
|
|
||||||
|
|
||||||
await asyncio.sleep(poll_interval)
|
|
||||||
|
|
||||||
async def _periodic_cleanup(self) -> None:
|
|
||||||
"""Periodic cleanup of old task results"""
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
# Wait 1 hour between cleanups
|
|
||||||
await asyncio.sleep(3600)
|
|
||||||
|
|
||||||
# Clean up old results
|
|
||||||
cutoff_time = datetime.now() - timedelta(hours=self.result_retention_hours)
|
|
||||||
all_results = self.worker_pool.get_all_results()
|
|
||||||
|
|
||||||
to_remove = []
|
|
||||||
for task_id, result in all_results.items():
|
|
||||||
if (result.is_finished and
|
|
||||||
result.completed_at and
|
|
||||||
result.completed_at < cutoff_time):
|
|
||||||
to_remove.append(task_id)
|
|
||||||
|
|
||||||
# Remove old results
|
|
||||||
for task_id in to_remove:
|
|
||||||
if task_id in self.worker_pool._results:
|
|
||||||
del self.worker_pool._results[task_id]
|
|
||||||
|
|
||||||
if to_remove:
|
|
||||||
logger.info(f"Cleaned up {len(to_remove)} old task results")
|
|
||||||
|
|
||||||
# Also clean up excess results
|
|
||||||
self.worker_pool.cleanup_finished_tasks(self.max_retained_results)
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in periodic cleanup: {e}")
|
|
||||||
|
|
||||||
async def shutdown(self, wait: bool = True) -> None:
|
|
||||||
"""Shutdown task manager"""
|
|
||||||
logger.info("Shutting down task manager...")
|
|
||||||
|
|
||||||
# Cancel cleanup task
|
|
||||||
if not self._cleanup_task.done():
|
|
||||||
self._cleanup_task.cancel()
|
|
||||||
try:
|
|
||||||
await self._cleanup_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Shutdown worker pool
|
|
||||||
await self.worker_pool.shutdown(wait=wait)
|
|
||||||
|
|
||||||
logger.info("Task manager shutdown complete")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def stats(self) -> Dict[str, Any]:
|
|
||||||
"""Get task manager statistics"""
|
|
||||||
all_results = self.worker_pool.get_all_results()
|
|
||||||
|
|
||||||
status_counts = {}
|
|
||||||
for status in TaskStatus:
|
|
||||||
status_counts[status.value] = sum(
|
|
||||||
1 for r in all_results.values() if r.status == status
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
'total_tasks': len(all_results),
|
|
||||||
'active_tasks': self.worker_pool.active_task_count,
|
|
||||||
'max_workers': self.worker_pool.max_workers,
|
|
||||||
'worker_type': 'process' if self.worker_pool.use_process_pool else 'thread',
|
|
||||||
'status_counts': status_counts
|
|
||||||
}
|
|
||||||
@@ -1,258 +0,0 @@
|
|||||||
"""
|
|
||||||
Worker Pool for Background Task Execution
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import multiprocessing
|
|
||||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
|
||||||
from typing import Any, Callable, Optional, Union
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from .models import TaskResult, TaskStatus, generate_task_id
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class WorkerPool:
|
|
||||||
"""Worker pool for executing background tasks"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
max_workers: Optional[int] = None,
|
|
||||||
use_process_pool: bool = False,
|
|
||||||
task_timeout: float = 600.0 # 10 minutes default
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize worker pool
|
|
||||||
|
|
||||||
Args:
|
|
||||||
max_workers: Maximum number of worker threads/processes
|
|
||||||
use_process_pool: Use ProcessPoolExecutor instead of ThreadPoolExecutor
|
|
||||||
task_timeout: Default timeout for tasks in seconds
|
|
||||||
"""
|
|
||||||
self.max_workers = max_workers or min(32, (multiprocessing.cpu_count() or 1) + 4)
|
|
||||||
self.use_process_pool = use_process_pool
|
|
||||||
self.task_timeout = task_timeout
|
|
||||||
|
|
||||||
# Initialize executor
|
|
||||||
if use_process_pool:
|
|
||||||
self.executor = ProcessPoolExecutor(max_workers=self.max_workers)
|
|
||||||
logger.info(f"Initialized ProcessPoolExecutor with {self.max_workers} workers")
|
|
||||||
else:
|
|
||||||
self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
|
|
||||||
logger.info(f"Initialized ThreadPoolExecutor with {self.max_workers} workers")
|
|
||||||
|
|
||||||
# Track running tasks
|
|
||||||
self._running_tasks: dict[str, asyncio.Task] = {}
|
|
||||||
self._results: dict[str, TaskResult] = {}
|
|
||||||
|
|
||||||
async def submit_task(
|
|
||||||
self,
|
|
||||||
func: Callable,
|
|
||||||
*args,
|
|
||||||
task_id: Optional[str] = None,
|
|
||||||
timeout: Optional[float] = None,
|
|
||||||
**kwargs
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Submit a task for background execution
|
|
||||||
|
|
||||||
Args:
|
|
||||||
func: Function to execute
|
|
||||||
*args: Function arguments
|
|
||||||
task_id: Optional custom task ID
|
|
||||||
timeout: Task timeout in seconds
|
|
||||||
**kwargs: Function keyword arguments
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Task ID for tracking
|
|
||||||
"""
|
|
||||||
if task_id is None:
|
|
||||||
task_id = generate_task_id()
|
|
||||||
|
|
||||||
timeout = timeout or self.task_timeout
|
|
||||||
|
|
||||||
# Create task result
|
|
||||||
task_result = TaskResult(
|
|
||||||
task_id=task_id,
|
|
||||||
status=TaskStatus.PENDING,
|
|
||||||
metadata={
|
|
||||||
'function_name': func.__name__,
|
|
||||||
'timeout': timeout,
|
|
||||||
'worker_type': 'process' if self.use_process_pool else 'thread'
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
self._results[task_id] = task_result
|
|
||||||
|
|
||||||
# Create and start background task
|
|
||||||
background_task = asyncio.create_task(
|
|
||||||
self._execute_task(func, args, kwargs, task_result, timeout)
|
|
||||||
)
|
|
||||||
|
|
||||||
self._running_tasks[task_id] = background_task
|
|
||||||
|
|
||||||
logger.info(f"Task {task_id} submitted for execution (timeout: {timeout}s)")
|
|
||||||
return task_id
|
|
||||||
|
|
||||||
    async def _execute_task(
        self,
        func: Callable,
        args: tuple,
        kwargs: dict,
        task_result: TaskResult,
        timeout: float
    ) -> None:
        """Execute task in background.

        Runs *func* on the configured executor, enforcing *timeout*, and
        records the outcome (result / error / status / timestamps) on the
        shared *task_result* record. Never raises: every failure mode is
        translated into a terminal TaskStatus.
        """
        try:
            # Mark the transition PENDING -> RUNNING before the work starts.
            task_result.status = TaskStatus.RUNNING
            task_result.started_at = datetime.now()

            logger.info(f"Starting task {task_result.task_id}: {func.__name__}")

            # Execute function with timeout
            if self.use_process_pool:
                # For process pool, function must be pickleable
                # (submit + wrap_future bridges the concurrent.futures
                # Future into the asyncio event loop).
                future = self.executor.submit(func, *args, **kwargs)
                result = await asyncio.wait_for(
                    asyncio.wrap_future(future),
                    timeout=timeout
                )
            else:
                # For thread pool, can use lambda/closure
                # NOTE(review): asyncio.to_thread uses the loop's default
                # executor, not self.executor — confirm that is intended.
                result = await asyncio.wait_for(
                    asyncio.to_thread(func, *args, **kwargs),
                    timeout=timeout
                )

            task_result.result = result
            task_result.status = TaskStatus.COMPLETED
            task_result.completed_at = datetime.now()

            logger.info(f"Task {task_result.task_id} completed successfully in {task_result.duration:.2f}s")

        except asyncio.TimeoutError:
            # Deadline exceeded: surface as FAILED with a descriptive error.
            task_result.status = TaskStatus.FAILED
            task_result.error = f"Task timed out after {timeout} seconds"
            task_result.completed_at = datetime.now()
            logger.error(f"Task {task_result.task_id} timed out after {timeout}s")

        except asyncio.CancelledError:
            # Cooperative cancellation requested via cancel_task()/shutdown().
            task_result.status = TaskStatus.CANCELLED
            task_result.error = "Task was cancelled"
            task_result.completed_at = datetime.now()
            logger.info(f"Task {task_result.task_id} was cancelled")

        except Exception as e:
            # Any other exception from the user function is captured, not
            # propagated — callers read it off the TaskResult.
            task_result.status = TaskStatus.FAILED
            task_result.error = str(e)
            task_result.completed_at = datetime.now()
            logger.error(f"Task {task_result.task_id} failed: {str(e)}", exc_info=True)

        finally:
            # Clean up running task reference (results stay in self._results
            # until cleanup_finished_tasks() prunes them).
            if task_result.task_id in self._running_tasks:
                del self._running_tasks[task_result.task_id]
def get_task_result(self, task_id: str) -> Optional[TaskResult]:
|
|
||||||
"""Get task result by ID"""
|
|
||||||
return self._results.get(task_id)
|
|
||||||
|
|
||||||
def get_all_results(self) -> dict[str, TaskResult]:
|
|
||||||
"""Get all task results"""
|
|
||||||
return self._results.copy()
|
|
||||||
|
|
||||||
def cancel_task(self, task_id: str) -> bool:
|
|
||||||
"""
|
|
||||||
Cancel a running task
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id: Task ID to cancel
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if task was cancelled, False if not found or already finished
|
|
||||||
"""
|
|
||||||
if task_id in self._running_tasks:
|
|
||||||
task = self._running_tasks[task_id]
|
|
||||||
if not task.done():
|
|
||||||
task.cancel()
|
|
||||||
logger.info(f"Task {task_id} cancellation requested")
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def cleanup_finished_tasks(self, max_results: int = 100) -> int:
|
|
||||||
"""
|
|
||||||
Clean up finished task results to prevent memory buildup
|
|
||||||
|
|
||||||
Args:
|
|
||||||
max_results: Maximum number of results to keep
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: Number of results cleaned up
|
|
||||||
"""
|
|
||||||
if len(self._results) <= max_results:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
# Sort by completion time, keep most recent
|
|
||||||
finished_results = {
|
|
||||||
k: v for k, v in self._results.items()
|
|
||||||
if v.is_finished and v.completed_at
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(finished_results) <= max_results:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
sorted_results = sorted(
|
|
||||||
finished_results.items(),
|
|
||||||
key=lambda x: x[1].completed_at,
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Keep most recent max_results
|
|
||||||
to_keep = set(k for k, _ in sorted_results[:max_results])
|
|
||||||
|
|
||||||
# Remove old results
|
|
||||||
to_remove = [k for k in finished_results.keys() if k not in to_keep]
|
|
||||||
for task_id in to_remove:
|
|
||||||
del self._results[task_id]
|
|
||||||
|
|
||||||
cleaned_count = len(to_remove)
|
|
||||||
if cleaned_count > 0:
|
|
||||||
logger.info(f"Cleaned up {cleaned_count} old task results")
|
|
||||||
|
|
||||||
return cleaned_count
|
|
||||||
|
|
||||||
    async def shutdown(self, wait: bool = True) -> None:
        """
        Shutdown worker pool

        Args:
            wait: Whether to wait for running tasks to complete
        """
        logger.info("Shutting down worker pool...")

        if wait:
            # NOTE(review): despite the parameter name, wait=True *cancels*
            # every still-running task and then waits for the cancellations
            # to settle — it does not let tasks run to completion. Confirm
            # this is the intended semantics before relying on it.
            # Cancel all running tasks
            for task_id, task in self._running_tasks.items():
                if not task.done():
                    task.cancel()
                    logger.info(f"Cancelled task {task_id}")

            # Wait for tasks to complete cancellation
            # (return_exceptions=True swallows the CancelledErrors so the
            # gather itself cannot raise here)
            if self._running_tasks:
                await asyncio.gather(*self._running_tasks.values(), return_exceptions=True)

        # Shutdown executor
        self.executor.shutdown(wait=wait)
        logger.info("Worker pool shutdown complete")
@property
|
|
||||||
def active_task_count(self) -> int:
|
|
||||||
"""Get number of currently running tasks"""
|
|
||||||
return len([t for t in self._running_tasks.values() if not t.done()])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def total_task_count(self) -> int:
|
|
||||||
"""Get total number of tracked tasks"""
|
|
||||||
return len(self._results)
|
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
"""
|
|
||||||
Imagen4 Server Package
|
|
||||||
|
|
||||||
MCP 서버 구현을 담당하는 모듈
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Import basic server components
|
|
||||||
from src.server.mcp_server import Imagen4MCPServer
|
|
||||||
from src.server.handlers import ToolHandlers
|
|
||||||
from src.server.models import MCPToolDefinitions
|
|
||||||
|
|
||||||
# Import enhanced components
|
|
||||||
try:
|
|
||||||
from src.server.enhanced_mcp_server import EnhancedImagen4MCPServer
|
|
||||||
from src.server.enhanced_handlers import EnhancedToolHandlers
|
|
||||||
from src.server.enhanced_models import ImageGenerationResult, PreviewImageResponse
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'Imagen4MCPServer', 'ToolHandlers', 'MCPToolDefinitions',
|
|
||||||
'EnhancedImagen4MCPServer', 'EnhancedToolHandlers',
|
|
||||||
'ImageGenerationResult', 'PreviewImageResponse'
|
|
||||||
]
|
|
||||||
except ImportError as e:
|
|
||||||
# Fallback if enhanced modules are not available
|
|
||||||
print(f"Warning: Enhanced features not available: {e}")
|
|
||||||
__all__ = ['Imagen4MCPServer', 'ToolHandlers', 'MCPToolDefinitions']
|
|
||||||
@@ -1,344 +0,0 @@
|
|||||||
"""
|
|
||||||
Enhanced Tool Handlers for MCP Server with Preview Image Support and Unicode Logging
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import base64
|
|
||||||
import json
|
|
||||||
import random
|
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
from typing import List, Dict, Any
|
|
||||||
|
|
||||||
from mcp.types import TextContent, ImageContent
|
|
||||||
|
|
||||||
from src.connector import Imagen4Client, Config
|
|
||||||
from src.connector.imagen4_client import ImageGenerationRequest
|
|
||||||
from src.connector.utils import save_generated_images
|
|
||||||
from src.utils.image_utils import create_preview_image_b64, get_image_info
|
|
||||||
from src.server.enhanced_models import ImageGenerationResult, PreviewImageResponse
|
|
||||||
|
|
||||||
# Configure Unicode logging
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Ensure proper Unicode handling for all string operations
|
|
||||||
def ensure_unicode_string(value):
    """Coerce *value* into a Unicode ``str``.

    Strings pass through unchanged; bytes are decoded as UTF-8 with
    undecodable sequences replaced by U+FFFD; anything else is
    stringified via ``str()``.
    """
    if isinstance(value, str):
        return value
    if isinstance(value, bytes):
        return value.decode('utf-8', errors='replace')
    return str(value)
def sanitize_args_for_logging(arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of *arguments* that is safe to write to the log.

    Base64 image payloads are replaced by a short placeholder, very long
    strings are truncated, and Unicode text (e.g. Korean prompts) is kept
    intact. Non-string values pass through unchanged.
    """
    image_keys = ('data', 'image_data', 'base64', 'preview_image_b64')
    # Base64 prefixes of PNG / JPEG / GIF payloads respectively.
    image_magic = ('iVBORw0KGgo', '/9j/', 'R0lGOD')

    sanitized: Dict[str, Any] = {}
    for name, raw in arguments.items():
        if not isinstance(raw, str):
            sanitized[name] = raw
            continue

        # Values reaching this point are already str, so the original
        # ensure_unicode_string() normalization is a no-op — inlined away.
        text = raw
        looks_like_image = (
            name in image_keys
            or (len(text) > 100 and text.startswith(image_magic))
        )

        if looks_like_image:
            # Hide bulky image payloads, keeping only their size.
            sanitized[name] = f"<image_data:{len(text)} chars>"
        elif len(text) > 1000:
            # Truncate very long strings but preserve Unicode characters.
            sanitized[name] = f"{text[:100]}...<truncated:{len(text)} total chars>"
        else:
            sanitized[name] = text

    return sanitized
class EnhancedToolHandlers:
    """Enhanced MCP tool handler class with preview image support and Unicode logging"""

    def __init__(self, config: Config):
        """Initialize handler with the app config and an Imagen4 API client."""
        # config: holds output_path / default_model used by the handlers below
        self.config = config
        self.client = Imagen4Client(config)

    async def handle_generate_random_seed(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Random seed generation handler.

        Returns a single TextContent with a random 32-bit unsigned seed;
        *arguments* is unused.
        """
        try:
            random_seed = random.randint(0, 2**32 - 1)
            return [TextContent(
                type="text",
                text=f"Generated random seed: {random_seed}"
            )]
        except Exception as e:
            # Errors are reported as text content, never raised to the caller.
            logger.error(f"Random seed generation error: {str(e)}")
            return [TextContent(
                type="text",
                text=f"Error occurred during random seed generation: {str(e)}"
            )]

    async def handle_regenerate_from_json(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Image regeneration from JSON file handler with preview support.

        Reads generation parameters from arguments["json_file_path"],
        re-runs the generation, builds 512x512 JPEG previews, optionally
        saves the full-size images, and returns a textual summary.
        All failure modes are returned as TextContent, never raised.
        """
        try:
            json_file_path = arguments.get("json_file_path")
            save_to_file = arguments.get("save_to_file", True)

            if not json_file_path:
                return [TextContent(
                    type="text",
                    text="Error: JSON file path is required."
                )]

            # Load parameters from JSON file with proper UTF-8 encoding
            try:
                with open(json_file_path, 'r', encoding='utf-8') as f:
                    params = json.load(f)
            except FileNotFoundError:
                return [TextContent(
                    type="text",
                    text=f"Error: Cannot find JSON file: {json_file_path}"
                )]
            except json.JSONDecodeError as e:
                return [TextContent(
                    type="text",
                    text=f"Error: JSON parsing error: {str(e)}"
                )]

            # Check required parameters (prompt + seed must be present for
            # a reproducible regeneration)
            required_params = ['prompt', 'seed']
            missing_params = [p for p in required_params if p not in params]

            if missing_params:
                return [TextContent(
                    type="text",
                    text=f"Error: Required parameters are missing from JSON file: {', '.join(missing_params)}"
                )]

            # Create image generation request object; optional fields fall
            # back to the same defaults used at original generation time
            request = ImageGenerationRequest(
                prompt=params.get('prompt'),
                negative_prompt=params.get('negative_prompt', ''),
                number_of_images=params.get('number_of_images', 1),
                seed=params.get('seed'),
                aspect_ratio=params.get('aspect_ratio', '1:1'),
                model=params.get('model', self.config.default_model)
            )

            logger.info(f"Loaded parameters from JSON: {json_file_path}")

            # Generate image
            response = await self.client.generate_image(request)

            if not response.success:
                return [TextContent(
                    type="text",
                    text=f"Error occurred during image regeneration: {response.error_message}"
                )]

            # Create preview images (512x512 JPEG, base64) from the
            # full-size results; failed previews are silently skipped here
            preview_images_b64 = []
            for image_data in response.images_data:
                preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85)
                if preview_b64:
                    preview_images_b64.append(preview_b64)
                    logger.info(f"Created preview image: {len(preview_b64)} chars base64 JPEG")

            # Save files (optional) — metadata records the source JSON so
            # the provenance of a regeneration is traceable
            saved_files = []
            if save_to_file:
                regeneration_params = {
                    "prompt": request.prompt,
                    "negative_prompt": request.negative_prompt,
                    "number_of_images": request.number_of_images,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "model": request.model,
                    "regenerated_from": json_file_path,
                    "original_generated_at": params.get('generated_at', 'unknown')
                }

                saved_files = save_generated_images(
                    images_data=response.images_data,
                    save_directory=self.config.output_path,
                    seed=request.seed,
                    generation_params=regeneration_params,
                    filename_prefix="imagen4_regen"
                )

            # Create result with preview images
            result = ImageGenerationResult(
                success=True,
                message=f"✅ Images have been successfully regenerated from {json_file_path}",
                original_images_count=len(response.images_data),
                preview_images_b64=preview_images_b64,
                saved_files=saved_files,
                generation_params={
                    "prompt": request.prompt,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "number_of_images": request.number_of_images,
                    "model": request.model
                }
            )

            return [TextContent(
                type="text",
                text=result.to_text_content()
            )]

        except Exception as e:
            logger.error(f"Error occurred during image regeneration: {str(e)}")
            return [TextContent(
                type="text",
                text=f"Error occurred during image regeneration: {str(e)}"
            )]

    async def handle_generate_image(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Enhanced image generation handler with preview image support and proper Unicode logging.

        Validates prompt/seed, generates images with a 6-minute timeout,
        builds 512x512 JPEG previews, optionally saves full-size files to
        self.config.output_path, and returns a textual summary.
        All failure modes are returned as TextContent, never raised.
        """
        try:
            # Log arguments safely without exposing image data but preserve Unicode
            safe_args = sanitize_args_for_logging(arguments)
            logger.info(f"handle_generate_image called with arguments: {safe_args}")

            # Extract and validate arguments
            prompt = arguments.get("prompt")
            if not prompt:
                logger.error("No prompt provided")
                return [TextContent(
                    type="text",
                    text="Error: Prompt is required."
                )]

            # Ensure prompt is properly handled as Unicode
            prompt = ensure_unicode_string(prompt)

            # Seed is deliberately mandatory so every generation is reproducible
            seed = arguments.get("seed")
            if seed is None:
                logger.error("No seed provided")
                return [TextContent(
                    type="text",
                    text="Error: Seed value is required. You can use the generate_random_seed tool to generate a random seed."
                )]

            # Create image generation request object
            request = ImageGenerationRequest(
                prompt=prompt,
                negative_prompt=ensure_unicode_string(arguments.get("negative_prompt", "")),
                number_of_images=arguments.get("number_of_images", 1),
                seed=seed,
                aspect_ratio=arguments.get("aspect_ratio", "1:1"),
                model=arguments.get("model", self.config.default_model)
            )

            save_to_file = arguments.get("save_to_file", True)

            # Log with proper Unicode handling for Korean/international text
            prompt_preview = prompt[:50] + "..." if len(prompt) > 50 else prompt
            logger.info(f"Starting image generation: '{prompt_preview}', Seed: {seed}")

            # Generate image with timeout
            try:
                logger.info("Calling client.generate_image()...")
                response = await asyncio.wait_for(
                    self.client.generate_image(request),
                    timeout=360.0  # 6 minute timeout
                )
                logger.info(f"Image generation completed. Success: {response.success}")
            except asyncio.TimeoutError:
                logger.error("Image generation timed out after 6 minutes")
                return [TextContent(
                    type="text",
                    text="Error: Image generation timed out after 6 minutes. Please try again."
                )]

            if not response.success:
                logger.error(f"Image generation failed: {response.error_message}")
                return [TextContent(
                    type="text",
                    text=f"Error occurred during image generation: {response.error_message}"
                )]

            logger.info(f"Generated {len(response.images_data)} images successfully")

            # Create preview images (512x512 JPEG base64)
            preview_images_b64 = []
            for i, image_data in enumerate(response.images_data):
                # Log original image info
                image_info = get_image_info(image_data)
                if image_info:
                    logger.info(f"Original image {i+1}: {image_info}")

                # Create preview
                preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85)
                if preview_b64:
                    preview_images_b64.append(preview_b64)
                    logger.info(f"Created preview image {i+1}: {len(preview_b64)} chars base64 JPEG")
                else:
                    logger.warning(f"Failed to create preview for image {i+1}")

            # Save files if requested
            saved_files = []
            if save_to_file:
                logger.info("Saving files to disk...")
                generation_params = {
                    "prompt": request.prompt,
                    "negative_prompt": request.negative_prompt,
                    "number_of_images": request.number_of_images,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "model": request.model,
                    "guidance_scale": 7.5,
                    "safety_filter_level": "block_only_high",
                    "person_generation": "allow_all",
                    "add_watermark": False
                }

                saved_files = save_generated_images(
                    images_data=response.images_data,
                    save_directory=self.config.output_path,
                    seed=request.seed,
                    generation_params=generation_params
                )
                logger.info(f"Files saved: {saved_files}")

                # Verify files were created (defensive check — surfaces any
                # silent save failure in the log)
                for file_path in saved_files:
                    if os.path.exists(file_path):
                        size = os.path.getsize(file_path)
                        logger.info(f"  ✓ Verified: {file_path} ({size} bytes)")
                    else:
                        logger.error(f"  ❌ Missing: {file_path}")

            # Create enhanced result with preview images
            result = ImageGenerationResult(
                success=True,
                message=f"✅ Images have been successfully generated!",
                original_images_count=len(response.images_data),
                preview_images_b64=preview_images_b64,
                saved_files=saved_files,
                generation_params={
                    "prompt": request.prompt,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "number_of_images": request.number_of_images,
                    "negative_prompt": request.negative_prompt,
                    "model": request.model
                }
            )

            logger.info(f"Returning response with {len(preview_images_b64)} preview images")
            return [TextContent(
                type="text",
                text=result.to_text_content()
            )]

        except Exception as e:
            logger.error(f"Error occurred during image generation: {str(e)}", exc_info=True)
            return [TextContent(
                type="text",
                text=f"Error occurred during image generation: {str(e)}"
            )]
@@ -1,63 +0,0 @@
|
|||||||
"""
|
|
||||||
Enhanced MCP Server implementation for Imagen4 with Preview Image Support
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Dict, Any, List
|
|
||||||
|
|
||||||
from mcp.server import Server
|
|
||||||
from mcp.types import Tool, TextContent
|
|
||||||
|
|
||||||
from src.connector import Config
|
|
||||||
from src.server.models import MCPToolDefinitions
|
|
||||||
from src.server.enhanced_handlers import EnhancedToolHandlers, sanitize_args_for_logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class EnhancedImagen4MCPServer:
    """Enhanced Imagen4 MCP server class with preview image support"""

    def __init__(self, config: Config):
        """Initialize server: create the MCP Server instance, the enhanced
        tool handlers, and register the list_tools/call_tool callbacks."""
        self.config = config
        self.server = Server("imagen4-enhanced-mcp-server")
        self.handlers = EnhancedToolHandlers(config)

        # Register handlers
        self._register_handlers()

    def _register_handlers(self) -> None:
        """Register MCP handlers.

        The handlers are defined as closures and attached via the MCP
        Server decorators; they only need to live for the lifetime of
        self.server, so no references are kept on the instance.
        """

        @self.server.list_tools()
        async def handle_list_tools() -> List[Tool]:
            """Return list of available tools"""
            logger.info("Listing available tools (enhanced version)")
            return MCPToolDefinitions.get_all_tools()

        @self.server.call_tool()
        async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]:
            """Handle tool calls with enhanced preview image support"""
            # Log tool call safely without exposing sensitive data
            # (image payloads are truncated by sanitize_args_for_logging)
            safe_args = sanitize_args_for_logging(arguments)
            logger.info(f"Enhanced tool called: {name} with arguments: {safe_args}")

            # Dispatch by tool name; unknown names are a protocol error.
            if name == "generate_random_seed":
                return await self.handlers.handle_generate_random_seed(arguments)
            elif name == "regenerate_from_json":
                return await self.handlers.handle_regenerate_from_json(arguments)
            elif name == "generate_image":
                return await self.handlers.handle_generate_image(arguments)
            else:
                raise ValueError(f"Unknown tool: {name}")

    def get_server(self) -> Server:
        """Return MCP server instance"""
        return self.server
# Create a factory function for easier import
def create_enhanced_server(config: Config) -> EnhancedImagen4MCPServer:
    """Build and return a ready-to-use enhanced MCP server for *config*."""
    server = EnhancedImagen4MCPServer(config)
    return server
@@ -1,71 +0,0 @@
|
|||||||
"""
|
|
||||||
Enhanced MCP response models with preview image support
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional, Dict, Any
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class ImageGenerationResult:
    """Enhanced image generation result with preview support.

    Carries the outcome of a generation run: counts, base64 JPEG
    previews (512x512), saved file paths, and the parameters used.
    """
    success: bool
    message: str
    original_images_count: int
    preview_images_b64: Optional[list[str]] = None  # base64 JPEG previews (512x512)
    saved_files: Optional[list[str]] = None
    generation_params: Optional[Dict[str, Any]] = None
    error_message: Optional[str] = None

    def to_text_content(self) -> str:
        """Render this result as the plain-text body of an MCP response."""
        parts = [self.message]

        # Preview section — only for successful runs that produced previews.
        if self.success and self.preview_images_b64:
            parts.append(f"\n🖼️ Preview Images Generated: {len(self.preview_images_b64)} images (512x512 JPEG)")
            for idx, preview_b64 in enumerate(self.preview_images_b64, start=1):
                parts.append(f"Preview {idx} (base64 JPEG): {preview_b64[:50]}...({len(preview_b64)} chars)")

        # Saved-file section.
        if self.saved_files:
            parts.append(f"\n📁 Files saved:")
            parts.extend(f"  - {filepath}" for filepath in self.saved_files)

        # Parameter section — long prompts are shortened for readability.
        if self.generation_params:
            parts.append(f"\n⚙️ Generation Parameters:")
            for key, value in self.generation_params.items():
                if key == 'prompt' and len(str(value)) > 100:
                    parts.append(f"  - {key}: {str(value)[:100]}...")
                else:
                    parts.append(f"  - {key}: {value}")

        return "\n".join(parts)
@dataclass
class PreviewImageResponse:
    """Response containing preview images in base64 format"""
    # Base64 encoded JPEG images (512x512)
    preview_images_b64: list[str]
    # Number of original full-size images the previews were derived from
    original_count: int
    message: str

    @classmethod
    def from_image_data(cls, images_data: list[bytes], message: str = "") -> 'PreviewImageResponse':
        """Create response from original PNG image data.

        Each full-size image is downscaled to a 512x512 JPEG preview; if
        preview creation fails for an image, the original bytes are
        base64-encoded as-is so no image is silently dropped.
        """
        # Imported lazily to avoid a hard module-level dependency on PIL utils.
        from src.utils.image_utils import create_preview_image_b64

        preview_images = []
        for i, image_data in enumerate(images_data):
            preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85)
            if preview_b64:
                preview_images.append(preview_b64)
            else:
                # Fallback - use original if preview creation fails
                # (note: the fallback entry is the full-size PNG, not a JPEG)
                import base64
                preview_images.append(base64.b64encode(image_data).decode('utf-8'))

        return cls(
            preview_images_b64=preview_images,
            original_count=len(images_data),
            message=message or f"Generated {len(preview_images)} preview images"
        )
@@ -1,308 +0,0 @@
|
|||||||
"""
|
|
||||||
Tool Handlers for MCP Server
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import base64
|
|
||||||
import json
|
|
||||||
import random
|
|
||||||
import logging
|
|
||||||
from typing import List, Dict, Any
|
|
||||||
|
|
||||||
from mcp.types import TextContent, ImageContent
|
|
||||||
|
|
||||||
from src.connector import Imagen4Client, Config
|
|
||||||
from src.connector.imagen4_client import ImageGenerationRequest
|
|
||||||
from src.connector.utils import save_generated_images
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_args_for_logging(arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Return a log-safe copy of *arguments*.

    Base64 image payloads become a size placeholder and overly long
    strings are truncated; non-string values pass through untouched.
    """
    image_keys = ('data', 'image_data', 'base64')
    # Base64 prefixes of PNG / JPEG / GIF payloads respectively.
    image_magic = ('iVBORw0KGgo', '/9j/', 'R0lGOD')

    sanitized: Dict[str, Any] = {}
    for name, value in arguments.items():
        if not isinstance(value, str):
            sanitized[name] = value
            continue

        is_image = (
            name in image_keys
            or (len(value) > 100 and value.startswith(image_magic))
        )
        if is_image:
            # Replace bulky image data with a short size marker.
            sanitized[name] = f"<image_data:{len(value)} chars>"
        elif len(value) > 1000:
            # Keep only a prefix of very long strings.
            sanitized[name] = f"{value[:100]}...<truncated:{len(value)} total chars>"
        else:
            sanitized[name] = value

    return sanitized
class ToolHandlers:
|
|
||||||
"""MCP tool handler class"""
|
|
||||||
|
|
||||||
    def __init__(self, config: Config):
        """Initialize handler with the app config and an Imagen4 API client."""
        # config: provides output_path used when saving generated images
        self.config = config
        self.client = Imagen4Client(config)
    async def handle_generate_random_seed(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Random seed generation handler.

        Returns one TextContent with a random 32-bit unsigned seed;
        *arguments* is unused. Failures are reported as text, not raised.
        """
        try:
            random_seed = random.randint(0, 2**32 - 1)
            return [TextContent(
                type="text",
                text=f"Generated random seed: {random_seed}"
            )]
        except Exception as e:
            logger.error(f"Random seed generation error: {str(e)}")
            return [TextContent(
                type="text",
                text=f"Error occurred during random seed generation: {str(e)}"
            )]
    async def handle_regenerate_from_json(self, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]:
        """Image regeneration from JSON file handler.

        Loads generation parameters from arguments["json_file_path"],
        re-runs the generation, optionally saves the results, and returns
        a text summary followed by the full-size images as base64 PNG
        ImageContent entries. All failures are returned as TextContent.
        """
        try:
            json_file_path = arguments.get("json_file_path")
            save_to_file = arguments.get("save_to_file", True)

            if not json_file_path:
                return [TextContent(
                    type="text",
                    text="Error: JSON file path is required."
                )]

            # Load parameters from JSON file
            try:
                with open(json_file_path, 'r', encoding='utf-8') as f:
                    params = json.load(f)
            except FileNotFoundError:
                return [TextContent(
                    type="text",
                    text=f"Error: Cannot find JSON file: {json_file_path}"
                )]
            except json.JSONDecodeError as e:
                return [TextContent(
                    type="text",
                    text=f"Error: JSON parsing error: {str(e)}"
                )]

            # Check required parameters (prompt + seed are needed for a
            # reproducible regeneration)
            required_params = ['prompt', 'seed']
            missing_params = [p for p in required_params if p not in params]

            if missing_params:
                return [TextContent(
                    type="text",
                    text=f"Error: Required parameters are missing from JSON file: {', '.join(missing_params)}"
                )]

            # Create image generation request object
            request = ImageGenerationRequest(
                prompt=params.get('prompt'),
                negative_prompt=params.get('negative_prompt', ''),
                number_of_images=params.get('number_of_images', 1),
                seed=params.get('seed'),
                aspect_ratio=params.get('aspect_ratio', '1:1')
            )

            logger.info(f"Loaded parameters from JSON: {json_file_path}")

            # Generate image
            response = await self.client.generate_image(request)

            if not response.success:
                return [TextContent(
                    type="text",
                    text=f"Error occurred during image regeneration: {response.error_message}"
                )]

            # Save files (optional) — metadata records the source JSON for
            # provenance tracking
            saved_files = []
            if save_to_file:
                regeneration_params = {
                    "prompt": request.prompt,
                    "negative_prompt": request.negative_prompt,
                    "number_of_images": request.number_of_images,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "regenerated_from": json_file_path,
                    "original_generated_at": params.get('generated_at', 'unknown')
                }

                saved_files = save_generated_images(
                    images_data=response.images_data,
                    save_directory=self.config.output_path,
                    seed=request.seed,
                    generation_params=regeneration_params,
                    filename_prefix="imagen4_regen"
                )

            # Generate result message
            message_parts = [
                f"Images have been successfully regenerated.",
                f"JSON file: {json_file_path}",
                f"Size: 2048x2048 PNG",
                f"Count: {len(response.images_data)} images",
                f"Seed: {request.seed}",
                f"Prompt: {request.prompt}"
            ]

            if saved_files:
                message_parts.append(f"\nRegenerated files saved successfully:")
                for filepath in saved_files:
                    message_parts.append(f"- {filepath}")

            # Generate result
            result = [
                TextContent(
                    type="text",
                    text="\n".join(message_parts)
                )
            ]

            # Add Base64 encoded images (full-size PNGs, unlike the
            # enhanced handler which returns 512x512 previews)
            for i, image_data in enumerate(response.images_data):
                image_base64 = base64.b64encode(image_data).decode('utf-8')
                result.append(
                    ImageContent(
                        type="image",
                        data=image_base64,
                        mimeType="image/png"
                    )
                )
                logger.info(f"Added image {i+1} to response: {len(image_base64)} chars base64 data")

            return result

        except Exception as e:
            logger.error(f"Error occurred during image regeneration: {str(e)}")
            return [TextContent(
                type="text",
                text=f"Error occurred during image regeneration: {str(e)}"
            )]
async def handle_generate_image(self, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]:
    """Image generation handler with synchronous processing.

    Args:
        arguments: Raw MCP tool arguments. Requires "prompt" and "seed";
            optional keys: "negative_prompt", "number_of_images",
            "aspect_ratio", "save_to_file".

    Returns:
        A single-element list with a TextContent describing success or the
        error that occurred.
    """
    try:
        # Log arguments safely without exposing image data
        safe_args = sanitize_args_for_logging(arguments)
        logger.info(f"handle_generate_image called with arguments: {safe_args}")

        # Extract and validate arguments: prompt is mandatory
        prompt = arguments.get("prompt")
        if not prompt:
            logger.error("No prompt provided")
            return [TextContent(
                type="text",
                text="Error: Prompt is required."
            )]

        # Seed is mandatory for reproducibility (0 is a valid seed, so compare to None)
        seed = arguments.get("seed")
        if seed is None:
            logger.error("No seed provided")
            return [TextContent(
                type="text",
                text="Error: Seed value is required. You can use the generate_random_seed tool to generate a random seed."
            )]

        # Create image generation request object
        request = ImageGenerationRequest(
            prompt=prompt,
            negative_prompt=arguments.get("negative_prompt", ""),
            number_of_images=arguments.get("number_of_images", 1),
            seed=seed,
            aspect_ratio=arguments.get("aspect_ratio", "1:1")
        )

        save_to_file = arguments.get("save_to_file", True)

        logger.info(f"Starting SYNCHRONOUS image generation: '{prompt[:50]}...', Seed: {seed}")

        # Generate image synchronously with longer timeout
        try:
            logger.info("Calling client.generate_image() synchronously...")
            response = await asyncio.wait_for(
                self.client.generate_image(request),
                timeout=360.0  # 6 minute timeout
            )
            logger.info(f"Image generation completed. Success: {response.success}")
        except asyncio.TimeoutError:
            logger.error("Image generation timed out after 6 minutes")
            return [TextContent(
                type="text",
                text="Error: Image generation timed out after 6 minutes. Please try again."
            )]

        if not response.success:
            logger.error(f"Image generation failed: {response.error_message}")
            return [TextContent(
                type="text",
                text=f"Error occurred during image generation: {response.error_message}"
            )]

        logger.info(f"Generated {len(response.images_data)} images successfully")

        # Save files if requested; generation_params are persisted alongside the
        # images so regenerate_from_json can reproduce the same request later
        saved_files = []
        if save_to_file:
            logger.info("Saving files to disk...")
            generation_params = {
                "prompt": request.prompt,
                "negative_prompt": request.negative_prompt,
                "number_of_images": request.number_of_images,
                "seed": request.seed,
                "aspect_ratio": request.aspect_ratio,
                "guidance_scale": 7.5,
                "safety_filter_level": "block_only_high",
                "person_generation": "allow_all",
                "add_watermark": False
            }

            saved_files = save_generated_images(
                images_data=response.images_data,
                save_directory=self.config.output_path,
                seed=request.seed,
                generation_params=generation_params
            )
            logger.info(f"Files saved: {saved_files}")

            # Verify files were created on disk (defensive check, log-only)
            import os
            for file_path in saved_files:
                if os.path.exists(file_path):
                    size = os.path.getsize(file_path)
                    logger.info(f" ✓ Verified: {file_path} ({size} bytes)")
                else:
                    logger.error(f" ❌ Missing: {file_path}")

        # Generate result message
        message_parts = [
            f"✅ Images have been successfully generated!",
            f"Prompt: {request.prompt}",
            f"Seed: {request.seed}",
            f"Size: 2048x2048 PNG",
            f"Count: {len(response.images_data)} images"
        ]

        if saved_files:
            message_parts.append(f"\n📁 Files saved successfully:")
            for filepath in saved_files:
                message_parts.append(f" - {filepath}")
        else:
            message_parts.append(f"\nℹ️ Images generated but not saved to file. Use save_to_file=true to save.")

        logger.info("Returning synchronous response")
        return [TextContent(
            type="text",
            text="\n".join(message_parts)
        )]

    except Exception as e:
        logger.error(f"Error occurred during image generation: {str(e)}", exc_info=True)
        return [TextContent(
            type="text",
            text=f"Error occurred during image generation: {str(e)}"
        )]
|
|
||||||
|
|
||||||
@@ -1,57 +0,0 @@
|
|||||||
"""
|
|
||||||
MCP Server implementation for Imagen4
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Dict, Any, List
|
|
||||||
|
|
||||||
from mcp.server import Server
|
|
||||||
from mcp.types import Tool, TextContent, ImageContent
|
|
||||||
|
|
||||||
from src.connector import Config
|
|
||||||
from src.server.models import MCPToolDefinitions
|
|
||||||
from src.server.handlers import ToolHandlers, sanitize_args_for_logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Imagen4MCPServer:
    """Imagen4 MCP server: wires tool definitions and handlers into an MCP Server."""

    def __init__(self, config: Config):
        """Create the underlying MCP server and tool handlers, then register callbacks."""
        self.config = config
        self.server = Server("imagen4-mcp-server")
        self.handlers = ToolHandlers(config)

        # Register handlers
        self._register_handlers()

    def _register_handlers(self) -> None:
        """Attach list_tools / call_tool callbacks to the MCP server."""

        @self.server.list_tools()
        async def handle_list_tools() -> List[Tool]:
            """Return list of available tools."""
            logger.info("Listing available tools")
            return MCPToolDefinitions.get_all_tools()

        @self.server.call_tool()
        async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]:
            """Dispatch a tool call to the matching handler."""
            # Log tool call safely without exposing sensitive data
            safe_args = sanitize_args_for_logging(arguments)
            logger.info(f"Tool called: {name} with arguments: {safe_args}")

            dispatch = {
                "generate_random_seed": self.handlers.handle_generate_random_seed,
                "regenerate_from_json": self.handlers.handle_regenerate_from_json,
                "generate_image": self.handlers.handle_generate_image,
            }
            handler = dispatch.get(name)
            if handler is None:
                raise ValueError(f"Unknown tool: {name}")
            return await handler(arguments)

    def get_server(self) -> Server:
        """Return MCP server instance."""
        return self.server
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
"""
|
|
||||||
Minimal Task Result Handler for Emergency Use
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import List, Dict, Any
|
|
||||||
|
|
||||||
from mcp.types import TextContent
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
async def minimal_get_task_result(task_manager, task_id: str) -> List[TextContent]:
    """Bare-bones fallback: report a task's status without touching its result payload."""
    try:
        logger.info(f"Minimal handler for task: {task_id}")

        # Only consult the status; never dereference the task's result object
        status = task_manager.get_task_status(task_id)
        if status is None:
            return [TextContent(type="text", text=f"❌ Task '{task_id}' not found.")]

        report = (
            f"📋 Task '{task_id}' Status: {status.value}\n"
            f"Note: This is a minimal emergency handler.\n"
            f"If you see this message, the original handler has a serious issue."
        )
        return [TextContent(type="text", text=report)]

    except Exception as e:
        logger.error(f"Even minimal handler failed: {str(e)}", exc_info=True)
        return [TextContent(type="text", text=f"Complete failure: {str(e)}")]
|
|
||||||
@@ -1,132 +0,0 @@
|
|||||||
"""
|
|
||||||
MCP Tool Models and Definitions
|
|
||||||
"""
|
|
||||||
|
|
||||||
from mcp.types import Tool
|
|
||||||
|
|
||||||
|
|
||||||
class MCPToolDefinitions:
    """MCP tool definition class.

    Each static method returns one `Tool` with its JSON-schema input contract;
    `get_all_tools` aggregates them for the server's list_tools response.
    """

    @staticmethod
    def get_generate_image_tool() -> Tool:
        """Image generation tool definition."""
        return Tool(
            name="generate_image",
            description="Generate 2048x2048 PNG images from text prompts using Google Imagen 4 AI.",
            inputSchema={
                "type": "object",
                "properties": {
                    "prompt": {
                        "type": "string",
                        "description": "Text prompt for image generation (Korean or English). Maximum 480 tokens allowed.",
                        "maxLength": 2000  # Approximate character limit for safety
                    },
                    "negative_prompt": {
                        "type": "string",
                        "description": "Negative prompt specifying elements not to generate (optional). Maximum 240 tokens allowed.",
                        "default": "",
                        "maxLength": 1000  # Approximate character limit for safety
                    },
                    "number_of_images": {
                        "type": "integer",
                        "description": "Number of images to generate (1 or 2 only)",
                        "enum": [1, 2],
                        "default": 1
                    },
                    "seed": {
                        "type": "integer",
                        "description": "Seed value for reproducible results (required, 0 ~ 4294967295 range)",
                        "minimum": 0,
                        "maximum": 4294967295  # 32-bit unsigned range
                    },
                    "aspect_ratio": {
                        "type": "string",
                        "description": "Image aspect ratio",
                        "enum": ["1:1", "9:16", "16:9", "3:4", "4:3"],
                        "default": "1:1"
                    },
                    "save_to_file": {
                        "type": "boolean",
                        "description": "Whether to save generated images to files (default: true)",
                        "default": True
                    },
                    "model": {
                        "type": "string",
                        "description": "Imagen 4 model to use for generation",
                        "enum": ["imagen-4.0-generate-001", "imagen-4.0-ultra-generate-001"],
                        "default": "imagen-4.0-generate-001"
                    }
                },
                "required": ["prompt", "seed"]
            }
        )

    @staticmethod
    def get_regenerate_from_json_tool() -> Tool:
        """Regenerate from JSON tool definition."""
        return Tool(
            name="regenerate_from_json",
            description="Read parameters from JSON file and regenerate images with the same settings.",
            inputSchema={
                "type": "object",
                "properties": {
                    "json_file_path": {
                        "type": "string",
                        "description": "Path to JSON file containing saved parameters"
                    },
                    "save_to_file": {
                        "type": "boolean",
                        "description": "Whether to save regenerated images to files (default: true)",
                        "default": True
                    }
                },
                "required": ["json_file_path"]
            }
        )

    @staticmethod
    def get_generate_random_seed_tool() -> Tool:
        """Random seed generation tool definition (takes no arguments)."""
        return Tool(
            name="generate_random_seed",
            description="Generate random seed value for image generation.",
            inputSchema={
                "type": "object",
                "properties": {},
                "required": []
            }
        )

    @staticmethod
    def get_validate_prompt_tool() -> Tool:
        """Prompt validation tool definition."""
        return Tool(
            name="validate_prompt",
            description="Validate prompt length and estimate token count before image generation.",
            inputSchema={
                "type": "object",
                "properties": {
                    "prompt": {
                        "type": "string",
                        "description": "Text prompt to validate"
                    },
                    "negative_prompt": {
                        "type": "string",
                        "description": "Negative prompt to validate (optional)",
                        "default": ""
                    }
                },
                "required": ["prompt"]
            }
        )

    @classmethod
    def get_all_tools(cls) -> list[Tool]:
        """Return all tool definitions."""
        return [
            cls.get_generate_image_tool(),
            cls.get_regenerate_from_json_tool(),
            cls.get_generate_random_seed_tool(),
            cls.get_validate_prompt_tool()
        ]
|
|
||||||
@@ -1,80 +0,0 @@
|
|||||||
"""
|
|
||||||
Safe Get Task Result Handler - Debug Version
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import List, Dict, Any
|
|
||||||
|
|
||||||
from mcp.types import TextContent
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
async def safe_get_task_result(task_manager, task_id: str) -> List[TextContent]:
    """Minimal safe version of get_task_result without image processing.

    Diagnostic variant: each access to the task object is wrapped in its own
    try/except so a failure is reported at the exact step it occurs, and the
    (potentially problematic) `result.result` payload is never dereferenced.
    """
    try:
        logger.info(f"Safe get_task_result called for: {task_id}")

        # Step 1: Try to get raw result
        try:
            raw_result = task_manager.worker_pool.get_task_result(task_id)
            logger.info(f"Raw result type: {type(raw_result)}")
        except Exception as e:
            logger.error(f"Failed to get raw result: {str(e)}")
            return [TextContent(
                type="text",
                text=f"Error accessing task data: {str(e)}"
            )]

        if not raw_result:
            return [TextContent(
                type="text",
                text=f"❌ Task '{task_id}' not found."
            )]

        # Step 2: Check status safely
        try:
            status = raw_result.status
            logger.info(f"Task status: {status}")
        except Exception as e:
            logger.error(f"Failed to get status: {str(e)}")
            return [TextContent(
                type="text",
                text=f"Error reading task status: {str(e)}"
            )]

        # Step 3: Return basic info without touching result.result
        try:
            # Optional attributes are probed with hasattr — tasks from older
            # code paths may not carry duration/created_at/error
            duration = raw_result.duration if hasattr(raw_result, 'duration') else None
            created = raw_result.created_at.isoformat() if hasattr(raw_result, 'created_at') and raw_result.created_at else "unknown"

            message = f"✅ Task '{task_id}' Status Report:\n"
            message += f"Status: {status.value}\n"
            message += f"Created: {created}\n"

            if duration:
                message += f"Duration: {duration:.2f}s\n"

            if hasattr(raw_result, 'error') and raw_result.error:
                message += f"Error: {raw_result.error}\n"

            message += "\nNote: This is a safe diagnostic version."

            return [TextContent(
                type="text",
                text=message
            )]

        except Exception as e:
            logger.error(f"Failed to format response: {str(e)}")
            return [TextContent(
                type="text",
                text=f"Task exists but data formatting failed: {str(e)}"
            )]

    except Exception as e:
        logger.error(f"Critical error in safe_get_task_result: {str(e)}", exc_info=True)
        return [TextContent(
            type="text",
            text=f"Critical error: {str(e)}"
        )]
|
|
||||||
@@ -1,107 +0,0 @@
|
|||||||
"""
|
|
||||||
Image processing utilities for preview generation
|
|
||||||
"""
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import io
|
|
||||||
import logging
|
|
||||||
from typing import Optional
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def create_preview_image_b64(image_data: bytes, target_size: int = 512, quality: int = 85) -> Optional[str]:
    """
    Convert PNG image data to a square JPEG preview and return it as base64.

    Args:
        image_data: Original PNG image data in bytes
        target_size: Edge length of the square preview (default: 512)
        quality: JPEG quality (1-100, default: 85)

    Returns:
        Base64 encoded JPEG image string, or None if conversion fails
    """
    try:
        with Image.open(io.BytesIO(image_data)) as img:
            # Flatten alpha/palette images onto a white canvas; JPEG has no alpha
            if img.mode in ('RGBA', 'LA', 'P'):
                canvas = Image.new('RGB', img.size, (255, 255, 255))
                if img.mode == 'P':
                    img = img.convert('RGBA')
                canvas.paste(img, mask=img.split()[-1] if img.mode in ('RGBA', 'LA') else None)
                img = canvas
            elif img.mode != 'RGB':
                img = img.convert('RGB')

            # Shrink in place, preserving aspect ratio
            img.thumbnail((target_size, target_size), Image.Resampling.LANCZOS)

            # Pad to an exact target_size x target_size square when needed
            if img.size != (target_size, target_size):
                padded = Image.new('RGB', (target_size, target_size), (255, 255, 255))
                offset_x = (target_size - img.size[0]) // 2
                offset_y = (target_size - img.size[1]) // 2
                padded.paste(img, (offset_x, offset_y))
                img = padded

            # Encode as JPEG, then base64
            buf = io.BytesIO()
            img.save(buf, format='JPEG', quality=quality, optimize=True)
            jpeg_data = buf.getvalue()
            base64_data = base64.b64encode(jpeg_data).decode('utf-8')

            logger.info(f"Preview image created: {target_size}x{target_size} JPEG, {len(jpeg_data)} bytes, base64 length: {len(base64_data)}")
            return base64_data

    except Exception as e:
        logger.error(f"Failed to create preview image: {str(e)}")
        return None
|
|
||||||
|
|
||||||
|
|
||||||
def validate_image_data(image_data: bytes) -> bool:
    """
    Check whether *image_data* decodes as a valid image.

    Args:
        image_data: Image data in bytes

    Returns:
        True if valid image, False otherwise
    """
    try:
        with Image.open(io.BytesIO(image_data)) as img:
            img.verify()  # raises on corrupt data
    except Exception:
        return False
    return True
|
|
||||||
|
|
||||||
|
|
||||||
def get_image_info(image_data: bytes) -> Optional[dict]:
    """
    Describe an image blob.

    Args:
        image_data: Image data in bytes

    Returns:
        Dict with keys 'format', 'size', 'mode', 'bytes', or None if the
        data cannot be decoded.
    """
    try:
        with Image.open(io.BytesIO(image_data)) as img:
            info = {
                'format': img.format,
                'size': img.size,
                'mode': img.mode,
                'bytes': len(image_data),
            }
        return info
    except Exception as e:
        logger.error(f"Failed to get image info: {str(e)}")
        return None
|
|
||||||
@@ -1,200 +0,0 @@
|
|||||||
"""
|
|
||||||
Token counting utilities for Imagen4 prompts
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
import logging
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Rough token-estimation constants
AVERAGE_CHARS_PER_TOKEN = 4  # average for English/other text
KOREAN_CHARS_PER_TOKEN = 2   # average for Korean (Hangul) text
MAX_PROMPT_TOKENS = 480      # maximum allowed prompt tokens


def estimate_token_count(text: str) -> int:
    """Estimate the token count of *text*.

    An exact count would require the real tokenizer; this heuristic exists
    for fast validation before calling the API.

    Args:
        text: Text to estimate tokens for.

    Returns:
        int: Estimated token count (at least 1 for non-empty text).
    """
    if not text:
        return 0

    stripped = text.strip()
    if not stripped:
        return 0

    # Count Hangul and Latin characters separately; everything else is
    # lumped with the Latin per-token average.
    hangul = len(re.findall(r'[가-힣]', stripped))
    latin = len(re.findall(r'[a-zA-Z]', stripped))
    remainder = len(stripped) - hangul - latin

    estimated = int(
        hangul / KOREAN_CHARS_PER_TOKEN
        + latin / AVERAGE_CHARS_PER_TOKEN
        + remainder / AVERAGE_CHARS_PER_TOKEN
    )

    # Guarantee at least one token for non-empty input
    return max(1, estimated)
|
|
||||||
|
|
||||||
|
|
||||||
def validate_prompt_length(prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> tuple[bool, int, Optional[str]]:
    """Validate that a prompt fits within the token budget.

    Args:
        prompt: Prompt to validate.
        max_tokens: Maximum allowed token count.

    Returns:
        tuple: (is_valid, token_count, error_message_or_None)
    """
    if not prompt:
        return False, 0, "프롬프트가 비어있습니다."

    tokens = estimate_token_count(prompt)
    if tokens <= max_tokens:
        return True, tokens, None

    error_msg = (
        f"프롬프트가 너무 깁니다. "
        f"현재: {tokens}토큰, 최대: {max_tokens}토큰. "
        f"프롬프트를 {tokens - max_tokens}토큰 줄여주세요."
    )
    return False, tokens, error_msg
|
|
||||||
|
|
||||||
|
|
||||||
def truncate_prompt(prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> str:
    """Truncate a prompt down to at most *max_tokens* estimated tokens.

    Args:
        prompt: Prompt to truncate.
        max_tokens: Maximum token count.

    Returns:
        str: Truncated (and stripped) prompt.
    """
    if not prompt:
        return ""

    current_tokens = estimate_token_count(prompt)
    if current_tokens <= max_tokens:
        return prompt

    # Cut by approximate character ratio, keeping a 10% safety margin
    ratio = max_tokens / current_tokens
    target_length = int(len(prompt) * ratio * 0.9)

    truncated = prompt[:target_length]

    # Prefer cutting on a word/sentence boundary
    if len(truncated) < len(prompt):
        # Keep only up to the last complete word
        last_space = truncated.rfind(' ')
        last_korean = truncated.rfind('다')  # common Korean sentence ending
        last_punct = max(truncated.rfind('.'), truncated.rfind(','), truncated.rfind('!'))

        cut_point = max(last_space, last_korean, last_punct)
        if cut_point > target_length * 0.8:  # avoid trimming away too much
            truncated = truncated[:cut_point]

    # Final check: force a character-level cut if the estimate is still over budget
    final_tokens = estimate_token_count(truncated)
    if final_tokens > max_tokens:
        chars_per_token = len(truncated) / final_tokens
        target_chars = int(max_tokens * chars_per_token * 0.95)
        truncated = truncated[:target_chars]

    return truncated.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def get_prompt_stats(prompt: str) -> dict:
    """Return statistics about a prompt: character/token breakdown and validity.

    Args:
        prompt: Prompt to analyze.

    Returns:
        dict: Prompt statistics. (The empty-prompt shape intentionally
        omits the "max_tokens" key, matching historical behavior.)
    """
    if not prompt:
        return {
            "character_count": 0,
            "estimated_tokens": 0,
            "korean_chars": 0,
            "english_chars": 0,
            "other_chars": 0,
            "is_valid": False,
            "remaining_tokens": MAX_PROMPT_TOKENS,
        }

    total_chars = len(prompt)
    hangul = len(re.findall(r'[가-힣]', prompt))
    latin = len(re.findall(r'[a-zA-Z]', prompt))
    tokens = estimate_token_count(prompt)

    return {
        "character_count": total_chars,
        "estimated_tokens": tokens,
        "korean_chars": hangul,
        "english_chars": latin,
        "other_chars": total_chars - hangul - latin,
        "is_valid": tokens <= MAX_PROMPT_TOKENS,
        "remaining_tokens": MAX_PROMPT_TOKENS - tokens,
        "max_tokens": MAX_PROMPT_TOKENS,
    }
|
|
||||||
|
|
||||||
|
|
||||||
# 실제 토크나이저 사용 시 대체할 수 있는 인터페이스
|
|
||||||
class TokenCounter:
    """Token-counter interface with a pluggable backend.

    Only the built-in "estimate" backend exists today; the tokenizer name
    is an extension hook for a real tokenizer later.
    """

    def __init__(self, tokenizer_name: Optional[str] = None):
        """Initialize the counter.

        Args:
            tokenizer_name: Name of the tokenizer to use (future extension).
        """
        self.tokenizer_name = tokenizer_name or "estimate"
        logger.info(f"토큰 카운터 초기화: {self.tokenizer_name}")

    def count_tokens(self, text: str) -> int:
        """Count tokens in *text* using the configured backend."""
        if self.tokenizer_name != "estimate":
            # Real tokenizer backends are not implemented yet
            raise NotImplementedError(f"토크나이저 '{self.tokenizer_name}'는 아직 구현되지 않았습니다.")
        return estimate_token_count(text)

    def validate_prompt(self, prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> tuple[bool, int, Optional[str]]:
        """Validate a prompt against the token budget."""
        return validate_prompt_length(prompt, max_tokens)

    def truncate_prompt(self, prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> str:
        """Truncate a prompt to fit the token budget."""
        return truncate_prompt(prompt, max_tokens)


# Module-level default token counter instance
default_token_counter = TokenCounter()
|
|
||||||
40
test_mcp_compatibility.py
Normal file
40
test_mcp_compatibility.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test script to verify MCP protocol compatibility
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Add current directory to path
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
def test_json_output():
    """Emit a sample JSON-RPC message to stdout, as the MCP protocol expects."""
    payload = {
        "jsonrpc": "2.0",
        "method": "notifications/initialized",
        "params": {}
    }
    # stdout is reserved for protocol JSON — this must print without interference
    print(json.dumps(payload))
|
||||||
|
|
||||||
|
def test_unicode_handling():
    """Print a Unicode sample to stderr only, keeping stdout JSON-clean."""
    sample = "Test UTF-8: 한글 테스트 ✓"
    print(f"[UTF8-TEST] {sample}", file=sys.stderr)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("[MCP-TEST] Testing JSON output compatibility...", file=sys.stderr)
|
||||||
|
test_json_output()
|
||||||
|
print("[MCP-TEST] JSON output test completed", file=sys.stderr)
|
||||||
|
|
||||||
|
print("[MCP-TEST] Testing Unicode handling...", file=sys.stderr)
|
||||||
|
test_unicode_handling()
|
||||||
|
print("[MCP-TEST] Unicode test completed", file=sys.stderr)
|
||||||
|
|
||||||
|
print("[MCP-TEST] All tests passed!", file=sys.stderr)
|
||||||
@@ -1,178 +0,0 @@
|
|||||||
"""
|
|
||||||
Tests for Imagen4 Connector
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import asyncio
|
|
||||||
from unittest.mock import Mock, patch
|
|
||||||
|
|
||||||
from src.connector import Config, Imagen4Client
|
|
||||||
from src.connector.imagen4_client import ImageGenerationRequest
|
|
||||||
|
|
||||||
|
|
||||||
class TestConfig:
    """Tests for the Config class."""

    def test_config_creation(self):
        """Config stores constructor arguments as-is."""
        config = Config(
            project_id="test-project",
            location="us-central1",
            output_path="./test_images"
        )

        assert config.project_id == "test-project"
        assert config.location == "us-central1"
        assert config.output_path == "./test_images"

    def test_config_validation(self):
        """validate() accepts a non-empty project_id and rejects an empty one."""
        # valid configuration
        valid_config = Config(project_id="test-project")
        assert valid_config.validate() is True

        # invalid configuration
        invalid_config = Config(project_id="")
        assert invalid_config.validate() is False

    @patch.dict('os.environ', {
        'GOOGLE_CLOUD_PROJECT_ID': 'test-project',
        'GOOGLE_CLOUD_LOCATION': 'us-west1',
        'GENERATED_IMAGES_PATH': './custom_path'
    })
    def test_config_from_env(self):
        """Config.from_env() reads values from environment variables."""
        config = Config.from_env()

        assert config.project_id == "test-project"
        assert config.location == "us-west1"
        assert config.output_path == "./custom_path"

    @patch.dict('os.environ', {}, clear=True)
    def test_config_from_env_missing_project_id(self):
        """A missing required environment variable raises ValueError."""
        with pytest.raises(ValueError, match="GOOGLE_CLOUD_PROJECT_ID environment variable is required"):
            Config.from_env()
|
|
||||||
|
|
||||||
|
|
||||||
class TestImageGenerationRequest:
    """Tests for ImageGenerationRequest validation."""

    def test_valid_request(self):
        """A request with prompt and seed validates without raising."""
        request = ImageGenerationRequest(
            prompt="A beautiful landscape",
            seed=12345
        )

        # validation must complete without an exception
        request.validate()

    def test_invalid_prompt(self):
        """An empty prompt is rejected."""
        request = ImageGenerationRequest(
            prompt="",
            seed=12345
        )

        with pytest.raises(ValueError, match="Prompt is required"):
            request.validate()

    def test_invalid_seed(self):
        """A seed outside 0..4294967295 is rejected."""
        request = ImageGenerationRequest(
            prompt="Test prompt",
            seed=-1
        )

        with pytest.raises(ValueError, match="Seed value must be an integer between 0 and 4294967295"):
            request.validate()

    def test_invalid_number_of_images(self):
        """Only 1 or 2 images may be requested."""
        request = ImageGenerationRequest(
            prompt="Test prompt",
            seed=12345,
            number_of_images=3
        )

        with pytest.raises(ValueError, match="Number of images must be 1 or 2 only"):
            request.validate()
|
|
||||||
|
|
||||||
|
|
||||||
class TestImagen4Client:
    """Behavioural tests for Imagen4Client."""

    def test_client_initialization(self):
        """The client stores the configuration it was constructed with."""
        cfg = Config(project_id="test-project")

        with patch('src.connector.imagen4_client.genai.Client'):
            assert Imagen4Client(cfg).config == cfg

    def test_client_invalid_config(self):
        """Constructing a client from an invalid config must fail."""
        with pytest.raises(ValueError, match="Invalid configuration"):
            Imagen4Client(Config(project_id=""))

    @pytest.mark.asyncio
    async def test_generate_image_success(self):
        """A successful API call yields the raw image bytes."""
        cfg = Config(project_id="test-project")
        fake_bytes = b"fake_image_data"

        # Build a response object shaped like the SDK's generate_images result.
        generated = Mock()
        generated.image = Mock()
        generated.image.image_bytes = fake_bytes
        fake_response = Mock()
        fake_response.generated_images = [generated]

        with patch('src.connector.imagen4_client.genai.Client') as client_cls:
            sdk_client = Mock()
            sdk_client.models.generate_images = Mock(return_value=fake_response)
            client_cls.return_value = sdk_client

            client = Imagen4Client(cfg)
            req = ImageGenerationRequest(
                prompt="Test prompt",
                seed=12345,
            )

            # generate_image offloads the blocking SDK call via asyncio.to_thread,
            # so patching that is enough to control the outcome.
            with patch('asyncio.to_thread', return_value=fake_response):
                result = await client.generate_image(req)

            assert result.success is True
            assert len(result.images_data) == 1
            assert result.images_data[0] == fake_bytes

    @pytest.mark.asyncio
    async def test_generate_image_failure(self):
        """An API error is reported as a failed response, not an exception."""
        cfg = Config(project_id="test-project")

        with patch('src.connector.imagen4_client.genai.Client') as client_cls:
            client_cls.return_value = Mock()

            client = Imagen4Client(cfg)
            req = ImageGenerationRequest(
                prompt="Test prompt",
                seed=12345,
            )

            # Simulate the underlying API call failing.
            with patch('asyncio.to_thread', side_effect=Exception("API Error")):
                result = await client.generate_image(req)

            assert result.success is False
            assert "API Error" in result.error_message
            assert len(result.images_data) == 0
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Propagate pytest's exit status: the original call discarded the
    # return value, so running this module directly always exited 0
    # even when tests failed. SystemExit makes failures visible to CI.
    raise SystemExit(pytest.main([__file__]))
|
|
||||||
@@ -1,136 +0,0 @@
|
|||||||
"""
|
|
||||||
Tests for Imagen4 Server
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import Mock, AsyncMock, patch
|
|
||||||
|
|
||||||
from src.connector import Config
|
|
||||||
from src.server import Imagen4MCPServer, ToolHandlers, MCPToolDefinitions
|
|
||||||
from src.server.models import MCPToolDefinitions
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPToolDefinitions:
    """Tests for the static MCP tool definitions."""

    def test_get_generate_image_tool(self):
        """The image-generation tool exposes prompt/seed in its schema."""
        definition = MCPToolDefinitions.get_generate_image_tool()

        assert definition.name == "generate_image"
        assert "Google Imagen 4" in definition.description
        assert "prompt" in definition.inputSchema["properties"]
        assert "seed" in definition.inputSchema["required"]

    def test_get_regenerate_from_json_tool(self):
        """The JSON-regeneration tool takes a json_file_path argument."""
        definition = MCPToolDefinitions.get_regenerate_from_json_tool()

        assert definition.name == "regenerate_from_json"
        assert "JSON file" in definition.description
        assert "json_file_path" in definition.inputSchema["properties"]

    def test_get_generate_random_seed_tool(self):
        """The random-seed tool is named and described as such."""
        definition = MCPToolDefinitions.get_generate_random_seed_tool()

        assert definition.name == "generate_random_seed"
        assert "random seed" in definition.description

    def test_get_all_tools(self):
        """get_all_tools returns exactly the three known tools."""
        tools = MCPToolDefinitions.get_all_tools()

        assert len(tools) == 3
        names = {tool.name for tool in tools}
        assert "generate_image" in names
        assert "regenerate_from_json" in names
        assert "generate_random_seed" in names
|
|
||||||
|
|
||||||
|
|
||||||
class TestToolHandlers:
    """Tests for the MCP tool request handlers."""

    @pytest.fixture
    def config(self):
        """Minimal valid configuration for handler construction."""
        return Config(project_id="test-project")

    @pytest.fixture
    def handlers(self, config):
        """ToolHandlers built with the Imagen4 client patched out."""
        with patch('src.server.handlers.Imagen4Client'):
            return ToolHandlers(config)

    @staticmethod
    def _assert_single_text(result, fragment):
        """Assert a handler returned exactly one text content holding *fragment*."""
        assert len(result) == 1
        assert result[0].type == "text"
        assert fragment in result[0].text

    @pytest.mark.asyncio
    async def test_handle_generate_random_seed(self, handlers):
        """The random-seed handler reports the generated seed as text."""
        result = await handlers.handle_generate_random_seed({})

        self._assert_single_text(result, "Generated random seed:")

    @pytest.mark.asyncio
    async def test_handle_generate_image_missing_prompt(self, handlers):
        """generate_image without a prompt yields an error message."""
        result = await handlers.handle_generate_image({"seed": 12345})

        self._assert_single_text(result, "Prompt is required")

    @pytest.mark.asyncio
    async def test_handle_generate_image_missing_seed(self, handlers):
        """generate_image without a seed yields an error message."""
        result = await handlers.handle_generate_image({"prompt": "Test prompt"})

        self._assert_single_text(result, "Seed value is required")

    @pytest.mark.asyncio
    async def test_handle_regenerate_from_json_missing_path(self, handlers):
        """regenerate_from_json without a path yields an error message."""
        result = await handlers.handle_regenerate_from_json({})

        self._assert_single_text(result, "JSON file path is required")
|
|
||||||
|
|
||||||
|
|
||||||
class TestImagen4MCPServer:
    """Tests for the Imagen4 MCP server wrapper."""

    @pytest.fixture
    def config(self):
        """Minimal valid configuration for server construction."""
        return Config(project_id="test-project")

    def test_server_initialization(self, config):
        """Construction wires up config, MCP server and handlers."""
        with patch('src.server.mcp_server.ToolHandlers'):
            wrapper = Imagen4MCPServer(config)

            assert wrapper.config == config
            assert wrapper.server is not None
            assert wrapper.handlers is not None

    def test_get_server(self, config):
        """get_server returns the named underlying MCP server instance."""
        with patch('src.server.mcp_server.ToolHandlers'):
            wrapper = Imagen4MCPServer(config)
            inner = wrapper.get_server()

            assert inner is not None
            assert inner.name == "imagen4-mcp-server"
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Propagate pytest's exit status: the original call discarded the
    # return value, so running this module directly always exited 0
    # even when tests failed. SystemExit makes failures visible to CI.
    raise SystemExit(pytest.main([__file__]))
|
|
||||||
Reference in New Issue
Block a user