#!/usr/bin/env python3
"""
Imagen4 MCP Server with Preview Image Support

MCP server for AI image generation using Google Imagen 4:
- Generates original 2048x2048 PNG images
- Provides 512x512 JPEG preview images encoded as base64
- Enhanced response format

Run from the imagen4 root directory.
"""
|
|
|
|
import asyncio
|
|
import base64
|
|
import json
|
|
import random
|
|
import logging
|
|
import sys
|
|
import os
|
|
import io
|
|
from typing import Dict, Any, List, Optional
|
|
from dataclasses import dataclass
|
|
|
|
# Load environment variables
|
|
try:
|
|
from dotenv import load_dotenv
|
|
load_dotenv()
|
|
except ImportError:
|
|
print("Warning: python-dotenv not installed", file=sys.stderr)
|
|
|
|
# Add current directory to PYTHONPATH
|
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
sys.path.insert(0, current_dir)
|
|
|
|
# MCP imports
|
|
try:
|
|
from mcp.server.models import InitializationOptions
|
|
from mcp.server import NotificationOptions
|
|
from mcp.server.stdio import stdio_server
|
|
from mcp.server import Server
|
|
from mcp.types import Tool, TextContent
|
|
except ImportError as e:
|
|
print(f"Error importing MCP: {e}", file=sys.stderr)
|
|
print("Please install required packages: pip install mcp", file=sys.stderr)
|
|
sys.exit(1)
|
|
|
|
# Connector imports
|
|
from src.connector import Config, Imagen4Client
|
|
from src.connector.imagen4_client import ImageGenerationRequest
|
|
from src.connector.utils import save_generated_images
|
|
|
|
# Image processing import
|
|
try:
|
|
from PIL import Image
|
|
except ImportError as e:
|
|
print(f"Error importing Pillow: {e}", file=sys.stderr)
|
|
print("Please install Pillow: pip install Pillow", file=sys.stderr)
|
|
sys.exit(1)
|
|
|
|
# ==================== Unicode-Safe Logging Setup ====================


class UnicodeStreamHandler(logging.StreamHandler):
    """Stream handler that writes log records as UTF-8 regardless of console encoding.

    On construction it tries to reconfigure the target text stream to UTF-8
    (``errors='replace'``). When emitting, it writes UTF-8 bytes directly to the
    stream's underlying binary ``buffer`` when one exists, so non-ASCII text
    (e.g. Korean prompts) cannot fail with ``UnicodeEncodeError`` on a console
    that uses a legacy code page.
    """

    def __init__(self, stream=None):
        """Create the handler; *stream* defaults to ``sys.stderr`` as in StreamHandler."""
        super().__init__(stream)
        # Force UTF-8 encoding for Windows compatibility. Best-effort only:
        # some streams (pipes, custom wrappers) cannot be reconfigured, and
        # logging must keep working in that case.
        if hasattr(self.stream, 'reconfigure'):
            try:
                self.stream.reconfigure(encoding='utf-8', errors='replace')
            except Exception:
                pass

    def emit(self, record):
        """Format *record* and write it to the stream as UTF-8.

        Errors are routed to ``handleError`` like ``logging.StreamHandler``;
        ``RecursionError`` is re-raised, matching the standard library.
        """
        try:
            msg = self.format(record)
            if sys.platform.startswith('win'):
                # Replace characters that cannot be encoded (e.g. lone
                # surrogates) so the write below cannot fail.
                msg = msg.encode('utf-8', errors='replace').decode('utf-8', errors='replace')

            stream = self.stream
            buffer = getattr(stream, 'buffer', None)
            if buffer is not None:
                # Flush text still pending in the text layer first; otherwise
                # bytes written straight to the binary buffer would overtake it.
                if hasattr(stream, 'flush'):
                    stream.flush()
                buffer.write((msg + self.terminator).encode('utf-8', errors='replace'))
                buffer.flush()
            else:
                stream.write(msg + self.terminator)
                if hasattr(stream, 'flush'):
                    stream.flush()
        except RecursionError:
            raise
        except Exception:
            self.handleError(record)
|
|
|
|
# Custom formatter for proper Unicode handling
class UnicodeFormatter(logging.Formatter):
    """Formatter paired with UnicodeStreamHandler.

    Python 3 ``str`` is already Unicode, so log arguments containing Korean or
    other non-ASCII text need no conversion. ``record.args`` is deliberately
    left untouched: stringifying it would break numeric placeholders
    (``'%d' % '5'`` raises ``TypeError``), change ``%r`` output, and replace
    mapping-style args (``%(key)s``) with a tuple of the mapping's keys. It
    would also mutate the record seen by every other handler.
    """

    def format(self, record):
        """Return the formatted log line for *record* without modifying the record."""
        return super().format(record)
|
|
|
|
# Set up UTF-8 encoding for stdout/stderr on Windows
if sys.platform.startswith('win'):
    # Force UTF-8 encoding for Windows console.
    # NOTE(review): `locale` is not used below; kept because it is a
    # module-level name on Windows.
    import locale
    try:
        # Switch the console code page to UTF-8 (65001); command output is discarded.
        os.system('chcp 65001 >nul 2>&1')
        # PYTHONIOENCODING is read at interpreter start-up, so setting it here
        # only affects child processes started later, not this process's
        # already-open stdout/stderr.
        os.environ['PYTHONIOENCODING'] = 'utf-8'
    except Exception:
        # Best-effort: failing to adjust the console must not stop the server,
        # but Ctrl+C / SystemExit are no longer swallowed (bare except did that).
        pass
|
|
|
|
# Logging configuration with Unicode support.
# All records go to stderr through a single UTF-8-safe handler, leaving
# stdout to the stdio transport started in main().
unicode_handler = UnicodeStreamHandler(sys.stderr)
unicode_handler.setFormatter(
    UnicodeFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
)

# Configure root logger: discard any previously installed handlers and
# capture everything from DEBUG upward.
root_logger = logging.getLogger()
root_logger.handlers.clear()
root_logger.setLevel(logging.DEBUG)
root_logger.addHandler(unicode_handler)

# Logger used throughout this module.
logger = logging.getLogger("imagen4-mcp-server")
|
|
|
|
|
|
# ==================== Image Processing Utilities ====================


def create_preview_image_b64(image_data: bytes, target_size: int = 512, quality: int = 85) -> Optional[str]:
    """
    Convert image data to a square JPEG preview and return it as base64.

    Transparency is composited onto a white background. The image is then
    shrunk (never enlarged) with LANCZOS resampling to fit inside
    ``target_size`` x ``target_size`` while keeping its aspect ratio. Any
    leftover area is padded with white, so the preview is always exactly
    ``target_size`` x ``target_size``.

    Args:
        image_data: Original image data in bytes (PNG for this server; any
            format Pillow can open works).
        target_size: Edge length in pixels of the square preview (default: 512).
        quality: JPEG quality (1-100, default: 85).

    Returns:
        Base64 encoded JPEG image string, or None if conversion fails
        (the failure is logged).
    """
    try:
        with Image.open(io.BytesIO(image_data)) as img:
            # JPEG cannot store alpha. Anything that can carry transparency -
            # alpha-band modes, palette images, or L/RGB images with a
            # "transparency" (tRNS) entry - is flattened onto white. A plain
            # convert('RGB') would drop the alpha and expose whatever colour
            # the transparent pixels happen to hold.
            if img.mode in ('RGBA', 'LA', 'PA', 'P') or 'transparency' in img.info:
                rgba = img.convert('RGBA')
                background = Image.new('RGB', rgba.size, (255, 255, 255))
                background.paste(rgba, mask=rgba.getchannel('A'))
                img = background
            elif img.mode != 'RGB':
                img = img.convert('RGB')

            # Downscale in place to fit the target box, keeping the aspect
            # ratio; thumbnail() never upscales.
            img.thumbnail((target_size, target_size), Image.Resampling.LANCZOS)

            # Pad to an exact square. This covers non-square aspect ratios as
            # well as images that were already smaller than the target.
            if img.size != (target_size, target_size):
                canvas = Image.new('RGB', (target_size, target_size), (255, 255, 255))
                offset = ((target_size - img.size[0]) // 2, (target_size - img.size[1]) // 2)
                canvas.paste(img, offset)
                img = canvas

            # Convert to JPEG
            output_buffer = io.BytesIO()
            img.save(output_buffer, format='JPEG', quality=quality, optimize=True)
            jpeg_data = output_buffer.getvalue()

        base64_data = base64.b64encode(jpeg_data).decode('utf-8')
        logger.info(f"Preview image created: {target_size}x{target_size} JPEG, {len(jpeg_data)} bytes")
        return base64_data

    except Exception as e:
        logger.error(f"Failed to create preview image: {str(e)}")
        return None
|
|
|
|
|
|
def get_image_info(image_data: bytes) -> Optional[dict]:
    """Open encoded image bytes with Pillow and report basic metadata.

    Args:
        image_data: Encoded image bytes (e.g. PNG).

    Returns:
        A dict with keys ``format``, ``size`` (width, height), ``mode`` and
        ``bytes`` (length of *image_data*), or None if the data cannot be
        opened (the error is logged).
    """
    try:
        with Image.open(io.BytesIO(image_data)) as opened:
            info = dict(
                format=opened.format,
                size=opened.size,
                mode=opened.mode,
                bytes=len(image_data),
            )
    except Exception as exc:
        logger.error(f"Failed to get image info: {str(exc)}")
        return None
    return info
|
|
|
|
|
|
# ==================== Response Models ====================


@dataclass
class ImageGenerationResult:
    """Outcome of a generate/regenerate tool call.

    The fields are plain data; ``to_text_content`` renders them as the
    human-readable text returned to the MCP client.
    """

    success: bool
    message: str
    original_images_count: int  # number of original images (not shown in the text output)
    preview_images_b64: Optional[list[str]] = None  # base64 JPEG previews (512x512)
    saved_files: Optional[list[str]] = None  # paths written to disk, if any
    generation_params: Optional[Dict[str, Any]] = None  # parameters echoed back to the user
    error_message: Optional[str] = None  # not shown in the text output

    def to_text_content(self) -> str:
        """Render the result as the text body of an MCP response.

        Previews are listed only when ``success`` is true, and each is
        summarized by its first 50 base64 characters plus its total length.
        A ``prompt`` parameter longer than 100 characters is shortened.
        NOTE(review): the full preview data is therefore not included in
        this text.
        """
        parts = [self.message]

        previews = self.preview_images_b64
        if self.success and previews:
            parts.append(f"\n🖼️ Preview Images Generated: {len(previews)} images (512x512 JPEG)")
            parts.extend(
                f"Preview {number} (base64 JPEG): {encoded[:50]}...({len(encoded)} chars)"
                for number, encoded in enumerate(previews, start=1)
            )

        if self.saved_files:
            parts.append(f"\n📁 Files saved:")
            parts.extend(f" - {path}" for path in self.saved_files)

        if self.generation_params:
            parts.append(f"\n⚙️ Generation Parameters:")
            for name, raw in self.generation_params.items():
                text = str(raw)
                if name == 'prompt' and len(text) > 100:
                    parts.append(f" - {name}: {text[:100]}...")
                else:
                    parts.append(f" - {name}: {raw}")

        return "\n".join(parts)
|
|
|
|
|
|
# ==================== MCP Tool Definitions ====================


def get_tools() -> List[Tool]:
    """Build the MCP tool definitions advertised by this server.

    Returns:
        Three ``Tool`` objects, in order: ``generate_image``,
        ``regenerate_from_json`` and ``generate_random_seed``, each carrying
        its JSON input schema.
    """
    # The seed range matches what generate_random_seed produces (0 .. 2**32 - 1).
    generate_image_schema = {
        "type": "object",
        "properties": {
            "prompt": {
                "type": "string",
                "description": "Text prompt for image generation (Korean or English)",
            },
            "negative_prompt": {
                "type": "string",
                "description": "Negative prompt specifying elements not to generate (optional)",
                "default": "",
            },
            "number_of_images": {
                "type": "integer",
                "description": "Number of images to generate (1 or 2 only)",
                "enum": [1, 2],
                "default": 1,
            },
            "seed": {
                "type": "integer",
                "description": "Seed value for reproducible results (required, 0 ~ 4294967295 range)",
                "minimum": 0,
                "maximum": 4294967295,
            },
            "aspect_ratio": {
                "type": "string",
                "description": "Image aspect ratio",
                "enum": ["1:1", "9:16", "16:9", "3:4", "4:3"],
                "default": "1:1",
            },
            "save_to_file": {
                "type": "boolean",
                "description": "Whether to save generated images to files (default: true)",
                "default": True,
            },
        },
        "required": ["prompt", "seed"],
    }

    regenerate_schema = {
        "type": "object",
        "properties": {
            "json_file_path": {
                "type": "string",
                "description": "Path to JSON file containing saved parameters",
            },
            "save_to_file": {
                "type": "boolean",
                "description": "Whether to save regenerated images to files (default: true)",
                "default": True,
            },
        },
        "required": ["json_file_path"],
    }

    # generate_random_seed takes no arguments.
    random_seed_schema = {
        "type": "object",
        "properties": {},
        "required": [],
    }

    image_tool = Tool(
        name="generate_image",
        description="Generate 2048x2048 PNG images with 512x512 JPEG previews using Google Imagen 4 AI.",
        inputSchema=generate_image_schema,
    )
    regenerate_tool = Tool(
        name="regenerate_from_json",
        description="Read parameters from JSON file and regenerate images with previews.",
        inputSchema=regenerate_schema,
    )
    seed_tool = Tool(
        name="generate_random_seed",
        description="Generate random seed value for image generation.",
        inputSchema=random_seed_schema,
    )
    return [image_tool, regenerate_tool, seed_tool]
|
|
|
|
|
|
# ==================== Utility Functions ====================


def sanitize_args_for_logging(arguments: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Return a copy of tool *arguments* that is safe and compact to log.

    String values are rewritten as follows; non-string values are kept as is:

    * Image payloads become ``"<image_data:N chars>"``. A value counts as an
      image payload if its key is ``data``, ``image_data``, ``base64`` or
      ``preview_image_b64``, or if it is a string longer than 100 characters
      that starts like base64 PNG (``iVBORw0KGgo``), JPEG (``/9j/``) or GIF
      (``R0lGOD``) data.
    * Any other string longer than 1000 characters is cut to its first 100
      characters plus a ``...<truncated:N total chars>`` marker.
    * Everything else, including Korean text, is kept verbatim.

    Args:
        arguments: Tool-call arguments. ``None`` (a call without arguments)
            is treated as an empty mapping instead of raising AttributeError.

    Returns:
        A new dict; *arguments* itself is not modified.
    """
    if not arguments:
        return {}

    safe_args: Dict[str, Any] = {}
    for key, value in arguments.items():
        if not isinstance(value, str):
            safe_args[key] = value
        elif (key in ('data', 'image_data', 'base64', 'preview_image_b64') or
              (len(value) > 100 and value.startswith(('iVBORw0KGgo', '/9j/', 'R0lGOD')))):
            # Never log raw image data, only its size.
            safe_args[key] = f"<image_data:{len(value)} chars>"
        elif len(value) > 1000:
            # Slicing a str cuts on code points, so Unicode text stays valid.
            safe_args[key] = f"{value[:100]}...<truncated:{len(value)} total chars>"
        else:
            safe_args[key] = value
    return safe_args
|
|
|
|
|
|
# ==================== Tool Handlers ====================


class Imagen4ToolHandlers:
    """Implementations of the MCP tools exposed by this server.

    Each ``handle_*`` coroutine takes the tool-call ``arguments`` dict and
    returns a one-item list of ``TextContent``. Failures are reported to the
    client as error text rather than raised.
    """

    # Upper bound for a single Imagen call, shared by generate and regenerate.
    GENERATION_TIMEOUT_SECONDS = 360.0  # 6 minutes

    def __init__(self, config: Config):
        """Store *config* and create the Imagen 4 client built from it."""
        self.config = config
        self.client = Imagen4Client(config)

    @staticmethod
    def _text(text: str) -> List[TextContent]:
        """Wrap *text* in the one-item response list returned by every handler."""
        return [TextContent(type="text", text=text)]

    async def _generate_with_timeout(self, request):
        """Run ``client.generate_image`` and raise asyncio.TimeoutError after the timeout."""
        return await asyncio.wait_for(
            self.client.generate_image(request),
            timeout=self.GENERATION_TIMEOUT_SECONDS,
        )

    @staticmethod
    def _create_previews(images_data) -> List[str]:
        """Build a 512x512 JPEG base64 preview per image, logging and skipping failures."""
        previews: List[str] = []
        for index, image_data in enumerate(images_data, start=1):
            image_info = get_image_info(image_data)
            if image_info:
                logger.info(f"Original image {index}: {image_info}")

            preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85)
            if preview_b64:
                previews.append(preview_b64)
                logger.info(f"Created preview image {index}: {len(preview_b64)} chars base64 JPEG")
            else:
                logger.warning(f"Failed to create preview for image {index}")
        return previews

    async def handle_generate_random_seed(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Return a random seed in [0, 2**32 - 1]; *arguments* is ignored."""
        try:
            random_seed = random.randint(0, 2**32 - 1)
            return self._text(f"Generated random seed: {random_seed}")
        except Exception as e:
            logger.error(f"Random seed generation error: {str(e)}")
            return self._text(f"Error occurred during random seed generation: {str(e)}")

    async def handle_regenerate_from_json(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Regenerate images from parameters previously saved to a JSON file.

        Args:
            arguments: ``json_file_path`` (required) and ``save_to_file``
                (default True).

        The JSON file must contain an object with at least ``prompt`` and
        ``seed``. ``negative_prompt``, ``number_of_images`` and
        ``aspect_ratio`` default to "", 1 and "1:1". Saved files use the
        ``imagen4_regen`` filename prefix.

        Returns:
            A single TextContent with the result summary or an error message.
        """
        try:
            json_file_path = arguments.get("json_file_path")
            save_to_file = arguments.get("save_to_file", True)

            if not json_file_path:
                return self._text("Error: JSON file path is required.")

            # Load parameters from JSON file with proper UTF-8 encoding
            try:
                with open(json_file_path, 'r', encoding='utf-8') as f:
                    params = json.load(f)
            except FileNotFoundError:
                return self._text(f"Error: Cannot find JSON file: {json_file_path}")
            except json.JSONDecodeError as e:
                return self._text(f"Error: JSON parsing error: {str(e)}")

            # A list/string/number at the top level would otherwise fail later
            # with an opaque AttributeError on params.get().
            if not isinstance(params, dict):
                return self._text("Error: JSON file must contain an object with generation parameters.")

            missing_params = [p for p in ('prompt', 'seed') if p not in params]
            if missing_params:
                return self._text(
                    f"Error: Required parameters are missing from JSON file: {', '.join(missing_params)}"
                )

            request = ImageGenerationRequest(
                prompt=params.get('prompt'),
                negative_prompt=params.get('negative_prompt', ''),
                number_of_images=params.get('number_of_images', 1),
                seed=params.get('seed'),
                aspect_ratio=params.get('aspect_ratio', '1:1'),
            )
            logger.info(f"Loaded parameters from JSON: {json_file_path}")

            # Same timeout as generate_image, so a stalled API call cannot hang this tool.
            try:
                response = await self._generate_with_timeout(request)
            except asyncio.TimeoutError:
                minutes = f"{self.GENERATION_TIMEOUT_SECONDS / 60:g}"
                logger.error(f"Image regeneration timed out after {minutes} minutes")
                return self._text(f"Error: Image regeneration timed out after {minutes} minutes. Please try again.")

            if not response.success:
                return self._text(f"Error occurred during image regeneration: {response.error_message}")

            preview_images_b64 = self._create_previews(response.images_data)

            # Save files (optional)
            saved_files = []
            if save_to_file:
                regeneration_params = {
                    "prompt": request.prompt,
                    "negative_prompt": request.negative_prompt,
                    "number_of_images": request.number_of_images,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "regenerated_from": json_file_path,
                    "original_generated_at": params.get('generated_at', 'unknown'),
                }
                saved_files = save_generated_images(
                    images_data=response.images_data,
                    save_directory=self.config.output_path,
                    seed=request.seed,
                    generation_params=regeneration_params,
                    filename_prefix="imagen4_regen",
                )

            result = ImageGenerationResult(
                success=True,
                message=f"✅ Images have been successfully regenerated from {json_file_path}",
                original_images_count=len(response.images_data),
                preview_images_b64=preview_images_b64,
                saved_files=saved_files,
                generation_params={
                    "prompt": request.prompt,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "number_of_images": request.number_of_images,
                },
            )
            return self._text(result.to_text_content())

        except Exception as e:
            logger.error(f"Error occurred during image regeneration: {str(e)}", exc_info=True)
            return self._text(f"Error occurred during image regeneration: {str(e)}")

    async def handle_generate_image(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Generate images for a prompt and seed, with JPEG previews.

        Args:
            arguments: ``prompt`` and ``seed`` (required), plus optional
                ``negative_prompt`` (default ""), ``number_of_images``
                (default 1), ``aspect_ratio`` (default "1:1") and
                ``save_to_file`` (default True).

        Returns:
            A single TextContent with the result summary or an error message.
        """
        try:
            # Log arguments without image payloads; Unicode text is kept.
            safe_args = sanitize_args_for_logging(arguments)
            logger.info(f"handle_generate_image called with arguments: {safe_args}")

            prompt = arguments.get("prompt")
            if not prompt:
                logger.error("No prompt provided")
                return self._text("Error: Prompt is required.")

            seed = arguments.get("seed")
            if seed is None:
                logger.error("No seed provided")
                return self._text(
                    "Error: Seed value is required. You can use the generate_random_seed tool to generate a random seed."
                )

            request = ImageGenerationRequest(
                prompt=prompt,
                negative_prompt=arguments.get("negative_prompt", ""),
                number_of_images=arguments.get("number_of_images", 1),
                seed=seed,
                aspect_ratio=arguments.get("aspect_ratio", "1:1"),
            )
            save_to_file = arguments.get("save_to_file", True)

            prompt_preview = prompt[:50] + "..." if len(prompt) > 50 else prompt
            logger.info(f"Starting image generation: '{prompt_preview}', Seed: {seed}")

            try:
                logger.info("Calling client.generate_image()...")
                response = await self._generate_with_timeout(request)
                logger.info(f"Image generation completed. Success: {response.success}")
            except asyncio.TimeoutError:
                minutes = f"{self.GENERATION_TIMEOUT_SECONDS / 60:g}"
                logger.error(f"Image generation timed out after {minutes} minutes")
                return self._text(f"Error: Image generation timed out after {minutes} minutes. Please try again.")

            if not response.success:
                logger.error(f"Image generation failed: {response.error_message}")
                return self._text(f"Error occurred during image generation: {response.error_message}")

            logger.info(f"Generated {len(response.images_data)} images successfully")

            # Create preview images (512x512 JPEG base64)
            preview_images_b64 = self._create_previews(response.images_data)

            # Save files if requested
            saved_files = []
            if save_to_file:
                logger.info("Saving files to disk...")
                # NOTE(review): guidance_scale, safety_filter_level,
                # person_generation and add_watermark are recorded as constants
                # here, not read from the request; confirm they match what
                # Imagen4Client actually sends.
                generation_params = {
                    "prompt": request.prompt,
                    "negative_prompt": request.negative_prompt,
                    "number_of_images": request.number_of_images,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "guidance_scale": 7.5,
                    "safety_filter_level": "block_only_high",
                    "person_generation": "allow_all",
                    "add_watermark": False,
                }
                saved_files = save_generated_images(
                    images_data=response.images_data,
                    save_directory=self.config.output_path,
                    seed=request.seed,
                    generation_params=generation_params,
                )
                logger.info(f"Files saved: {saved_files}")

                # Verify files were created
                for file_path in saved_files:
                    if os.path.exists(file_path):
                        size = os.path.getsize(file_path)
                        logger.info(f" ✓ Verified: {file_path} ({size} bytes)")
                    else:
                        logger.error(f" ❌ Missing: {file_path}")

            result = ImageGenerationResult(
                success=True,
                message="✅ Images have been successfully generated!",
                original_images_count=len(response.images_data),
                preview_images_b64=preview_images_b64,
                saved_files=saved_files,
                generation_params={
                    "prompt": request.prompt,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "number_of_images": request.number_of_images,
                    "negative_prompt": request.negative_prompt,
                },
            )

            logger.info(f"Returning response with {len(preview_images_b64)} preview images")
            return self._text(result.to_text_content())

        except Exception as e:
            logger.error(f"Error occurred during image generation: {str(e)}", exc_info=True)
            return self._text(f"Error occurred during image generation: {str(e)}")
|
|
|
|
|
|
# ==================== MCP Server ====================


class Imagen4MCPServer:
    """MCP ``Server`` wired to ``Imagen4ToolHandlers`` (tool listing and dispatch)."""

    # Tool name -> name of the Imagen4ToolHandlers coroutine that serves it.
    _TOOL_HANDLERS = {
        "generate_random_seed": "handle_generate_random_seed",
        "regenerate_from_json": "handle_regenerate_from_json",
        "generate_image": "handle_generate_image",
    }

    def __init__(self, config: Config):
        """Create the server named "imagen4-mcp-server" and register its callbacks.

        Args:
            config: Configuration shared with the tool handlers.
        """
        self.config = config
        self.server = Server("imagen4-mcp-server")
        self.handlers = Imagen4ToolHandlers(config)

        # Register handlers
        self._register_handlers()

    def _register_handlers(self) -> None:
        """Register the list_tools and call_tool callbacks on ``self.server``."""

        @self.server.list_tools()
        async def handle_list_tools() -> List[Tool]:
            """Return the tool definitions from get_tools(); errors are logged and re-raised."""
            try:
                logger.info("Listing available tools with preview image support")
                return get_tools()
            except Exception as e:
                logger.error(f"Error listing tools: {e}")
                raise

        @self.server.call_tool()
        async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]:
            """Dispatch a tool call by *name*; any failure is returned as error text."""
            try:
                # Defensive: if a call arrives without an arguments object,
                # treat it as empty, because the logging helper and the
                # handlers call .items()/.get() on it.
                if arguments is None:
                    arguments = {}

                # Log tool call safely without exposing image data
                safe_args = sanitize_args_for_logging(arguments)
                logger.info(f"Tool called: {name} with arguments: {safe_args}")

                method_name = self._TOOL_HANDLERS.get(name)
                if method_name is None:
                    error_msg = f"Unknown tool: {name}"
                    logger.error(error_msg)
                    return [TextContent(type="text", text=f"Error: {error_msg}")]

                # Resolve on self.handlers at call time, as the original if/elif did.
                handler = getattr(self.handlers, method_name)
                return await handler(arguments)

            except Exception as e:
                logger.error(f"Error handling tool call '{name}': {e}")
                import traceback
                logger.error(f"Tool call traceback: {traceback.format_exc()}")
                return [TextContent(
                    type="text",
                    text=f"Error occurred while processing tool '{name}': {str(e)}"
                )]

    def get_server(self) -> Server:
        """Return MCP server instance"""
        return self.server
|
|
|
|
|
|
# ==================== Main Function ====================


async def _run_stdio_server(server) -> None:
    """Serve the MCP *server* over the stdio transport until it stops.

    Errors are logged, with an extra hint for TaskGroup errors, and then re-raised.
    """
    try:
        async with stdio_server() as (read_stream, write_stream):
            logger.info("Starting MCP server with stdio transport")
            await server.run(
                read_stream,
                write_stream,
                InitializationOptions(
                    server_name="imagen4-mcp-server",
                    server_version="3.1.0",
                    capabilities=server.get_capabilities(
                        notification_options=NotificationOptions(),
                        experimental_capabilities={
                            "preview_images": {},
                            "base64_jpeg_previews": {},
                            "unicode_logging": {}
                        },
                    )
                )
            )
    except Exception as stdio_error:
        logger.error(f"STDIO server error: {stdio_error}")
        # Try to handle specific MCP protocol errors
        if "TaskGroup" in str(stdio_error):
            logger.error("TaskGroup error detected - this may be due to client disconnection")
        raise


async def main():
    """Load configuration, build the MCP server and serve it over stdio.

    Returns:
        1 on a configuration error (``ValueError`` while loading the config
        or constructing the server), 0 if interrupted with Ctrl+C, otherwise
        None once the server stops. Other exceptions are logged and re-raised.
    """
    logger.info("Starting Imagen 4 MCP Server with Preview Image Support...")

    try:
        # Only start-up failures count as configuration errors. A ValueError
        # raised later while serving is a runtime failure and must not be
        # reported as "check your .env file".
        try:
            config = Config.from_env()
            logger.info(f"Configuration loaded - Project: {config.project_id}, Location: {config.location}")

            mcp_server = Imagen4MCPServer(config)
            server = mcp_server.get_server()
        except ValueError as e:
            logger.error(f"Configuration error: {e}")
            logger.error("Please check your .env file or environment variables.")
            return 1

        logger.info("Imagen 4 MCP Server initialized successfully")
        logger.info("Features: 512x512 JPEG preview images, base64 encoding, enhanced responses, Unicode support")

        await _run_stdio_server(server)

    except KeyboardInterrupt:
        logger.info("Server stopped by user (Ctrl+C)")
        return 0

    except Exception as e:
        logger.error(f"Server error: {e}")
        logger.error(f"Error type: {type(e).__name__}")
        import traceback
        logger.error(f"Traceback: {traceback.format_exc()}")
        raise
|
|
|
|
|
|
if __name__ == "__main__":
    try:
        import signal

        def signal_handler(signum, frame):
            """Log the received signal and leave with exit status 0 (raises SystemExit)."""
            logger.info(f"Received signal {signum}, shutting down gracefully...")
            sys.exit(0)

        # Ctrl+C and termination requests both take the same graceful-exit path.
        for handled_signal in (signal.SIGINT, signal.SIGTERM):
            signal.signal(handled_signal, signal_handler)

        # main() returns None, 0 or 1; a None result counts as success.
        exit_code = asyncio.run(main())
        sys.exit(exit_code if exit_code else 0)

    except KeyboardInterrupt:
        logger.info("Server stopped by user")
        sys.exit(0)

    except SystemExit as e:
        # Every exit path (signal handler, normal return) is logged before leaving.
        logger.info(f"Server exiting with code {e.code}")
        sys.exit(e.code)

    except Exception as e:
        import traceback

        logger.error(f"Fatal error: {e}")
        logger.error(f"Error type: {type(e).__name__}")
        logger.error(f"Full traceback: {traceback.format_exc()}")
        sys.exit(1)
|