Files
imagen4/main.py
2025-08-26 02:40:23 +09:00

975 lines
40 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Imagen4 MCP Server with Preview Image Support
AI image generation MCP server using Google Imagen 4
- Generates 2048x2048 PNG original images
- Provides 512x512 JPEG preview images as base64
- Enhanced response format
Run from imagen4 root directory
"""
# ==================== CRITICAL: UTF-8 SETUP MUST BE FIRST ====================
# This must be done before ANY other imports to prevent cp949 codec errors
import os
import sys
import locale

# Force UTF-8 environment variables - set immediately.
# NOTE: PYTHONIOENCODING / PYTHONUTF8 are read at interpreter start-up, so in
# this process they only affect child processes; the stream reconfiguration
# below is what switches this process's stdout/stderr to UTF-8.
os.environ['PYTHONIOENCODING'] = 'utf-8'
os.environ['PYTHONUTF8'] = '1'
os.environ['LC_ALL'] = 'C.UTF-8'

# Windows-specific UTF-8 setup
if sys.platform.startswith('win'):
    # Set console code page to UTF-8; the command's own output is discarded
    # so nothing reaches the stdout protocol stream.
    try:
        os.system('chcp 65001 >nul 2>&1')
    except Exception:
        pass

# Force locale to UTF-8, falling back to the user's default locale when
# 'C.UTF-8' is not available on this platform.
try:
    locale.setlocale(locale.LC_ALL, 'C.UTF-8')
except locale.Error:
    try:
        locale.setlocale(locale.LC_ALL, '')
    except locale.Error:
        pass

# Reconfigure stdout/stderr with UTF-8 encoding.
# (codecs is no longer used by the setup below but stays importable here.)
import codecs
import io


def _force_utf8_stream(name: str) -> None:
    """Switch ``sys.<name>`` to UTF-8 output with ``errors='replace'``.

    Each stream is handled on its own, so a failure on one stream does not
    change how the other is treated. ``reconfigure`` (Python 3.7+) is tried
    first because it keeps the same stream object. Otherwise the stream's
    binary ``buffer`` is re-wrapped in an ``io.TextIOWrapper``. A
    ``TextIOWrapper`` is used instead of a ``codecs`` writer because the
    result must still expose ``.buffer``, which the MCP stdio transport
    writes to. If neither approach works, the existing stream is kept.
    """
    stream = getattr(sys, name, None)
    if stream is None:
        return
    try:
        stream.reconfigure(encoding='utf-8', errors='replace')
        return
    except Exception:
        pass
    try:
        buffer = stream.buffer
    except AttributeError:
        # Final fallback: continue with the existing stream
        return
    try:
        # Push any pending text out before a second wrapper starts writing
        # to the same buffer, so output is not reordered.
        stream.flush()
    except Exception:
        pass
    try:
        wrapped = io.TextIOWrapper(
            buffer,
            encoding='utf-8',
            errors='replace',
            line_buffering=getattr(stream, 'line_buffering', False),
        )
    except Exception:
        # Final fallback: continue with the existing stream
        return
    # The old wrapper stays referenced by sys.__stdout__/sys.__stderr__, so
    # it is not garbage-collected (which would close the shared buffer).
    setattr(sys, name, wrapped)


_force_utf8_stream('stdout')
_force_utf8_stream('stderr')
# NOTE: no UTF-8 self-test is printed here; anything written to stdout would
# interfere with the MCP JSON-RPC protocol stream.
# ==================== Regular Imports (after UTF-8 setup) ====================
import asyncio
import base64
import io
import json
import logging
import random
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# Load environment variables from a .env file when python-dotenv is present.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    # Reported through logging (stderr) rather than print(), which would
    # write into the MCP protocol stream on stdout.
    logging.getLogger("imagen4-mcp-server").warning("python-dotenv not installed")

# Make packages that live next to this file (e.g. ``src``) importable.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)

# MCP SDK imports; the server cannot run without them.
try:
    from mcp.server.models import InitializationOptions
    from mcp.server import NotificationOptions, Server
    from mcp.server.stdio import stdio_server
    from mcp.types import Tool, TextContent
except ImportError as e:
    logger = logging.getLogger("imagen4-mcp-server")
    logger.error(f"Error importing MCP: {e}")
    logger.error("Please install required packages: pip install mcp")
    sys.exit(1)

# Project-local connector imports
from src.connector import Config, Imagen4Client
from src.connector.imagen4_client import ImageGenerationRequest
from src.connector.utils import save_generated_images

# Pillow is required to build the JPEG previews.
try:
    from PIL import Image
except ImportError as e:
    logger = logging.getLogger("imagen4-mcp-server")
    logger.error(f"Error importing Pillow: {e}")
    logger.error("Please install Pillow: pip install Pillow")
    sys.exit(1)
# ==================== Unicode-Safe Logging Setup ====================
class SafeUnicodeHandler(logging.StreamHandler):
    """Stream handler that keeps Unicode encoding errors out of logging.

    On Windows, known emoji are replaced by ASCII tags. Any message that
    still cannot be encoded as cp949 is then reduced to ASCII.

    The record is written as UTF-8 bytes to ``stream.buffer`` when the stream
    has one, and through ``stream.write`` otherwise. If that raises an
    encoding error, an ASCII-only copy is attempted. If that also fails, the
    message is dropped.
    """

    # Emoji -> ASCII tag map applied on Windows consoles.
    # Every key must be a non-empty string: str.replace('', tag) would insert
    # the tag between every character of the message.
    # NOTE(review): the '[SUCCESS]', '[ERROR]', '[WAIT]' and '[TIMEOUT]' keys
    # were empty strings in the reviewed copy (their emoji seem to have been
    # lost); the characters used here are a best guess -- confirm.
    # '📊' was listed twice ('[STATS]' then '[STATUS]'); the later mapping,
    # which was the one actually in effect, is kept.
    EMOJI_REPLACEMENTS = {
        '💭': '[THOUGHT]', '📊': '[STATUS]', '⚠️': '[WARNING]', '💡': '[TIP]',
        '✅': '[SUCCESS]', '❌': '[ERROR]', '🔄': '[RETRY]', '⏳': '[WAIT]',
        '🚀': '[START]', '📦': '[RESPONSE]', '💥': '[FAILED]', '🖼️': '[IMAGE]',
        '📁': '[FILES]', '⚙️': '[PARAMS]', '🎨': '[GENERATE]', '📏': '[SIZE]',
        '🎯': '[SETTINGS]', '⏱️': '[TIME]', '⏰': '[TIMEOUT]', '🎉': '[COMPLETE]',
        '🔍': '[DEBUG]', '🚨': '[CIRCUIT]',
    }

    def __init__(self, stream=None):
        super().__init__(stream)
        self.encoding = 'utf-8'

    @classmethod
    def _make_console_safe(cls, msg: str) -> str:
        """Return *msg* rewritten so it can be shown on a cp949 (Korean Windows) console."""
        for emoji, replacement in cls.EMOJI_REPLACEMENTS.items():
            msg = msg.replace(emoji, replacement)
        # Round-trip through UTF-8 so unencodable code points (lone surrogates) become '?'
        msg = msg.encode('utf-8', errors='replace').decode('utf-8', errors='replace')
        try:
            msg.encode('cp949')
        except UnicodeEncodeError:
            # Still not representable in cp949: make it ASCII-safe
            msg = msg.encode('ascii', errors='replace').decode('ascii')
        return msg

    def emit(self, record):
        try:
            msg = self.format(record)
            if sys.platform.startswith('win'):
                msg = self._make_console_safe(msg)

            stream = self.stream
            terminator = getattr(self, 'terminator', '\n')
            try:
                if hasattr(stream, 'buffer'):
                    # Flush pending text first so it is not reordered after
                    # the bytes written straight to the binary buffer.
                    stream.flush()
                    stream.buffer.write((msg + terminator).encode('utf-8', errors='replace'))
                    stream.buffer.flush()
                else:
                    stream.write(msg + terminator)
                    if hasattr(stream, 'flush'):
                        stream.flush()
            except (UnicodeEncodeError, UnicodeDecodeError):
                # Emergency fallback: write ASCII version only
                try:
                    safe_msg = msg.encode('ascii', errors='replace').decode('ascii')
                    if hasattr(stream, 'buffer'):
                        stream.buffer.write((safe_msg + terminator).encode('ascii'))
                        stream.buffer.flush()
                    else:
                        stream.write(safe_msg + terminator)
                        if hasattr(stream, 'flush'):
                            stream.flush()
                except Exception:
                    # Absolute last resort: skip this log message rather than crash
                    pass
        except RecursionError:
            # Same policy as logging.StreamHandler.emit: never swallow recursion errors
            raise
        except Exception:
            self.handleError(record)
# Emoji-safe formatter
class SafeFormatter(logging.Formatter):
    """Formatter that scrubs unencodable characters from string arguments.

    String %-arguments, whether positional or mapping-style, are
    round-tripped through UTF-8 with ``errors='replace'``, so lone
    surrogates become ``'?'``. Non-string arguments pass through unchanged,
    so specifiers such as ``%d`` and ``%.1f`` keep working.

    ``record.args`` is restored after formatting, so other handlers see the
    original values. If formatting still fails, a short placeholder line is
    returned instead of raising.
    """

    @staticmethod
    def _scrub(value):
        """Return *value* with invalid code points replaced when it is a str."""
        if isinstance(value, str):
            return value.encode('utf-8', errors='replace').decode('utf-8', errors='replace')
        return value

    @classmethod
    def _scrub_args(cls, args):
        """Scrub every string in a tuple of args or a mapping of named args."""
        from collections.abc import Mapping
        if isinstance(args, Mapping):
            return {key: cls._scrub(val) for key, val in args.items()}
        if isinstance(args, tuple):
            return tuple(cls._scrub(arg) for arg in args)
        return args

    def format(self, record):
        original_args = getattr(record, 'args', None)
        try:
            if original_args:
                record.args = self._scrub_args(original_args)
            return super().format(record)
        except Exception:
            # Fallback: return a safe version of the log message
            return f"[LOG-ERROR] Could not format log message safely: {record.levelname}"
        finally:
            record.args = original_args
# Set up ultra-safe logging: drop any pre-existing root handlers and send
# every record at DEBUG and above through one Unicode-safe stderr handler
# (stdout is reserved for the MCP JSON-RPC stream).
root_logger = logging.getLogger()
root_logger.handlers.clear()
safe_handler = SafeUnicodeHandler(sys.stderr)
safe_handler.setFormatter(
    SafeFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
)
root_logger.addHandler(safe_handler)
root_logger.setLevel(logging.DEBUG)

logger = logging.getLogger("imagen4-mcp-server")
# Emit one line at import time to show the handler is working.
logger.info("[INIT] Imagen4 MCP Server initializing with Unicode-safe logging")
# ==================== Image Processing Utilities ====================
def create_preview_image_b64(image_data: bytes, target_size: int = 512, quality: int = 85) -> Optional[str]:
    """
    Convert image bytes to a square JPEG preview and return it as base64.

    Processing steps:
    - Images that can carry transparency are flattened onto a white
      background.
    - The image is shrunk with LANCZOS to fit inside
      ``target_size`` x ``target_size``, keeping its aspect ratio.
    - The result is centred on a white square canvas of exactly that size.

    Args:
        image_data: Original image data in bytes (PNG in normal use)
        target_size: Edge length in pixels of the square preview (default: 512)
        quality: JPEG quality (1-100, default: 85)

    Returns:
        Base64 encoded JPEG image string, or None if conversion fails
        (the error is logged).
    """
    try:
        with Image.open(io.BytesIO(image_data)) as img:
            # JPEG has no alpha channel. Every alpha-capable mode, including
            # palette images that may have a transparent index and
            # palette+alpha ('PA'), is converted to RGBA first and then
            # composited onto white. Palette images without transparency just
            # get a fully opaque alpha channel.
            if img.mode in ('RGBA', 'LA', 'P', 'PA'):
                rgba = img.convert('RGBA')
                background = Image.new('RGB', rgba.size, (255, 255, 255))
                background.paste(rgba, mask=rgba.getchannel('A'))
                img = background
            elif img.mode != 'RGB':
                img = img.convert('RGB')

            # Shrink to fit inside the target square, maintaining aspect ratio
            img.thumbnail((target_size, target_size), Image.Resampling.LANCZOS)

            # Non-square (or smaller) results are centred on a white square canvas
            if img.size != (target_size, target_size):
                canvas = Image.new('RGB', (target_size, target_size), (255, 255, 255))
                offset = ((target_size - img.size[0]) // 2, (target_size - img.size[1]) // 2)
                canvas.paste(img, offset)
                img = canvas

            output_buffer = io.BytesIO()
            img.save(output_buffer, format='JPEG', quality=quality, optimize=True)
            jpeg_data = output_buffer.getvalue()

        base64_data = base64.b64encode(jpeg_data).decode('utf-8')
        logger.info(f"Preview image created: {target_size}x{target_size} JPEG, {len(jpeg_data)} bytes")
        return base64_data
    except Exception as e:
        logger.error(f"Failed to create preview image: {str(e)}")
        return None
def get_image_info(image_data: bytes) -> Optional[dict]:
    """
    Describe an encoded image.

    Returns a dict with the Pillow ``format``, ``size``, ``mode`` and the
    encoded length in ``bytes``. Returns None, after logging the error, if
    Pillow cannot open the data.
    """
    try:
        with Image.open(io.BytesIO(image_data)) as opened:
            info = dict(
                format=opened.format,
                size=opened.size,
                mode=opened.mode,
                bytes=len(image_data),
            )
        return info
    except Exception as exc:
        logger.error(f"Failed to get image info: {str(exc)}")
        return None
# ==================== Response Models ====================
@dataclass
class ImageGenerationResult:
    """Outcome of an image-generation tool call, renderable as MCP text.

    Fields:
        success: whether generation succeeded; previews are listed only on success
        message: headline placed first in the rendered text
        original_images_count: number of full-size images produced
        preview_images_b64: base64 JPEG preview strings (512x512)
        saved_files: paths of files written to disk
        generation_params: parameters echoed back to the caller
        error_message: error detail (not rendered by ``to_text_content``)
    """
    success: bool
    message: str
    original_images_count: int
    preview_images_b64: Optional[list[str]] = None  # base64 JPEG previews (512x512)
    saved_files: Optional[list[str]] = None
    generation_params: Optional[Dict[str, Any]] = None
    error_message: Optional[str] = None

    def to_text_content(self) -> str:
        """Render the result as the plain-text body of an MCP response.

        Each preview string is shown as its first 50 characters plus its
        total length. A ``prompt`` parameter longer than 100 characters is
        cut to 100 characters followed by ``...``.
        """
        out = [self.message]

        previews = self.preview_images_b64 if self.success else None
        if previews:
            out.append(f"\n[IMAGE] Preview Images Generated: {len(previews)} images (512x512 JPEG)")
            out.extend(
                f"Preview {n} (base64 JPEG): {b64[:50]}...({len(b64)} chars)"
                for n, b64 in enumerate(previews, start=1)
            )

        if self.saved_files:
            out.append("\n[FILES] Files saved:")
            out.extend(f" - {path}" for path in self.saved_files)

        if self.generation_params:
            out.append("\n[PARAMS] Generation Parameters:")
            for name, val in self.generation_params.items():
                text = str(val)
                if name == 'prompt' and len(text) > 100:
                    out.append(f" - {name}: {text[:100]}...")
                else:
                    out.append(f" - {name}: {val}")

        return "\n".join(out)
# ==================== MCP Tool Definitions ====================
def get_tools() -> List[Tool]:
    """Build the MCP tool definitions exposed by this server.

    Returns three tools, each with its JSON-Schema ``inputSchema``:
    ``generate_image``, ``regenerate_from_json`` and ``generate_random_seed``.
    """
    generate_image_schema = {
        "type": "object",
        "properties": {
            "prompt": {
                "type": "string",
                "description": "Text prompt for image generation (Korean or English)",
            },
            "negative_prompt": {
                "type": "string",
                "description": "Negative prompt specifying elements not to generate (optional)",
                "default": "",
            },
            "number_of_images": {
                "type": "integer",
                "description": "Number of images to generate (1 or 2 only)",
                "enum": [1, 2],
                "default": 1,
            },
            "seed": {
                "type": "integer",
                "description": "Seed value for reproducible results (required, 0 ~ 4294967295 range)",
                "minimum": 0,
                "maximum": 4294967295,
            },
            "aspect_ratio": {
                "type": "string",
                "description": "Image aspect ratio",
                "enum": ["1:1", "9:16", "16:9", "3:4", "4:3"],
                "default": "1:1",
            },
            "save_to_file": {
                "type": "boolean",
                "description": "Whether to save generated images to files (default: true)",
                "default": True,
            },
        },
        "required": ["prompt", "seed"],
    }

    regenerate_schema = {
        "type": "object",
        "properties": {
            "json_file_path": {
                "type": "string",
                "description": "Path to JSON file containing saved parameters",
            },
            "save_to_file": {
                "type": "boolean",
                "description": "Whether to save regenerated images to files (default: true)",
                "default": True,
            },
        },
        "required": ["json_file_path"],
    }

    random_seed_schema = {
        "type": "object",
        "properties": {},
        "required": [],
    }

    generate_tool = Tool(
        name="generate_image",
        description="Generate 2048x2048 PNG images with 512x512 JPEG previews using Google Imagen 4 AI.",
        inputSchema=generate_image_schema,
    )
    regenerate_tool = Tool(
        name="regenerate_from_json",
        description="Read parameters from JSON file and regenerate images with previews.",
        inputSchema=regenerate_schema,
    )
    seed_tool = Tool(
        name="generate_random_seed",
        description="Generate random seed value for image generation.",
        inputSchema=random_seed_schema,
    )
    return [generate_tool, regenerate_tool, seed_tool]
# ==================== Utility Functions ====================
def sanitize_args_for_logging(arguments: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Return a log-safe copy of tool-call *arguments*, preserving Unicode text.

    Rules:
    - A string becomes "<image_data:N chars>" if it sits under an image-data
      key ('data', 'image_data', 'base64', 'preview_image_b64'), or if it is
      longer than 100 characters and starts with a PNG/JPEG/GIF base64
      signature.
    - Any other string longer than 1000 characters is cut to its first 100
      characters followed by "...<truncated:N total chars>".
    - Shorter strings, including Korean text, are kept unchanged.
    - Nested dicts, lists and tuples are sanitized recursively, so image data
      inside structured arguments is not logged either.

    ``None`` (or an empty dict) yields ``{}``. The input is never modified.
    """
    if not arguments:
        return {}
    return {key: _sanitize_value_for_logging(key, value) for key, value in arguments.items()}


def _sanitize_value_for_logging(key: Any, value: Any) -> Any:
    """Return a log-safe version of one argument value (*key* is None inside sequences)."""
    if isinstance(value, str):
        # Check if it's likely base64 image data
        if (key in ('data', 'image_data', 'base64', 'preview_image_b64') or
                (len(value) > 100 and value.startswith(('iVBORw0KGgo', '/9j/', 'R0lGOD')))):
            return f"<image_data:{len(value)} chars>"
        if len(value) > 1000:
            # Truncate very long strings but preserve Unicode
            return f"{value[:100]}...<truncated:{len(value)} total chars>"
        return value
    if isinstance(value, dict):
        return {k: _sanitize_value_for_logging(k, v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        items = [_sanitize_value_for_logging(None, item) for item in value]
        return tuple(items) if isinstance(value, tuple) else items
    return value
# ==================== Tool Handlers ====================
class Imagen4ToolHandlers:
    """MCP tool handler class with preview image support.

    Every ``handle_*`` coroutine takes the tool-call argument dict, treating
    ``None`` as empty, and returns a one-element ``TextContent`` list. Errors
    are reported as text in that list rather than raised.
    """

    def __init__(self, config: Config):
        """Keep *config* and create the ``Imagen4Client`` shared by all handlers."""
        self.config = config
        self.client = Imagen4Client(config)

    @staticmethod
    def _text(text: str) -> List[TextContent]:
        """Wrap *text* in the single-item ``TextContent`` list MCP handlers return."""
        return [TextContent(type="text", text=text)]

    @staticmethod
    def _create_previews(images_data) -> List[str]:
        """Return base64 512x512 JPEG previews for *images_data*.

        Logs the original image info for each image. Images whose preview
        cannot be created are skipped with a warning, so the result may be
        shorter than the input.
        """
        previews: List[str] = []
        for index, image_data in enumerate(images_data, start=1):
            image_info = get_image_info(image_data)
            if image_info:
                logger.info(f"Original image {index}: {image_info}")
            preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85)
            if preview_b64:
                previews.append(preview_b64)
                logger.info(f"Created preview image {index}: {len(preview_b64)} chars base64 JPEG")
            else:
                logger.warning(f"Failed to create preview for image {index}")
        return previews

    @staticmethod
    def _error_tip(error_type_str: str, execution_time: float) -> str:
        """Return a user-facing hint for a known error type string, or '' if none applies."""
        if error_type_str == "timeout":
            return f"\n[TIP] Try simplifying your prompt or reducing the number of images. (Took: {execution_time:.1f}s)"
        tips = {
            "quota_exceeded": "\n[TIP] Please check your API quota or payment information.",
            "safety_violation": "\n[TIP] Try removing sensitive content from your prompt or use different expressions.",
            "network": "\n[TIP] Please check your network connection and try again later.",
            "service_unavailable": "\n[TIP] Google service may be temporarily unstable. Please try again in a few minutes.",
        }
        return tips.get(error_type_str, "")

    async def handle_generate_random_seed(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Return a random seed in 0 .. 2**32 - 1 (the range of the generate_image seed schema)."""
        try:
            random_seed = random.randint(0, 2**32 - 1)
            return self._text(f"Generated random seed: {random_seed}")
        except Exception as e:
            logger.error(f"Random seed generation error: {str(e)}")
            return self._text(f"Error occurred during random seed generation: {str(e)}")

    async def handle_regenerate_from_json(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Regenerate images from parameters saved in a JSON file, with previews.

        Arguments:
            json_file_path: path to a JSON object containing at least ``prompt``
                and ``seed``. ``negative_prompt``, ``number_of_images`` and
                ``aspect_ratio`` are optional.
            save_to_file: whether to save the regenerated images (default True).
        """
        try:
            arguments = arguments or {}
            json_file_path = arguments.get("json_file_path")
            save_to_file = arguments.get("save_to_file", True)
            if not json_file_path:
                return self._text("Error: JSON file path is required.")

            # Load parameters from JSON file with proper UTF-8 encoding.
            # NOTE(review): the path comes straight from the tool call and is
            # not restricted to the output directory -- confirm that is intended.
            try:
                with open(json_file_path, 'r', encoding='utf-8') as f:
                    params = json.load(f)
            except FileNotFoundError:
                return self._text(f"Error: Cannot find JSON file: {json_file_path}")
            except json.JSONDecodeError as e:
                return self._text(f"Error: JSON parsing error: {str(e)}")

            # A JSON array or scalar would break the key checks and .get() calls below
            if not isinstance(params, dict):
                return self._text("Error: JSON file must contain an object with generation parameters.")

            missing_params = [p for p in ('prompt', 'seed') if p not in params]
            if missing_params:
                return self._text(f"Error: Required parameters are missing from JSON file: {', '.join(missing_params)}")

            request = ImageGenerationRequest(
                prompt=params.get('prompt'),
                negative_prompt=params.get('negative_prompt', ''),
                number_of_images=params.get('number_of_images', 1),
                seed=params.get('seed'),
                aspect_ratio=params.get('aspect_ratio', '1:1'),
            )
            logger.info(f"Loaded parameters from JSON: {json_file_path}")

            response = await self.client.generate_image(request)
            if not response.success:
                return self._text(f"Error occurred during image regeneration: {response.error_message}")

            preview_images_b64 = self._create_previews(response.images_data)

            saved_files = []
            if save_to_file:
                regeneration_params = {
                    "prompt": request.prompt,
                    "negative_prompt": request.negative_prompt,
                    "number_of_images": request.number_of_images,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "regenerated_from": json_file_path,
                    "original_generated_at": params.get('generated_at', 'unknown'),
                }
                saved_files = save_generated_images(
                    images_data=response.images_data,
                    save_directory=self.config.output_path,
                    seed=request.seed,
                    generation_params=regeneration_params,
                    filename_prefix="imagen4_regen",
                )

            result = ImageGenerationResult(
                success=True,
                message=f"[SUCCESS] Images have been successfully regenerated from {json_file_path}",
                original_images_count=len(response.images_data),
                preview_images_b64=preview_images_b64,
                saved_files=saved_files,
                generation_params={
                    "prompt": request.prompt,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "number_of_images": request.number_of_images,
                },
            )
            return self._text(result.to_text_content())
        except Exception as e:
            logger.error(f"Error occurred during image regeneration: {str(e)}")
            return self._text(f"Error occurred during image regeneration: {str(e)}")

    async def handle_generate_image(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Generate images for a prompt and seed, build previews and optionally save files.

        Arguments follow the ``generate_image`` tool schema. ``prompt`` and
        ``seed`` are required; ``negative_prompt``, ``number_of_images``,
        ``aspect_ratio`` and ``save_to_file`` are optional.

        When ``src.utils.error_recovery`` is importable, the API call goes
        through its circuit breaker and retry helper, and outcomes are
        recorded in its health monitor. Otherwise the client is called
        directly.
        """
        import time
        try:
            arguments = arguments or {}
            # Log arguments without image payloads (Unicode text is preserved)
            safe_args = sanitize_args_for_logging(arguments)
            logger.info(f"handle_generate_image called with arguments: {safe_args}")

            prompt = arguments.get("prompt")
            if not prompt:
                logger.error("No prompt provided")
                return self._text("Error: Prompt is required.")

            seed = arguments.get("seed")
            if seed is None:  # 0 is a valid seed, so test for None explicitly
                logger.error("No seed provided")
                return self._text(
                    "Error: Seed value is required. You can use the generate_random_seed tool to generate a random seed."
                )

            request = ImageGenerationRequest(
                prompt=prompt,
                negative_prompt=arguments.get("negative_prompt", ""),
                number_of_images=arguments.get("number_of_images", 1),
                seed=seed,
                aspect_ratio=arguments.get("aspect_ratio", "1:1"),
            )
            save_to_file = arguments.get("save_to_file", True)

            prompt_preview = prompt[:50] + "..." if len(prompt) > 50 else prompt
            logger.info(f"Starting image generation: '{prompt_preview}', Seed: {seed}")

            # Optional retry / circuit-breaker utilities
            try:
                from src.utils.error_recovery import health_monitor, circuit_breaker, default_retry_config, retry_with_backoff
                error_recovery_available = True
            except ImportError:
                logger.warning("Error recovery utilities not available, using basic error handling")
                error_recovery_available = False

            logger.info("Starting image generation with enhanced error handling...")
            if error_recovery_available:
                if not health_monitor.is_healthy():
                    stats = health_monitor.get_stats()
                    logger.warning(f"[WARNING] API status unstable: Success rate {stats['success_rate']:.1%}, Consecutive errors {stats['consecutive_errors']}")

                async def safe_generate():
                    return await circuit_breaker.call(self.client.generate_image, request)

                start_time = time.time()
                # Only the API call is guarded here. A failure in the
                # bookkeeping below must not be recorded as an API error.
                try:
                    response = await retry_with_backoff(safe_generate, default_retry_config)
                except Exception as e:
                    from types import SimpleNamespace
                    from src.connector.imagen4_client import classify_api_error
                    error_type, user_message = classify_api_error(e)
                    health_monitor.record_error(error_type.value)
                    # Same attributes as a client failure response, so the
                    # failure handling below treats both alike
                    response = SimpleNamespace(
                        success=False,
                        error_message=user_message,
                        error_type=error_type,
                        execution_time=time.time() - start_time,
                        images_data=[],
                    )
                    logger.error(f"Image generation failed after retries: {user_message}")
                else:
                    execution_time = time.time() - start_time
                    health_monitor.record_success(execution_time)
                    logger.info(f"Image generation completed successfully. Execution time: {execution_time:.1f}s")
            else:
                # Basic error handling fallback
                logger.info("Calling client.generate_image()...")
                response = await self.client.generate_image(request)
                logger.info(f"Image generation completed. Success: {response.success}")

            if not response.success:
                error_type = getattr(response, 'error_type', None)
                error_type_str = getattr(error_type, 'value', 'unknown') if error_type else 'unknown'
                execution_time = getattr(response, 'execution_time', 0) or 0
                logger.error(f"Image generation failed: {response.error_message} (Type: {error_type_str}, Time: {execution_time:.1f}s)")

                additional_info = self._error_tip(error_type_str, execution_time)
                if error_recovery_available:
                    try:
                        stats = health_monitor.get_stats()
                        if stats['total_requests'] > 0:
                            additional_info += f"\n[API STATUS] Success rate: {stats['success_rate']:.1%} ({stats['success_count']}/{stats['total_requests']})"
                    except Exception:
                        # The status line is best-effort decoration only
                        pass
                return self._text(f"Error occurred during image generation: {response.error_message}{additional_info}")

            logger.info(f"Generated {len(response.images_data)} images successfully")
            preview_images_b64 = self._create_previews(response.images_data)

            saved_files = []
            if save_to_file:
                logger.info("Saving files to disk...")
                generation_params = {
                    "prompt": request.prompt,
                    "negative_prompt": request.negative_prompt,
                    "number_of_images": request.number_of_images,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "guidance_scale": 7.5,
                    "safety_filter_level": "block_only_high",
                    "person_generation": "allow_all",
                    "add_watermark": False,
                }
                saved_files = save_generated_images(
                    images_data=response.images_data,
                    save_directory=self.config.output_path,
                    seed=request.seed,
                    generation_params=generation_params,
                )
                logger.info(f"Files saved: {saved_files}")
                for file_path in saved_files:
                    if os.path.exists(file_path):
                        size = os.path.getsize(file_path)
                        logger.info(f" [OK] Verified: {file_path} ({size} bytes)")
                    else:
                        logger.error(f" [MISSING] File not found: {file_path}")

            # Guard the same way as the failure path: execution_time may be unset
            elapsed = getattr(response, 'execution_time', 0) or 0
            result = ImageGenerationResult(
                success=True,
                message=f"[SUCCESS] Images have been successfully generated! (Took {elapsed:.1f}s)",
                original_images_count=len(response.images_data),
                preview_images_b64=preview_images_b64,
                saved_files=saved_files,
                generation_params={
                    "prompt": request.prompt,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "number_of_images": request.number_of_images,
                    "negative_prompt": request.negative_prompt,
                    "execution_time": elapsed,
                    "model": request.model,
                },
            )
            logger.info(f"Returning response with {len(preview_images_b64)} preview images")
            return self._text(result.to_text_content())
        except Exception as e:
            logger.error(f"Error occurred during image generation: {str(e)}", exc_info=True)
            return self._text(f"Error occurred during image generation: {str(e)}")
# ==================== MCP Server ====================
class Imagen4MCPServer:
    """Connects the Imagen4 tool handlers to an MCP ``Server`` instance."""

    def __init__(self, config: Config):
        """Create the MCP server and tool handlers for *config*, then register callbacks."""
        self.config = config
        self.server = Server("imagen4-mcp-server")
        self.handlers = Imagen4ToolHandlers(config)
        self._register_handlers()

    def _register_handlers(self) -> None:
        """Attach the list-tools and call-tool callbacks to ``self.server``."""
        server = self.server

        @server.list_tools()
        async def handle_list_tools() -> List[Tool]:
            """Return the tool definitions from get_tools(), logging and re-raising any error."""
            try:
                logger.info("Listing available tools with preview image support")
                return get_tools()
            except Exception as e:
                logger.error(f"Error listing tools: {e}")
                raise

        @server.call_tool()
        async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]:
            """Dispatch a tool call by name; failures come back as error text."""
            try:
                # Log tool call safely without exposing image payloads
                safe_args = sanitize_args_for_logging(arguments)
                logger.info(f"Tool called: {name} with arguments: {safe_args}")

                dispatch = {
                    "generate_random_seed": self.handlers.handle_generate_random_seed,
                    "regenerate_from_json": self.handlers.handle_regenerate_from_json,
                    "generate_image": self.handlers.handle_generate_image,
                }
                tool_handler = dispatch.get(name)
                if tool_handler is not None:
                    return await tool_handler(arguments)

                error_msg = f"Unknown tool: {name}"
                logger.error(error_msg)
                return [TextContent(type="text", text=f"Error: {error_msg}")]
            except Exception as e:
                logger.error(f"Error handling tool call '{name}': {e}")
                import traceback
                logger.error(f"Tool call traceback: {traceback.format_exc()}")
                return [TextContent(
                    type="text",
                    text=f"Error occurred while processing tool '{name}': {str(e)}",
                )]

    def get_server(self) -> Server:
        """Return MCP server instance"""
        return self.server
# ==================== Main Function ====================
async def main():
    """Load configuration, build the MCP server and serve it over stdio.

    Returns:
        - 1 when loading the configuration or constructing the server raises
          ``ValueError``.
        - 0 when interrupted by ``KeyboardInterrupt``.
        - ``None`` after the stdio session ends normally.

    Any other exception is logged with its traceback and re-raised.
    """
    logger.info("Starting Imagen 4 MCP Server with Preview Image Support...")
    try:
        # Only setup-time ValueErrors are reported as configuration problems.
        # A ValueError raised later while serving is a runtime error and goes
        # to the generic handler with its traceback.
        try:
            config = Config.from_env()
            logger.info(f"Configuration loaded - Project: {config.project_id}, Location: {config.location}")
            mcp_server = Imagen4MCPServer(config)
        except ValueError as e:
            logger.error(f"Configuration error: {e}")
            logger.error("Please check your .env file or environment variables.")
            return 1
        server = mcp_server.get_server()

        # Read timeout settings from the handlers' existing client instead of
        # constructing a second Imagen4Client just for this log line.
        try:
            timeout_settings = mcp_server.handlers.client.get_timeout_settings()
            config_timeout_info = f"API: {timeout_settings['api_timeout']}s, Progress: {timeout_settings['progress_interval']}s"
        except Exception:
            config_timeout_info = "Unknown"

        logger.info("Imagen 4 MCP Server initialized successfully")
        logger.info("Features: 512x512 JPEG preview images, base64 encoding, enhanced responses, Unicode support")
        logger.info(f"Timeout settings: {config_timeout_info}")

        # Run MCP server with better error handling
        try:
            async with stdio_server() as (read_stream, write_stream):
                logger.info("Starting MCP server with stdio transport")
                await server.run(
                    read_stream,
                    write_stream,
                    InitializationOptions(
                        server_name="imagen4-mcp-server",
                        server_version="3.1.0",
                        capabilities=server.get_capabilities(
                            notification_options=NotificationOptions(),
                            experimental_capabilities={
                                "preview_images": {},
                                "base64_jpeg_previews": {},
                                "unicode_logging": {},
                            },
                        ),
                    ),
                )
        except Exception as stdio_error:
            logger.error(f"STDIO server error: {stdio_error}")
            # Hint for the common client-disconnect failure mode
            if "TaskGroup" in str(stdio_error):
                logger.error("TaskGroup error detected - this may be due to client disconnection")
            raise
    except KeyboardInterrupt:
        logger.info("Server stopped by user (Ctrl+C)")
        return 0
    except Exception as e:
        logger.error(f"Server error: {e}")
        logger.error(f"Error type: {type(e).__name__}")
        import traceback
        logger.error(f"Traceback: {traceback.format_exc()}")
        raise
if __name__ == "__main__":
    try:
        import signal

        def signal_handler(signum, frame):
            """Log the received signal and exit with status 0."""
            logger.info(f"Received signal {signum}, shutting down gracefully...")
            sys.exit(0)

        # Route both Ctrl+C and termination requests through the same handler
        for _shutdown_signal in (signal.SIGINT, signal.SIGTERM):
            signal.signal(_shutdown_signal, signal_handler)

        # main() returns an exit code or None; None means success
        sys.exit(asyncio.run(main()) or 0)
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
        sys.exit(0)
    except SystemExit as e:
        logger.info(f"Server exiting with code {e.code}")
        sys.exit(e.code)
    except Exception as e:
        import traceback
        logger.error(f"Fatal error: {e}")
        logger.error(f"Error type: {type(e).__name__}")
        logger.error(f"Full traceback: {traceback.format_exc()}")
        sys.exit(1)