imagen4 mcp server, mcp connector implementation

This commit is contained in:
2025-08-22 20:03:38 +09:00
commit dcf2305e4b
31 changed files with 3984 additions and 0 deletions

13
.env.example Normal file
View File

@@ -0,0 +1,13 @@
# Google Cloud 설정
GOOGLE_CLOUD_PROJECT_ID=your-project-id
GOOGLE_CLOUD_LOCATION=us-central1
# 이미지 저장 설정
GENERATED_IMAGES_PATH=./generated_images
# Imagen4 모델 설정
# IMAGEN4_DEFAULT_MODEL=imagen-4.0-generate-001 # 기본 모델
# IMAGEN4_DEFAULT_MODEL=imagen-4.0-ultra-generate-001 # Ultra 모델 (더 높은 품질)
# 선택적 설정
# LOG_LEVEL=INFO

119
.gitignore vendored Normal file
View File

@@ -0,0 +1,119 @@
# 바이트 컴파일된 파일 / 최적화된 파일 / DLL 파일
__pycache__/
*.py[cod]
*$py.class
# C 확장
*.so
# 배포
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
*.manifest
*.spec
# 단위 테스트 / 커버리지 보고서
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
Pipfile.lock
# PEP 582
__pypackages__/
# Celery
celerybeat-schedule
celerybeat.pid
# SageMath
*.sage.py
# 환경
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder IDE
.spyderproject
.spyproject
# Rope 프로젝트
.ropeproject
# mkdocs
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre
.pyre/
# Google Cloud 인증 키
*.json
!claude_desktop_config.json
# 로그 파일
*.log
# 생성된 이미지
generated_images/
*.png
*.jpg
*.jpeg
# VS Code
.vscode/
# PyCharm
.idea/
# 임시 파일들
temp_*
cleanup*

204
README.md Normal file
View File

@@ -0,0 +1,204 @@
# Imagen4 MCP Server with Preview Images
Google Imagen 4를 사용한 AI 이미지 생성 MCP 서버입니다. 하나의 `main.py` 파일로 모든 기능이 통합되어 있습니다.
## 🚀 주요 기능
- **고품질 이미지 생성**: 2048x2048 PNG 이미지 생성
- **미리보기 이미지**: 512x512 JPEG 미리보기를 base64로 제공
- **파일 저장**: 원본 PNG + 메타데이터 JSON 저장
- **재생성**: JSON 파일로부터 동일한 설정으로 재생성
- **랜덤 시드**: 재현 가능한 결과를 위한 시드 생성
## 📦 설치
### 1. 의존성 설치
```bash
pip install -r requirements.txt
```
필요한 패키지:
- `google-genai>=0.2.0` - Google Imagen 4 API
- `mcp>=1.0.0` - MCP 프로토콜
- `python-dotenv>=1.0.0` - 환경변수 관리
- `Pillow>=10.0.0` - 이미지 처리
### 2. 환경 설정
`.env` 파일을 생성하고 Google Cloud 자격 증명을 설정:
```env
GOOGLE_APPLICATION_CREDENTIALS=path/to/your/service-account-key.json
PROJECT_ID=your-gcp-project-id
LOCATION=us-central1
OUTPUT_PATH=./generated_images
```
## 🎮 사용법
### 서버 실행
#### Windows
```bash
run.bat
```
#### 직접 실행
```bash
python main.py
```
### Claude Desktop 설정
`claude_desktop_config.json` 내용을 Claude Desktop 설정에 추가:
```json
{
"mcpServers": {
"imagen4": {
"command": "python",
"args": ["main.py"],
"cwd": "D:\\Project\\little-fairy\\imagen4"
}
}
}
```
## 🛠️ 사용 가능한 도구
### 1. generate_image
고품질 이미지를 생성하고 미리보기를 제공합니다.
**파라미터:**
- `prompt` (필수): 이미지 생성 프롬프트
- `seed` (필수): 재현 가능한 결과를 위한 시드 값
- `negative_prompt` (선택): 제외할 요소 지정
- `number_of_images` (선택): 생성할 이미지 수 (1-2)
- `aspect_ratio` (선택): 종횡비 ("1:1", "9:16", "16:9", "3:4", "4:3")
- `save_to_file` (선택): 파일 저장 여부
**응답 예시:**
```
✅ Images have been successfully generated!
🖼️ Preview Images Generated: 1 images (512x512 JPEG)
Preview 1 (base64 JPEG): /9j/4AAQSkZJRgABAQAAAQABAAD/2wBD...(67234 chars)
📁 Files saved:
- ./generated_images/imagen4_20250821_143052_seed_1234567890.png
- ./generated_images/imagen4_20250821_143052_seed_1234567890.json
⚙️ Generation Parameters:
- prompt: A beautiful sunset over mountains
- seed: 1234567890
- aspect_ratio: 1:1
- number_of_images: 1
```
### 2. generate_random_seed
이미지 생성용 랜덤 시드를 생성합니다.
### 3. regenerate_from_json
저장된 JSON 파라미터 파일로부터 이미지를 재생성합니다.
## 🖼️ 미리보기 이미지 특징
### 변환 과정
1. **원본**: 2048x2048 PNG (약 8-15MB)
2. **변환**: 512x512 JPEG (약 50-200KB)
3. **최적화**: 품질 85, 최적화 활성화
4. **인코딩**: Base64 문자열
### 장점
- **95% 크기 감소**: 빠른 전송 및 표시
- **즉시 사용**: 웹 환경에서 바로 표시 가능
- **투명도 처리**: PNG 투명도를 흰색 배경으로 변환
- **종횡비 유지**: 원본 비율을 유지하면서 크기 조정
## 🧪 테스트
기능을 테스트하려면:
```bash
python test_main.py
```
이 스크립트는 다음을 확인합니다:
- 모든 클래스 및 함수 import
- 이미지 처리 기능
- 미리보기 생성 기능
## 📁 프로젝트 구조
```
imagen4/
├── main.py # 통합된 메인 서버 파일
├── run.bat # Windows 실행 파일
├── test_main.py # 테스트 스크립트
├── claude_desktop_config.json # Claude Desktop 설정
├── requirements.txt # 의존성 목록
├── .env # 환경 변수 (사용자가 생성)
├── generated_images/ # 생성된 이미지 저장소
└── src/
└── connector/ # Google API 연결 모듈
├── __init__.py
├── config.py
├── imagen4_client.py
└── utils.py
```
## 🔧 기술적 세부사항
### 이미지 처리
- **PIL (Pillow)** 사용으로 안정적인 이미지 변환
- **LANCZOS 리샘플링**으로 고품질 크기 조정
- **투명도 자동 처리**: RGBA → RGB 변환 시 흰색 배경 적용
- **중앙 정렬**: 목표 크기보다 작은 이미지는 중앙에 배치
### MCP 프로토콜
- **버전 3.0.0**: 미리보기 이미지 지원
- **비동기 처리**: asyncio 기반 안정적인 비동기 작업
- **오류 처리**: 상세한 로깅 및 사용자 친화적 오류 메시지
- **타임아웃**: 6분 타임아웃으로 안전한 작업 보장
## 🐛 문제 해결
### 일반적인 문제
1. **Pillow 설치 오류**
```bash
pip install --upgrade pip
pip install Pillow
```
2. **Google Cloud 인증 오류**
- 서비스 계정 키 파일 경로 확인
- PROJECT_ID가 올바른지 확인
- Imagen API가 활성화되어 있는지 확인
3. **메모리 부족**
- 이미지 수를 1개로 제한
- 시스템 메모리 확인 (최소 8GB 권장)
### 로그 확인
서버 실행 시 자세한 로그가 stderr로 출력됩니다:
- 이미지 생성 진행 상황
- 미리보기 생성 과정
- 파일 저장 상태
- 오류 및 경고 메시지
## 📄 라이센스
MIT License - 자세한 내용은 LICENSE 파일을 참조하세요.
## 🎯 요약
이 통합된 `main.py`는 다음과 같은 이점을 제공합니다:
- **단순성**: 하나의 파일로 모든 기능 제공
- **효율성**: 512x512 JPEG 미리보기로 빠른 응답
- **호환성**: 기존 MCP 클라이언트와 완벽 호환
- **확장성**: 필요에 따라 쉽게 기능 추가 가능
Google Imagen 4의 강력한 이미지 생성 기능을 MCP 프로토콜을 통해 효율적으로 활용할 수 있습니다!

View File

@@ -0,0 +1,12 @@
{
"mcpServers": {
"imagen4": {
"command": "python",
"args": ["main.py"],
"cwd": "D:\\Project\\imagen4",
"env": {
"PYTHONPATH": "D:\\Project\\imagen4"
}
}
}
}

706
main.py Normal file
View File

@@ -0,0 +1,706 @@
#!/usr/bin/env python3
"""
Imagen4 MCP Server with Preview Image Support
Google Imagen 4를 사용한 AI 이미지 생성 MCP 서버
- 2048x2048 PNG 원본 이미지 생성
- 512x512 JPEG 미리보기 이미지 base64 제공
- 향상된 응답 형식
Run from imagen4 root directory
"""
import asyncio
import base64
import json
import random
import logging
import sys
import os
import io
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
# Load environment variables
try:
from dotenv import load_dotenv
load_dotenv()
except ImportError:
print("Warning: python-dotenv not installed", file=sys.stderr)
# Add current directory to PYTHONPATH
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)
# MCP imports
try:
from mcp.server.models import InitializationOptions
from mcp.server import NotificationOptions
from mcp.server.stdio import stdio_server
from mcp.server import Server
from mcp.types import Tool, TextContent
except ImportError as e:
print(f"Error importing MCP: {e}", file=sys.stderr)
print("Please install required packages: pip install mcp", file=sys.stderr)
sys.exit(1)
# Connector imports
from src.connector import Config, Imagen4Client
from src.connector.imagen4_client import ImageGenerationRequest
from src.connector.utils import save_generated_images
# Image processing import
try:
from PIL import Image
except ImportError as e:
print(f"Error importing Pillow: {e}", file=sys.stderr)
print("Please install Pillow: pip install Pillow", file=sys.stderr)
sys.exit(1)
# Logging configuration
logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler(sys.stderr)]
)
logger = logging.getLogger("imagen4-mcp-server")
# ==================== Image Processing Utilities ====================
def create_preview_image_b64(image_data: bytes, target_size: int = 512, quality: int = 85) -> Optional[str]:
"""
Convert PNG image data to JPEG preview with specified size and return as base64
Args:
image_data: Original PNG image data in bytes
target_size: Target size for the preview (default: 512x512)
quality: JPEG quality (1-100, default: 85)
Returns:
Base64 encoded JPEG image string, or None if conversion fails
"""
try:
# Open image from bytes
with Image.open(io.BytesIO(image_data)) as img:
# Convert to RGB if necessary (PNG might have alpha channel)
if img.mode in ('RGBA', 'LA', 'P'):
# Create white background for transparent images
background = Image.new('RGB', img.size, (255, 255, 255))
if img.mode == 'P':
img = img.convert('RGBA')
background.paste(img, mask=img.split()[-1] if img.mode in ('RGBA', 'LA') else None)
img = background
elif img.mode != 'RGB':
img = img.convert('RGB')
# Resize to target size maintaining aspect ratio
img.thumbnail((target_size, target_size), Image.Resampling.LANCZOS)
# If image is smaller than target size, pad it to exact size
if img.size != (target_size, target_size):
# Create new image with target size and white background
new_img = Image.new('RGB', (target_size, target_size), (255, 255, 255))
# Center the resized image
x = (target_size - img.size[0]) // 2
y = (target_size - img.size[1]) // 2
new_img.paste(img, (x, y))
img = new_img
# Convert to JPEG and encode as base64
output_buffer = io.BytesIO()
img.save(output_buffer, format='JPEG', quality=quality, optimize=True)
jpeg_data = output_buffer.getvalue()
# Encode to base64
base64_data = base64.b64encode(jpeg_data).decode('utf-8')
logger.info(f"Preview image created: {target_size}x{target_size} JPEG, {len(jpeg_data)} bytes")
return base64_data
except Exception as e:
logger.error(f"Failed to create preview image: {str(e)}")
return None
def get_image_info(image_data: bytes) -> Optional[dict]:
"""Get image information"""
try:
with Image.open(io.BytesIO(image_data)) as img:
return {
'format': img.format,
'size': img.size,
'mode': img.mode,
'bytes': len(image_data)
}
except Exception as e:
logger.error(f"Failed to get image info: {str(e)}")
return None
# ==================== Response Models ====================
@dataclass
class ImageGenerationResult:
"""Enhanced image generation result with preview support"""
success: bool
message: str
original_images_count: int
preview_images_b64: Optional[list[str]] = None # List of base64 JPEG previews (512x512)
saved_files: Optional[list[str]] = None
generation_params: Optional[Dict[str, Any]] = None
error_message: Optional[str] = None
def to_text_content(self) -> str:
"""Convert result to text format for MCP response"""
lines = [self.message]
if self.success and self.preview_images_b64:
lines.append(f"\n🖼️ Preview Images Generated: {len(self.preview_images_b64)} images (512x512 JPEG)")
for i, preview_b64 in enumerate(self.preview_images_b64):
lines.append(f"Preview {i+1} (base64 JPEG): {preview_b64[:50]}...({len(preview_b64)} chars)")
if self.saved_files:
lines.append(f"\n📁 Files saved:")
for filepath in self.saved_files:
lines.append(f" - {filepath}")
if self.generation_params:
lines.append(f"\n⚙️ Generation Parameters:")
for key, value in self.generation_params.items():
if key == 'prompt' and len(str(value)) > 100:
lines.append(f" - {key}: {str(value)[:100]}...")
else:
lines.append(f" - {key}: {value}")
return "\n".join(lines)
# ==================== MCP Tool Definitions ====================
def get_tools() -> List[Tool]:
"""Return all tool definitions"""
return [
Tool(
name="generate_image",
description="Generate 2048x2048 PNG images with 512x512 JPEG previews using Google Imagen 4 AI.",
inputSchema={
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "Text prompt for image generation (Korean or English)"
},
"negative_prompt": {
"type": "string",
"description": "Negative prompt specifying elements not to generate (optional)",
"default": ""
},
"number_of_images": {
"type": "integer",
"description": "Number of images to generate (1 or 2 only)",
"enum": [1, 2],
"default": 1
},
"seed": {
"type": "integer",
"description": "Seed value for reproducible results (required, 0 ~ 4294967295 range)",
"minimum": 0,
"maximum": 4294967295
},
"aspect_ratio": {
"type": "string",
"description": "Image aspect ratio",
"enum": ["1:1", "9:16", "16:9", "3:4", "4:3"],
"default": "1:1"
},
"save_to_file": {
"type": "boolean",
"description": "Whether to save generated images to files (default: true)",
"default": True
}
},
"required": ["prompt", "seed"]
}
),
Tool(
name="regenerate_from_json",
description="Read parameters from JSON file and regenerate images with previews.",
inputSchema={
"type": "object",
"properties": {
"json_file_path": {
"type": "string",
"description": "Path to JSON file containing saved parameters"
},
"save_to_file": {
"type": "boolean",
"description": "Whether to save regenerated images to files (default: true)",
"default": True
}
},
"required": ["json_file_path"]
}
),
Tool(
name="generate_random_seed",
description="Generate random seed value for image generation.",
inputSchema={
"type": "object",
"properties": {},
"required": []
}
)
]
# ==================== Utility Functions ====================
def sanitize_args_for_logging(arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Remove or truncate sensitive data from arguments for safe logging"""
safe_args = {}
for key, value in arguments.items():
if isinstance(value, str):
# Check if it's likely base64 image data
if (key in ['data', 'image_data', 'base64', 'preview_image_b64'] or
(len(value) > 100 and value.startswith(('iVBORw0KGgo', '/9j/', 'R0lGOD')))):
# Truncate long image data
safe_args[key] = f"<image_data:{len(value)} chars>"
elif len(value) > 1000:
# Truncate any very long strings
safe_args[key] = f"{value[:100]}...<truncated:{len(value)} total chars>"
else:
safe_args[key] = value
else:
safe_args[key] = value
return safe_args
# ==================== Tool Handlers ====================
class Imagen4ToolHandlers:
"""MCP tool handler class with preview image support"""
def __init__(self, config: Config):
"""Initialize handler"""
self.config = config
self.client = Imagen4Client(config)
async def handle_generate_random_seed(self, arguments: Dict[str, Any]) -> List[TextContent]:
"""Random seed generation handler"""
try:
random_seed = random.randint(0, 2**32 - 1)
return [TextContent(
type="text",
text=f"Generated random seed: {random_seed}"
)]
except Exception as e:
logger.error(f"Random seed generation error: {str(e)}")
return [TextContent(
type="text",
text=f"Error occurred during random seed generation: {str(e)}"
)]
async def handle_regenerate_from_json(self, arguments: Dict[str, Any]) -> List[TextContent]:
"""Image regeneration from JSON file handler with preview support"""
try:
json_file_path = arguments.get("json_file_path")
save_to_file = arguments.get("save_to_file", True)
if not json_file_path:
return [TextContent(
type="text",
text="Error: JSON file path is required."
)]
# Load parameters from JSON file
try:
with open(json_file_path, 'r', encoding='utf-8') as f:
params = json.load(f)
except FileNotFoundError:
return [TextContent(
type="text",
text=f"Error: Cannot find JSON file: {json_file_path}"
)]
except json.JSONDecodeError as e:
return [TextContent(
type="text",
text=f"Error: JSON parsing error: {str(e)}"
)]
# Check required parameters
required_params = ['prompt', 'seed']
missing_params = [p for p in required_params if p not in params]
if missing_params:
return [TextContent(
type="text",
text=f"Error: Required parameters are missing from JSON file: {', '.join(missing_params)}"
)]
# Create image generation request object
request = ImageGenerationRequest(
prompt=params.get('prompt'),
negative_prompt=params.get('negative_prompt', ''),
number_of_images=params.get('number_of_images', 1),
seed=params.get('seed'),
aspect_ratio=params.get('aspect_ratio', '1:1')
)
logger.info(f"Loaded parameters from JSON: {json_file_path}")
# Generate image
response = await self.client.generate_image(request)
if not response.success:
return [TextContent(
type="text",
text=f"Error occurred during image regeneration: {response.error_message}"
)]
# Create preview images
preview_images_b64 = []
for image_data in response.images_data:
preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85)
if preview_b64:
preview_images_b64.append(preview_b64)
logger.info(f"Created preview image: {len(preview_b64)} chars base64 JPEG")
# Save files (optional)
saved_files = []
if save_to_file:
regeneration_params = {
"prompt": request.prompt,
"negative_prompt": request.negative_prompt,
"number_of_images": request.number_of_images,
"seed": request.seed,
"aspect_ratio": request.aspect_ratio,
"regenerated_from": json_file_path,
"original_generated_at": params.get('generated_at', 'unknown')
}
saved_files = save_generated_images(
images_data=response.images_data,
save_directory=self.config.output_path,
seed=request.seed,
generation_params=regeneration_params,
filename_prefix="imagen4_regen"
)
# Create result with preview images
result = ImageGenerationResult(
success=True,
message=f"✅ Images have been successfully regenerated from {json_file_path}",
original_images_count=len(response.images_data),
preview_images_b64=preview_images_b64,
saved_files=saved_files,
generation_params={
"prompt": request.prompt,
"seed": request.seed,
"aspect_ratio": request.aspect_ratio,
"number_of_images": request.number_of_images
}
)
return [TextContent(
type="text",
text=result.to_text_content()
)]
except Exception as e:
logger.error(f"Error occurred during image regeneration: {str(e)}")
return [TextContent(
type="text",
text=f"Error occurred during image regeneration: {str(e)}"
)]
async def handle_generate_image(self, arguments: Dict[str, Any]) -> List[TextContent]:
"""Image generation handler with preview image support"""
try:
# Log arguments safely without exposing image data
safe_args = sanitize_args_for_logging(arguments)
logger.info(f"handle_generate_image called with arguments: {safe_args}")
# Extract and validate arguments
prompt = arguments.get("prompt")
if not prompt:
logger.error("No prompt provided")
return [TextContent(
type="text",
text="Error: Prompt is required."
)]
seed = arguments.get("seed")
if seed is None:
logger.error("No seed provided")
return [TextContent(
type="text",
text="Error: Seed value is required. You can use the generate_random_seed tool to generate a random seed."
)]
# Create image generation request object
request = ImageGenerationRequest(
prompt=prompt,
negative_prompt=arguments.get("negative_prompt", ""),
number_of_images=arguments.get("number_of_images", 1),
seed=seed,
aspect_ratio=arguments.get("aspect_ratio", "1:1")
)
save_to_file = arguments.get("save_to_file", True)
logger.info(f"Starting image generation: '{prompt[:50]}...', Seed: {seed}")
# Generate image with timeout
try:
logger.info("Calling client.generate_image()...")
response = await asyncio.wait_for(
self.client.generate_image(request),
timeout=360.0 # 6 minute timeout
)
logger.info(f"Image generation completed. Success: {response.success}")
except asyncio.TimeoutError:
logger.error("Image generation timed out after 6 minutes")
return [TextContent(
type="text",
text="Error: Image generation timed out after 6 minutes. Please try again."
)]
if not response.success:
logger.error(f"Image generation failed: {response.error_message}")
return [TextContent(
type="text",
text=f"Error occurred during image generation: {response.error_message}"
)]
logger.info(f"Generated {len(response.images_data)} images successfully")
# Create preview images (512x512 JPEG base64)
preview_images_b64 = []
for i, image_data in enumerate(response.images_data):
# Log original image info
image_info = get_image_info(image_data)
if image_info:
logger.info(f"Original image {i+1}: {image_info}")
# Create preview
preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85)
if preview_b64:
preview_images_b64.append(preview_b64)
logger.info(f"Created preview image {i+1}: {len(preview_b64)} chars base64 JPEG")
else:
logger.warning(f"Failed to create preview for image {i+1}")
# Save files if requested
saved_files = []
if save_to_file:
logger.info("Saving files to disk...")
generation_params = {
"prompt": request.prompt,
"negative_prompt": request.negative_prompt,
"number_of_images": request.number_of_images,
"seed": request.seed,
"aspect_ratio": request.aspect_ratio,
"guidance_scale": 7.5,
"safety_filter_level": "block_only_high",
"person_generation": "allow_all",
"add_watermark": False
}
saved_files = save_generated_images(
images_data=response.images_data,
save_directory=self.config.output_path,
seed=request.seed,
generation_params=generation_params
)
logger.info(f"Files saved: {saved_files}")
# Verify files were created
for file_path in saved_files:
if os.path.exists(file_path):
size = os.path.getsize(file_path)
logger.info(f" ✓ Verified: {file_path} ({size} bytes)")
else:
logger.error(f" ❌ Missing: {file_path}")
# Create enhanced result with preview images
result = ImageGenerationResult(
success=True,
message=f"✅ Images have been successfully generated!",
original_images_count=len(response.images_data),
preview_images_b64=preview_images_b64,
saved_files=saved_files,
generation_params={
"prompt": request.prompt,
"seed": request.seed,
"aspect_ratio": request.aspect_ratio,
"number_of_images": request.number_of_images,
"negative_prompt": request.negative_prompt
}
)
logger.info(f"Returning response with {len(preview_images_b64)} preview images")
return [TextContent(
type="text",
text=result.to_text_content()
)]
except Exception as e:
logger.error(f"Error occurred during image generation: {str(e)}", exc_info=True)
return [TextContent(
type="text",
text=f"Error occurred during image generation: {str(e)}"
)]
# ==================== MCP Server ====================
class Imagen4MCPServer:
"""Imagen4 MCP server class with preview image support"""
def __init__(self, config: Config):
"""Initialize server"""
self.config = config
self.server = Server("imagen4-mcp-server")
self.handlers = Imagen4ToolHandlers(config)
# Register handlers
self._register_handlers()
def _register_handlers(self) -> None:
"""Register MCP handlers"""
@self.server.list_tools()
async def handle_list_tools() -> List[Tool]:
"""Return list of available tools"""
try:
logger.info("Listing available tools with preview image support")
return get_tools()
except Exception as e:
logger.error(f"Error listing tools: {e}")
raise
@self.server.call_tool()
async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]:
"""Handle tool calls with preview image support"""
try:
# Log tool call safely without exposing sensitive data
safe_args = sanitize_args_for_logging(arguments)
logger.info(f"Tool called: {name} with arguments: {safe_args}")
if name == "generate_random_seed":
return await self.handlers.handle_generate_random_seed(arguments)
elif name == "regenerate_from_json":
return await self.handlers.handle_regenerate_from_json(arguments)
elif name == "generate_image":
return await self.handlers.handle_generate_image(arguments)
else:
error_msg = f"Unknown tool: {name}"
logger.error(error_msg)
return [TextContent(
type="text",
text=f"Error: {error_msg}"
)]
except Exception as e:
logger.error(f"Error handling tool call '{name}': {e}")
import traceback
logger.error(f"Tool call traceback: {traceback.format_exc()}")
return [TextContent(
type="text",
text=f"Error occurred while processing tool '{name}': {str(e)}"
)]
def get_server(self) -> Server:
"""Return MCP server instance"""
return self.server
# ==================== Main Function ====================
async def main():
"""Main function"""
logger.info("Starting Imagen 4 MCP Server with Preview Image Support...")
try:
# Load configuration
config = Config.from_env()
logger.info(f"Configuration loaded - Project: {config.project_id}, Location: {config.location}")
# Create MCP server
mcp_server = Imagen4MCPServer(config)
server = mcp_server.get_server()
logger.info("Imagen 4 MCP Server initialized successfully")
logger.info("Features: 512x512 JPEG preview images, base64 encoding, enhanced responses")
# Run MCP server with better error handling
try:
async with stdio_server() as (read_stream, write_stream):
logger.info("Starting MCP server with stdio transport")
await server.run(
read_stream,
write_stream,
InitializationOptions(
server_name="imagen4-mcp-server",
server_version="3.0.0",
capabilities=server.get_capabilities(
notification_options=NotificationOptions(),
experimental_capabilities={
"preview_images": {},
"base64_jpeg_previews": {}
},
)
)
)
except Exception as stdio_error:
logger.error(f"STDIO server error: {stdio_error}")
# Try to handle specific MCP protocol errors
if "TaskGroup" in str(stdio_error):
logger.error("TaskGroup error detected - this may be due to client disconnection")
raise
except ValueError as e:
logger.error(f"Configuration error: {e}")
logger.error("Please check your .env file or environment variables.")
return 1
except KeyboardInterrupt:
logger.info("Server stopped by user (Ctrl+C)")
return 0
except Exception as e:
logger.error(f"Server error: {e}")
logger.error(f"Error type: {type(e).__name__}")
import traceback
logger.error(f"Traceback: {traceback.format_exc()}")
raise
if __name__ == "__main__":
try:
# Set up signal handling for graceful shutdown
import signal
def signal_handler(signum, frame):
logger.info(f"Received signal {signum}, shutting down gracefully...")
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
# Run the server
exit_code = asyncio.run(main())
sys.exit(exit_code or 0)
except KeyboardInterrupt:
logger.info("Server stopped by user")
sys.exit(0)
except SystemExit as e:
logger.info(f"Server exiting with code {e.code}")
sys.exit(e.code)
except Exception as e:
logger.error(f"Fatal error: {e}")
logger.error(f"Error type: {type(e).__name__}")
import traceback
logger.error(f"Full traceback: {traceback.format_exc()}")
sys.exit(1)

7
requirements.txt Normal file
View File

@@ -0,0 +1,7 @@
# Requirements for separated Imagen4 architecture
google-genai>=0.2.0
mcp>=1.0.0
python-dotenv>=1.0.0
pytest>=7.0.0
pytest-asyncio>=0.21.0
Pillow>=10.0.0

50
run.bat Normal file
View File

@@ -0,0 +1,50 @@
@echo off
echo Starting Imagen4 MCP Server with Preview Image Support...
echo Features: 512x512 JPEG preview images, base64 encoding
echo.
REM Check if virtual environment exists
if not exist "venv\Scripts\activate.bat" (
echo Creating virtual environment...
python -m venv venv
if errorlevel 1 (
echo Failed to create virtual environment
pause
exit /b 1
)
)
REM Activate virtual environment
echo Activating virtual environment...
call venv\Scripts\activate.bat
REM Check if requirements are installed
python -c "import PIL" 2>nul
if errorlevel 1 (
echo Installing requirements...
pip install -r requirements.txt
if errorlevel 1 (
echo Failed to install requirements
pause
exit /b 1
)
)
REM Check if .env file exists
if not exist ".env" (
echo Warning: .env file not found
echo Please copy .env.example to .env and configure your API credentials
pause
exit /b 1
)
REM Run server
echo Starting MCP server...
python main.py
REM Keep window open if there's an error
if errorlevel 1 (
echo.
echo Server exited with error
pause
)

10
src/__init__.py Normal file
View File

@@ -0,0 +1,10 @@
"""
Imagen4 Package
분리된 아키텍처를 가진 Google Imagen 4 MCP 서버
"""
# 기본 컴포넌트만 export
# 자세한 import는 각 모듈에서 직접 하도록 함
__version__ = "2.1.0"

View File

@@ -0,0 +1,9 @@
"""
Async Task Management Module for MCP Server
"""
from .models import TaskStatus, TaskResult, generate_task_id
from .task_manager import TaskManager
from .worker_pool import WorkerPool
__all__ = ['TaskManager', 'TaskStatus', 'TaskResult', 'WorkerPool']

48
src/async_task/models.py Normal file
View File

@@ -0,0 +1,48 @@
"""
Task Status and Result Models
"""
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Optional, Dict
class TaskStatus(Enum):
"""Task execution status"""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
@dataclass
class TaskResult:
"""Task execution result"""
task_id: str
status: TaskStatus
result: Optional[Any] = None
error: Optional[str] = None
created_at: datetime = field(default_factory=datetime.now)
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
metadata: Dict[str, Any] = field(default_factory=dict)
@property
def duration(self) -> Optional[float]:
"""Calculate task execution duration in seconds"""
if self.started_at and self.completed_at:
return (self.completed_at - self.started_at).total_seconds()
return None
@property
def is_finished(self) -> bool:
"""Check if task is finished (completed, failed, or cancelled)"""
return self.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]
def generate_task_id() -> str:
"""Generate unique task ID"""
return str(uuid.uuid4())

View File

@@ -0,0 +1,324 @@
"""
Task Manager for MCP Server
Manages background task execution with status tracking and result retrieval
"""
import asyncio
import logging
from typing import Any, Callable, Optional, Dict, List
from datetime import datetime, timedelta
from .models import TaskResult, TaskStatus, generate_task_id
from .worker_pool import WorkerPool
logger = logging.getLogger(__name__)
class TaskManager:
"""Main task manager for MCP server"""
def __init__(
self,
max_workers: Optional[int] = None,
use_process_pool: bool = False,
default_timeout: float = 600.0, # 10 minutes
result_retention_hours: int = 24, # Keep results for 24 hours
max_retained_results: int = 200
):
"""
Initialize task manager
Args:
max_workers: Maximum number of worker threads/processes
use_process_pool: Use process pool instead of thread pool
default_timeout: Default timeout for tasks in seconds
result_retention_hours: How long to keep finished task results
max_retained_results: Maximum number of results to retain
"""
self.worker_pool = WorkerPool(
max_workers=max_workers,
use_process_pool=use_process_pool,
task_timeout=default_timeout
)
self.result_retention_hours = result_retention_hours
self.max_retained_results = max_retained_results
# Start cleanup task
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
logger.info(f"TaskManager initialized with {self.worker_pool.max_workers} workers")
async def submit_task(
self,
func: Callable,
*args,
task_id: Optional[str] = None,
timeout: Optional[float] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs
) -> str:
"""
Submit a task for background execution
Args:
func: Function to execute
*args: Function arguments
task_id: Optional custom task ID
timeout: Task timeout in seconds
metadata: Additional metadata for the task
**kwargs: Function keyword arguments
Returns:
str: Task ID for tracking progress
"""
task_id = await self.worker_pool.submit_task(
func, *args, task_id=task_id, timeout=timeout, **kwargs
)
# Add additional metadata if provided
if metadata:
result = self.worker_pool.get_task_result(task_id)
if result:
result.metadata.update(metadata)
return task_id
def get_task_status(self, task_id: str) -> Optional[TaskStatus]:
"""Get task status by ID"""
result = self.worker_pool.get_task_result(task_id)
return result.status if result else None
def get_task_result(self, task_id: str) -> Optional[TaskResult]:
"""Get complete task result by ID with circular reference protection"""
try:
raw_result = self.worker_pool.get_task_result(task_id)
if not raw_result:
return None
# Create a safe copy to avoid circular references
safe_result = TaskResult(
task_id=raw_result.task_id,
status=raw_result.status,
result=None, # We'll set this safely
error=raw_result.error,
created_at=raw_result.created_at,
started_at=raw_result.started_at,
completed_at=raw_result.completed_at,
metadata=raw_result.metadata.copy() if raw_result.metadata else {}
)
# Safely copy the result data
if raw_result.result is not None:
try:
# If it's a dict, create a clean copy
if isinstance(raw_result.result, dict):
safe_result.result = {
'success': raw_result.result.get('success', False),
'error': raw_result.result.get('error'),
'images_b64': raw_result.result.get('images_b64', []),
'saved_files': raw_result.result.get('saved_files', []),
'request': raw_result.result.get('request', {}),
'image_count': raw_result.result.get('image_count', 0)
}
else:
# For non-dict results, just copy directly
safe_result.result = raw_result.result
except Exception as copy_error:
logger.error(f"Error creating safe copy of result: {str(copy_error)}")
safe_result.result = {
'success': False,
'error': f'Result data corrupted: {str(copy_error)}'
}
return safe_result
except Exception as e:
logger.error(f"Error in get_task_result: {str(e)}", exc_info=True)
return None
def get_task_progress(self, task_id: str) -> Dict[str, Any]:
"""
Get task progress information
Returns:
Dict with task status, timing info, and metadata
"""
result = self.get_task_result(task_id)
if not result:
return {'error': 'Task not found'}
progress = {
'task_id': result.task_id,
'status': result.status.value,
'created_at': result.created_at.isoformat(),
'metadata': result.metadata
}
if result.started_at:
progress['started_at'] = result.started_at.isoformat()
if result.status == TaskStatus.RUNNING:
elapsed = (datetime.now() - result.started_at).total_seconds()
progress['elapsed_seconds'] = elapsed
if result.completed_at:
progress['completed_at'] = result.completed_at.isoformat()
progress['duration_seconds'] = result.duration
if result.error:
progress['error'] = result.error
return progress
def list_tasks(
self,
status_filter: Optional[TaskStatus] = None,
limit: int = 50
) -> List[Dict[str, Any]]:
"""
List tasks with optional filtering
Args:
status_filter: Filter by task status
limit: Maximum number of tasks to return
Returns:
List of task progress dictionaries
"""
all_results = self.worker_pool.get_all_results()
# Filter by status if requested
if status_filter:
filtered_results = {
k: v for k, v in all_results.items()
if v.status == status_filter
}
else:
filtered_results = all_results
# Sort by creation time (most recent first)
sorted_items = sorted(
filtered_results.items(),
key=lambda x: x[1].created_at,
reverse=True
)
# Apply limit and return progress info
return [
self.get_task_progress(task_id)
for task_id, _ in sorted_items[:limit]
]
def cancel_task(self, task_id: str) -> bool:
"""Cancel a running task"""
return self.worker_pool.cancel_task(task_id)
async def wait_for_task(
self,
task_id: str,
timeout: Optional[float] = None,
poll_interval: float = 1.0
) -> TaskResult:
"""
Wait for a task to complete
Args:
task_id: Task ID to wait for
timeout: Maximum time to wait in seconds
poll_interval: How often to check task status
Returns:
TaskResult: Final task result
Raises:
asyncio.TimeoutError: If timeout is reached
ValueError: If task doesn't exist
"""
start_time = datetime.now()
while True:
result = self.get_task_result(task_id)
if not result:
raise ValueError(f"Task {task_id} not found")
if result.is_finished:
return result
# Check timeout
if timeout:
elapsed = (datetime.now() - start_time).total_seconds()
if elapsed >= timeout:
raise asyncio.TimeoutError(f"Timeout waiting for task {task_id}")
await asyncio.sleep(poll_interval)
async def _periodic_cleanup(self) -> None:
"""Periodic cleanup of old task results"""
while True:
try:
# Wait 1 hour between cleanups
await asyncio.sleep(3600)
# Clean up old results
cutoff_time = datetime.now() - timedelta(hours=self.result_retention_hours)
all_results = self.worker_pool.get_all_results()
to_remove = []
for task_id, result in all_results.items():
if (result.is_finished and
result.completed_at and
result.completed_at < cutoff_time):
to_remove.append(task_id)
# Remove old results
for task_id in to_remove:
if task_id in self.worker_pool._results:
del self.worker_pool._results[task_id]
if to_remove:
logger.info(f"Cleaned up {len(to_remove)} old task results")
# Also clean up excess results
self.worker_pool.cleanup_finished_tasks(self.max_retained_results)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error in periodic cleanup: {e}")
async def shutdown(self, wait: bool = True) -> None:
"""Shutdown task manager"""
logger.info("Shutting down task manager...")
# Cancel cleanup task
if not self._cleanup_task.done():
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
# Shutdown worker pool
await self.worker_pool.shutdown(wait=wait)
logger.info("Task manager shutdown complete")
@property
def stats(self) -> Dict[str, Any]:
"""Get task manager statistics"""
all_results = self.worker_pool.get_all_results()
status_counts = {}
for status in TaskStatus:
status_counts[status.value] = sum(
1 for r in all_results.values() if r.status == status
)
return {
'total_tasks': len(all_results),
'active_tasks': self.worker_pool.active_task_count,
'max_workers': self.worker_pool.max_workers,
'worker_type': 'process' if self.worker_pool.use_process_pool else 'thread',
'status_counts': status_counts
}

View File

@@ -0,0 +1,258 @@
"""
Worker Pool for Background Task Execution
"""
import asyncio
import logging
import multiprocessing
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from typing import Any, Callable, Optional, Union
from datetime import datetime
from .models import TaskResult, TaskStatus, generate_task_id
logger = logging.getLogger(__name__)
class WorkerPool:
"""Worker pool for executing background tasks"""
def __init__(
self,
max_workers: Optional[int] = None,
use_process_pool: bool = False,
task_timeout: float = 600.0 # 10 minutes default
):
"""
Initialize worker pool
Args:
max_workers: Maximum number of worker threads/processes
use_process_pool: Use ProcessPoolExecutor instead of ThreadPoolExecutor
task_timeout: Default timeout for tasks in seconds
"""
self.max_workers = max_workers or min(32, (multiprocessing.cpu_count() or 1) + 4)
self.use_process_pool = use_process_pool
self.task_timeout = task_timeout
# Initialize executor
if use_process_pool:
self.executor = ProcessPoolExecutor(max_workers=self.max_workers)
logger.info(f"Initialized ProcessPoolExecutor with {self.max_workers} workers")
else:
self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
logger.info(f"Initialized ThreadPoolExecutor with {self.max_workers} workers")
# Track running tasks
self._running_tasks: dict[str, asyncio.Task] = {}
self._results: dict[str, TaskResult] = {}
async def submit_task(
self,
func: Callable,
*args,
task_id: Optional[str] = None,
timeout: Optional[float] = None,
**kwargs
) -> str:
"""
Submit a task for background execution
Args:
func: Function to execute
*args: Function arguments
task_id: Optional custom task ID
timeout: Task timeout in seconds
**kwargs: Function keyword arguments
Returns:
str: Task ID for tracking
"""
if task_id is None:
task_id = generate_task_id()
timeout = timeout or self.task_timeout
# Create task result
task_result = TaskResult(
task_id=task_id,
status=TaskStatus.PENDING,
metadata={
'function_name': func.__name__,
'timeout': timeout,
'worker_type': 'process' if self.use_process_pool else 'thread'
}
)
self._results[task_id] = task_result
# Create and start background task
background_task = asyncio.create_task(
self._execute_task(func, args, kwargs, task_result, timeout)
)
self._running_tasks[task_id] = background_task
logger.info(f"Task {task_id} submitted for execution (timeout: {timeout}s)")
return task_id
async def _execute_task(
self,
func: Callable,
args: tuple,
kwargs: dict,
task_result: TaskResult,
timeout: float
) -> None:
"""Execute task in background"""
try:
task_result.status = TaskStatus.RUNNING
task_result.started_at = datetime.now()
logger.info(f"Starting task {task_result.task_id}: {func.__name__}")
# Execute function with timeout
if self.use_process_pool:
# For process pool, function must be pickleable
future = self.executor.submit(func, *args, **kwargs)
result = await asyncio.wait_for(
asyncio.wrap_future(future),
timeout=timeout
)
else:
# For thread pool, can use lambda/closure
result = await asyncio.wait_for(
asyncio.to_thread(func, *args, **kwargs),
timeout=timeout
)
task_result.result = result
task_result.status = TaskStatus.COMPLETED
task_result.completed_at = datetime.now()
logger.info(f"Task {task_result.task_id} completed successfully in {task_result.duration:.2f}s")
except asyncio.TimeoutError:
task_result.status = TaskStatus.FAILED
task_result.error = f"Task timed out after {timeout} seconds"
task_result.completed_at = datetime.now()
logger.error(f"Task {task_result.task_id} timed out after {timeout}s")
except asyncio.CancelledError:
task_result.status = TaskStatus.CANCELLED
task_result.error = "Task was cancelled"
task_result.completed_at = datetime.now()
logger.info(f"Task {task_result.task_id} was cancelled")
except Exception as e:
task_result.status = TaskStatus.FAILED
task_result.error = str(e)
task_result.completed_at = datetime.now()
logger.error(f"Task {task_result.task_id} failed: {str(e)}", exc_info=True)
finally:
# Clean up running task reference
if task_result.task_id in self._running_tasks:
del self._running_tasks[task_result.task_id]
def get_task_result(self, task_id: str) -> Optional[TaskResult]:
"""Get task result by ID"""
return self._results.get(task_id)
def get_all_results(self) -> dict[str, TaskResult]:
"""Get all task results"""
return self._results.copy()
def cancel_task(self, task_id: str) -> bool:
"""
Cancel a running task
Args:
task_id: Task ID to cancel
Returns:
bool: True if task was cancelled, False if not found or already finished
"""
if task_id in self._running_tasks:
task = self._running_tasks[task_id]
if not task.done():
task.cancel()
logger.info(f"Task {task_id} cancellation requested")
return True
return False
def cleanup_finished_tasks(self, max_results: int = 100) -> int:
"""
Clean up finished task results to prevent memory buildup
Args:
max_results: Maximum number of results to keep
Returns:
int: Number of results cleaned up
"""
if len(self._results) <= max_results:
return 0
# Sort by completion time, keep most recent
finished_results = {
k: v for k, v in self._results.items()
if v.is_finished and v.completed_at
}
if len(finished_results) <= max_results:
return 0
sorted_results = sorted(
finished_results.items(),
key=lambda x: x[1].completed_at,
reverse=True
)
# Keep most recent max_results
to_keep = set(k for k, _ in sorted_results[:max_results])
# Remove old results
to_remove = [k for k in finished_results.keys() if k not in to_keep]
for task_id in to_remove:
del self._results[task_id]
cleaned_count = len(to_remove)
if cleaned_count > 0:
logger.info(f"Cleaned up {cleaned_count} old task results")
return cleaned_count
async def shutdown(self, wait: bool = True) -> None:
"""
Shutdown worker pool
Args:
wait: Whether to wait for running tasks to complete
"""
logger.info("Shutting down worker pool...")
if wait:
# Cancel all running tasks
for task_id, task in self._running_tasks.items():
if not task.done():
task.cancel()
logger.info(f"Cancelled task {task_id}")
# Wait for tasks to complete cancellation
if self._running_tasks:
await asyncio.gather(*self._running_tasks.values(), return_exceptions=True)
# Shutdown executor
self.executor.shutdown(wait=wait)
logger.info("Worker pool shutdown complete")
@property
def active_task_count(self) -> int:
"""Get number of currently running tasks"""
return len([t for t in self._running_tasks.values() if not t.done()])
@property
def total_task_count(self) -> int:
"""Get total number of tracked tasks"""
return len(self._results)

10
src/connector/__init__.py Normal file
View File

@@ -0,0 +1,10 @@
"""
Imagen4 Connector Package
Google Imagen 4 API와의 통신을 담당하는 커넥터 모듈
"""
from .imagen4_client import Imagen4Client
from .config import Config
__all__ = ['Imagen4Client', 'Config']

101
src/connector/config.py Normal file
View File

@@ -0,0 +1,101 @@
"""
Configuration management for Imagen4 connector
"""
import os
from typing import Optional
from dataclasses import dataclass
@dataclass
class Config:
"""Imagen4 configuration class"""
project_id: str
location: str = "us-central1"
output_path: str = "./generated_images"
default_model: str = "imagen-4.0-generate-001" # Default model
@classmethod
def from_env(cls, env_file: Optional[str] = None) -> 'Config':
"""
Load configuration from environment variables
Args:
env_file: Optional path to .env file to load first
"""
# Try to load .env file if specified or if python-dotenv is available
if env_file or not os.getenv("GOOGLE_CLOUD_PROJECT_ID"):
try:
from dotenv import load_dotenv
if env_file:
load_dotenv(env_file)
else:
# Try to find .env file in current directory or parent directories
current_dir = os.getcwd()
env_paths = [
os.path.join(current_dir, '.env'),
os.path.join(os.path.dirname(current_dir), '.env'),
os.path.join(os.path.dirname(__file__), '..', '..', '.env')
]
for env_path in env_paths:
if os.path.exists(env_path):
load_dotenv(env_path)
break
except ImportError:
pass # python-dotenv not available, continue with system env vars
# Try multiple environment variable names for better compatibility
project_id = (
os.getenv("GOOGLE_CLOUD_PROJECT_ID") or
os.getenv("GOOGLE_CLOUD_PROJECT") or
os.getenv("GCP_PROJECT")
)
if not project_id:
raise ValueError(
"Google Cloud Project ID is required. Please set one of these environment variables: "
"GOOGLE_CLOUD_PROJECT_ID, GOOGLE_CLOUD_PROJECT, or GCP_PROJECT"
)
# Also set GOOGLE_CLOUD_PROJECT for Google Cloud libraries if not already set
if not os.getenv("GOOGLE_CLOUD_PROJECT"):
os.environ["GOOGLE_CLOUD_PROJECT"] = project_id
return cls(
project_id=project_id,
location=os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1"),
output_path=os.getenv("GENERATED_IMAGES_PATH", "./generated_images"),
default_model=os.getenv("IMAGEN4_DEFAULT_MODEL", "imagen-4.0-generate-001")
)
@classmethod
def for_testing(cls, project_id: str = "test-project-id") -> 'Config':
"""Create a config instance for testing purposes"""
return cls(
project_id=project_id,
location="us-central1",
output_path="./test_generated_images",
default_model="imagen-4.0-generate-001"
)
def validate(self) -> bool:
"""Validate configuration"""
if not self.project_id:
return False
if not self.location:
return False
if not self.output_path:
return False
return True
def __str__(self) -> str:
"""String representation of config"""
return (
f"Config(project_id='{self.project_id}', "
f"location='{self.location}', "
f"output_path='{self.output_path}', "
f"default_model='{self.default_model}')"
)

View File

@@ -0,0 +1,215 @@
"""
Google Imagen 4 API Client
"""
import asyncio
import logging
from typing import List, Optional
from dataclasses import dataclass
try:
from google import genai
from google.genai.types import GenerateImagesConfig
except ImportError as e:
raise ImportError(f"Google GenAI library not found: {e}. Install with: pip install google-genai")
from .config import Config
logger = logging.getLogger(__name__)
@dataclass
class ImageGenerationRequest:
"""Image generation request data class"""
prompt: str
negative_prompt: str = ""
number_of_images: int = 1
seed: int = 0
aspect_ratio: str = "1:1"
model: str = "imagen-4.0-generate-001" # Model selection
def validate(self) -> None:
"""Validate request data"""
# Import here to avoid circular imports
from ..utils.token_utils import validate_prompt_length
if not self.prompt:
raise ValueError("Prompt is required")
# 프롬프트 토큰 수 검증
is_valid, token_count, error_msg = validate_prompt_length(self.prompt)
if not is_valid:
raise ValueError(f"Prompt validation failed: {error_msg}")
# negative_prompt도 검증 (더 관대한 제한)
if self.negative_prompt:
neg_is_valid, neg_token_count, neg_error_msg = validate_prompt_length(
self.negative_prompt, max_tokens=240 # negative prompt는 절반 제한
)
if not neg_is_valid:
raise ValueError(f"Negative prompt validation failed: {neg_error_msg}")
if not isinstance(self.seed, int) or self.seed < 0 or self.seed > 4294967295:
raise ValueError("Seed value must be an integer between 0 and 4294967295")
if self.number_of_images not in [1, 2]:
raise ValueError("Number of images must be 1 or 2 only")
# Validate model name
valid_models = ["imagen-4.0-generate-001", "imagen-4.0-ultra-generate-001"]
if self.model not in valid_models:
raise ValueError(f"Model must be one of: {valid_models}")
@dataclass
class ImageGenerationResponse:
"""Image generation response data class"""
images_data: List[bytes]
request: ImageGenerationRequest
success: bool = True
error_message: Optional[str] = None
class Imagen4Client:
"""Google Imagen 4 API client"""
def __init__(self, config: Config):
"""
Initialize client
Args:
config: Imagen4 configuration object
"""
if not config.validate():
raise ValueError("Invalid configuration")
self.config = config
self._client = None
self._initialize_client()
def _initialize_client(self) -> None:
"""Initialize Google GenAI client"""
try:
self._client = genai.Client(
vertexai=True,
project=self.config.project_id,
location=self.config.location
)
# Log client information
if not self._client.vertexai:
logger.info("Using Gemini Developer API.")
elif self._client._api_client.project:
logger.info(f"Using Vertex AI with project: {self._client._api_client.project} in location: {self._client._api_client.location}")
elif self._client._api_client.api_key:
logger.info(f"Using Vertex AI in express mode with API key: {self._client._api_client.api_key[:5]}...{self._client._api_client.api_key[-5:]}")
except Exception as e:
logger.error(f"Failed to initialize Imagen4Client: {e}")
raise
async def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse:
"""
Process image generation request with MCP-safe execution
Args:
request: Image generation request object
Returns:
ImageGenerationResponse: Generation result
"""
try:
# Validate request
request.validate()
# 토큰 수 정보 로깅
from ..utils.token_utils import get_prompt_stats
prompt_stats = get_prompt_stats(request.prompt)
logger.info(f"Prompt token analysis: {prompt_stats['estimated_tokens']}/{prompt_stats['max_tokens']} tokens")
if request.negative_prompt:
neg_stats = get_prompt_stats(request.negative_prompt)
logger.info(f"Negative prompt token analysis: {neg_stats['estimated_tokens']} tokens")
logger.info(f"Starting image generation - Prompt: '{request.prompt[:50]}...', Seed: {request.seed}, Model: {request.model}")
print(f"Aspect Ratio: {request.aspect_ratio}, Number of Images: {request.number_of_images}, Model: {request.model}")
print(f"Prompt Tokens: {prompt_stats['estimated_tokens']}/{prompt_stats['max_tokens']}")
# Create a new event loop if needed (for MCP compatibility)
def _sync_generate_images():
"""Synchronous wrapper for API call"""
return self._client.models.generate_images(
model=request.model,
prompt=request.prompt,
config=GenerateImagesConfig(
add_watermark=False,
aspect_ratio=request.aspect_ratio,
image_size="2K", # 2048x2048
negative_prompt=request.negative_prompt,
number_of_images=request.number_of_images,
output_mime_type="image/png",
person_generation="allow_all",
safety_filter_level="block_only_high",
seed=request.seed,
)
)
# Run in thread with explicit timeout
logger.info("Making API call with thread executor...")
response = await asyncio.wait_for(
asyncio.to_thread(_sync_generate_images),
timeout=300.0 # 5 minute timeout (increased from 90s)
)
logger.info("API call completed successfully")
print(f"Get response: {response}")
# Extract image data from response
images_data = []
if response.generated_images:
for i, gen_image in enumerate(response.generated_images):
if hasattr(gen_image, 'image') and hasattr(gen_image.image, 'image_bytes'):
image_bytes = gen_image.image.image_bytes
if image_bytes:
images_data.append(image_bytes)
logger.info(f"Image {i+1} extraction complete (size: {len(image_bytes)} bytes)")
else:
logger.warning(f"Image {i+1}'s image_bytes is empty.")
else:
logger.warning(f"Cannot find image_bytes in image {i+1}.")
if not images_data:
raise Exception("Cannot find generated image data.")
logger.info(f"Total {len(images_data)} images generated successfully")
return ImageGenerationResponse(
images_data=images_data,
request=request,
success=True
)
except asyncio.TimeoutError:
error_msg = "Image generation timed out after 5 minutes"
logger.error(error_msg)
return ImageGenerationResponse(
images_data=[],
request=request,
success=False,
error_message=error_msg
)
except Exception as e:
logger.error(f"Error occurred during image generation: {str(e)}")
return ImageGenerationResponse(
images_data=[],
request=request,
success=False,
error_message=str(e)
)
def health_check(self) -> bool:
"""Check client status"""
try:
return self._client is not None and self.config.validate()
except Exception:
return False

70
src/connector/utils.py Normal file
View File

@@ -0,0 +1,70 @@
"""
Utility functions for Imagen4 connector
"""
import os
import json
import logging
from datetime import datetime
from typing import List, Dict, Any, Optional
logger = logging.getLogger(__name__)
def save_generated_images(
images_data: List[bytes],
save_directory: str = "./generated_images",
filename_prefix: str = "imagen4",
seed: Optional[int] = None,
generation_params: Optional[Dict[str, Any]] = None
) -> List[str]:
"""Save generated images to files"""
os.makedirs(save_directory, exist_ok=True)
saved_files = []
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
for i, image_data in enumerate(images_data):
try:
if seed is not None:
base_filename = f"{filename_prefix}_{seed}_{timestamp}_{i+1:03d}"
else:
base_filename = f"{filename_prefix}_{timestamp}_{i+1:03d}"
# Save image file
image_filename = f"{base_filename}.png"
image_filepath = os.path.join(save_directory, image_filename)
with open(image_filepath, 'wb') as f:
f.write(image_data)
saved_files.append(image_filepath)
logger.info(f"Image saved successfully: {image_filepath}")
# Save JSON parameter file
if generation_params:
json_filename = f"{base_filename}.json"
json_filepath = os.path.join(save_directory, json_filename)
params_with_metadata = {
**generation_params,
"generated_at": datetime.now().isoformat(),
"image_filename": image_filename,
"image_index": i + 1,
"total_images": len(images_data),
"file_size_bytes": len(image_data),
"model_version": "imagen-4.0-generate-001",
"image_format": "PNG",
"image_size": "2048x2048"
}
with open(json_filepath, 'w', encoding='utf-8') as f:
json.dump(params_with_metadata, f, ensure_ascii=False, indent=2)
saved_files.append(json_filepath)
logger.info(f"Parameters saved successfully: {json_filepath}")
except Exception as e:
logger.error(f"Failed to save image {i+1}: {str(e)}")
continue
return saved_files

26
src/server/__init__.py Normal file
View File

@@ -0,0 +1,26 @@
"""
Imagen4 Server Package
MCP 서버 구현을 담당하는 모듈
"""
# Import basic server components
from src.server.mcp_server import Imagen4MCPServer
from src.server.handlers import ToolHandlers
from src.server.models import MCPToolDefinitions
# Import enhanced components
try:
from src.server.enhanced_mcp_server import EnhancedImagen4MCPServer
from src.server.enhanced_handlers import EnhancedToolHandlers
from src.server.enhanced_models import ImageGenerationResult, PreviewImageResponse
__all__ = [
'Imagen4MCPServer', 'ToolHandlers', 'MCPToolDefinitions',
'EnhancedImagen4MCPServer', 'EnhancedToolHandlers',
'ImageGenerationResult', 'PreviewImageResponse'
]
except ImportError as e:
# Fallback if enhanced modules are not available
print(f"Warning: Enhanced features not available: {e}")
__all__ = ['Imagen4MCPServer', 'ToolHandlers', 'MCPToolDefinitions']

View File

@@ -0,0 +1,322 @@
"""
Enhanced Tool Handlers for MCP Server with Preview Image Support
"""
import asyncio
import base64
import json
import random
import logging
from typing import List, Dict, Any
from mcp.types import TextContent, ImageContent
from src.connector import Imagen4Client, Config
from src.connector.imagen4_client import ImageGenerationRequest
from src.connector.utils import save_generated_images
from src.utils.image_utils import create_preview_image_b64, get_image_info
from src.server.enhanced_models import ImageGenerationResult, PreviewImageResponse
logger = logging.getLogger(__name__)
def sanitize_args_for_logging(arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Remove or truncate sensitive data from arguments for safe logging"""
safe_args = {}
for key, value in arguments.items():
if isinstance(value, str):
# Check if it's likely base64 image data
if (key in ['data', 'image_data', 'base64', 'preview_image_b64'] or
(len(value) > 100 and value.startswith(('iVBORw0KGgo', '/9j/', 'R0lGOD')))):
# Truncate long image data
safe_args[key] = f"<image_data:{len(value)} chars>"
elif len(value) > 1000:
# Truncate any very long strings
safe_args[key] = f"{value[:100]}...<truncated:{len(value)} total chars>"
else:
safe_args[key] = value
else:
safe_args[key] = value
return safe_args
class EnhancedToolHandlers:
"""Enhanced MCP tool handler class with preview image support"""
def __init__(self, config: Config):
"""Initialize handler"""
self.config = config
self.client = Imagen4Client(config)
async def handle_generate_random_seed(self, arguments: Dict[str, Any]) -> List[TextContent]:
"""Random seed generation handler"""
try:
random_seed = random.randint(0, 2**32 - 1)
return [TextContent(
type="text",
text=f"Generated random seed: {random_seed}"
)]
except Exception as e:
logger.error(f"Random seed generation error: {str(e)}")
return [TextContent(
type="text",
text=f"Error occurred during random seed generation: {str(e)}"
)]
async def handle_regenerate_from_json(self, arguments: Dict[str, Any]) -> List[TextContent]:
"""Image regeneration from JSON file handler with preview support"""
try:
json_file_path = arguments.get("json_file_path")
save_to_file = arguments.get("save_to_file", True)
if not json_file_path:
return [TextContent(
type="text",
text="Error: JSON file path is required."
)]
# Load parameters from JSON file
try:
with open(json_file_path, 'r', encoding='utf-8') as f:
params = json.load(f)
except FileNotFoundError:
return [TextContent(
type="text",
text=f"Error: Cannot find JSON file: {json_file_path}"
)]
except json.JSONDecodeError as e:
return [TextContent(
type="text",
text=f"Error: JSON parsing error: {str(e)}"
)]
# Check required parameters
required_params = ['prompt', 'seed']
missing_params = [p for p in required_params if p not in params]
if missing_params:
return [TextContent(
type="text",
text=f"Error: Required parameters are missing from JSON file: {', '.join(missing_params)}"
)]
# Create image generation request object
request = ImageGenerationRequest(
prompt=params.get('prompt'),
negative_prompt=params.get('negative_prompt', ''),
number_of_images=params.get('number_of_images', 1),
seed=params.get('seed'),
aspect_ratio=params.get('aspect_ratio', '1:1'),
model=params.get('model', self.config.default_model)
)
logger.info(f"Loaded parameters from JSON: {json_file_path}")
# Generate image
response = await self.client.generate_image(request)
if not response.success:
return [TextContent(
type="text",
text=f"Error occurred during image regeneration: {response.error_message}"
)]
# Create preview images
preview_images_b64 = []
for image_data in response.images_data:
preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85)
if preview_b64:
preview_images_b64.append(preview_b64)
logger.info(f"Created preview image: {len(preview_b64)} chars base64 JPEG")
# Save files (optional)
saved_files = []
if save_to_file:
regeneration_params = {
"prompt": request.prompt,
"negative_prompt": request.negative_prompt,
"number_of_images": request.number_of_images,
"seed": request.seed,
"aspect_ratio": request.aspect_ratio,
"model": request.model,
"regenerated_from": json_file_path,
"original_generated_at": params.get('generated_at', 'unknown')
}
saved_files = save_generated_images(
images_data=response.images_data,
save_directory=self.config.output_path,
seed=request.seed,
generation_params=regeneration_params,
filename_prefix="imagen4_regen"
)
# Create result with preview images
result = ImageGenerationResult(
success=True,
message=f"✅ Images have been successfully regenerated from {json_file_path}",
original_images_count=len(response.images_data),
preview_images_b64=preview_images_b64,
saved_files=saved_files,
generation_params={
"prompt": request.prompt,
"seed": request.seed,
"aspect_ratio": request.aspect_ratio,
"number_of_images": request.number_of_images,
"model": request.model
}
)
return [TextContent(
type="text",
text=result.to_text_content()
)]
except Exception as e:
logger.error(f"Error occurred during image regeneration: {str(e)}")
return [TextContent(
type="text",
text=f"Error occurred during image regeneration: {str(e)}"
)]
async def handle_generate_image(self, arguments: Dict[str, Any]) -> List[TextContent]:
"""Enhanced image generation handler with preview image support"""
try:
# Log arguments safely without exposing image data
safe_args = sanitize_args_for_logging(arguments)
logger.info(f"handle_generate_image called with arguments: {safe_args}")
# Extract and validate arguments
prompt = arguments.get("prompt")
if not prompt:
logger.error("No prompt provided")
return [TextContent(
type="text",
text="Error: Prompt is required."
)]
seed = arguments.get("seed")
if seed is None:
logger.error("No seed provided")
return [TextContent(
type="text",
text="Error: Seed value is required. You can use the generate_random_seed tool to generate a random seed."
)]
# Create image generation request object
request = ImageGenerationRequest(
prompt=prompt,
negative_prompt=arguments.get("negative_prompt", ""),
number_of_images=arguments.get("number_of_images", 1),
seed=seed,
aspect_ratio=arguments.get("aspect_ratio", "1:1"),
model=arguments.get("model", self.config.default_model)
)
save_to_file = arguments.get("save_to_file", True)
logger.info(f"Starting image generation: '{prompt[:50]}...', Seed: {seed}")
# Generate image with timeout
try:
logger.info("Calling client.generate_image()...")
response = await asyncio.wait_for(
self.client.generate_image(request),
timeout=360.0 # 6 minute timeout
)
logger.info(f"Image generation completed. Success: {response.success}")
except asyncio.TimeoutError:
logger.error("Image generation timed out after 6 minutes")
return [TextContent(
type="text",
text="Error: Image generation timed out after 6 minutes. Please try again."
)]
if not response.success:
logger.error(f"Image generation failed: {response.error_message}")
return [TextContent(
type="text",
text=f"Error occurred during image generation: {response.error_message}"
)]
logger.info(f"Generated {len(response.images_data)} images successfully")
# Create preview images (512x512 JPEG base64)
preview_images_b64 = []
for i, image_data in enumerate(response.images_data):
# Log original image info
image_info = get_image_info(image_data)
if image_info:
logger.info(f"Original image {i+1}: {image_info}")
# Create preview
preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85)
if preview_b64:
preview_images_b64.append(preview_b64)
logger.info(f"Created preview image {i+1}: {len(preview_b64)} chars base64 JPEG")
else:
logger.warning(f"Failed to create preview for image {i+1}")
# Save files if requested
saved_files = []
if save_to_file:
logger.info("Saving files to disk...")
generation_params = {
"prompt": request.prompt,
"negative_prompt": request.negative_prompt,
"number_of_images": request.number_of_images,
"seed": request.seed,
"aspect_ratio": request.aspect_ratio,
"model": request.model,
"guidance_scale": 7.5,
"safety_filter_level": "block_only_high",
"person_generation": "allow_all",
"add_watermark": False
}
saved_files = save_generated_images(
images_data=response.images_data,
save_directory=self.config.output_path,
seed=request.seed,
generation_params=generation_params
)
logger.info(f"Files saved: {saved_files}")
# Verify files were created
import os
for file_path in saved_files:
if os.path.exists(file_path):
size = os.path.getsize(file_path)
logger.info(f" ✓ Verified: {file_path} ({size} bytes)")
else:
logger.error(f" ❌ Missing: {file_path}")
# Create enhanced result with preview images
result = ImageGenerationResult(
success=True,
message=f"✅ Images have been successfully generated!",
original_images_count=len(response.images_data),
preview_images_b64=preview_images_b64,
saved_files=saved_files,
generation_params={
"prompt": request.prompt,
"seed": request.seed,
"aspect_ratio": request.aspect_ratio,
"number_of_images": request.number_of_images,
"negative_prompt": request.negative_prompt,
"model": request.model
}
)
logger.info(f"Returning response with {len(preview_images_b64)} preview images")
return [TextContent(
type="text",
text=result.to_text_content()
)]
except Exception as e:
logger.error(f"Error occurred during image generation: {str(e)}", exc_info=True)
return [TextContent(
type="text",
text=f"Error occurred during image generation: {str(e)}"
)]

View File

@@ -0,0 +1,63 @@
"""
Enhanced MCP Server implementation for Imagen4 with Preview Image Support
"""
import logging
from typing import Dict, Any, List
from mcp.server import Server
from mcp.types import Tool, TextContent
from src.connector import Config
from src.server.models import MCPToolDefinitions
from src.server.enhanced_handlers import EnhancedToolHandlers, sanitize_args_for_logging
logger = logging.getLogger(__name__)
class EnhancedImagen4MCPServer:
"""Enhanced Imagen4 MCP server class with preview image support"""
def __init__(self, config: Config):
"""Initialize server"""
self.config = config
self.server = Server("imagen4-enhanced-mcp-server")
self.handlers = EnhancedToolHandlers(config)
# Register handlers
self._register_handlers()
def _register_handlers(self) -> None:
"""Register MCP handlers"""
@self.server.list_tools()
async def handle_list_tools() -> List[Tool]:
"""Return list of available tools"""
logger.info("Listing available tools (enhanced version)")
return MCPToolDefinitions.get_all_tools()
@self.server.call_tool()
async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]:
"""Handle tool calls with enhanced preview image support"""
# Log tool call safely without exposing sensitive data
safe_args = sanitize_args_for_logging(arguments)
logger.info(f"Enhanced tool called: {name} with arguments: {safe_args}")
if name == "generate_random_seed":
return await self.handlers.handle_generate_random_seed(arguments)
elif name == "regenerate_from_json":
return await self.handlers.handle_regenerate_from_json(arguments)
elif name == "generate_image":
return await self.handlers.handle_generate_image(arguments)
else:
raise ValueError(f"Unknown tool: {name}")
def get_server(self) -> Server:
"""Return MCP server instance"""
return self.server
# Create a factory function for easier import
def create_enhanced_server(config: Config) -> EnhancedImagen4MCPServer:
"""Factory function to create enhanced server"""
return EnhancedImagen4MCPServer(config)

View File

@@ -0,0 +1,71 @@
"""
Enhanced MCP response models with preview image support
"""
from typing import Optional, Dict, Any
from dataclasses import dataclass
@dataclass
class ImageGenerationResult:
"""Enhanced image generation result with preview support"""
success: bool
message: str
original_images_count: int
preview_images_b64: Optional[list[str]] = None # List of base64 JPEG previews (512x512)
saved_files: Optional[list[str]] = None
generation_params: Optional[Dict[str, Any]] = None
error_message: Optional[str] = None
def to_text_content(self) -> str:
"""Convert result to text format for MCP response"""
lines = [self.message]
if self.success and self.preview_images_b64:
lines.append(f"\n🖼️ Preview Images Generated: {len(self.preview_images_b64)} images (512x512 JPEG)")
for i, preview_b64 in enumerate(self.preview_images_b64):
lines.append(f"Preview {i+1} (base64 JPEG): {preview_b64[:50]}...({len(preview_b64)} chars)")
if self.saved_files:
lines.append(f"\n📁 Files saved:")
for filepath in self.saved_files:
lines.append(f" - {filepath}")
if self.generation_params:
lines.append(f"\n⚙️ Generation Parameters:")
for key, value in self.generation_params.items():
if key == 'prompt' and len(str(value)) > 100:
lines.append(f" - {key}: {str(value)[:100]}...")
else:
lines.append(f" - {key}: {value}")
return "\n".join(lines)
@dataclass
class PreviewImageResponse:
"""Response containing preview images in base64 format"""
preview_images_b64: list[str] # Base64 encoded JPEG images (512x512)
original_count: int
message: str
@classmethod
def from_image_data(cls, images_data: list[bytes], message: str = "") -> 'PreviewImageResponse':
"""Create response from original PNG image data"""
from src.utils.image_utils import create_preview_image_b64
preview_images = []
for i, image_data in enumerate(images_data):
preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85)
if preview_b64:
preview_images.append(preview_b64)
else:
# Fallback - use original if preview creation fails
import base64
preview_images.append(base64.b64encode(image_data).decode('utf-8'))
return cls(
preview_images_b64=preview_images,
original_count=len(images_data),
message=message or f"Generated {len(preview_images)} preview images"
)

308
src/server/handlers.py Normal file
View File

@@ -0,0 +1,308 @@
"""
Tool Handlers for MCP Server
"""
import asyncio
import base64
import json
import random
import logging
from typing import List, Dict, Any
from mcp.types import TextContent, ImageContent
from src.connector import Imagen4Client, Config
from src.connector.imagen4_client import ImageGenerationRequest
from src.connector.utils import save_generated_images
logger = logging.getLogger(__name__)
def sanitize_args_for_logging(arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Remove or truncate sensitive data from arguments for safe logging"""
safe_args = {}
for key, value in arguments.items():
if isinstance(value, str):
# Check if it's likely base64 image data
if (key in ['data', 'image_data', 'base64'] or
(len(value) > 100 and value.startswith(('iVBORw0KGgo', '/9j/', 'R0lGOD')))):
# Truncate long image data
safe_args[key] = f"<image_data:{len(value)} chars>"
elif len(value) > 1000:
# Truncate any very long strings
safe_args[key] = f"{value[:100]}...<truncated:{len(value)} total chars>"
else:
safe_args[key] = value
else:
safe_args[key] = value
return safe_args
class ToolHandlers:
"""MCP tool handler class"""
def __init__(self, config: Config):
"""Initialize handler"""
self.config = config
self.client = Imagen4Client(config)
async def handle_generate_random_seed(self, arguments: Dict[str, Any]) -> List[TextContent]:
"""Random seed generation handler"""
try:
random_seed = random.randint(0, 2**32 - 1)
return [TextContent(
type="text",
text=f"Generated random seed: {random_seed}"
)]
except Exception as e:
logger.error(f"Random seed generation error: {str(e)}")
return [TextContent(
type="text",
text=f"Error occurred during random seed generation: {str(e)}"
)]
async def handle_regenerate_from_json(self, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]:
"""Image regeneration from JSON file handler"""
try:
json_file_path = arguments.get("json_file_path")
save_to_file = arguments.get("save_to_file", True)
if not json_file_path:
return [TextContent(
type="text",
text="Error: JSON file path is required."
)]
# Load parameters from JSON file
try:
with open(json_file_path, 'r', encoding='utf-8') as f:
params = json.load(f)
except FileNotFoundError:
return [TextContent(
type="text",
text=f"Error: Cannot find JSON file: {json_file_path}"
)]
except json.JSONDecodeError as e:
return [TextContent(
type="text",
text=f"Error: JSON parsing error: {str(e)}"
)]
# Check required parameters
required_params = ['prompt', 'seed']
missing_params = [p for p in required_params if p not in params]
if missing_params:
return [TextContent(
type="text",
text=f"Error: Required parameters are missing from JSON file: {', '.join(missing_params)}"
)]
# Create image generation request object
request = ImageGenerationRequest(
prompt=params.get('prompt'),
negative_prompt=params.get('negative_prompt', ''),
number_of_images=params.get('number_of_images', 1),
seed=params.get('seed'),
aspect_ratio=params.get('aspect_ratio', '1:1')
)
logger.info(f"Loaded parameters from JSON: {json_file_path}")
# Generate image
response = await self.client.generate_image(request)
if not response.success:
return [TextContent(
type="text",
text=f"Error occurred during image regeneration: {response.error_message}"
)]
# Save files (optional)
saved_files = []
if save_to_file:
regeneration_params = {
"prompt": request.prompt,
"negative_prompt": request.negative_prompt,
"number_of_images": request.number_of_images,
"seed": request.seed,
"aspect_ratio": request.aspect_ratio,
"regenerated_from": json_file_path,
"original_generated_at": params.get('generated_at', 'unknown')
}
saved_files = save_generated_images(
images_data=response.images_data,
save_directory=self.config.output_path,
seed=request.seed,
generation_params=regeneration_params,
filename_prefix="imagen4_regen"
)
# Generate result message
message_parts = [
f"Images have been successfully regenerated.",
f"JSON file: {json_file_path}",
f"Size: 2048x2048 PNG",
f"Count: {len(response.images_data)} images",
f"Seed: {request.seed}",
f"Prompt: {request.prompt}"
]
if saved_files:
message_parts.append(f"\nRegenerated files saved successfully:")
for filepath in saved_files:
message_parts.append(f"- {filepath}")
# Generate result
result = [
TextContent(
type="text",
text="\n".join(message_parts)
)
]
# Add Base64 encoded images
for i, image_data in enumerate(response.images_data):
image_base64 = base64.b64encode(image_data).decode('utf-8')
result.append(
ImageContent(
type="image",
data=image_base64,
mimeType="image/png"
)
)
logger.info(f"Added image {i+1} to response: {len(image_base64)} chars base64 data")
return result
except Exception as e:
logger.error(f"Error occurred during image regeneration: {str(e)}")
return [TextContent(
type="text",
text=f"Error occurred during image regeneration: {str(e)}"
)]
async def handle_generate_image(self, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]:
"""Image generation handler with synchronous processing"""
try:
# Log arguments safely without exposing image data
safe_args = sanitize_args_for_logging(arguments)
logger.info(f"handle_generate_image called with arguments: {safe_args}")
# Extract and validate arguments
prompt = arguments.get("prompt")
if not prompt:
logger.error("No prompt provided")
return [TextContent(
type="text",
text="Error: Prompt is required."
)]
seed = arguments.get("seed")
if seed is None:
logger.error("No seed provided")
return [TextContent(
type="text",
text="Error: Seed value is required. You can use the generate_random_seed tool to generate a random seed."
)]
# Create image generation request object
request = ImageGenerationRequest(
prompt=prompt,
negative_prompt=arguments.get("negative_prompt", ""),
number_of_images=arguments.get("number_of_images", 1),
seed=seed,
aspect_ratio=arguments.get("aspect_ratio", "1:1")
)
save_to_file = arguments.get("save_to_file", True)
logger.info(f"Starting SYNCHRONOUS image generation: '{prompt[:50]}...', Seed: {seed}")
# Generate image synchronously with longer timeout
try:
logger.info("Calling client.generate_image() synchronously...")
response = await asyncio.wait_for(
self.client.generate_image(request),
timeout=360.0 # 6 minute timeout
)
logger.info(f"Image generation completed. Success: {response.success}")
except asyncio.TimeoutError:
logger.error("Image generation timed out after 6 minutes")
return [TextContent(
type="text",
text="Error: Image generation timed out after 6 minutes. Please try again."
)]
if not response.success:
logger.error(f"Image generation failed: {response.error_message}")
return [TextContent(
type="text",
text=f"Error occurred during image generation: {response.error_message}"
)]
logger.info(f"Generated {len(response.images_data)} images successfully")
# Save files if requested
saved_files = []
if save_to_file:
logger.info("Saving files to disk...")
generation_params = {
"prompt": request.prompt,
"negative_prompt": request.negative_prompt,
"number_of_images": request.number_of_images,
"seed": request.seed,
"aspect_ratio": request.aspect_ratio,
"guidance_scale": 7.5,
"safety_filter_level": "block_only_high",
"person_generation": "allow_all",
"add_watermark": False
}
saved_files = save_generated_images(
images_data=response.images_data,
save_directory=self.config.output_path,
seed=request.seed,
generation_params=generation_params
)
logger.info(f"Files saved: {saved_files}")
# Verify files were created
import os
for file_path in saved_files:
if os.path.exists(file_path):
size = os.path.getsize(file_path)
logger.info(f" ✓ Verified: {file_path} ({size} bytes)")
else:
logger.error(f" ❌ Missing: {file_path}")
# Generate result message
message_parts = [
f"✅ Images have been successfully generated!",
f"Prompt: {request.prompt}",
f"Seed: {request.seed}",
f"Size: 2048x2048 PNG",
f"Count: {len(response.images_data)} images"
]
if saved_files:
message_parts.append(f"\n📁 Files saved successfully:")
for filepath in saved_files:
message_parts.append(f" - {filepath}")
else:
message_parts.append(f"\n Images generated but not saved to file. Use save_to_file=true to save.")
logger.info("Returning synchronous response")
return [TextContent(
type="text",
text="\n".join(message_parts)
)]
except Exception as e:
logger.error(f"Error occurred during image generation: {str(e)}", exc_info=True)
return [TextContent(
type="text",
text=f"Error occurred during image generation: {str(e)}"
)]

57
src/server/mcp_server.py Normal file
View File

@@ -0,0 +1,57 @@
"""
MCP Server implementation for Imagen4
"""
import logging
from typing import Dict, Any, List
from mcp.server import Server
from mcp.types import Tool, TextContent, ImageContent
from src.connector import Config
from src.server.models import MCPToolDefinitions
from src.server.handlers import ToolHandlers, sanitize_args_for_logging
logger = logging.getLogger(__name__)
class Imagen4MCPServer:
"""Imagen4 MCP server class"""
def __init__(self, config: Config):
"""Initialize server"""
self.config = config
self.server = Server("imagen4-mcp-server")
self.handlers = ToolHandlers(config)
# Register handlers
self._register_handlers()
def _register_handlers(self) -> None:
"""Register MCP handlers"""
@self.server.list_tools()
async def handle_list_tools() -> List[Tool]:
"""Return list of available tools"""
logger.info("Listing available tools")
return MCPToolDefinitions.get_all_tools()
@self.server.call_tool()
async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]:
"""Handle tool calls"""
# Log tool call safely without exposing sensitive data
safe_args = sanitize_args_for_logging(arguments)
logger.info(f"Tool called: {name} with arguments: {safe_args}")
if name == "generate_random_seed":
return await self.handlers.handle_generate_random_seed(arguments)
elif name == "regenerate_from_json":
return await self.handlers.handle_regenerate_from_json(arguments)
elif name == "generate_image":
return await self.handlers.handle_generate_image(arguments)
else:
raise ValueError(f"Unknown tool: {name}")
def get_server(self) -> Server:
"""Return MCP server instance"""
return self.server

View File

@@ -0,0 +1,39 @@
"""
Minimal Task Result Handler for Emergency Use
"""
import logging
from typing import List, Dict, Any
from mcp.types import TextContent
logger = logging.getLogger(__name__)
async def minimal_get_task_result(task_manager, task_id: str) -> List[TextContent]:
"""Absolutely minimal task result handler"""
try:
logger.info(f"Minimal handler for task: {task_id}")
# Just return the task status for now
status = task_manager.get_task_status(task_id)
if status is None:
return [TextContent(
type="text",
text=f"❌ Task '{task_id}' not found."
)]
return [TextContent(
type="text",
text=f"📋 Task '{task_id}' Status: {status.value}\n"
f"Note: This is a minimal emergency handler.\n"
f"If you see this message, the original handler has a serious issue."
)]
except Exception as e:
logger.error(f"Even minimal handler failed: {str(e)}", exc_info=True)
return [TextContent(
type="text",
text=f"Complete failure: {str(e)}"
)]

132
src/server/models.py Normal file
View File

@@ -0,0 +1,132 @@
"""
MCP Tool Models and Definitions
"""
from mcp.types import Tool
class MCPToolDefinitions:
"""MCP tool definition class"""
@staticmethod
def get_generate_image_tool() -> Tool:
"""Image generation tool definition"""
return Tool(
name="generate_image",
description="Generate 2048x2048 PNG images from text prompts using Google Imagen 4 AI.",
inputSchema={
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "Text prompt for image generation (Korean or English). Maximum 480 tokens allowed.",
"maxLength": 2000 # Approximate character limit for safety
},
"negative_prompt": {
"type": "string",
"description": "Negative prompt specifying elements not to generate (optional). Maximum 240 tokens allowed.",
"default": "",
"maxLength": 1000 # Approximate character limit for safety
},
"number_of_images": {
"type": "integer",
"description": "Number of images to generate (1 or 2 only)",
"enum": [1, 2],
"default": 1
},
"seed": {
"type": "integer",
"description": "Seed value for reproducible results (required, 0 ~ 4294967295 range)",
"minimum": 0,
"maximum": 4294967295
},
"aspect_ratio": {
"type": "string",
"description": "Image aspect ratio",
"enum": ["1:1", "9:16", "16:9", "3:4", "4:3"],
"default": "1:1"
},
"save_to_file": {
"type": "boolean",
"description": "Whether to save generated images to files (default: true)",
"default": True
},
"model": {
"type": "string",
"description": "Imagen 4 model to use for generation",
"enum": ["imagen-4.0-generate-001", "imagen-4.0-ultra-generate-001"],
"default": "imagen-4.0-generate-001"
}
},
"required": ["prompt", "seed"]
}
)
@staticmethod
def get_regenerate_from_json_tool() -> Tool:
"""Regenerate from JSON tool definition"""
return Tool(
name="regenerate_from_json",
description="Read parameters from JSON file and regenerate images with the same settings.",
inputSchema={
"type": "object",
"properties": {
"json_file_path": {
"type": "string",
"description": "Path to JSON file containing saved parameters"
},
"save_to_file": {
"type": "boolean",
"description": "Whether to save regenerated images to files (default: true)",
"default": True
}
},
"required": ["json_file_path"]
}
)
@staticmethod
def get_generate_random_seed_tool() -> Tool:
"""Random seed generation tool definition"""
return Tool(
name="generate_random_seed",
description="Generate random seed value for image generation.",
inputSchema={
"type": "object",
"properties": {},
"required": []
}
)
@staticmethod
def get_validate_prompt_tool() -> Tool:
"""Prompt validation tool definition"""
return Tool(
name="validate_prompt",
description="Validate prompt length and estimate token count before image generation.",
inputSchema={
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "Text prompt to validate"
},
"negative_prompt": {
"type": "string",
"description": "Negative prompt to validate (optional)",
"default": ""
}
},
"required": ["prompt"]
}
)
@classmethod
def get_all_tools(cls) -> list[Tool]:
"""Return all tool definitions"""
return [
cls.get_generate_image_tool(),
cls.get_regenerate_from_json_tool(),
cls.get_generate_random_seed_tool(),
cls.get_validate_prompt_tool()
]

View File

@@ -0,0 +1,80 @@
"""
Safe Get Task Result Handler - Debug Version
"""
import logging
from typing import List, Dict, Any
from mcp.types import TextContent
logger = logging.getLogger(__name__)
async def safe_get_task_result(task_manager, task_id: str) -> List[TextContent]:
"""Minimal safe version of get_task_result without image processing"""
try:
logger.info(f"Safe get_task_result called for: {task_id}")
# Step 1: Try to get raw result
try:
raw_result = task_manager.worker_pool.get_task_result(task_id)
logger.info(f"Raw result type: {type(raw_result)}")
except Exception as e:
logger.error(f"Failed to get raw result: {str(e)}")
return [TextContent(
type="text",
text=f"Error accessing task data: {str(e)}"
)]
if not raw_result:
return [TextContent(
type="text",
text=f"❌ Task '{task_id}' not found."
)]
# Step 2: Check status safely
try:
status = raw_result.status
logger.info(f"Task status: {status}")
except Exception as e:
logger.error(f"Failed to get status: {str(e)}")
return [TextContent(
type="text",
text=f"Error reading task status: {str(e)}"
)]
# Step 3: Return basic info without touching result.result
try:
duration = raw_result.duration if hasattr(raw_result, 'duration') else None
created = raw_result.created_at.isoformat() if hasattr(raw_result, 'created_at') and raw_result.created_at else "unknown"
message = f"✅ Task '{task_id}' Status Report:\n"
message += f"Status: {status.value}\n"
message += f"Created: {created}\n"
if duration:
message += f"Duration: {duration:.2f}s\n"
if hasattr(raw_result, 'error') and raw_result.error:
message += f"Error: {raw_result.error}\n"
message += "\nNote: This is a safe diagnostic version."
return [TextContent(
type="text",
text=message
)]
except Exception as e:
logger.error(f"Failed to format response: {str(e)}")
return [TextContent(
type="text",
text=f"Task exists but data formatting failed: {str(e)}"
)]
except Exception as e:
logger.error(f"Critical error in safe_get_task_result: {str(e)}", exc_info=True)
return [TextContent(
type="text",
text=f"Critical error: {str(e)}"
)]

1
src/utils/__init__.py Normal file
View File

@@ -0,0 +1 @@
# Utils package

107
src/utils/image_utils.py Normal file
View File

@@ -0,0 +1,107 @@
"""
Image processing utilities for preview generation
"""
import base64
import io
import logging
from typing import Optional
from PIL import Image
logger = logging.getLogger(__name__)
def create_preview_image_b64(image_data: bytes, target_size: int = 512, quality: int = 85) -> Optional[str]:
"""
Convert PNG image data to JPEG preview with specified size and return as base64
Args:
image_data: Original PNG image data in bytes
target_size: Target size for the preview (default: 512x512)
quality: JPEG quality (1-100, default: 85)
Returns:
Base64 encoded JPEG image string, or None if conversion fails
"""
try:
# Open image from bytes
with Image.open(io.BytesIO(image_data)) as img:
# Convert to RGB if necessary (PNG might have alpha channel)
if img.mode in ('RGBA', 'LA', 'P'):
# Create white background for transparent images
background = Image.new('RGB', img.size, (255, 255, 255))
if img.mode == 'P':
img = img.convert('RGBA')
background.paste(img, mask=img.split()[-1] if img.mode in ('RGBA', 'LA') else None)
img = background
elif img.mode != 'RGB':
img = img.convert('RGB')
# Resize to target size maintaining aspect ratio
img.thumbnail((target_size, target_size), Image.Resampling.LANCZOS)
# If image is smaller than target size, pad it to exact size
if img.size != (target_size, target_size):
# Create new image with target size and white background
new_img = Image.new('RGB', (target_size, target_size), (255, 255, 255))
# Center the resized image
x = (target_size - img.size[0]) // 2
y = (target_size - img.size[1]) // 2
new_img.paste(img, (x, y))
img = new_img
# Convert to JPEG and encode as base64
output_buffer = io.BytesIO()
img.save(output_buffer, format='JPEG', quality=quality, optimize=True)
jpeg_data = output_buffer.getvalue()
# Encode to base64
base64_data = base64.b64encode(jpeg_data).decode('utf-8')
logger.info(f"Preview image created: {target_size}x{target_size} JPEG, {len(jpeg_data)} bytes, base64 length: {len(base64_data)}")
return base64_data
except Exception as e:
logger.error(f"Failed to create preview image: {str(e)}")
return None
def validate_image_data(image_data: bytes) -> bool:
"""
Validate if image data is a valid image
Args:
image_data: Image data in bytes
Returns:
True if valid image, False otherwise
"""
try:
with Image.open(io.BytesIO(image_data)) as img:
img.verify() # Verify image integrity
return True
except Exception:
return False
def get_image_info(image_data: bytes) -> Optional[dict]:
"""
Get image information
Args:
image_data: Image data in bytes
Returns:
Dictionary with image info (format, size, mode) or None if invalid
"""
try:
with Image.open(io.BytesIO(image_data)) as img:
return {
'format': img.format,
'size': img.size,
'mode': img.mode,
'bytes': len(image_data)
}
except Exception as e:
logger.error(f"Failed to get image info: {str(e)}")
return None

200
src/utils/token_utils.py Normal file
View File

@@ -0,0 +1,200 @@
"""
Token counting utilities for Imagen4 prompts
"""
import re
import logging
from typing import Optional
logger = logging.getLogger(__name__)
# 기본 토큰 추정 상수
AVERAGE_CHARS_PER_TOKEN = 4 # 영어 기준 평균값
KOREAN_CHARS_PER_TOKEN = 2 # 한글 기준 평균값
MAX_PROMPT_TOKENS = 480 # 최대 프롬프트 토큰 수
def estimate_token_count(text: str) -> int:
"""
텍스트의 토큰 수를 추정합니다.
정확한 토큰 계산을 위해서는 실제 토크나이저가 필요하지만,
API 호출 전 빠른 검증을 위해 추정값을 사용합니다.
Args:
text: 토큰 수를 계산할 텍스트
Returns:
int: 추정된 토큰 수
"""
if not text:
return 0
# 텍스트 정리
text = text.strip()
if not text:
return 0
# 한글과 영어 문자 분리
korean_chars = len(re.findall(r'[가-힣]', text))
english_chars = len(re.findall(r'[a-zA-Z]', text))
other_chars = len(text) - korean_chars - english_chars
# 토큰 수 추정
korean_tokens = korean_chars / KOREAN_CHARS_PER_TOKEN
english_tokens = english_chars / AVERAGE_CHARS_PER_TOKEN
other_tokens = other_chars / AVERAGE_CHARS_PER_TOKEN
estimated_tokens = int(korean_tokens + english_tokens + other_tokens)
# 최소 1토큰은 보장
return max(1, estimated_tokens)
def validate_prompt_length(prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> tuple[bool, int, Optional[str]]:
"""
프롬프트 길이를 검증합니다.
Args:
prompt: 검증할 프롬프트
max_tokens: 최대 허용 토큰 수
Returns:
tuple: (유효성, 토큰 수, 오류 메시지)
"""
if not prompt:
return False, 0, "프롬프트가 비어있습니다."
token_count = estimate_token_count(prompt)
if token_count > max_tokens:
error_msg = (
f"프롬프트가 너무 깁니다. "
f"현재: {token_count}토큰, 최대: {max_tokens}토큰. "
f"프롬프트를 {token_count - max_tokens}토큰 줄여주세요."
)
return False, token_count, error_msg
return True, token_count, None
def truncate_prompt(prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> str:
"""
프롬프트를 지정된 토큰 수로 자릅니다.
Args:
prompt: 자를 프롬프트
max_tokens: 최대 토큰 수
Returns:
str: 잘린 프롬프트
"""
if not prompt:
return ""
current_tokens = estimate_token_count(prompt)
if current_tokens <= max_tokens:
return prompt
# 대략적인 비율로 텍스트 자르기
ratio = max_tokens / current_tokens
target_length = int(len(prompt) * ratio * 0.9) # 여유분 10%
truncated = prompt[:target_length]
# 단어/문장 경계에서 자르기
if len(truncated) < len(prompt):
# 마지막 완전한 단어까지만 유지
last_space = truncated.rfind(' ')
last_korean = truncated.rfind('') # 한글 어미
last_punct = max(truncated.rfind('.'), truncated.rfind(','), truncated.rfind('!'))
cut_point = max(last_space, last_korean, last_punct)
if cut_point > target_length * 0.8: # 너무 많이 잘리지 않도록
truncated = truncated[:cut_point]
# 최종 검증
final_tokens = estimate_token_count(truncated)
if final_tokens > max_tokens:
# 강제로 문자 단위로 자르기
chars_per_token = len(truncated) / final_tokens
target_chars = int(max_tokens * chars_per_token * 0.95)
truncated = truncated[:target_chars]
return truncated.strip()
def get_prompt_stats(prompt: str) -> dict:
"""
프롬프트 통계 정보를 반환합니다.
Args:
prompt: 분석할 프롬프트
Returns:
dict: 프롬프트 통계
"""
if not prompt:
return {
"character_count": 0,
"estimated_tokens": 0,
"korean_chars": 0,
"english_chars": 0,
"other_chars": 0,
"is_valid": False,
"remaining_tokens": MAX_PROMPT_TOKENS
}
char_count = len(prompt)
korean_chars = len(re.findall(r'[가-힣]', prompt))
english_chars = len(re.findall(r'[a-zA-Z]', prompt))
other_chars = char_count - korean_chars - english_chars
estimated_tokens = estimate_token_count(prompt)
is_valid = estimated_tokens <= MAX_PROMPT_TOKENS
remaining_tokens = MAX_PROMPT_TOKENS - estimated_tokens
return {
"character_count": char_count,
"estimated_tokens": estimated_tokens,
"korean_chars": korean_chars,
"english_chars": english_chars,
"other_chars": other_chars,
"is_valid": is_valid,
"remaining_tokens": remaining_tokens,
"max_tokens": MAX_PROMPT_TOKENS
}
# 실제 토크나이저 사용 시 대체할 수 있는 인터페이스
class TokenCounter:
"""토큰 카운터 인터페이스"""
def __init__(self, tokenizer_name: Optional[str] = None):
"""
토큰 카운터 초기화
Args:
tokenizer_name: 사용할 토크나이저 이름 (향후 확장용)
"""
self.tokenizer_name = tokenizer_name or "estimate"
logger.info(f"토큰 카운터 초기화: {self.tokenizer_name}")
def count_tokens(self, text: str) -> int:
"""텍스트의 토큰 수 계산"""
if self.tokenizer_name == "estimate":
return estimate_token_count(text)
else:
# 향후 실제 토크나이저 구현
raise NotImplementedError(f"토크나이저 '{self.tokenizer_name}'는 아직 구현되지 않았습니다.")
def validate_prompt(self, prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> tuple[bool, int, Optional[str]]:
"""프롬프트 검증"""
return validate_prompt_length(prompt, max_tokens)
def truncate_prompt(self, prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> str:
"""프롬프트 자르기"""
return truncate_prompt(prompt, max_tokens)
# 전역 토큰 카운터 인스턴스
default_token_counter = TokenCounter()

108
test_main.py Normal file
View File

@@ -0,0 +1,108 @@
#!/usr/bin/env python3
"""
Simple test for the consolidated main.py
"""
import sys
import os
import io
from PIL import Image
import base64
# Add current directory to PYTHONPATH
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)
def create_test_image():
"""Create a simple test image"""
img = Image.new('RGB', (2048, 2048), color='lightblue')
from PIL import ImageDraw
draw = ImageDraw.Draw(img)
draw.rectangle([500, 500, 1500, 1500], fill='red')
buffer = io.BytesIO()
img.save(buffer, format='PNG')
return buffer.getvalue()
def test_preview_function():
"""Test the preview image creation"""
# Import the function from main.py
from main import create_preview_image_b64, get_image_info
# Create test image
test_png = create_test_image()
print(f"Created test PNG: {len(test_png)} bytes")
# Get image info
info = get_image_info(test_png)
print(f"Image info: {info}")
# Create preview
preview_b64 = create_preview_image_b64(test_png, target_size=512, quality=85)
if preview_b64:
print(f"✅ Preview created: {len(preview_b64)} chars")
# Save preview to verify
preview_bytes = base64.b64decode(preview_b64)
with open('test_preview.jpg', 'wb') as f:
f.write(preview_bytes)
ratio = (len(preview_bytes) / len(test_png)) * 100
print(f"Compression: {ratio:.1f}% (original: {len(test_png)}, preview: {len(preview_bytes)})")
print("Preview saved as test_preview.jpg")
return True
else:
print("❌ Failed to create preview")
return False
def test_imports():
"""Test all necessary imports"""
try:
from main import (
Imagen4MCPServer,
Imagen4ToolHandlers,
ImageGenerationResult,
get_tools
)
print("✅ All classes imported successfully")
# Test tool definitions
tools = get_tools()
print(f"{len(tools)} tools defined")
return True
except Exception as e:
print(f"❌ Import error: {e}")
return False
def main():
"""Main test function"""
print("=== Testing Consolidated main.py ===\n")
success = True
# Test imports
if not test_imports():
success = False
print()
# Test preview function
if not test_preview_function():
success = False
print(f"\n=== Test Results ===")
if success:
print("✅ All tests passed!")
print("\nTo run the server:")
print("1. Make sure .env is configured")
print("2. Run: python main.py")
print("3. Or use: run.bat")
else:
print("❌ Some tests failed")
return 0 if success else 1
if __name__ == "__main__":
sys.exit(main())

178
tests/test_connector.py Normal file
View File

@@ -0,0 +1,178 @@
"""
Tests for Imagen4 Connector
"""
import pytest
import asyncio
from unittest.mock import Mock, patch
from src.connector import Config, Imagen4Client
from src.connector.imagen4_client import ImageGenerationRequest
class TestConfig:
"""Config 클래스 테스트"""
def test_config_creation(self):
"""설정 생성 테스트"""
config = Config(
project_id="test-project",
location="us-central1",
output_path="./test_images"
)
assert config.project_id == "test-project"
assert config.location == "us-central1"
assert config.output_path == "./test_images"
def test_config_validation(self):
"""설정 유효성 검사 테스트"""
# 유효한 설정
valid_config = Config(project_id="test-project")
assert valid_config.validate() is True
# 무효한 설정
invalid_config = Config(project_id="")
assert invalid_config.validate() is False
@patch.dict('os.environ', {
'GOOGLE_CLOUD_PROJECT_ID': 'test-project',
'GOOGLE_CLOUD_LOCATION': 'us-west1',
'GENERATED_IMAGES_PATH': './custom_path'
})
def test_config_from_env(self):
"""환경 변수에서 설정 로드 테스트"""
config = Config.from_env()
assert config.project_id == "test-project"
assert config.location == "us-west1"
assert config.output_path == "./custom_path"
@patch.dict('os.environ', {}, clear=True)
def test_config_from_env_missing_project_id(self):
"""필수 환경 변수 누락 테스트"""
with pytest.raises(ValueError, match="GOOGLE_CLOUD_PROJECT_ID environment variable is required"):
Config.from_env()
class TestImageGenerationRequest:
"""ImageGenerationRequest 테스트"""
def test_valid_request(self):
"""유효한 요청 테스트"""
request = ImageGenerationRequest(
prompt="A beautiful landscape",
seed=12345
)
# 유효성 검사가 예외 없이 완료되어야 함
request.validate()
def test_invalid_prompt(self):
"""무효한 프롬프트 테스트"""
request = ImageGenerationRequest(
prompt="",
seed=12345
)
with pytest.raises(ValueError, match="Prompt is required"):
request.validate()
def test_invalid_seed(self):
"""무효한 시드값 테스트"""
request = ImageGenerationRequest(
prompt="Test prompt",
seed=-1
)
with pytest.raises(ValueError, match="Seed value must be an integer between 0 and 4294967295"):
request.validate()
def test_invalid_number_of_images(self):
"""무효한 이미지 개수 테스트"""
request = ImageGenerationRequest(
prompt="Test prompt",
seed=12345,
number_of_images=3
)
with pytest.raises(ValueError, match="Number of images must be 1 or 2 only"):
request.validate()
class TestImagen4Client:
"""Imagen4Client 테스트"""
def test_client_initialization(self):
"""클라이언트 초기화 테스트"""
config = Config(project_id="test-project")
with patch('src.connector.imagen4_client.genai.Client'):
client = Imagen4Client(config)
assert client.config == config
def test_client_invalid_config(self):
"""무효한 설정으로 클라이언트 초기화 테스트"""
invalid_config = Config(project_id="")
with pytest.raises(ValueError, match="Invalid configuration"):
Imagen4Client(invalid_config)
@pytest.mark.asyncio
async def test_generate_image_success(self):
"""이미지 생성 성공 테스트"""
config = Config(project_id="test-project")
# Create mock response
mock_image_data = b"fake_image_data"
mock_response = Mock()
mock_response.generated_images = [Mock()]
mock_response.generated_images[0].image = Mock()
mock_response.generated_images[0].image.image_bytes = mock_image_data
with patch('src.connector.imagen4_client.genai.Client') as mock_client_class:
mock_client = Mock()
mock_client_class.return_value = mock_client
mock_client.models.generate_images = Mock(return_value=mock_response)
client = Imagen4Client(config)
request = ImageGenerationRequest(
prompt="Test prompt",
seed=12345
)
with patch('asyncio.to_thread', return_value=mock_response):
response = await client.generate_image(request)
assert response.success is True
assert len(response.images_data) == 1
assert response.images_data[0] == mock_image_data
@pytest.mark.asyncio
async def test_generate_image_failure(self):
"""Image generation failure test"""
config = Config(project_id="test-project")
with patch('src.connector.imagen4_client.genai.Client') as mock_client_class:
mock_client = Mock()
mock_client_class.return_value = mock_client
client = Imagen4Client(config)
request = ImageGenerationRequest(
prompt="Test prompt",
seed=12345
)
# Simulate API call failure
with patch('asyncio.to_thread', side_effect=Exception("API Error")):
response = await client.generate_image(request)
assert response.success is False
assert "API Error" in response.error_message
assert len(response.images_data) == 0
if __name__ == "__main__":
pytest.main([__file__])

136
tests/test_server.py Normal file
View File

@@ -0,0 +1,136 @@
"""
Tests for Imagen4 Server
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch
from src.connector import Config
from src.server import Imagen4MCPServer, ToolHandlers, MCPToolDefinitions
from src.server.models import MCPToolDefinitions
class TestMCPToolDefinitions:
"""MCPToolDefinitions tests"""
def test_get_generate_image_tool(self):
"""Image generation tool definition test"""
tool = MCPToolDefinitions.get_generate_image_tool()
assert tool.name == "generate_image"
assert "Google Imagen 4" in tool.description
assert "prompt" in tool.inputSchema["properties"]
assert "seed" in tool.inputSchema["required"]
def test_get_regenerate_from_json_tool(self):
"""JSON regeneration tool definition test"""
tool = MCPToolDefinitions.get_regenerate_from_json_tool()
assert tool.name == "regenerate_from_json"
assert "JSON file" in tool.description
assert "json_file_path" in tool.inputSchema["properties"]
def test_get_generate_random_seed_tool(self):
"""Random seed generation tool definition test"""
tool = MCPToolDefinitions.get_generate_random_seed_tool()
assert tool.name == "generate_random_seed"
assert "random seed" in tool.description
def test_get_all_tools(self):
"""All tools return test"""
tools = MCPToolDefinitions.get_all_tools()
assert len(tools) == 3
tool_names = [tool.name for tool in tools]
assert "generate_image" in tool_names
assert "regenerate_from_json" in tool_names
assert "generate_random_seed" in tool_names
class TestToolHandlers:
"""ToolHandlers tests"""
@pytest.fixture
def config(self):
"""Test configuration"""
return Config(project_id="test-project")
@pytest.fixture
def handlers(self, config):
"""Test handlers"""
with patch('src.server.handlers.Imagen4Client'):
return ToolHandlers(config)
@pytest.mark.asyncio
async def test_handle_generate_random_seed(self, handlers):
"""Random seed generation handler test"""
result = await handlers.handle_generate_random_seed({})
assert len(result) == 1
assert result[0].type == "text"
assert "Generated random seed:" in result[0].text
@pytest.mark.asyncio
async def test_handle_generate_image_missing_prompt(self, handlers):
"""Missing prompt test"""
arguments = {"seed": 12345}
result = await handlers.handle_generate_image(arguments)
assert len(result) == 1
assert result[0].type == "text"
assert "Prompt is required" in result[0].text
@pytest.mark.asyncio
async def test_handle_generate_image_missing_seed(self, handlers):
"""Missing seed value test"""
arguments = {"prompt": "Test prompt"}
result = await handlers.handle_generate_image(arguments)
assert len(result) == 1
assert result[0].type == "text"
assert "Seed value is required" in result[0].text
@pytest.mark.asyncio
async def test_handle_regenerate_from_json_missing_path(self, handlers):
"""Missing JSON file path test"""
arguments = {}
result = await handlers.handle_regenerate_from_json(arguments)
assert len(result) == 1
assert result[0].type == "text"
assert "JSON file path is required" in result[0].text
class TestImagen4MCPServer:
"""Imagen4MCPServer tests"""
@pytest.fixture
def config(self):
"""Test configuration"""
return Config(project_id="test-project")
def test_server_initialization(self, config):
"""Server initialization test"""
with patch('src.server.mcp_server.ToolHandlers'):
server = Imagen4MCPServer(config)
assert server.config == config
assert server.server is not None
assert server.handlers is not None
def test_get_server(self, config):
"""Server instance return test"""
with patch('src.server.mcp_server.ToolHandlers'):
mcp_server = Imagen4MCPServer(config)
server = mcp_server.get_server()
assert server is not None
assert server.name == "imagen4-mcp-server"
if __name__ == "__main__":
pytest.main([__file__])