From dcf2305e4bff78f71624d090f60198c352bdf8ba Mon Sep 17 00:00:00 2001 From: ened Date: Fri, 22 Aug 2025 20:03:38 +0900 Subject: [PATCH] imagen4 mcp server, mcp connector implementation --- .env.example | 13 + .gitignore | 119 +++++ README.md | 204 +++++++++ claude_desktop_config.json | 12 + main.py | 706 ++++++++++++++++++++++++++++++ requirements.txt | 7 + run.bat | 50 +++ src/__init__.py | 10 + src/async_task/__init__.py | 9 + src/async_task/models.py | 48 ++ src/async_task/task_manager.py | 324 ++++++++++++++ src/async_task/worker_pool.py | 258 +++++++++++ src/connector/__init__.py | 10 + src/connector/config.py | 101 +++++ src/connector/imagen4_client.py | 215 +++++++++ src/connector/utils.py | 70 +++ src/server/__init__.py | 26 ++ src/server/enhanced_handlers.py | 322 ++++++++++++++ src/server/enhanced_mcp_server.py | 63 +++ src/server/enhanced_models.py | 71 +++ src/server/handlers.py | 308 +++++++++++++ src/server/mcp_server.py | 57 +++ src/server/minimal_handler.py | 39 ++ src/server/models.py | 132 ++++++ src/server/safe_result_handler.py | 80 ++++ src/utils/__init__.py | 1 + src/utils/image_utils.py | 107 +++++ src/utils/token_utils.py | 200 +++++++++ test_main.py | 108 +++++ tests/test_connector.py | 178 ++++++++ tests/test_server.py | 136 ++++++ 31 files changed, 3984 insertions(+) create mode 100644 .env.example create mode 100644 .gitignore create mode 100644 README.md create mode 100644 claude_desktop_config.json create mode 100644 main.py create mode 100644 requirements.txt create mode 100644 run.bat create mode 100644 src/__init__.py create mode 100644 src/async_task/__init__.py create mode 100644 src/async_task/models.py create mode 100644 src/async_task/task_manager.py create mode 100644 src/async_task/worker_pool.py create mode 100644 src/connector/__init__.py create mode 100644 src/connector/config.py create mode 100644 src/connector/imagen4_client.py create mode 100644 src/connector/utils.py create mode 100644 src/server/__init__.py create mode 100644 src/server/enhanced_handlers.py create mode 100644 src/server/enhanced_mcp_server.py create mode 100644 src/server/enhanced_models.py create mode 100644 src/server/handlers.py create mode 100644 src/server/mcp_server.py create mode 100644 src/server/minimal_handler.py create mode 100644 src/server/models.py create mode 100644 src/server/safe_result_handler.py create mode 100644 src/utils/__init__.py create mode 100644 src/utils/image_utils.py create mode 100644 src/utils/token_utils.py create mode 100644 test_main.py create mode 100644 tests/test_connector.py create mode 100644 tests/test_server.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..a043900 --- /dev/null +++ b/.env.example @@ -0,0 +1,13 @@ +# Google Cloud 설정 +GOOGLE_CLOUD_PROJECT_ID=your-project-id +GOOGLE_CLOUD_LOCATION=us-central1 + +# 이미지 저장 설정 +GENERATED_IMAGES_PATH=./generated_images + +# Imagen4 모델 설정 +# IMAGEN4_DEFAULT_MODEL=imagen-4.0-generate-001 # 기본 모델 +# IMAGEN4_DEFAULT_MODEL=imagen-4.0-ultra-generate-001 # Ultra 모델 (더 높은 품질) + +# 선택적 설정 +# LOG_LEVEL=INFO \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..70e14d4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,119 @@ +# 바이트 컴파일된 파일 / 최적화된 파일 / DLL 파일 +__pycache__/ +*.py[cod] +*$py.class + +# C 확장 +*.so + +# 배포 +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller 
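+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to be sure to inject the date/other infos into it.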
+*.manifest
+*.spec
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+Pipfile.lock
+
+# PEP 582
+__pypackages__/
+
+# Celery
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder IDE
+.spyderproject
+.spyproject
+
+# Rope project
+.ropeproject
+
+# mkdocs
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre
+.pyre/
+
+# Google Cloud credential keys
+*.json
+!claude_desktop_config.json
+
+# Log files
+*.log
+
+# Generated images
+generated_images/
+*.png
+*.jpg
+*.jpeg
+
+# VS Code
+.vscode/
+
+# PyCharm
+.idea/
+
+# Temporary files
+temp_*
+cleanup*
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..54f51c3
--- /dev/null
+++ b/README.md
@@ -0,0 +1,204 @@
+# Imagen4 MCP Server with Preview Images
+
+An MCP server for AI image generation using Google Imagen 4. All functionality is consolidated into a single `main.py` file.
+
+## 🚀 Key Features
+
+- **High-quality image generation**: produces 2048x2048 PNG images
+- **Preview images**: provides 512x512 JPEG previews as base64
+- **File saving**: stores the original PNG plus a metadata JSON file
+- **Regeneration**: regenerates images with identical settings from a saved JSON file
+- **Random seeds**: seed generation for reproducible results
+
+## 📦 Installation
+
+### 1. Install dependencies
+```bash
+pip install -r requirements.txt
+```
+
+Required packages:
+- `google-genai>=0.2.0` - Google Imagen 4 API
+- `mcp>=1.0.0` - MCP protocol
+- `python-dotenv>=1.0.0` - environment variable management
+- `Pillow>=10.0.0` - image processing
+
+### 2. Environment configuration
+
+Create a `.env` file (you can copy `.env.example`) and set your Google Cloud settings. The variable names below are the ones read by `src/connector/config.py`:
+
+```env
+GOOGLE_APPLICATION_CREDENTIALS=path/to/your/service-account-key.json
+GOOGLE_CLOUD_PROJECT_ID=your-gcp-project-id
+GOOGLE_CLOUD_LOCATION=us-central1
+GENERATED_IMAGES_PATH=./generated_images
+```
+
+## 🎮 Usage
+
+### Running the server
+
+#### Windows
+```bash
+run.bat
+```
+
+#### Direct execution
+```bash
+python main.py
+```
+
+### Claude Desktop configuration
+
+Add the contents of `claude_desktop_config.json` to your Claude Desktop settings:
+
+```json
+{
+  "mcpServers": {
+    "imagen4": {
+      "command": "python",
+      "args": ["main.py"],
+      "cwd": "D:\\Project\\little-fairy\\imagen4"
+    }
+  }
+}
+```
+
+## 🛠️ Available Tools
+
+### 1. generate_image
+Generates high-quality images and returns previews.
+
+**Parameters:**
+- `prompt` (required): text prompt for image generation
+- `seed` (required): seed value for reproducible results
+- `negative_prompt` (optional): elements to exclude from the image
+- `number_of_images` (optional): number of images to generate (1-2)
+- `aspect_ratio` (optional): aspect ratio ("1:1", "9:16", "16:9", "3:4", "4:3")
+- `save_to_file` (optional): whether to save the images to disk
+
+**Example response:**
+```
+✅ Images have been successfully generated!
+
+🖼️ Preview Images Generated: 1 images (512x512 JPEG)
+Preview 1 (base64 JPEG): /9j/4AAQSkZJRgABAQAAAQABAAD/2wBD...(67234 chars)
+
+📁 Files saved:
+  - ./generated_images/imagen4_1234567890_20250821_143052_001.png
+  - ./generated_images/imagen4_1234567890_20250821_143052_001.json
+
+⚙️ Generation Parameters:
+  - prompt: A beautiful sunset over mountains
+  - seed: 1234567890
+  - aspect_ratio: 1:1
+  - number_of_images: 1
+```
+
+### 2. generate_random_seed
+Generates a random seed for image generation.
+
+### 3. regenerate_from_json
+Regenerates images from a saved JSON parameter file.
+
+## 🖼️ Preview Image Details
+
+### Conversion process
+1. **Original**: 2048x2048 PNG (roughly 8-15 MB)
+2. **Conversion**: 512x512 JPEG (roughly 50-200 KB)
+3. **Optimization**: JPEG quality 85 with optimization enabled
+4. 
**Encoding**: Base64 string
+
+### Benefits
+- **95% size reduction**: faster transfer and display
+- **Ready to use**: can be displayed directly in web environments
+- **Transparency handling**: PNG transparency is flattened onto a white background
+- **Aspect ratio preserved**: images are resized without distorting the original proportions
+
+## 🧪 Testing
+
+To test the functionality:
+
+```bash
+python test_main.py
+```
+
+This script checks:
+- that all classes and functions import correctly
+- image processing
+- preview generation
+
+## 📁 Project Structure
+
+```
+imagen4/
+├── main.py                    # consolidated main server file
+├── run.bat                    # Windows launcher
+├── test_main.py               # test script
+├── claude_desktop_config.json # Claude Desktop configuration
+├── requirements.txt           # dependency list
+├── .env                       # environment variables (created by the user)
+├── generated_images/          # output directory for generated images
+└── src/
+    └── connector/             # Google API connector module
+        ├── __init__.py
+        ├── config.py
+        ├── imagen4_client.py
+        └── utils.py
+```
+
+## 🔧 Technical Details
+
+### Image processing
+- **PIL (Pillow)** for reliable image conversion
+- **LANCZOS resampling** for high-quality resizing
+- **Automatic transparency handling**: a white background is applied when converting RGBA → RGB
+- **Centering**: images smaller than the target size are centered on the canvas
+
+### MCP protocol
+- **Server version 3.0.0**: preview image support
+- **Asynchronous processing**: reliable async work based on asyncio
+- **Error handling**: detailed logging and user-friendly error messages
+- **Timeout**: a 6-minute timeout keeps generation calls bounded
+
+## 🐛 Troubleshooting
+
+### Common issues
+
+1. **Pillow installation errors**
+   ```bash
+   pip install --upgrade pip
+   pip install Pillow
+   ```
+
+2. **Google Cloud authentication errors**
+   - Check the service account key file path
+   - Verify that GOOGLE_CLOUD_PROJECT_ID is correct
+   - Make sure the Imagen API is enabled
+
+3. **Out of memory**
+   - Limit generation to a single image
+   - Check system memory (at least 8 GB recommended)
+
+### Checking logs
+
+Detailed logs are written to stderr while the server runs:
+- image generation progress
+- preview creation steps
+- file saving status
+- error and warning messages
+
+## 📄 License
+
+MIT License - see the LICENSE file for details.
+
+## 🎯 Summary
+
+This consolidated `main.py` provides the following benefits:
+
+- **Simplicity**: all functionality in a single file
+- **Efficiency**: fast responses thanks to 512x512 JPEG previews
+- **Compatibility**: works with existing MCP clients
+- **Extensibility**: easy to add features as needed
+
+It makes Google Imagen 4's powerful image generation available efficiently over the MCP protocol!
diff --git a/claude_desktop_config.json b/claude_desktop_config.json
new file mode 100644
index 0000000..4aba09e
--- /dev/null
+++ b/claude_desktop_config.json
@@ -0,0 +1,12 @@
+{
+  "mcpServers": {
+    "imagen4": {
+      "command": "python",
+      "args": ["main.py"],
+      "cwd": "D:\\Project\\imagen4",
+      "env": {
+        "PYTHONPATH": "D:\\Project\\imagen4"
+      }
+    }
+  }
+}
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..96f7340
--- /dev/null
+++ b/main.py
@@ -0,0 +1,706 @@
+#!/usr/bin/env python3
+"""
+Imagen4 MCP Server with Preview Image Support
+
+MCP server for AI image generation using Google Imagen 4
+- Generates 2048x2048 PNG original images
+- Provides 512x512 JPEG preview images as base64
+- Enhanced response format
+
+Run from imagen4 root directory
+"""
+
+import asyncio
+import base64
+import json
+import random
+import logging
+import sys
+import os
+import io
+from typing import Dict, Any, List, Optional
+from dataclasses import dataclass
+
+# Load environment variables
+try:
+    from dotenv import load_dotenv
+    load_dotenv()
+except ImportError:
+    print("Warning: python-dotenv not installed", file=sys.stderr)
+
+# Add current directory to PYTHONPATH
+current_dir = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, current_dir)
+
+# MCP imports
+try:
+    from mcp.server.models import InitializationOptions
+    from mcp.server import NotificationOptions
+    from mcp.server.stdio import stdio_server
+    from mcp.server import Server
+    from mcp.types import Tool, TextContent
+except ImportError as e:
+    print(f"Error importing MCP: {e}", file=sys.stderr)
+    print("Please install required packages: pip install mcp", file=sys.stderr)
+    sys.exit(1)
+
+# Connector imports
+from src.connector import Config, Imagen4Client
+from src.connector.imagen4_client import ImageGenerationRequest
+from src.connector.utils import save_generated_images
+
+# Image processing import
+try: 
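+    # Pillow is used below to build the 512x512 JPEG previews and to read basic image info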
+ from PIL import Image +except ImportError as e: + print(f"Error importing Pillow: {e}", file=sys.stderr) + print("Please install Pillow: pip install Pillow", file=sys.stderr) + sys.exit(1) + +# Logging configuration +logging.basicConfig( + level=logging.DEBUG, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[logging.StreamHandler(sys.stderr)] +) +logger = logging.getLogger("imagen4-mcp-server") + + +# ==================== Image Processing Utilities ==================== + +def create_preview_image_b64(image_data: bytes, target_size: int = 512, quality: int = 85) -> Optional[str]: + """ + Convert PNG image data to JPEG preview with specified size and return as base64 + + Args: + image_data: Original PNG image data in bytes + target_size: Target size for the preview (default: 512x512) + quality: JPEG quality (1-100, default: 85) + + Returns: + Base64 encoded JPEG image string, or None if conversion fails + """ + try: + # Open image from bytes + with Image.open(io.BytesIO(image_data)) as img: + # Convert to RGB if necessary (PNG might have alpha channel) + if img.mode in ('RGBA', 'LA', 'P'): + # Create white background for transparent images + background = Image.new('RGB', img.size, (255, 255, 255)) + if img.mode == 'P': + img = img.convert('RGBA') + background.paste(img, mask=img.split()[-1] if img.mode in ('RGBA', 'LA') else None) + img = background + elif img.mode != 'RGB': + img = img.convert('RGB') + + # Resize to target size maintaining aspect ratio + img.thumbnail((target_size, target_size), Image.Resampling.LANCZOS) + + # If image is smaller than target size, pad it to exact size + if img.size != (target_size, target_size): + # Create new image with target size and white background + new_img = Image.new('RGB', (target_size, target_size), (255, 255, 255)) + # Center the resized image + x = (target_size - img.size[0]) // 2 + y = (target_size - img.size[1]) // 2 + new_img.paste(img, (x, y)) + img = new_img + + # Convert to JPEG and encode as base64 + output_buffer = io.BytesIO() + img.save(output_buffer, format='JPEG', quality=quality, optimize=True) + jpeg_data = output_buffer.getvalue() + + # Encode to base64 + base64_data = base64.b64encode(jpeg_data).decode('utf-8') + + logger.info(f"Preview image created: {target_size}x{target_size} JPEG, {len(jpeg_data)} bytes") + return base64_data + + except Exception as e: + logger.error(f"Failed to create preview image: {str(e)}") + return None + + +def get_image_info(image_data: bytes) -> Optional[dict]: + """Get image information""" + try: + with Image.open(io.BytesIO(image_data)) as img: + return { + 'format': img.format, + 'size': img.size, + 'mode': img.mode, + 'bytes': len(image_data) + } + except Exception as e: + logger.error(f"Failed to get image info: {str(e)}") + return None + + +# ==================== Response Models ==================== + +@dataclass +class ImageGenerationResult: + """Enhanced image generation result with preview support""" + success: bool + message: str + original_images_count: int + preview_images_b64: Optional[list[str]] = None # List of base64 JPEG previews (512x512) + saved_files: Optional[list[str]] = None + generation_params: Optional[Dict[str, Any]] = None + error_message: Optional[str] = None + + def to_text_content(self) -> str: + """Convert result to text format for MCP response""" + lines = [self.message] + + if self.success and self.preview_images_b64: + lines.append(f"\n🖼️ Preview Images Generated: {len(self.preview_images_b64)} images (512x512 JPEG)") + for i, preview_b64 
in enumerate(self.preview_images_b64): + lines.append(f"Preview {i+1} (base64 JPEG): {preview_b64[:50]}...({len(preview_b64)} chars)") + + if self.saved_files: + lines.append(f"\n📁 Files saved:") + for filepath in self.saved_files: + lines.append(f" - {filepath}") + + if self.generation_params: + lines.append(f"\n⚙️ Generation Parameters:") + for key, value in self.generation_params.items(): + if key == 'prompt' and len(str(value)) > 100: + lines.append(f" - {key}: {str(value)[:100]}...") + else: + lines.append(f" - {key}: {value}") + + return "\n".join(lines) + + +# ==================== MCP Tool Definitions ==================== + +def get_tools() -> List[Tool]: + """Return all tool definitions""" + return [ + Tool( + name="generate_image", + description="Generate 2048x2048 PNG images with 512x512 JPEG previews using Google Imagen 4 AI.", + inputSchema={ + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "Text prompt for image generation (Korean or English)" + }, + "negative_prompt": { + "type": "string", + "description": "Negative prompt specifying elements not to generate (optional)", + "default": "" + }, + "number_of_images": { + "type": "integer", + "description": "Number of images to generate (1 or 2 only)", + "enum": [1, 2], + "default": 1 + }, + "seed": { + "type": "integer", + "description": "Seed value for reproducible results (required, 0 ~ 4294967295 range)", + "minimum": 0, + "maximum": 4294967295 + }, + "aspect_ratio": { + "type": "string", + "description": "Image aspect ratio", + "enum": ["1:1", "9:16", "16:9", "3:4", "4:3"], + "default": "1:1" + }, + "save_to_file": { + "type": "boolean", + "description": "Whether to save generated images to files (default: true)", + "default": True + } + }, + "required": ["prompt", "seed"] + } + ), + Tool( + name="regenerate_from_json", + description="Read parameters from JSON file and regenerate images with previews.", + inputSchema={ + "type": "object", + "properties": { + "json_file_path": { + "type": "string", + "description": "Path to JSON file containing saved parameters" + }, + "save_to_file": { + "type": "boolean", + "description": "Whether to save regenerated images to files (default: true)", + "default": True + } + }, + "required": ["json_file_path"] + } + ), + Tool( + name="generate_random_seed", + description="Generate random seed value for image generation.", + inputSchema={ + "type": "object", + "properties": {}, + "required": [] + } + ) + ] + + +# ==================== Utility Functions ==================== + +def sanitize_args_for_logging(arguments: Dict[str, Any]) -> Dict[str, Any]: + """Remove or truncate sensitive data from arguments for safe logging""" + safe_args = {} + for key, value in arguments.items(): + if isinstance(value, str): + # Check if it's likely base64 image data + if (key in ['data', 'image_data', 'base64', 'preview_image_b64'] or + (len(value) > 100 and value.startswith(('iVBORw0KGgo', '/9j/', 'R0lGOD')))): + # Truncate long image data + safe_args[key] = f"" + elif len(value) > 1000: + # Truncate any very long strings + safe_args[key] = f"{value[:100]}..." 
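+            # Shorter non-image strings and non-string values are logged as-is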
+ else: + safe_args[key] = value + else: + safe_args[key] = value + return safe_args + + +# ==================== Tool Handlers ==================== + +class Imagen4ToolHandlers: + """MCP tool handler class with preview image support""" + + def __init__(self, config: Config): + """Initialize handler""" + self.config = config + self.client = Imagen4Client(config) + + async def handle_generate_random_seed(self, arguments: Dict[str, Any]) -> List[TextContent]: + """Random seed generation handler""" + try: + random_seed = random.randint(0, 2**32 - 1) + return [TextContent( + type="text", + text=f"Generated random seed: {random_seed}" + )] + except Exception as e: + logger.error(f"Random seed generation error: {str(e)}") + return [TextContent( + type="text", + text=f"Error occurred during random seed generation: {str(e)}" + )] + + async def handle_regenerate_from_json(self, arguments: Dict[str, Any]) -> List[TextContent]: + """Image regeneration from JSON file handler with preview support""" + try: + json_file_path = arguments.get("json_file_path") + save_to_file = arguments.get("save_to_file", True) + + if not json_file_path: + return [TextContent( + type="text", + text="Error: JSON file path is required." + )] + + # Load parameters from JSON file + try: + with open(json_file_path, 'r', encoding='utf-8') as f: + params = json.load(f) + except FileNotFoundError: + return [TextContent( + type="text", + text=f"Error: Cannot find JSON file: {json_file_path}" + )] + except json.JSONDecodeError as e: + return [TextContent( + type="text", + text=f"Error: JSON parsing error: {str(e)}" + )] + + # Check required parameters + required_params = ['prompt', 'seed'] + missing_params = [p for p in required_params if p not in params] + + if missing_params: + return [TextContent( + type="text", + text=f"Error: Required parameters are missing from JSON file: {', '.join(missing_params)}" + )] + + # Create image generation request object + request = ImageGenerationRequest( + prompt=params.get('prompt'), + negative_prompt=params.get('negative_prompt', ''), + number_of_images=params.get('number_of_images', 1), + seed=params.get('seed'), + aspect_ratio=params.get('aspect_ratio', '1:1') + ) + + logger.info(f"Loaded parameters from JSON: {json_file_path}") + + # Generate image + response = await self.client.generate_image(request) + + if not response.success: + return [TextContent( + type="text", + text=f"Error occurred during image regeneration: {response.error_message}" + )] + + # Create preview images + preview_images_b64 = [] + for image_data in response.images_data: + preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85) + if preview_b64: + preview_images_b64.append(preview_b64) + logger.info(f"Created preview image: {len(preview_b64)} chars base64 JPEG") + + # Save files (optional) + saved_files = [] + if save_to_file: + regeneration_params = { + "prompt": request.prompt, + "negative_prompt": request.negative_prompt, + "number_of_images": request.number_of_images, + "seed": request.seed, + "aspect_ratio": request.aspect_ratio, + "regenerated_from": json_file_path, + "original_generated_at": params.get('generated_at', 'unknown') + } + + saved_files = save_generated_images( + images_data=response.images_data, + save_directory=self.config.output_path, + seed=request.seed, + generation_params=regeneration_params, + filename_prefix="imagen4_regen" + ) + + # Create result with preview images + result = ImageGenerationResult( + success=True, + message=f"✅ Images have been successfully 
regenerated from {json_file_path}", + original_images_count=len(response.images_data), + preview_images_b64=preview_images_b64, + saved_files=saved_files, + generation_params={ + "prompt": request.prompt, + "seed": request.seed, + "aspect_ratio": request.aspect_ratio, + "number_of_images": request.number_of_images + } + ) + + return [TextContent( + type="text", + text=result.to_text_content() + )] + + except Exception as e: + logger.error(f"Error occurred during image regeneration: {str(e)}") + return [TextContent( + type="text", + text=f"Error occurred during image regeneration: {str(e)}" + )] + + async def handle_generate_image(self, arguments: Dict[str, Any]) -> List[TextContent]: + """Image generation handler with preview image support""" + try: + # Log arguments safely without exposing image data + safe_args = sanitize_args_for_logging(arguments) + logger.info(f"handle_generate_image called with arguments: {safe_args}") + + # Extract and validate arguments + prompt = arguments.get("prompt") + if not prompt: + logger.error("No prompt provided") + return [TextContent( + type="text", + text="Error: Prompt is required." + )] + + seed = arguments.get("seed") + if seed is None: + logger.error("No seed provided") + return [TextContent( + type="text", + text="Error: Seed value is required. You can use the generate_random_seed tool to generate a random seed." + )] + + # Create image generation request object + request = ImageGenerationRequest( + prompt=prompt, + negative_prompt=arguments.get("negative_prompt", ""), + number_of_images=arguments.get("number_of_images", 1), + seed=seed, + aspect_ratio=arguments.get("aspect_ratio", "1:1") + ) + + save_to_file = arguments.get("save_to_file", True) + + logger.info(f"Starting image generation: '{prompt[:50]}...', Seed: {seed}") + + # Generate image with timeout + try: + logger.info("Calling client.generate_image()...") + response = await asyncio.wait_for( + self.client.generate_image(request), + timeout=360.0 # 6 minute timeout + ) + logger.info(f"Image generation completed. Success: {response.success}") + except asyncio.TimeoutError: + logger.error("Image generation timed out after 6 minutes") + return [TextContent( + type="text", + text="Error: Image generation timed out after 6 minutes. Please try again." 
+ )] + + if not response.success: + logger.error(f"Image generation failed: {response.error_message}") + return [TextContent( + type="text", + text=f"Error occurred during image generation: {response.error_message}" + )] + + logger.info(f"Generated {len(response.images_data)} images successfully") + + # Create preview images (512x512 JPEG base64) + preview_images_b64 = [] + for i, image_data in enumerate(response.images_data): + # Log original image info + image_info = get_image_info(image_data) + if image_info: + logger.info(f"Original image {i+1}: {image_info}") + + # Create preview + preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85) + if preview_b64: + preview_images_b64.append(preview_b64) + logger.info(f"Created preview image {i+1}: {len(preview_b64)} chars base64 JPEG") + else: + logger.warning(f"Failed to create preview for image {i+1}") + + # Save files if requested + saved_files = [] + if save_to_file: + logger.info("Saving files to disk...") + generation_params = { + "prompt": request.prompt, + "negative_prompt": request.negative_prompt, + "number_of_images": request.number_of_images, + "seed": request.seed, + "aspect_ratio": request.aspect_ratio, + "guidance_scale": 7.5, + "safety_filter_level": "block_only_high", + "person_generation": "allow_all", + "add_watermark": False + } + + saved_files = save_generated_images( + images_data=response.images_data, + save_directory=self.config.output_path, + seed=request.seed, + generation_params=generation_params + ) + logger.info(f"Files saved: {saved_files}") + + # Verify files were created + for file_path in saved_files: + if os.path.exists(file_path): + size = os.path.getsize(file_path) + logger.info(f" ✓ Verified: {file_path} ({size} bytes)") + else: + logger.error(f" ❌ Missing: {file_path}") + + # Create enhanced result with preview images + result = ImageGenerationResult( + success=True, + message=f"✅ Images have been successfully generated!", + original_images_count=len(response.images_data), + preview_images_b64=preview_images_b64, + saved_files=saved_files, + generation_params={ + "prompt": request.prompt, + "seed": request.seed, + "aspect_ratio": request.aspect_ratio, + "number_of_images": request.number_of_images, + "negative_prompt": request.negative_prompt + } + ) + + logger.info(f"Returning response with {len(preview_images_b64)} preview images") + return [TextContent( + type="text", + text=result.to_text_content() + )] + + except Exception as e: + logger.error(f"Error occurred during image generation: {str(e)}", exc_info=True) + return [TextContent( + type="text", + text=f"Error occurred during image generation: {str(e)}" + )] + + +# ==================== MCP Server ==================== + +class Imagen4MCPServer: + """Imagen4 MCP server class with preview image support""" + + def __init__(self, config: Config): + """Initialize server""" + self.config = config + self.server = Server("imagen4-mcp-server") + self.handlers = Imagen4ToolHandlers(config) + + # Register handlers + self._register_handlers() + + def _register_handlers(self) -> None: + """Register MCP handlers""" + + @self.server.list_tools() + async def handle_list_tools() -> List[Tool]: + """Return list of available tools""" + try: + logger.info("Listing available tools with preview image support") + return get_tools() + except Exception as e: + logger.error(f"Error listing tools: {e}") + raise + + @self.server.call_tool() + async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: + """Handle tool calls 
with preview image support""" + try: + # Log tool call safely without exposing sensitive data + safe_args = sanitize_args_for_logging(arguments) + logger.info(f"Tool called: {name} with arguments: {safe_args}") + + if name == "generate_random_seed": + return await self.handlers.handle_generate_random_seed(arguments) + elif name == "regenerate_from_json": + return await self.handlers.handle_regenerate_from_json(arguments) + elif name == "generate_image": + return await self.handlers.handle_generate_image(arguments) + else: + error_msg = f"Unknown tool: {name}" + logger.error(error_msg) + return [TextContent( + type="text", + text=f"Error: {error_msg}" + )] + except Exception as e: + logger.error(f"Error handling tool call '{name}': {e}") + import traceback + logger.error(f"Tool call traceback: {traceback.format_exc()}") + return [TextContent( + type="text", + text=f"Error occurred while processing tool '{name}': {str(e)}" + )] + + def get_server(self) -> Server: + """Return MCP server instance""" + return self.server + + +# ==================== Main Function ==================== + +async def main(): + """Main function""" + logger.info("Starting Imagen 4 MCP Server with Preview Image Support...") + + try: + # Load configuration + config = Config.from_env() + logger.info(f"Configuration loaded - Project: {config.project_id}, Location: {config.location}") + + # Create MCP server + mcp_server = Imagen4MCPServer(config) + server = mcp_server.get_server() + + logger.info("Imagen 4 MCP Server initialized successfully") + logger.info("Features: 512x512 JPEG preview images, base64 encoding, enhanced responses") + + # Run MCP server with better error handling + try: + async with stdio_server() as (read_stream, write_stream): + logger.info("Starting MCP server with stdio transport") + await server.run( + read_stream, + write_stream, + InitializationOptions( + server_name="imagen4-mcp-server", + server_version="3.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={ + "preview_images": {}, + "base64_jpeg_previews": {} + }, + ) + ) + ) + except Exception as stdio_error: + logger.error(f"STDIO server error: {stdio_error}") + # Try to handle specific MCP protocol errors + if "TaskGroup" in str(stdio_error): + logger.error("TaskGroup error detected - this may be due to client disconnection") + raise + + except ValueError as e: + logger.error(f"Configuration error: {e}") + logger.error("Please check your .env file or environment variables.") + return 1 + + except KeyboardInterrupt: + logger.info("Server stopped by user (Ctrl+C)") + return 0 + + except Exception as e: + logger.error(f"Server error: {e}") + logger.error(f"Error type: {type(e).__name__}") + import traceback + logger.error(f"Traceback: {traceback.format_exc()}") + raise + + +if __name__ == "__main__": + try: + # Set up signal handling for graceful shutdown + import signal + + def signal_handler(signum, frame): + logger.info(f"Received signal {signum}, shutting down gracefully...") + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + # Run the server + exit_code = asyncio.run(main()) + sys.exit(exit_code or 0) + + except KeyboardInterrupt: + logger.info("Server stopped by user") + sys.exit(0) + except SystemExit as e: + logger.info(f"Server exiting with code {e.code}") + sys.exit(e.code) + except Exception as e: + logger.error(f"Fatal error: {e}") + logger.error(f"Error type: {type(e).__name__}") + import traceback + 
logger.error(f"Full traceback: {traceback.format_exc()}") + sys.exit(1) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..8702aec --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +# Requirements for separated Imagen4 architecture +google-genai>=0.2.0 +mcp>=1.0.0 +python-dotenv>=1.0.0 +pytest>=7.0.0 +pytest-asyncio>=0.21.0 +Pillow>=10.0.0 diff --git a/run.bat b/run.bat new file mode 100644 index 0000000..a5f5ada --- /dev/null +++ b/run.bat @@ -0,0 +1,50 @@ +@echo off +echo Starting Imagen4 MCP Server with Preview Image Support... +echo Features: 512x512 JPEG preview images, base64 encoding +echo. + +REM Check if virtual environment exists +if not exist "venv\Scripts\activate.bat" ( + echo Creating virtual environment... + python -m venv venv + if errorlevel 1 ( + echo Failed to create virtual environment + pause + exit /b 1 + ) +) + +REM Activate virtual environment +echo Activating virtual environment... +call venv\Scripts\activate.bat + +REM Check if requirements are installed +python -c "import PIL" 2>nul +if errorlevel 1 ( + echo Installing requirements... + pip install -r requirements.txt + if errorlevel 1 ( + echo Failed to install requirements + pause + exit /b 1 + ) +) + +REM Check if .env file exists +if not exist ".env" ( + echo Warning: .env file not found + echo Please copy .env.example to .env and configure your API credentials + pause + exit /b 1 +) + +REM Run server +echo Starting MCP server... +python main.py + +REM Keep window open if there's an error +if errorlevel 1 ( + echo. + echo Server exited with error + pause +) diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..371d8b0 --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,10 @@ +""" +Imagen4 Package + +분리된 아키텍처를 가진 Google Imagen 4 MCP 서버 +""" + +# 기본 컴포넌트만 export +# 자세한 import는 각 모듈에서 직접 하도록 함 + +__version__ = "2.1.0" diff --git a/src/async_task/__init__.py b/src/async_task/__init__.py new file mode 100644 index 0000000..2c7d4c0 --- /dev/null +++ b/src/async_task/__init__.py @@ -0,0 +1,9 @@ +""" +Async Task Management Module for MCP Server +""" + +from .models import TaskStatus, TaskResult, generate_task_id +from .task_manager import TaskManager +from .worker_pool import WorkerPool + +__all__ = ['TaskManager', 'TaskStatus', 'TaskResult', 'WorkerPool'] diff --git a/src/async_task/models.py b/src/async_task/models.py new file mode 100644 index 0000000..f9d06fb --- /dev/null +++ b/src/async_task/models.py @@ -0,0 +1,48 @@ +""" +Task Status and Result Models +""" + +import uuid +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any, Optional, Dict + + +class TaskStatus(Enum): + """Task execution status""" + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +@dataclass +class TaskResult: + """Task execution result""" + task_id: str + status: TaskStatus + result: Optional[Any] = None + error: Optional[str] = None + created_at: datetime = field(default_factory=datetime.now) + started_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + @property + def duration(self) -> Optional[float]: + """Calculate task execution duration in seconds""" + if self.started_at and self.completed_at: + return (self.completed_at - self.started_at).total_seconds() + return None + + @property + def is_finished(self) -> bool: + """Check if task is finished (completed, 
failed, or cancelled)""" + return self.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED] + + +def generate_task_id() -> str: + """Generate unique task ID""" + return str(uuid.uuid4()) diff --git a/src/async_task/task_manager.py b/src/async_task/task_manager.py new file mode 100644 index 0000000..3c4e590 --- /dev/null +++ b/src/async_task/task_manager.py @@ -0,0 +1,324 @@ +""" +Task Manager for MCP Server +Manages background task execution with status tracking and result retrieval +""" + +import asyncio +import logging +from typing import Any, Callable, Optional, Dict, List +from datetime import datetime, timedelta + +from .models import TaskResult, TaskStatus, generate_task_id +from .worker_pool import WorkerPool + +logger = logging.getLogger(__name__) + + +class TaskManager: + """Main task manager for MCP server""" + + def __init__( + self, + max_workers: Optional[int] = None, + use_process_pool: bool = False, + default_timeout: float = 600.0, # 10 minutes + result_retention_hours: int = 24, # Keep results for 24 hours + max_retained_results: int = 200 + ): + """ + Initialize task manager + + Args: + max_workers: Maximum number of worker threads/processes + use_process_pool: Use process pool instead of thread pool + default_timeout: Default timeout for tasks in seconds + result_retention_hours: How long to keep finished task results + max_retained_results: Maximum number of results to retain + """ + self.worker_pool = WorkerPool( + max_workers=max_workers, + use_process_pool=use_process_pool, + task_timeout=default_timeout + ) + + self.result_retention_hours = result_retention_hours + self.max_retained_results = max_retained_results + + # Start cleanup task + self._cleanup_task = asyncio.create_task(self._periodic_cleanup()) + + logger.info(f"TaskManager initialized with {self.worker_pool.max_workers} workers") + + async def submit_task( + self, + func: Callable, + *args, + task_id: Optional[str] = None, + timeout: Optional[float] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs + ) -> str: + """ + Submit a task for background execution + + Args: + func: Function to execute + *args: Function arguments + task_id: Optional custom task ID + timeout: Task timeout in seconds + metadata: Additional metadata for the task + **kwargs: Function keyword arguments + + Returns: + str: Task ID for tracking progress + """ + task_id = await self.worker_pool.submit_task( + func, *args, task_id=task_id, timeout=timeout, **kwargs + ) + + # Add additional metadata if provided + if metadata: + result = self.worker_pool.get_task_result(task_id) + if result: + result.metadata.update(metadata) + + return task_id + + def get_task_status(self, task_id: str) -> Optional[TaskStatus]: + """Get task status by ID""" + result = self.worker_pool.get_task_result(task_id) + return result.status if result else None + + def get_task_result(self, task_id: str) -> Optional[TaskResult]: + """Get complete task result by ID with circular reference protection""" + try: + raw_result = self.worker_pool.get_task_result(task_id) + if not raw_result: + return None + + # Create a safe copy to avoid circular references + safe_result = TaskResult( + task_id=raw_result.task_id, + status=raw_result.status, + result=None, # We'll set this safely + error=raw_result.error, + created_at=raw_result.created_at, + started_at=raw_result.started_at, + completed_at=raw_result.completed_at, + metadata=raw_result.metadata.copy() if raw_result.metadata else {} + ) + + # Safely copy the result data + if 
raw_result.result is not None: + try: + # If it's a dict, create a clean copy + if isinstance(raw_result.result, dict): + safe_result.result = { + 'success': raw_result.result.get('success', False), + 'error': raw_result.result.get('error'), + 'images_b64': raw_result.result.get('images_b64', []), + 'saved_files': raw_result.result.get('saved_files', []), + 'request': raw_result.result.get('request', {}), + 'image_count': raw_result.result.get('image_count', 0) + } + else: + # For non-dict results, just copy directly + safe_result.result = raw_result.result + except Exception as copy_error: + logger.error(f"Error creating safe copy of result: {str(copy_error)}") + safe_result.result = { + 'success': False, + 'error': f'Result data corrupted: {str(copy_error)}' + } + + return safe_result + + except Exception as e: + logger.error(f"Error in get_task_result: {str(e)}", exc_info=True) + return None + + def get_task_progress(self, task_id: str) -> Dict[str, Any]: + """ + Get task progress information + + Returns: + Dict with task status, timing info, and metadata + """ + result = self.get_task_result(task_id) + if not result: + return {'error': 'Task not found'} + + progress = { + 'task_id': result.task_id, + 'status': result.status.value, + 'created_at': result.created_at.isoformat(), + 'metadata': result.metadata + } + + if result.started_at: + progress['started_at'] = result.started_at.isoformat() + + if result.status == TaskStatus.RUNNING: + elapsed = (datetime.now() - result.started_at).total_seconds() + progress['elapsed_seconds'] = elapsed + + if result.completed_at: + progress['completed_at'] = result.completed_at.isoformat() + progress['duration_seconds'] = result.duration + + if result.error: + progress['error'] = result.error + + return progress + + def list_tasks( + self, + status_filter: Optional[TaskStatus] = None, + limit: int = 50 + ) -> List[Dict[str, Any]]: + """ + List tasks with optional filtering + + Args: + status_filter: Filter by task status + limit: Maximum number of tasks to return + + Returns: + List of task progress dictionaries + """ + all_results = self.worker_pool.get_all_results() + + # Filter by status if requested + if status_filter: + filtered_results = { + k: v for k, v in all_results.items() + if v.status == status_filter + } + else: + filtered_results = all_results + + # Sort by creation time (most recent first) + sorted_items = sorted( + filtered_results.items(), + key=lambda x: x[1].created_at, + reverse=True + ) + + # Apply limit and return progress info + return [ + self.get_task_progress(task_id) + for task_id, _ in sorted_items[:limit] + ] + + def cancel_task(self, task_id: str) -> bool: + """Cancel a running task""" + return self.worker_pool.cancel_task(task_id) + + async def wait_for_task( + self, + task_id: str, + timeout: Optional[float] = None, + poll_interval: float = 1.0 + ) -> TaskResult: + """ + Wait for a task to complete + + Args: + task_id: Task ID to wait for + timeout: Maximum time to wait in seconds + poll_interval: How often to check task status + + Returns: + TaskResult: Final task result + + Raises: + asyncio.TimeoutError: If timeout is reached + ValueError: If task doesn't exist + """ + start_time = datetime.now() + + while True: + result = self.get_task_result(task_id) + if not result: + raise ValueError(f"Task {task_id} not found") + + if result.is_finished: + return result + + # Check timeout + if timeout: + elapsed = (datetime.now() - start_time).total_seconds() + if elapsed >= timeout: + raise asyncio.TimeoutError(f"Timeout 
waiting for task {task_id}") + + await asyncio.sleep(poll_interval) + + async def _periodic_cleanup(self) -> None: + """Periodic cleanup of old task results""" + while True: + try: + # Wait 1 hour between cleanups + await asyncio.sleep(3600) + + # Clean up old results + cutoff_time = datetime.now() - timedelta(hours=self.result_retention_hours) + all_results = self.worker_pool.get_all_results() + + to_remove = [] + for task_id, result in all_results.items(): + if (result.is_finished and + result.completed_at and + result.completed_at < cutoff_time): + to_remove.append(task_id) + + # Remove old results + for task_id in to_remove: + if task_id in self.worker_pool._results: + del self.worker_pool._results[task_id] + + if to_remove: + logger.info(f"Cleaned up {len(to_remove)} old task results") + + # Also clean up excess results + self.worker_pool.cleanup_finished_tasks(self.max_retained_results) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error in periodic cleanup: {e}") + + async def shutdown(self, wait: bool = True) -> None: + """Shutdown task manager""" + logger.info("Shutting down task manager...") + + # Cancel cleanup task + if not self._cleanup_task.done(): + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + + # Shutdown worker pool + await self.worker_pool.shutdown(wait=wait) + + logger.info("Task manager shutdown complete") + + @property + def stats(self) -> Dict[str, Any]: + """Get task manager statistics""" + all_results = self.worker_pool.get_all_results() + + status_counts = {} + for status in TaskStatus: + status_counts[status.value] = sum( + 1 for r in all_results.values() if r.status == status + ) + + return { + 'total_tasks': len(all_results), + 'active_tasks': self.worker_pool.active_task_count, + 'max_workers': self.worker_pool.max_workers, + 'worker_type': 'process' if self.worker_pool.use_process_pool else 'thread', + 'status_counts': status_counts + } diff --git a/src/async_task/worker_pool.py b/src/async_task/worker_pool.py new file mode 100644 index 0000000..c40cb61 --- /dev/null +++ b/src/async_task/worker_pool.py @@ -0,0 +1,258 @@ +""" +Worker Pool for Background Task Execution +""" + +import asyncio +import logging +import multiprocessing +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor +from typing import Any, Callable, Optional, Union +from datetime import datetime + +from .models import TaskResult, TaskStatus, generate_task_id + +logger = logging.getLogger(__name__) + + +class WorkerPool: + """Worker pool for executing background tasks""" + + def __init__( + self, + max_workers: Optional[int] = None, + use_process_pool: bool = False, + task_timeout: float = 600.0 # 10 minutes default + ): + """ + Initialize worker pool + + Args: + max_workers: Maximum number of worker threads/processes + use_process_pool: Use ProcessPoolExecutor instead of ThreadPoolExecutor + task_timeout: Default timeout for tasks in seconds + """ + self.max_workers = max_workers or min(32, (multiprocessing.cpu_count() or 1) + 4) + self.use_process_pool = use_process_pool + self.task_timeout = task_timeout + + # Initialize executor + if use_process_pool: + self.executor = ProcessPoolExecutor(max_workers=self.max_workers) + logger.info(f"Initialized ProcessPoolExecutor with {self.max_workers} workers") + else: + self.executor = ThreadPoolExecutor(max_workers=self.max_workers) + logger.info(f"Initialized ThreadPoolExecutor with {self.max_workers} workers") + + # Track running tasks 
+ self._running_tasks: dict[str, asyncio.Task] = {} + self._results: dict[str, TaskResult] = {} + + async def submit_task( + self, + func: Callable, + *args, + task_id: Optional[str] = None, + timeout: Optional[float] = None, + **kwargs + ) -> str: + """ + Submit a task for background execution + + Args: + func: Function to execute + *args: Function arguments + task_id: Optional custom task ID + timeout: Task timeout in seconds + **kwargs: Function keyword arguments + + Returns: + str: Task ID for tracking + """ + if task_id is None: + task_id = generate_task_id() + + timeout = timeout or self.task_timeout + + # Create task result + task_result = TaskResult( + task_id=task_id, + status=TaskStatus.PENDING, + metadata={ + 'function_name': func.__name__, + 'timeout': timeout, + 'worker_type': 'process' if self.use_process_pool else 'thread' + } + ) + + self._results[task_id] = task_result + + # Create and start background task + background_task = asyncio.create_task( + self._execute_task(func, args, kwargs, task_result, timeout) + ) + + self._running_tasks[task_id] = background_task + + logger.info(f"Task {task_id} submitted for execution (timeout: {timeout}s)") + return task_id + + async def _execute_task( + self, + func: Callable, + args: tuple, + kwargs: dict, + task_result: TaskResult, + timeout: float + ) -> None: + """Execute task in background""" + try: + task_result.status = TaskStatus.RUNNING + task_result.started_at = datetime.now() + + logger.info(f"Starting task {task_result.task_id}: {func.__name__}") + + # Execute function with timeout + if self.use_process_pool: + # For process pool, function must be pickleable + future = self.executor.submit(func, *args, **kwargs) + result = await asyncio.wait_for( + asyncio.wrap_future(future), + timeout=timeout + ) + else: + # For thread pool, can use lambda/closure + result = await asyncio.wait_for( + asyncio.to_thread(func, *args, **kwargs), + timeout=timeout + ) + + task_result.result = result + task_result.status = TaskStatus.COMPLETED + task_result.completed_at = datetime.now() + + logger.info(f"Task {task_result.task_id} completed successfully in {task_result.duration:.2f}s") + + except asyncio.TimeoutError: + task_result.status = TaskStatus.FAILED + task_result.error = f"Task timed out after {timeout} seconds" + task_result.completed_at = datetime.now() + logger.error(f"Task {task_result.task_id} timed out after {timeout}s") + + except asyncio.CancelledError: + task_result.status = TaskStatus.CANCELLED + task_result.error = "Task was cancelled" + task_result.completed_at = datetime.now() + logger.info(f"Task {task_result.task_id} was cancelled") + + except Exception as e: + task_result.status = TaskStatus.FAILED + task_result.error = str(e) + task_result.completed_at = datetime.now() + logger.error(f"Task {task_result.task_id} failed: {str(e)}", exc_info=True) + + finally: + # Clean up running task reference + if task_result.task_id in self._running_tasks: + del self._running_tasks[task_result.task_id] + + def get_task_result(self, task_id: str) -> Optional[TaskResult]: + """Get task result by ID""" + return self._results.get(task_id) + + def get_all_results(self) -> dict[str, TaskResult]: + """Get all task results""" + return self._results.copy() + + def cancel_task(self, task_id: str) -> bool: + """ + Cancel a running task + + Args: + task_id: Task ID to cancel + + Returns: + bool: True if task was cancelled, False if not found or already finished + """ + if task_id in self._running_tasks: + task = self._running_tasks[task_id] + if 
not task.done(): + task.cancel() + logger.info(f"Task {task_id} cancellation requested") + return True + return False + + def cleanup_finished_tasks(self, max_results: int = 100) -> int: + """ + Clean up finished task results to prevent memory buildup + + Args: + max_results: Maximum number of results to keep + + Returns: + int: Number of results cleaned up + """ + if len(self._results) <= max_results: + return 0 + + # Sort by completion time, keep most recent + finished_results = { + k: v for k, v in self._results.items() + if v.is_finished and v.completed_at + } + + if len(finished_results) <= max_results: + return 0 + + sorted_results = sorted( + finished_results.items(), + key=lambda x: x[1].completed_at, + reverse=True + ) + + # Keep most recent max_results + to_keep = set(k for k, _ in sorted_results[:max_results]) + + # Remove old results + to_remove = [k for k in finished_results.keys() if k not in to_keep] + for task_id in to_remove: + del self._results[task_id] + + cleaned_count = len(to_remove) + if cleaned_count > 0: + logger.info(f"Cleaned up {cleaned_count} old task results") + + return cleaned_count + + async def shutdown(self, wait: bool = True) -> None: + """ + Shutdown worker pool + + Args: + wait: Whether to wait for running tasks to complete + """ + logger.info("Shutting down worker pool...") + + if wait: + # Cancel all running tasks + for task_id, task in self._running_tasks.items(): + if not task.done(): + task.cancel() + logger.info(f"Cancelled task {task_id}") + + # Wait for tasks to complete cancellation + if self._running_tasks: + await asyncio.gather(*self._running_tasks.values(), return_exceptions=True) + + # Shutdown executor + self.executor.shutdown(wait=wait) + logger.info("Worker pool shutdown complete") + + @property + def active_task_count(self) -> int: + """Get number of currently running tasks""" + return len([t for t in self._running_tasks.values() if not t.done()]) + + @property + def total_task_count(self) -> int: + """Get total number of tracked tasks""" + return len(self._results) diff --git a/src/connector/__init__.py b/src/connector/__init__.py new file mode 100644 index 0000000..62293c8 --- /dev/null +++ b/src/connector/__init__.py @@ -0,0 +1,10 @@ +""" +Imagen4 Connector Package + +Google Imagen 4 API와의 통신을 담당하는 커넥터 모듈 +""" + +from .imagen4_client import Imagen4Client +from .config import Config + +__all__ = ['Imagen4Client', 'Config'] diff --git a/src/connector/config.py b/src/connector/config.py new file mode 100644 index 0000000..c792312 --- /dev/null +++ b/src/connector/config.py @@ -0,0 +1,101 @@ +""" +Configuration management for Imagen4 connector +""" + +import os +from typing import Optional +from dataclasses import dataclass + + +@dataclass +class Config: + """Imagen4 configuration class""" + project_id: str + location: str = "us-central1" + output_path: str = "./generated_images" + default_model: str = "imagen-4.0-generate-001" # Default model + + @classmethod + def from_env(cls, env_file: Optional[str] = None) -> 'Config': + """ + Load configuration from environment variables + + Args: + env_file: Optional path to .env file to load first + """ + # Try to load .env file if specified or if python-dotenv is available + if env_file or not os.getenv("GOOGLE_CLOUD_PROJECT_ID"): + try: + from dotenv import load_dotenv + + if env_file: + load_dotenv(env_file) + else: + # Try to find .env file in current directory or parent directories + current_dir = os.getcwd() + env_paths = [ + os.path.join(current_dir, '.env'), + 
os.path.join(os.path.dirname(current_dir), '.env'), + os.path.join(os.path.dirname(__file__), '..', '..', '.env') + ] + + for env_path in env_paths: + if os.path.exists(env_path): + load_dotenv(env_path) + break + + except ImportError: + pass # python-dotenv not available, continue with system env vars + + # Try multiple environment variable names for better compatibility + project_id = ( + os.getenv("GOOGLE_CLOUD_PROJECT_ID") or + os.getenv("GOOGLE_CLOUD_PROJECT") or + os.getenv("GCP_PROJECT") + ) + + if not project_id: + raise ValueError( + "Google Cloud Project ID is required. Please set one of these environment variables: " + "GOOGLE_CLOUD_PROJECT_ID, GOOGLE_CLOUD_PROJECT, or GCP_PROJECT" + ) + + # Also set GOOGLE_CLOUD_PROJECT for Google Cloud libraries if not already set + if not os.getenv("GOOGLE_CLOUD_PROJECT"): + os.environ["GOOGLE_CLOUD_PROJECT"] = project_id + + return cls( + project_id=project_id, + location=os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1"), + output_path=os.getenv("GENERATED_IMAGES_PATH", "./generated_images"), + default_model=os.getenv("IMAGEN4_DEFAULT_MODEL", "imagen-4.0-generate-001") + ) + + @classmethod + def for_testing(cls, project_id: str = "test-project-id") -> 'Config': + """Create a config instance for testing purposes""" + return cls( + project_id=project_id, + location="us-central1", + output_path="./test_generated_images", + default_model="imagen-4.0-generate-001" + ) + + def validate(self) -> bool: + """Validate configuration""" + if not self.project_id: + return False + if not self.location: + return False + if not self.output_path: + return False + return True + + def __str__(self) -> str: + """String representation of config""" + return ( + f"Config(project_id='{self.project_id}', " + f"location='{self.location}', " + f"output_path='{self.output_path}', " + f"default_model='{self.default_model}')" + ) diff --git a/src/connector/imagen4_client.py b/src/connector/imagen4_client.py new file mode 100644 index 0000000..5f3a3b8 --- /dev/null +++ b/src/connector/imagen4_client.py @@ -0,0 +1,215 @@ +""" +Google Imagen 4 API Client +""" + +import asyncio +import logging +from typing import List, Optional +from dataclasses import dataclass + +try: + from google import genai + from google.genai.types import GenerateImagesConfig +except ImportError as e: + raise ImportError(f"Google GenAI library not found: {e}. 
Install with: pip install google-genai") + +from .config import Config + +logger = logging.getLogger(__name__) + + +@dataclass +class ImageGenerationRequest: + """Image generation request data class""" + prompt: str + negative_prompt: str = "" + number_of_images: int = 1 + seed: int = 0 + aspect_ratio: str = "1:1" + model: str = "imagen-4.0-generate-001" # Model selection + + def validate(self) -> None: + """Validate request data""" + # Import here to avoid circular imports + from ..utils.token_utils import validate_prompt_length + + if not self.prompt: + raise ValueError("Prompt is required") + + # 프롬프트 토큰 수 검증 + is_valid, token_count, error_msg = validate_prompt_length(self.prompt) + if not is_valid: + raise ValueError(f"Prompt validation failed: {error_msg}") + + # negative_prompt도 검증 (더 관대한 제한) + if self.negative_prompt: + neg_is_valid, neg_token_count, neg_error_msg = validate_prompt_length( + self.negative_prompt, max_tokens=240 # negative prompt는 절반 제한 + ) + if not neg_is_valid: + raise ValueError(f"Negative prompt validation failed: {neg_error_msg}") + + if not isinstance(self.seed, int) or self.seed < 0 or self.seed > 4294967295: + raise ValueError("Seed value must be an integer between 0 and 4294967295") + + if self.number_of_images not in [1, 2]: + raise ValueError("Number of images must be 1 or 2 only") + + # Validate model name + valid_models = ["imagen-4.0-generate-001", "imagen-4.0-ultra-generate-001"] + if self.model not in valid_models: + raise ValueError(f"Model must be one of: {valid_models}") + + +@dataclass +class ImageGenerationResponse: + """Image generation response data class""" + images_data: List[bytes] + request: ImageGenerationRequest + success: bool = True + error_message: Optional[str] = None + + +class Imagen4Client: + """Google Imagen 4 API client""" + + def __init__(self, config: Config): + """ + Initialize client + + Args: + config: Imagen4 configuration object + """ + if not config.validate(): + raise ValueError("Invalid configuration") + + self.config = config + self._client = None + self._initialize_client() + + def _initialize_client(self) -> None: + """Initialize Google GenAI client""" + try: + self._client = genai.Client( + vertexai=True, + project=self.config.project_id, + location=self.config.location + ) + + # Log client information + if not self._client.vertexai: + logger.info("Using Gemini Developer API.") + elif self._client._api_client.project: + logger.info(f"Using Vertex AI with project: {self._client._api_client.project} in location: {self._client._api_client.location}") + elif self._client._api_client.api_key: + logger.info(f"Using Vertex AI in express mode with API key: {self._client._api_client.api_key[:5]}...{self._client._api_client.api_key[-5:]}") + + except Exception as e: + logger.error(f"Failed to initialize Imagen4Client: {e}") + raise + + async def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + """ + Process image generation request with MCP-safe execution + + Args: + request: Image generation request object + + Returns: + ImageGenerationResponse: Generation result + """ + try: + # Validate request + request.validate() + + # 토큰 수 정보 로깅 + from ..utils.token_utils import get_prompt_stats + prompt_stats = get_prompt_stats(request.prompt) + logger.info(f"Prompt token analysis: {prompt_stats['estimated_tokens']}/{prompt_stats['max_tokens']} tokens") + + if request.negative_prompt: + neg_stats = get_prompt_stats(request.negative_prompt) + logger.info(f"Negative prompt token analysis: 
{neg_stats['estimated_tokens']} tokens") + + logger.info(f"Starting image generation - Prompt: '{request.prompt[:50]}...', Seed: {request.seed}, Model: {request.model}") + print(f"Aspect Ratio: {request.aspect_ratio}, Number of Images: {request.number_of_images}, Model: {request.model}") + print(f"Prompt Tokens: {prompt_stats['estimated_tokens']}/{prompt_stats['max_tokens']}") + + # Create a new event loop if needed (for MCP compatibility) + def _sync_generate_images(): + """Synchronous wrapper for API call""" + return self._client.models.generate_images( + model=request.model, + prompt=request.prompt, + config=GenerateImagesConfig( + add_watermark=False, + aspect_ratio=request.aspect_ratio, + image_size="2K", # 2048x2048 + negative_prompt=request.negative_prompt, + number_of_images=request.number_of_images, + output_mime_type="image/png", + person_generation="allow_all", + safety_filter_level="block_only_high", + seed=request.seed, + ) + ) + + # Run in thread with explicit timeout + logger.info("Making API call with thread executor...") + response = await asyncio.wait_for( + asyncio.to_thread(_sync_generate_images), + timeout=300.0 # 5 minute timeout (increased from 90s) + ) + logger.info("API call completed successfully") + + print(f"Get response: {response}") + + # Extract image data from response + images_data = [] + if response.generated_images: + for i, gen_image in enumerate(response.generated_images): + if hasattr(gen_image, 'image') and hasattr(gen_image.image, 'image_bytes'): + image_bytes = gen_image.image.image_bytes + if image_bytes: + images_data.append(image_bytes) + logger.info(f"Image {i+1} extraction complete (size: {len(image_bytes)} bytes)") + else: + logger.warning(f"Image {i+1}'s image_bytes is empty.") + else: + logger.warning(f"Cannot find image_bytes in image {i+1}.") + + if not images_data: + raise Exception("Cannot find generated image data.") + + logger.info(f"Total {len(images_data)} images generated successfully") + + return ImageGenerationResponse( + images_data=images_data, + request=request, + success=True + ) + + except asyncio.TimeoutError: + error_msg = "Image generation timed out after 5 minutes" + logger.error(error_msg) + return ImageGenerationResponse( + images_data=[], + request=request, + success=False, + error_message=error_msg + ) + except Exception as e: + logger.error(f"Error occurred during image generation: {str(e)}") + return ImageGenerationResponse( + images_data=[], + request=request, + success=False, + error_message=str(e) + ) + + def health_check(self) -> bool: + """Check client status""" + try: + return self._client is not None and self.config.validate() + except Exception: + return False diff --git a/src/connector/utils.py b/src/connector/utils.py new file mode 100644 index 0000000..e6cc176 --- /dev/null +++ b/src/connector/utils.py @@ -0,0 +1,70 @@ +""" +Utility functions for Imagen4 connector +""" + +import os +import json +import logging +from datetime import datetime +from typing import List, Dict, Any, Optional + +logger = logging.getLogger(__name__) + + +def save_generated_images( + images_data: List[bytes], + save_directory: str = "./generated_images", + filename_prefix: str = "imagen4", + seed: Optional[int] = None, + generation_params: Optional[Dict[str, Any]] = None +) -> List[str]: + """Save generated images to files""" + os.makedirs(save_directory, exist_ok=True) + saved_files = [] + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + for i, image_data in enumerate(images_data): + try: + if seed is not None: + 
base_filename = f"{filename_prefix}_{seed}_{timestamp}_{i+1:03d}" + else: + base_filename = f"{filename_prefix}_{timestamp}_{i+1:03d}" + + # Save image file + image_filename = f"{base_filename}.png" + image_filepath = os.path.join(save_directory, image_filename) + + with open(image_filepath, 'wb') as f: + f.write(image_data) + + saved_files.append(image_filepath) + logger.info(f"Image saved successfully: {image_filepath}") + + # Save JSON parameter file + if generation_params: + json_filename = f"{base_filename}.json" + json_filepath = os.path.join(save_directory, json_filename) + + params_with_metadata = { + **generation_params, + "generated_at": datetime.now().isoformat(), + "image_filename": image_filename, + "image_index": i + 1, + "total_images": len(images_data), + "file_size_bytes": len(image_data), + "model_version": "imagen-4.0-generate-001", + "image_format": "PNG", + "image_size": "2048x2048" + } + + with open(json_filepath, 'w', encoding='utf-8') as f: + json.dump(params_with_metadata, f, ensure_ascii=False, indent=2) + + saved_files.append(json_filepath) + logger.info(f"Parameters saved successfully: {json_filepath}") + + except Exception as e: + logger.error(f"Failed to save image {i+1}: {str(e)}") + continue + + return saved_files \ No newline at end of file diff --git a/src/server/__init__.py b/src/server/__init__.py new file mode 100644 index 0000000..0aa1108 --- /dev/null +++ b/src/server/__init__.py @@ -0,0 +1,26 @@ +""" +Imagen4 Server Package + +MCP 서버 구현을 담당하는 모듈 +""" + +# Import basic server components +from src.server.mcp_server import Imagen4MCPServer +from src.server.handlers import ToolHandlers +from src.server.models import MCPToolDefinitions + +# Import enhanced components +try: + from src.server.enhanced_mcp_server import EnhancedImagen4MCPServer + from src.server.enhanced_handlers import EnhancedToolHandlers + from src.server.enhanced_models import ImageGenerationResult, PreviewImageResponse + + __all__ = [ + 'Imagen4MCPServer', 'ToolHandlers', 'MCPToolDefinitions', + 'EnhancedImagen4MCPServer', 'EnhancedToolHandlers', + 'ImageGenerationResult', 'PreviewImageResponse' + ] +except ImportError as e: + # Fallback if enhanced modules are not available + print(f"Warning: Enhanced features not available: {e}") + __all__ = ['Imagen4MCPServer', 'ToolHandlers', 'MCPToolDefinitions'] diff --git a/src/server/enhanced_handlers.py b/src/server/enhanced_handlers.py new file mode 100644 index 0000000..16d42ac --- /dev/null +++ b/src/server/enhanced_handlers.py @@ -0,0 +1,322 @@ +""" +Enhanced Tool Handlers for MCP Server with Preview Image Support +""" + +import asyncio +import base64 +import json +import random +import logging +from typing import List, Dict, Any + +from mcp.types import TextContent, ImageContent + +from src.connector import Imagen4Client, Config +from src.connector.imagen4_client import ImageGenerationRequest +from src.connector.utils import save_generated_images +from src.utils.image_utils import create_preview_image_b64, get_image_info +from src.server.enhanced_models import ImageGenerationResult, PreviewImageResponse + +logger = logging.getLogger(__name__) + + +def sanitize_args_for_logging(arguments: Dict[str, Any]) -> Dict[str, Any]: + """Remove or truncate sensitive data from arguments for safe logging""" + safe_args = {} + for key, value in arguments.items(): + if isinstance(value, str): + # Check if it's likely base64 image data + if (key in ['data', 'image_data', 'base64', 'preview_image_b64'] or + (len(value) > 100 and 
value.startswith(('iVBORw0KGgo', '/9j/', 'R0lGOD')))): + # Truncate long image data + safe_args[key] = f"" + elif len(value) > 1000: + # Truncate any very long strings + safe_args[key] = f"{value[:100]}..." + else: + safe_args[key] = value + else: + safe_args[key] = value + return safe_args + + +class EnhancedToolHandlers: + """Enhanced MCP tool handler class with preview image support""" + + def __init__(self, config: Config): + """Initialize handler""" + self.config = config + self.client = Imagen4Client(config) + + async def handle_generate_random_seed(self, arguments: Dict[str, Any]) -> List[TextContent]: + """Random seed generation handler""" + try: + random_seed = random.randint(0, 2**32 - 1) + return [TextContent( + type="text", + text=f"Generated random seed: {random_seed}" + )] + except Exception as e: + logger.error(f"Random seed generation error: {str(e)}") + return [TextContent( + type="text", + text=f"Error occurred during random seed generation: {str(e)}" + )] + + async def handle_regenerate_from_json(self, arguments: Dict[str, Any]) -> List[TextContent]: + """Image regeneration from JSON file handler with preview support""" + try: + json_file_path = arguments.get("json_file_path") + save_to_file = arguments.get("save_to_file", True) + + if not json_file_path: + return [TextContent( + type="text", + text="Error: JSON file path is required." + )] + + # Load parameters from JSON file + try: + with open(json_file_path, 'r', encoding='utf-8') as f: + params = json.load(f) + except FileNotFoundError: + return [TextContent( + type="text", + text=f"Error: Cannot find JSON file: {json_file_path}" + )] + except json.JSONDecodeError as e: + return [TextContent( + type="text", + text=f"Error: JSON parsing error: {str(e)}" + )] + + # Check required parameters + required_params = ['prompt', 'seed'] + missing_params = [p for p in required_params if p not in params] + + if missing_params: + return [TextContent( + type="text", + text=f"Error: Required parameters are missing from JSON file: {', '.join(missing_params)}" + )] + + # Create image generation request object + request = ImageGenerationRequest( + prompt=params.get('prompt'), + negative_prompt=params.get('negative_prompt', ''), + number_of_images=params.get('number_of_images', 1), + seed=params.get('seed'), + aspect_ratio=params.get('aspect_ratio', '1:1'), + model=params.get('model', self.config.default_model) + ) + + logger.info(f"Loaded parameters from JSON: {json_file_path}") + + # Generate image + response = await self.client.generate_image(request) + + if not response.success: + return [TextContent( + type="text", + text=f"Error occurred during image regeneration: {response.error_message}" + )] + + # Create preview images + preview_images_b64 = [] + for image_data in response.images_data: + preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85) + if preview_b64: + preview_images_b64.append(preview_b64) + logger.info(f"Created preview image: {len(preview_b64)} chars base64 JPEG") + + # Save files (optional) + saved_files = [] + if save_to_file: + regeneration_params = { + "prompt": request.prompt, + "negative_prompt": request.negative_prompt, + "number_of_images": request.number_of_images, + "seed": request.seed, + "aspect_ratio": request.aspect_ratio, + "model": request.model, + "regenerated_from": json_file_path, + "original_generated_at": params.get('generated_at', 'unknown') + } + + saved_files = save_generated_images( + images_data=response.images_data, + save_directory=self.config.output_path, + 
seed=request.seed, + generation_params=regeneration_params, + filename_prefix="imagen4_regen" + ) + + # Create result with preview images + result = ImageGenerationResult( + success=True, + message=f"✅ Images have been successfully regenerated from {json_file_path}", + original_images_count=len(response.images_data), + preview_images_b64=preview_images_b64, + saved_files=saved_files, + generation_params={ + "prompt": request.prompt, + "seed": request.seed, + "aspect_ratio": request.aspect_ratio, + "number_of_images": request.number_of_images, + "model": request.model + } + ) + + return [TextContent( + type="text", + text=result.to_text_content() + )] + + except Exception as e: + logger.error(f"Error occurred during image regeneration: {str(e)}") + return [TextContent( + type="text", + text=f"Error occurred during image regeneration: {str(e)}" + )] + + async def handle_generate_image(self, arguments: Dict[str, Any]) -> List[TextContent]: + """Enhanced image generation handler with preview image support""" + try: + # Log arguments safely without exposing image data + safe_args = sanitize_args_for_logging(arguments) + logger.info(f"handle_generate_image called with arguments: {safe_args}") + + # Extract and validate arguments + prompt = arguments.get("prompt") + if not prompt: + logger.error("No prompt provided") + return [TextContent( + type="text", + text="Error: Prompt is required." + )] + + seed = arguments.get("seed") + if seed is None: + logger.error("No seed provided") + return [TextContent( + type="text", + text="Error: Seed value is required. You can use the generate_random_seed tool to generate a random seed." + )] + + # Create image generation request object + request = ImageGenerationRequest( + prompt=prompt, + negative_prompt=arguments.get("negative_prompt", ""), + number_of_images=arguments.get("number_of_images", 1), + seed=seed, + aspect_ratio=arguments.get("aspect_ratio", "1:1"), + model=arguments.get("model", self.config.default_model) + ) + + save_to_file = arguments.get("save_to_file", True) + + logger.info(f"Starting image generation: '{prompt[:50]}...', Seed: {seed}") + + # Generate image with timeout + try: + logger.info("Calling client.generate_image()...") + response = await asyncio.wait_for( + self.client.generate_image(request), + timeout=360.0 # 6 minute timeout + ) + logger.info(f"Image generation completed. Success: {response.success}") + except asyncio.TimeoutError: + logger.error("Image generation timed out after 6 minutes") + return [TextContent( + type="text", + text="Error: Image generation timed out after 6 minutes. Please try again." 
+ )] + + if not response.success: + logger.error(f"Image generation failed: {response.error_message}") + return [TextContent( + type="text", + text=f"Error occurred during image generation: {response.error_message}" + )] + + logger.info(f"Generated {len(response.images_data)} images successfully") + + # Create preview images (512x512 JPEG base64) + preview_images_b64 = [] + for i, image_data in enumerate(response.images_data): + # Log original image info + image_info = get_image_info(image_data) + if image_info: + logger.info(f"Original image {i+1}: {image_info}") + + # Create preview + preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85) + if preview_b64: + preview_images_b64.append(preview_b64) + logger.info(f"Created preview image {i+1}: {len(preview_b64)} chars base64 JPEG") + else: + logger.warning(f"Failed to create preview for image {i+1}") + + # Save files if requested + saved_files = [] + if save_to_file: + logger.info("Saving files to disk...") + generation_params = { + "prompt": request.prompt, + "negative_prompt": request.negative_prompt, + "number_of_images": request.number_of_images, + "seed": request.seed, + "aspect_ratio": request.aspect_ratio, + "model": request.model, + "guidance_scale": 7.5, + "safety_filter_level": "block_only_high", + "person_generation": "allow_all", + "add_watermark": False + } + + saved_files = save_generated_images( + images_data=response.images_data, + save_directory=self.config.output_path, + seed=request.seed, + generation_params=generation_params + ) + logger.info(f"Files saved: {saved_files}") + + # Verify files were created + import os + for file_path in saved_files: + if os.path.exists(file_path): + size = os.path.getsize(file_path) + logger.info(f" ✓ Verified: {file_path} ({size} bytes)") + else: + logger.error(f" ❌ Missing: {file_path}") + + # Create enhanced result with preview images + result = ImageGenerationResult( + success=True, + message=f"✅ Images have been successfully generated!", + original_images_count=len(response.images_data), + preview_images_b64=preview_images_b64, + saved_files=saved_files, + generation_params={ + "prompt": request.prompt, + "seed": request.seed, + "aspect_ratio": request.aspect_ratio, + "number_of_images": request.number_of_images, + "negative_prompt": request.negative_prompt, + "model": request.model + } + ) + + logger.info(f"Returning response with {len(preview_images_b64)} preview images") + return [TextContent( + type="text", + text=result.to_text_content() + )] + + except Exception as e: + logger.error(f"Error occurred during image generation: {str(e)}", exc_info=True) + return [TextContent( + type="text", + text=f"Error occurred during image generation: {str(e)}" + )] diff --git a/src/server/enhanced_mcp_server.py b/src/server/enhanced_mcp_server.py new file mode 100644 index 0000000..3e25689 --- /dev/null +++ b/src/server/enhanced_mcp_server.py @@ -0,0 +1,63 @@ +""" +Enhanced MCP Server implementation for Imagen4 with Preview Image Support +""" + +import logging +from typing import Dict, Any, List + +from mcp.server import Server +from mcp.types import Tool, TextContent + +from src.connector import Config +from src.server.models import MCPToolDefinitions +from src.server.enhanced_handlers import EnhancedToolHandlers, sanitize_args_for_logging + +logger = logging.getLogger(__name__) + + +class EnhancedImagen4MCPServer: + """Enhanced Imagen4 MCP server class with preview image support""" + + def __init__(self, config: Config): + """Initialize server""" + self.config = 
config + self.server = Server("imagen4-enhanced-mcp-server") + self.handlers = EnhancedToolHandlers(config) + + # Register handlers + self._register_handlers() + + def _register_handlers(self) -> None: + """Register MCP handlers""" + + @self.server.list_tools() + async def handle_list_tools() -> List[Tool]: + """Return list of available tools""" + logger.info("Listing available tools (enhanced version)") + return MCPToolDefinitions.get_all_tools() + + @self.server.call_tool() + async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: + """Handle tool calls with enhanced preview image support""" + # Log tool call safely without exposing sensitive data + safe_args = sanitize_args_for_logging(arguments) + logger.info(f"Enhanced tool called: {name} with arguments: {safe_args}") + + if name == "generate_random_seed": + return await self.handlers.handle_generate_random_seed(arguments) + elif name == "regenerate_from_json": + return await self.handlers.handle_regenerate_from_json(arguments) + elif name == "generate_image": + return await self.handlers.handle_generate_image(arguments) + else: + raise ValueError(f"Unknown tool: {name}") + + def get_server(self) -> Server: + """Return MCP server instance""" + return self.server + + +# Create a factory function for easier import +def create_enhanced_server(config: Config) -> EnhancedImagen4MCPServer: + """Factory function to create enhanced server""" + return EnhancedImagen4MCPServer(config) diff --git a/src/server/enhanced_models.py b/src/server/enhanced_models.py new file mode 100644 index 0000000..670a64f --- /dev/null +++ b/src/server/enhanced_models.py @@ -0,0 +1,71 @@ +""" +Enhanced MCP response models with preview image support +""" + +from typing import Optional, Dict, Any +from dataclasses import dataclass + + +@dataclass +class ImageGenerationResult: + """Enhanced image generation result with preview support""" + success: bool + message: str + original_images_count: int + preview_images_b64: Optional[list[str]] = None # List of base64 JPEG previews (512x512) + saved_files: Optional[list[str]] = None + generation_params: Optional[Dict[str, Any]] = None + error_message: Optional[str] = None + + def to_text_content(self) -> str: + """Convert result to text format for MCP response""" + lines = [self.message] + + if self.success and self.preview_images_b64: + lines.append(f"\n🖼️ Preview Images Generated: {len(self.preview_images_b64)} images (512x512 JPEG)") + for i, preview_b64 in enumerate(self.preview_images_b64): + lines.append(f"Preview {i+1} (base64 JPEG): {preview_b64[:50]}...({len(preview_b64)} chars)") + + if self.saved_files: + lines.append(f"\n📁 Files saved:") + for filepath in self.saved_files: + lines.append(f" - {filepath}") + + if self.generation_params: + lines.append(f"\n⚙️ Generation Parameters:") + for key, value in self.generation_params.items(): + if key == 'prompt' and len(str(value)) > 100: + lines.append(f" - {key}: {str(value)[:100]}...") + else: + lines.append(f" - {key}: {value}") + + return "\n".join(lines) + + +@dataclass +class PreviewImageResponse: + """Response containing preview images in base64 format""" + preview_images_b64: list[str] # Base64 encoded JPEG images (512x512) + original_count: int + message: str + + @classmethod + def from_image_data(cls, images_data: list[bytes], message: str = "") -> 'PreviewImageResponse': + """Create response from original PNG image data""" + from src.utils.image_utils import create_preview_image_b64 + + preview_images = [] + for i, image_data in 
enumerate(images_data): + preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85) + if preview_b64: + preview_images.append(preview_b64) + else: + # Fallback - use original if preview creation fails + import base64 + preview_images.append(base64.b64encode(image_data).decode('utf-8')) + + return cls( + preview_images_b64=preview_images, + original_count=len(images_data), + message=message or f"Generated {len(preview_images)} preview images" + ) diff --git a/src/server/handlers.py b/src/server/handlers.py new file mode 100644 index 0000000..ead0186 --- /dev/null +++ b/src/server/handlers.py @@ -0,0 +1,308 @@ +""" +Tool Handlers for MCP Server +""" + +import asyncio +import base64 +import json +import random +import logging +from typing import List, Dict, Any + +from mcp.types import TextContent, ImageContent + +from src.connector import Imagen4Client, Config +from src.connector.imagen4_client import ImageGenerationRequest +from src.connector.utils import save_generated_images + +logger = logging.getLogger(__name__) + + +def sanitize_args_for_logging(arguments: Dict[str, Any]) -> Dict[str, Any]: + """Remove or truncate sensitive data from arguments for safe logging""" + safe_args = {} + for key, value in arguments.items(): + if isinstance(value, str): + # Check if it's likely base64 image data + if (key in ['data', 'image_data', 'base64'] or + (len(value) > 100 and value.startswith(('iVBORw0KGgo', '/9j/', 'R0lGOD')))): + # Truncate long image data + safe_args[key] = f"" + elif len(value) > 1000: + # Truncate any very long strings + safe_args[key] = f"{value[:100]}..." + else: + safe_args[key] = value + else: + safe_args[key] = value + return safe_args + + +class ToolHandlers: + """MCP tool handler class""" + + def __init__(self, config: Config): + """Initialize handler""" + self.config = config + self.client = Imagen4Client(config) + + async def handle_generate_random_seed(self, arguments: Dict[str, Any]) -> List[TextContent]: + """Random seed generation handler""" + try: + random_seed = random.randint(0, 2**32 - 1) + return [TextContent( + type="text", + text=f"Generated random seed: {random_seed}" + )] + except Exception as e: + logger.error(f"Random seed generation error: {str(e)}") + return [TextContent( + type="text", + text=f"Error occurred during random seed generation: {str(e)}" + )] + + async def handle_regenerate_from_json(self, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]: + """Image regeneration from JSON file handler""" + try: + json_file_path = arguments.get("json_file_path") + save_to_file = arguments.get("save_to_file", True) + + if not json_file_path: + return [TextContent( + type="text", + text="Error: JSON file path is required." 
+ )] + + # Load parameters from JSON file + try: + with open(json_file_path, 'r', encoding='utf-8') as f: + params = json.load(f) + except FileNotFoundError: + return [TextContent( + type="text", + text=f"Error: Cannot find JSON file: {json_file_path}" + )] + except json.JSONDecodeError as e: + return [TextContent( + type="text", + text=f"Error: JSON parsing error: {str(e)}" + )] + + # Check required parameters + required_params = ['prompt', 'seed'] + missing_params = [p for p in required_params if p not in params] + + if missing_params: + return [TextContent( + type="text", + text=f"Error: Required parameters are missing from JSON file: {', '.join(missing_params)}" + )] + + # Create image generation request object + request = ImageGenerationRequest( + prompt=params.get('prompt'), + negative_prompt=params.get('negative_prompt', ''), + number_of_images=params.get('number_of_images', 1), + seed=params.get('seed'), + aspect_ratio=params.get('aspect_ratio', '1:1') + ) + + logger.info(f"Loaded parameters from JSON: {json_file_path}") + + # Generate image + response = await self.client.generate_image(request) + + if not response.success: + return [TextContent( + type="text", + text=f"Error occurred during image regeneration: {response.error_message}" + )] + + # Save files (optional) + saved_files = [] + if save_to_file: + regeneration_params = { + "prompt": request.prompt, + "negative_prompt": request.negative_prompt, + "number_of_images": request.number_of_images, + "seed": request.seed, + "aspect_ratio": request.aspect_ratio, + "regenerated_from": json_file_path, + "original_generated_at": params.get('generated_at', 'unknown') + } + + saved_files = save_generated_images( + images_data=response.images_data, + save_directory=self.config.output_path, + seed=request.seed, + generation_params=regeneration_params, + filename_prefix="imagen4_regen" + ) + + # Generate result message + message_parts = [ + f"Images have been successfully regenerated.", + f"JSON file: {json_file_path}", + f"Size: 2048x2048 PNG", + f"Count: {len(response.images_data)} images", + f"Seed: {request.seed}", + f"Prompt: {request.prompt}" + ] + + if saved_files: + message_parts.append(f"\nRegenerated files saved successfully:") + for filepath in saved_files: + message_parts.append(f"- {filepath}") + + # Generate result + result = [ + TextContent( + type="text", + text="\n".join(message_parts) + ) + ] + + # Add Base64 encoded images + for i, image_data in enumerate(response.images_data): + image_base64 = base64.b64encode(image_data).decode('utf-8') + result.append( + ImageContent( + type="image", + data=image_base64, + mimeType="image/png" + ) + ) + logger.info(f"Added image {i+1} to response: {len(image_base64)} chars base64 data") + + return result + + except Exception as e: + logger.error(f"Error occurred during image regeneration: {str(e)}") + return [TextContent( + type="text", + text=f"Error occurred during image regeneration: {str(e)}" + )] + + async def handle_generate_image(self, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]: + """Image generation handler with synchronous processing""" + try: + # Log arguments safely without exposing image data + safe_args = sanitize_args_for_logging(arguments) + logger.info(f"handle_generate_image called with arguments: {safe_args}") + + # Extract and validate arguments + prompt = arguments.get("prompt") + if not prompt: + logger.error("No prompt provided") + return [TextContent( + type="text", + text="Error: Prompt is required." 
+ )] + + seed = arguments.get("seed") + if seed is None: + logger.error("No seed provided") + return [TextContent( + type="text", + text="Error: Seed value is required. You can use the generate_random_seed tool to generate a random seed." + )] + + # Create image generation request object + request = ImageGenerationRequest( + prompt=prompt, + negative_prompt=arguments.get("negative_prompt", ""), + number_of_images=arguments.get("number_of_images", 1), + seed=seed, + aspect_ratio=arguments.get("aspect_ratio", "1:1") + ) + + save_to_file = arguments.get("save_to_file", True) + + logger.info(f"Starting SYNCHRONOUS image generation: '{prompt[:50]}...', Seed: {seed}") + + # Generate image synchronously with longer timeout + try: + logger.info("Calling client.generate_image() synchronously...") + response = await asyncio.wait_for( + self.client.generate_image(request), + timeout=360.0 # 6 minute timeout + ) + logger.info(f"Image generation completed. Success: {response.success}") + except asyncio.TimeoutError: + logger.error("Image generation timed out after 6 minutes") + return [TextContent( + type="text", + text="Error: Image generation timed out after 6 minutes. Please try again." + )] + + if not response.success: + logger.error(f"Image generation failed: {response.error_message}") + return [TextContent( + type="text", + text=f"Error occurred during image generation: {response.error_message}" + )] + + logger.info(f"Generated {len(response.images_data)} images successfully") + + # Save files if requested + saved_files = [] + if save_to_file: + logger.info("Saving files to disk...") + generation_params = { + "prompt": request.prompt, + "negative_prompt": request.negative_prompt, + "number_of_images": request.number_of_images, + "seed": request.seed, + "aspect_ratio": request.aspect_ratio, + "guidance_scale": 7.5, + "safety_filter_level": "block_only_high", + "person_generation": "allow_all", + "add_watermark": False + } + + saved_files = save_generated_images( + images_data=response.images_data, + save_directory=self.config.output_path, + seed=request.seed, + generation_params=generation_params + ) + logger.info(f"Files saved: {saved_files}") + + # Verify files were created + import os + for file_path in saved_files: + if os.path.exists(file_path): + size = os.path.getsize(file_path) + logger.info(f" ✓ Verified: {file_path} ({size} bytes)") + else: + logger.error(f" ❌ Missing: {file_path}") + + # Generate result message + message_parts = [ + f"✅ Images have been successfully generated!", + f"Prompt: {request.prompt}", + f"Seed: {request.seed}", + f"Size: 2048x2048 PNG", + f"Count: {len(response.images_data)} images" + ] + + if saved_files: + message_parts.append(f"\n📁 Files saved successfully:") + for filepath in saved_files: + message_parts.append(f" - {filepath}") + else: + message_parts.append(f"\nℹ️ Images generated but not saved to file. 
Use save_to_file=true to save.") + + logger.info("Returning synchronous response") + return [TextContent( + type="text", + text="\n".join(message_parts) + )] + + except Exception as e: + logger.error(f"Error occurred during image generation: {str(e)}", exc_info=True) + return [TextContent( + type="text", + text=f"Error occurred during image generation: {str(e)}" + )] + diff --git a/src/server/mcp_server.py b/src/server/mcp_server.py new file mode 100644 index 0000000..f2f884b --- /dev/null +++ b/src/server/mcp_server.py @@ -0,0 +1,57 @@ +""" +MCP Server implementation for Imagen4 +""" + +import logging +from typing import Dict, Any, List + +from mcp.server import Server +from mcp.types import Tool, TextContent, ImageContent + +from src.connector import Config +from src.server.models import MCPToolDefinitions +from src.server.handlers import ToolHandlers, sanitize_args_for_logging + +logger = logging.getLogger(__name__) + + +class Imagen4MCPServer: + """Imagen4 MCP server class""" + + def __init__(self, config: Config): + """Initialize server""" + self.config = config + self.server = Server("imagen4-mcp-server") + self.handlers = ToolHandlers(config) + + # Register handlers + self._register_handlers() + + def _register_handlers(self) -> None: + """Register MCP handlers""" + + @self.server.list_tools() + async def handle_list_tools() -> List[Tool]: + """Return list of available tools""" + logger.info("Listing available tools") + return MCPToolDefinitions.get_all_tools() + + @self.server.call_tool() + async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]: + """Handle tool calls""" + # Log tool call safely without exposing sensitive data + safe_args = sanitize_args_for_logging(arguments) + logger.info(f"Tool called: {name} with arguments: {safe_args}") + + if name == "generate_random_seed": + return await self.handlers.handle_generate_random_seed(arguments) + elif name == "regenerate_from_json": + return await self.handlers.handle_regenerate_from_json(arguments) + elif name == "generate_image": + return await self.handlers.handle_generate_image(arguments) + else: + raise ValueError(f"Unknown tool: {name}") + + def get_server(self) -> Server: + """Return MCP server instance""" + return self.server diff --git a/src/server/minimal_handler.py b/src/server/minimal_handler.py new file mode 100644 index 0000000..d3a3bd8 --- /dev/null +++ b/src/server/minimal_handler.py @@ -0,0 +1,39 @@ +""" +Minimal Task Result Handler for Emergency Use +""" + +import logging +from typing import List, Dict, Any + +from mcp.types import TextContent + +logger = logging.getLogger(__name__) + + +async def minimal_get_task_result(task_manager, task_id: str) -> List[TextContent]: + """Absolutely minimal task result handler""" + try: + logger.info(f"Minimal handler for task: {task_id}") + + # Just return the task status for now + status = task_manager.get_task_status(task_id) + + if status is None: + return [TextContent( + type="text", + text=f"❌ Task '{task_id}' not found." + )] + + return [TextContent( + type="text", + text=f"📋 Task '{task_id}' Status: {status.value}\n" + f"Note: This is a minimal emergency handler.\n" + f"If you see this message, the original handler has a serious issue." 
+ )] + + except Exception as e: + logger.error(f"Even minimal handler failed: {str(e)}", exc_info=True) + return [TextContent( + type="text", + text=f"Complete failure: {str(e)}" + )] diff --git a/src/server/models.py b/src/server/models.py new file mode 100644 index 0000000..ddd7662 --- /dev/null +++ b/src/server/models.py @@ -0,0 +1,132 @@ +""" +MCP Tool Models and Definitions +""" + +from mcp.types import Tool + + +class MCPToolDefinitions: + """MCP tool definition class""" + + @staticmethod + def get_generate_image_tool() -> Tool: + """Image generation tool definition""" + return Tool( + name="generate_image", + description="Generate 2048x2048 PNG images from text prompts using Google Imagen 4 AI.", + inputSchema={ + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "Text prompt for image generation (Korean or English). Maximum 480 tokens allowed.", + "maxLength": 2000 # Approximate character limit for safety + }, + "negative_prompt": { + "type": "string", + "description": "Negative prompt specifying elements not to generate (optional). Maximum 240 tokens allowed.", + "default": "", + "maxLength": 1000 # Approximate character limit for safety + }, + "number_of_images": { + "type": "integer", + "description": "Number of images to generate (1 or 2 only)", + "enum": [1, 2], + "default": 1 + }, + "seed": { + "type": "integer", + "description": "Seed value for reproducible results (required, 0 ~ 4294967295 range)", + "minimum": 0, + "maximum": 4294967295 + }, + "aspect_ratio": { + "type": "string", + "description": "Image aspect ratio", + "enum": ["1:1", "9:16", "16:9", "3:4", "4:3"], + "default": "1:1" + }, + "save_to_file": { + "type": "boolean", + "description": "Whether to save generated images to files (default: true)", + "default": True + }, + "model": { + "type": "string", + "description": "Imagen 4 model to use for generation", + "enum": ["imagen-4.0-generate-001", "imagen-4.0-ultra-generate-001"], + "default": "imagen-4.0-generate-001" + } + }, + "required": ["prompt", "seed"] + } + ) + + @staticmethod + def get_regenerate_from_json_tool() -> Tool: + """Regenerate from JSON tool definition""" + return Tool( + name="regenerate_from_json", + description="Read parameters from JSON file and regenerate images with the same settings.", + inputSchema={ + "type": "object", + "properties": { + "json_file_path": { + "type": "string", + "description": "Path to JSON file containing saved parameters" + }, + "save_to_file": { + "type": "boolean", + "description": "Whether to save regenerated images to files (default: true)", + "default": True + } + }, + "required": ["json_file_path"] + } + ) + + @staticmethod + def get_generate_random_seed_tool() -> Tool: + """Random seed generation tool definition""" + return Tool( + name="generate_random_seed", + description="Generate random seed value for image generation.", + inputSchema={ + "type": "object", + "properties": {}, + "required": [] + } + ) + + @staticmethod + def get_validate_prompt_tool() -> Tool: + """Prompt validation tool definition""" + return Tool( + name="validate_prompt", + description="Validate prompt length and estimate token count before image generation.", + inputSchema={ + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "Text prompt to validate" + }, + "negative_prompt": { + "type": "string", + "description": "Negative prompt to validate (optional)", + "default": "" + } + }, + "required": ["prompt"] + } + ) + + @classmethod + def get_all_tools(cls) -> 
list[Tool]: + """Return all tool definitions""" + return [ + cls.get_generate_image_tool(), + cls.get_regenerate_from_json_tool(), + cls.get_generate_random_seed_tool(), + cls.get_validate_prompt_tool() + ] diff --git a/src/server/safe_result_handler.py b/src/server/safe_result_handler.py new file mode 100644 index 0000000..2c6e727 --- /dev/null +++ b/src/server/safe_result_handler.py @@ -0,0 +1,80 @@ +""" +Safe Get Task Result Handler - Debug Version +""" + +import logging +from typing import List, Dict, Any + +from mcp.types import TextContent + +logger = logging.getLogger(__name__) + + +async def safe_get_task_result(task_manager, task_id: str) -> List[TextContent]: + """Minimal safe version of get_task_result without image processing""" + try: + logger.info(f"Safe get_task_result called for: {task_id}") + + # Step 1: Try to get raw result + try: + raw_result = task_manager.worker_pool.get_task_result(task_id) + logger.info(f"Raw result type: {type(raw_result)}") + except Exception as e: + logger.error(f"Failed to get raw result: {str(e)}") + return [TextContent( + type="text", + text=f"Error accessing task data: {str(e)}" + )] + + if not raw_result: + return [TextContent( + type="text", + text=f"❌ Task '{task_id}' not found." + )] + + # Step 2: Check status safely + try: + status = raw_result.status + logger.info(f"Task status: {status}") + except Exception as e: + logger.error(f"Failed to get status: {str(e)}") + return [TextContent( + type="text", + text=f"Error reading task status: {str(e)}" + )] + + # Step 3: Return basic info without touching result.result + try: + duration = raw_result.duration if hasattr(raw_result, 'duration') else None + created = raw_result.created_at.isoformat() if hasattr(raw_result, 'created_at') and raw_result.created_at else "unknown" + + message = f"✅ Task '{task_id}' Status Report:\n" + message += f"Status: {status.value}\n" + message += f"Created: {created}\n" + + if duration: + message += f"Duration: {duration:.2f}s\n" + + if hasattr(raw_result, 'error') and raw_result.error: + message += f"Error: {raw_result.error}\n" + + message += "\nNote: This is a safe diagnostic version." 
+ + return [TextContent( + type="text", + text=message + )] + + except Exception as e: + logger.error(f"Failed to format response: {str(e)}") + return [TextContent( + type="text", + text=f"Task exists but data formatting failed: {str(e)}" + )] + + except Exception as e: + logger.error(f"Critical error in safe_get_task_result: {str(e)}", exc_info=True) + return [TextContent( + type="text", + text=f"Critical error: {str(e)}" + )] diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..67b9db6 --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1 @@ +# Utils package \ No newline at end of file diff --git a/src/utils/image_utils.py b/src/utils/image_utils.py new file mode 100644 index 0000000..1364cb3 --- /dev/null +++ b/src/utils/image_utils.py @@ -0,0 +1,107 @@ +""" +Image processing utilities for preview generation +""" + +import base64 +import io +import logging +from typing import Optional +from PIL import Image + +logger = logging.getLogger(__name__) + + +def create_preview_image_b64(image_data: bytes, target_size: int = 512, quality: int = 85) -> Optional[str]: + """ + Convert PNG image data to JPEG preview with specified size and return as base64 + + Args: + image_data: Original PNG image data in bytes + target_size: Target size for the preview (default: 512x512) + quality: JPEG quality (1-100, default: 85) + + Returns: + Base64 encoded JPEG image string, or None if conversion fails + """ + try: + # Open image from bytes + with Image.open(io.BytesIO(image_data)) as img: + # Convert to RGB if necessary (PNG might have alpha channel) + if img.mode in ('RGBA', 'LA', 'P'): + # Create white background for transparent images + background = Image.new('RGB', img.size, (255, 255, 255)) + if img.mode == 'P': + img = img.convert('RGBA') + background.paste(img, mask=img.split()[-1] if img.mode in ('RGBA', 'LA') else None) + img = background + elif img.mode != 'RGB': + img = img.convert('RGB') + + # Resize to target size maintaining aspect ratio + img.thumbnail((target_size, target_size), Image.Resampling.LANCZOS) + + # If image is smaller than target size, pad it to exact size + if img.size != (target_size, target_size): + # Create new image with target size and white background + new_img = Image.new('RGB', (target_size, target_size), (255, 255, 255)) + # Center the resized image + x = (target_size - img.size[0]) // 2 + y = (target_size - img.size[1]) // 2 + new_img.paste(img, (x, y)) + img = new_img + + # Convert to JPEG and encode as base64 + output_buffer = io.BytesIO() + img.save(output_buffer, format='JPEG', quality=quality, optimize=True) + jpeg_data = output_buffer.getvalue() + + # Encode to base64 + base64_data = base64.b64encode(jpeg_data).decode('utf-8') + + logger.info(f"Preview image created: {target_size}x{target_size} JPEG, {len(jpeg_data)} bytes, base64 length: {len(base64_data)}") + return base64_data + + except Exception as e: + logger.error(f"Failed to create preview image: {str(e)}") + return None + + +def validate_image_data(image_data: bytes) -> bool: + """ + Validate if image data is a valid image + + Args: + image_data: Image data in bytes + + Returns: + True if valid image, False otherwise + """ + try: + with Image.open(io.BytesIO(image_data)) as img: + img.verify() # Verify image integrity + return True + except Exception: + return False + + +def get_image_info(image_data: bytes) -> Optional[dict]: + """ + Get image information + + Args: + image_data: Image data in bytes + + Returns: + Dictionary with image info (format, size, mode) or 
None if invalid + """ + try: + with Image.open(io.BytesIO(image_data)) as img: + return { + 'format': img.format, + 'size': img.size, + 'mode': img.mode, + 'bytes': len(image_data) + } + except Exception as e: + logger.error(f"Failed to get image info: {str(e)}") + return None diff --git a/src/utils/token_utils.py b/src/utils/token_utils.py new file mode 100644 index 0000000..1f63890 --- /dev/null +++ b/src/utils/token_utils.py @@ -0,0 +1,200 @@ +""" +Token counting utilities for Imagen4 prompts +""" + +import re +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + +# 기본 토큰 추정 상수 +AVERAGE_CHARS_PER_TOKEN = 4 # 영어 기준 평균값 +KOREAN_CHARS_PER_TOKEN = 2 # 한글 기준 평균값 +MAX_PROMPT_TOKENS = 480 # 최대 프롬프트 토큰 수 + + +def estimate_token_count(text: str) -> int: + """ + 텍스트의 토큰 수를 추정합니다. + + 정확한 토큰 계산을 위해서는 실제 토크나이저가 필요하지만, + API 호출 전 빠른 검증을 위해 추정값을 사용합니다. + + Args: + text: 토큰 수를 계산할 텍스트 + + Returns: + int: 추정된 토큰 수 + """ + if not text: + return 0 + + # 텍스트 정리 + text = text.strip() + if not text: + return 0 + + # 한글과 영어 문자 분리 + korean_chars = len(re.findall(r'[가-힣]', text)) + english_chars = len(re.findall(r'[a-zA-Z]', text)) + other_chars = len(text) - korean_chars - english_chars + + # 토큰 수 추정 + korean_tokens = korean_chars / KOREAN_CHARS_PER_TOKEN + english_tokens = english_chars / AVERAGE_CHARS_PER_TOKEN + other_tokens = other_chars / AVERAGE_CHARS_PER_TOKEN + + estimated_tokens = int(korean_tokens + english_tokens + other_tokens) + + # 최소 1토큰은 보장 + return max(1, estimated_tokens) + + +def validate_prompt_length(prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> tuple[bool, int, Optional[str]]: + """ + 프롬프트 길이를 검증합니다. + + Args: + prompt: 검증할 프롬프트 + max_tokens: 최대 허용 토큰 수 + + Returns: + tuple: (유효성, 토큰 수, 오류 메시지) + """ + if not prompt: + return False, 0, "프롬프트가 비어있습니다." + + token_count = estimate_token_count(prompt) + + if token_count > max_tokens: + error_msg = ( + f"프롬프트가 너무 깁니다. " + f"현재: {token_count}토큰, 최대: {max_tokens}토큰. " + f"프롬프트를 {token_count - max_tokens}토큰 줄여주세요." + ) + return False, token_count, error_msg + + return True, token_count, None + + +def truncate_prompt(prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> str: + """ + 프롬프트를 지정된 토큰 수로 자릅니다. + + Args: + prompt: 자를 프롬프트 + max_tokens: 최대 토큰 수 + + Returns: + str: 잘린 프롬프트 + """ + if not prompt: + return "" + + current_tokens = estimate_token_count(prompt) + if current_tokens <= max_tokens: + return prompt + + # 대략적인 비율로 텍스트 자르기 + ratio = max_tokens / current_tokens + target_length = int(len(prompt) * ratio * 0.9) # 여유분 10% + + truncated = prompt[:target_length] + + # 단어/문장 경계에서 자르기 + if len(truncated) < len(prompt): + # 마지막 완전한 단어까지만 유지 + last_space = truncated.rfind(' ') + last_korean = truncated.rfind('다') # 한글 어미 + last_punct = max(truncated.rfind('.'), truncated.rfind(','), truncated.rfind('!')) + + cut_point = max(last_space, last_korean, last_punct) + if cut_point > target_length * 0.8: # 너무 많이 잘리지 않도록 + truncated = truncated[:cut_point] + + # 최종 검증 + final_tokens = estimate_token_count(truncated) + if final_tokens > max_tokens: + # 강제로 문자 단위로 자르기 + chars_per_token = len(truncated) / final_tokens + target_chars = int(max_tokens * chars_per_token * 0.95) + truncated = truncated[:target_chars] + + return truncated.strip() + + +def get_prompt_stats(prompt: str) -> dict: + """ + 프롬프트 통계 정보를 반환합니다. 
+ + Args: + prompt: 분석할 프롬프트 + + Returns: + dict: 프롬프트 통계 + """ + if not prompt: + return { + "character_count": 0, + "estimated_tokens": 0, + "korean_chars": 0, + "english_chars": 0, + "other_chars": 0, + "is_valid": False, + "remaining_tokens": MAX_PROMPT_TOKENS + } + + char_count = len(prompt) + korean_chars = len(re.findall(r'[가-힣]', prompt)) + english_chars = len(re.findall(r'[a-zA-Z]', prompt)) + other_chars = char_count - korean_chars - english_chars + estimated_tokens = estimate_token_count(prompt) + is_valid = estimated_tokens <= MAX_PROMPT_TOKENS + remaining_tokens = MAX_PROMPT_TOKENS - estimated_tokens + + return { + "character_count": char_count, + "estimated_tokens": estimated_tokens, + "korean_chars": korean_chars, + "english_chars": english_chars, + "other_chars": other_chars, + "is_valid": is_valid, + "remaining_tokens": remaining_tokens, + "max_tokens": MAX_PROMPT_TOKENS + } + + +# 실제 토크나이저 사용 시 대체할 수 있는 인터페이스 +class TokenCounter: + """토큰 카운터 인터페이스""" + + def __init__(self, tokenizer_name: Optional[str] = None): + """ + 토큰 카운터 초기화 + + Args: + tokenizer_name: 사용할 토크나이저 이름 (향후 확장용) + """ + self.tokenizer_name = tokenizer_name or "estimate" + logger.info(f"토큰 카운터 초기화: {self.tokenizer_name}") + + def count_tokens(self, text: str) -> int: + """텍스트의 토큰 수 계산""" + if self.tokenizer_name == "estimate": + return estimate_token_count(text) + else: + # 향후 실제 토크나이저 구현 + raise NotImplementedError(f"토크나이저 '{self.tokenizer_name}'는 아직 구현되지 않았습니다.") + + def validate_prompt(self, prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> tuple[bool, int, Optional[str]]: + """프롬프트 검증""" + return validate_prompt_length(prompt, max_tokens) + + def truncate_prompt(self, prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> str: + """프롬프트 자르기""" + return truncate_prompt(prompt, max_tokens) + + +# 전역 토큰 카운터 인스턴스 +default_token_counter = TokenCounter() diff --git a/test_main.py b/test_main.py new file mode 100644 index 0000000..2cedf43 --- /dev/null +++ b/test_main.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +""" +Simple test for the consolidated main.py +""" + +import sys +import os +import io +from PIL import Image +import base64 + +# Add current directory to PYTHONPATH +current_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, current_dir) + +def create_test_image(): + """Create a simple test image""" + img = Image.new('RGB', (2048, 2048), color='lightblue') + from PIL import ImageDraw + draw = ImageDraw.Draw(img) + draw.rectangle([500, 500, 1500, 1500], fill='red') + + buffer = io.BytesIO() + img.save(buffer, format='PNG') + return buffer.getvalue() + +def test_preview_function(): + """Test the preview image creation""" + # Import the function from main.py + from main import create_preview_image_b64, get_image_info + + # Create test image + test_png = create_test_image() + print(f"Created test PNG: {len(test_png)} bytes") + + # Get image info + info = get_image_info(test_png) + print(f"Image info: {info}") + + # Create preview + preview_b64 = create_preview_image_b64(test_png, target_size=512, quality=85) + + if preview_b64: + print(f"✅ Preview created: {len(preview_b64)} chars") + + # Save preview to verify + preview_bytes = base64.b64decode(preview_b64) + with open('test_preview.jpg', 'wb') as f: + f.write(preview_bytes) + + ratio = (len(preview_bytes) / len(test_png)) * 100 + print(f"Compression: {ratio:.1f}% (original: {len(test_png)}, preview: {len(preview_bytes)})") + print("Preview saved as test_preview.jpg") + return True + else: + print("❌ Failed to create preview") + 
return False + +def test_imports(): + """Test all necessary imports""" + try: + from main import ( + Imagen4MCPServer, + Imagen4ToolHandlers, + ImageGenerationResult, + get_tools + ) + print("✅ All classes imported successfully") + + # Test tool definitions + tools = get_tools() + print(f"✅ {len(tools)} tools defined") + + return True + except Exception as e: + print(f"❌ Import error: {e}") + return False + +def main(): + """Main test function""" + print("=== Testing Consolidated main.py ===\n") + + success = True + + # Test imports + if not test_imports(): + success = False + + print() + + # Test preview function + if not test_preview_function(): + success = False + + print(f"\n=== Test Results ===") + if success: + print("✅ All tests passed!") + print("\nTo run the server:") + print("1. Make sure .env is configured") + print("2. Run: python main.py") + print("3. Or use: run.bat") + else: + print("❌ Some tests failed") + + return 0 if success else 1 + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/test_connector.py b/tests/test_connector.py new file mode 100644 index 0000000..f95dcba --- /dev/null +++ b/tests/test_connector.py @@ -0,0 +1,178 @@ +""" +Tests for Imagen4 Connector +""" + +import pytest +import asyncio +from unittest.mock import Mock, patch + +from src.connector import Config, Imagen4Client +from src.connector.imagen4_client import ImageGenerationRequest + + +class TestConfig: + """Config 클래스 테스트""" + + def test_config_creation(self): + """설정 생성 테스트""" + config = Config( + project_id="test-project", + location="us-central1", + output_path="./test_images" + ) + + assert config.project_id == "test-project" + assert config.location == "us-central1" + assert config.output_path == "./test_images" + + def test_config_validation(self): + """설정 유효성 검사 테스트""" + # 유효한 설정 + valid_config = Config(project_id="test-project") + assert valid_config.validate() is True + + # 무효한 설정 + invalid_config = Config(project_id="") + assert invalid_config.validate() is False + + @patch.dict('os.environ', { + 'GOOGLE_CLOUD_PROJECT_ID': 'test-project', + 'GOOGLE_CLOUD_LOCATION': 'us-west1', + 'GENERATED_IMAGES_PATH': './custom_path' + }) + def test_config_from_env(self): + """환경 변수에서 설정 로드 테스트""" + config = Config.from_env() + + assert config.project_id == "test-project" + assert config.location == "us-west1" + assert config.output_path == "./custom_path" + + @patch.dict('os.environ', {}, clear=True) + def test_config_from_env_missing_project_id(self): + """필수 환경 변수 누락 테스트""" + with pytest.raises(ValueError, match="GOOGLE_CLOUD_PROJECT_ID environment variable is required"): + Config.from_env() + + +class TestImageGenerationRequest: + """ImageGenerationRequest 테스트""" + + def test_valid_request(self): + """유효한 요청 테스트""" + request = ImageGenerationRequest( + prompt="A beautiful landscape", + seed=12345 + ) + + # 유효성 검사가 예외 없이 완료되어야 함 + request.validate() + + def test_invalid_prompt(self): + """무효한 프롬프트 테스트""" + request = ImageGenerationRequest( + prompt="", + seed=12345 + ) + + with pytest.raises(ValueError, match="Prompt is required"): + request.validate() + + def test_invalid_seed(self): + """무효한 시드값 테스트""" + request = ImageGenerationRequest( + prompt="Test prompt", + seed=-1 + ) + + with pytest.raises(ValueError, match="Seed value must be an integer between 0 and 4294967295"): + request.validate() + + def test_invalid_number_of_images(self): + """무효한 이미지 개수 테스트""" + request = ImageGenerationRequest( + prompt="Test prompt", + seed=12345, + number_of_images=3 + ) + + with 
pytest.raises(ValueError, match="Number of images must be 1 or 2 only"): + request.validate() + + +class TestImagen4Client: + """Imagen4Client 테스트""" + + def test_client_initialization(self): + """클라이언트 초기화 테스트""" + config = Config(project_id="test-project") + + with patch('src.connector.imagen4_client.genai.Client'): + client = Imagen4Client(config) + assert client.config == config + + def test_client_invalid_config(self): + """무효한 설정으로 클라이언트 초기화 테스트""" + invalid_config = Config(project_id="") + + with pytest.raises(ValueError, match="Invalid configuration"): + Imagen4Client(invalid_config) + + @pytest.mark.asyncio + async def test_generate_image_success(self): + """이미지 생성 성공 테스트""" + config = Config(project_id="test-project") + + # Create mock response + mock_image_data = b"fake_image_data" + mock_response = Mock() + mock_response.generated_images = [Mock()] + mock_response.generated_images[0].image = Mock() + mock_response.generated_images[0].image.image_bytes = mock_image_data + + with patch('src.connector.imagen4_client.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + mock_client.models.generate_images = Mock(return_value=mock_response) + + client = Imagen4Client(config) + + request = ImageGenerationRequest( + prompt="Test prompt", + seed=12345 + ) + + with patch('asyncio.to_thread', return_value=mock_response): + response = await client.generate_image(request) + + assert response.success is True + assert len(response.images_data) == 1 + assert response.images_data[0] == mock_image_data + + @pytest.mark.asyncio + async def test_generate_image_failure(self): + """Image generation failure test""" + config = Config(project_id="test-project") + + with patch('src.connector.imagen4_client.genai.Client') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + client = Imagen4Client(config) + + request = ImageGenerationRequest( + prompt="Test prompt", + seed=12345 + ) + + # Simulate API call failure + with patch('asyncio.to_thread', side_effect=Exception("API Error")): + response = await client.generate_image(request) + + assert response.success is False + assert "API Error" in response.error_message + assert len(response.images_data) == 0 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..2c488ad --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,136 @@ +""" +Tests for Imagen4 Server +""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch + +from src.connector import Config +from src.server import Imagen4MCPServer, ToolHandlers, MCPToolDefinitions +from src.server.models import MCPToolDefinitions + + +class TestMCPToolDefinitions: + """MCPToolDefinitions tests""" + + def test_get_generate_image_tool(self): + """Image generation tool definition test""" + tool = MCPToolDefinitions.get_generate_image_tool() + + assert tool.name == "generate_image" + assert "Google Imagen 4" in tool.description + assert "prompt" in tool.inputSchema["properties"] + assert "seed" in tool.inputSchema["required"] + + def test_get_regenerate_from_json_tool(self): + """JSON regeneration tool definition test""" + tool = MCPToolDefinitions.get_regenerate_from_json_tool() + + assert tool.name == "regenerate_from_json" + assert "JSON file" in tool.description + assert "json_file_path" in tool.inputSchema["properties"] + + def test_get_generate_random_seed_tool(self): + """Random seed generation tool 
definition test""" + tool = MCPToolDefinitions.get_generate_random_seed_tool() + + assert tool.name == "generate_random_seed" + assert "random seed" in tool.description + + def test_get_all_tools(self): + """All tools return test""" + tools = MCPToolDefinitions.get_all_tools() + + assert len(tools) == 3 + tool_names = [tool.name for tool in tools] + assert "generate_image" in tool_names + assert "regenerate_from_json" in tool_names + assert "generate_random_seed" in tool_names + + +class TestToolHandlers: + """ToolHandlers tests""" + + @pytest.fixture + def config(self): + """Test configuration""" + return Config(project_id="test-project") + + @pytest.fixture + def handlers(self, config): + """Test handlers""" + with patch('src.server.handlers.Imagen4Client'): + return ToolHandlers(config) + + @pytest.mark.asyncio + async def test_handle_generate_random_seed(self, handlers): + """Random seed generation handler test""" + result = await handlers.handle_generate_random_seed({}) + + assert len(result) == 1 + assert result[0].type == "text" + assert "Generated random seed:" in result[0].text + + @pytest.mark.asyncio + async def test_handle_generate_image_missing_prompt(self, handlers): + """Missing prompt test""" + arguments = {"seed": 12345} + + result = await handlers.handle_generate_image(arguments) + + assert len(result) == 1 + assert result[0].type == "text" + assert "Prompt is required" in result[0].text + + @pytest.mark.asyncio + async def test_handle_generate_image_missing_seed(self, handlers): + """Missing seed value test""" + arguments = {"prompt": "Test prompt"} + + result = await handlers.handle_generate_image(arguments) + + assert len(result) == 1 + assert result[0].type == "text" + assert "Seed value is required" in result[0].text + + @pytest.mark.asyncio + async def test_handle_regenerate_from_json_missing_path(self, handlers): + """Missing JSON file path test""" + arguments = {} + + result = await handlers.handle_regenerate_from_json(arguments) + + assert len(result) == 1 + assert result[0].type == "text" + assert "JSON file path is required" in result[0].text + + +class TestImagen4MCPServer: + """Imagen4MCPServer tests""" + + @pytest.fixture + def config(self): + """Test configuration""" + return Config(project_id="test-project") + + def test_server_initialization(self, config): + """Server initialization test""" + with patch('src.server.mcp_server.ToolHandlers'): + server = Imagen4MCPServer(config) + + assert server.config == config + assert server.server is not None + assert server.handlers is not None + + def test_get_server(self, config): + """Server instance return test""" + with patch('src.server.mcp_server.ToolHandlers'): + mcp_server = Imagen4MCPServer(config) + server = mcp_server.get_server() + + assert server is not None + assert server.name == "imagen4-mcp-server" + + +if __name__ == "__main__": + pytest.main([__file__])