Files
imagen4/main.py

786 lines
31 KiB
Python

#!/usr/bin/env python3
"""
Imagen4 MCP Server with Preview Image Support
AI image generation MCP server built on Google Imagen 4
- Generates original 2048x2048 PNG images
- Provides 512x512 JPEG preview images as base64
- Enhanced response format
Run from imagen4 root directory
"""
import asyncio
import base64
import json
import random
import logging
import sys
import os
import io
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
# Load environment variables from a .env file (requires python-dotenv; if it
# is missing only a warning is printed and start-up continues).
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    print("Warning: python-dotenv not installed", file=sys.stderr)
# Put this file's directory at the front of sys.path so the local `src`
# package imported below can be found.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)
# MCP imports: the server cannot run without them, so exit with an install hint.
try:
    from mcp.server.models import InitializationOptions
    from mcp.server import NotificationOptions
    from mcp.server.stdio import stdio_server
    from mcp.server import Server
    from mcp.types import Tool, TextContent
except ImportError as e:
    print(f"Error importing MCP: {e}", file=sys.stderr)
    print("Please install required packages: pip install mcp", file=sys.stderr)
    sys.exit(1)
# Connector imports (project-local package under ./src)
from src.connector import Config, Imagen4Client
from src.connector.imagen4_client import ImageGenerationRequest
from src.connector.utils import save_generated_images
# Image processing import: Pillow is used to inspect images and build previews;
# exit with an install hint if it is missing.
try:
    from PIL import Image
except ImportError as e:
    print(f"Error importing Pillow: {e}", file=sys.stderr)
    print("Please install Pillow: pip install Pillow", file=sys.stderr)
    sys.exit(1)
# ==================== Unicode-Safe Logging Setup ====================
class UnicodeStreamHandler(logging.StreamHandler):
    """StreamHandler that writes every log line as UTF-8.

    On construction it tries to reconfigure the stream to UTF-8 with
    ``errors='replace'``. When emitting, it writes UTF-8 bytes straight to
    ``stream.buffer`` if the stream has one, so characters the console code
    page cannot represent (e.g. Korean text) do not raise while logging.
    Streams without ``.buffer`` (such as ``io.StringIO``) receive text.
    """

    def __init__(self, stream=None):
        """Wrap *stream* (``sys.stderr`` when None) and try to switch it to UTF-8."""
        super().__init__(stream)
        # Force UTF-8 encoding for Windows compatibility. This is best effort:
        # a stream that refuses (already read from, detached, different
        # signature) keeps its encoding, and emit() still writes UTF-8 bytes
        # through .buffer. Catch Exception rather than everything, so that
        # KeyboardInterrupt/SystemExit are not swallowed.
        if hasattr(self.stream, 'reconfigure'):
            try:
                self.stream.reconfigure(encoding='utf-8', errors='replace')
            except Exception:
                pass

    def emit(self, record):
        """Format *record* and write it, followed by the terminator, as UTF-8.

        Any failure is routed to ``handleError`` like the standard
        StreamHandler, except RecursionError, which propagates.
        """
        try:
            msg = self.format(record)
            if sys.platform.startswith('win'):
                # Replace lone surrogates, which a strict text stream would
                # refuse to encode.
                msg = msg.encode('utf-8', errors='replace').decode('utf-8', errors='replace')
            line = msg + self.terminator
            stream = self.stream
            buffer = getattr(stream, 'buffer', None)
            if buffer is not None:
                # Flush text still pending in the text layer first; otherwise
                # it would land in the output *after* this record.
                if hasattr(stream, 'flush'):
                    stream.flush()
                buffer.write(line.encode('utf-8', errors='replace'))
                buffer.flush()
            else:
                stream.write(line)
                if hasattr(stream, 'flush'):
                    stream.flush()
        except RecursionError:
            # Same policy as logging.StreamHandler.emit.
            raise
        except Exception:
            self.handleError(record)
# Custom formatter for proper Unicode handling
class UnicodeFormatter(logging.Formatter):
    """Formatter used with UnicodeStreamHandler for Unicode-safe log output.

    Python 3 strings are Unicode, so ``%``-style interpolation of a record's
    args already handles Korean and other non-ASCII text. The record's args
    are therefore left untouched:
    - Coercing them to ``str`` would break numeric directives such as ``%d``
      or ``%.1f``, which third-party libraries use when logging through the
      root logger.
    - Iterating a mapping arg would keep only its keys and break
      ``%(name)s``-style messages.
    - Mutating ``record.args`` would also leak into any other handler that
      formats the same record.
    """

    def format(self, record):
        """Format *record* without modifying its arguments."""
        return super().format(record)
# Set up UTF-8 console output on Windows (best effort, runs at import time).
if sys.platform.startswith('win'):
    import locale
    try:
        # Switch the console code page to UTF-8 (65001); the command's own
        # output is discarded.
        os.system('chcp 65001 >nul 2>&1')
        # NOTE: PYTHONIOENCODING is read at interpreter start-up, so setting it
        # here only affects child processes started from this one, not this
        # process's own sys.stdout/sys.stderr.
        os.environ['PYTHONIOENCODING'] = 'utf-8'
    except Exception:
        # Console setup must never abort start-up, but catching Exception
        # (instead of a bare except) keeps KeyboardInterrupt/SystemExit working.
        pass
# Logging configuration with Unicode support: the root logger gets exactly one
# handler (UTF-8, on stderr) and passes DEBUG and above, so records from this
# server and from libraries share the same output and format.
unicode_handler = UnicodeStreamHandler(sys.stderr)
unicode_handler.setFormatter(
    UnicodeFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
)

root_logger = logging.getLogger()
root_logger.handlers.clear()  # drop any handlers installed before this point
root_logger.addHandler(unicode_handler)
root_logger.setLevel(logging.DEBUG)

# Logger used by the rest of this module.
logger = logging.getLogger("imagen4-mcp-server")
# ==================== Image Processing Utilities ====================
def create_preview_image_b64(
    image_data: bytes,
    target_size: int = 512,
    quality: int = 85,
    background_color: tuple[int, int, int] = (255, 255, 255),
) -> Optional[str]:
    """Build a square JPEG preview of an image and return it base64-encoded.

    The steps are:
    1. Flatten any transparency onto ``background_color``.
    2. Shrink the image (never enlarge it) to fit inside
       ``target_size`` x ``target_size``, keeping its aspect ratio.
    3. Centre it on a ``background_color`` canvas of exactly that size, which
       letterboxes non-square images.

    Args:
        image_data: Encoded source image bytes (PNG from the generator, or
            any format Pillow can open).
        target_size: Edge length in pixels of the square preview (default 512).
        quality: JPEG quality passed to Pillow (1-100, default 85).
        background_color: RGB fill used behind transparent pixels and for
            the padding (default white).

    Returns:
        The JPEG preview as a base64 string, or None if the image could not be
        decoded or converted (the error is logged).
    """
    try:
        with Image.open(io.BytesIO(image_data)) as img:
            # Palette images may carry transparency ('PA' mode, or a
            # 'transparency' entry on 'P'). Expand them to RGBA so the alpha
            # band can serve as a paste mask. Without this, 'PA' would reach
            # convert('RGB') and lose its alpha instead of being flattened.
            if img.mode in ('P', 'PA'):
                img = img.convert('RGBA')
            if img.mode in ('RGBA', 'LA'):
                # Flatten transparency onto the background colour.
                background = Image.new('RGB', img.size, background_color)
                background.paste(img, mask=img.getchannel('A'))
                img = background
            elif img.mode != 'RGB':
                img = img.convert('RGB')

            # Shrink in place to fit target_size x target_size, keeping the
            # aspect ratio (thumbnail never enlarges).
            img.thumbnail((target_size, target_size), Image.Resampling.LANCZOS)

            # Non-square (or smaller) images are centred on a square canvas
            # so every preview is exactly target_size x target_size.
            if img.size != (target_size, target_size):
                canvas = Image.new('RGB', (target_size, target_size), background_color)
                offset = (
                    (target_size - img.size[0]) // 2,
                    (target_size - img.size[1]) // 2,
                )
                canvas.paste(img, offset)
                img = canvas

            output_buffer = io.BytesIO()
            img.save(output_buffer, format='JPEG', quality=quality, optimize=True)
            jpeg_data = output_buffer.getvalue()

        base64_data = base64.b64encode(jpeg_data).decode('ascii')
        logger.info(f"Preview image created: {target_size}x{target_size} JPEG, {len(jpeg_data)} bytes")
        return base64_data
    except Exception as e:
        logger.error(f"Failed to create preview image: {str(e)}")
        return None
def get_image_info(image_data: bytes) -> Optional[dict]:
    """Describe an encoded image.

    Args:
        image_data: Encoded image bytes.

    Returns:
        A dict with keys ``format``, ``size``, ``mode`` (as reported by
        Pillow) and ``bytes`` (length of *image_data*), or None if Pillow
        cannot open the data (the error is logged).
    """
    try:
        with Image.open(io.BytesIO(image_data)) as picture:
            details = {
                'format': picture.format,
                'size': picture.size,
                'mode': picture.mode,
            }
            details['bytes'] = len(image_data)
            return details
    except Exception as e:
        logger.error(f"Failed to get image info: {str(e)}")
        return None
# ==================== Response Models ====================
@dataclass
class ImageGenerationResult:
    """Outcome of an image generation call, renderable as MCP response text."""
    success: bool  # True when images were generated
    message: str  # headline, always the first line of the rendered text
    original_images_count: int  # number of full-size images returned
    preview_images_b64: Optional[list[str]] = None  # base64 JPEG previews (512x512)
    saved_files: Optional[list[str]] = None  # paths written to disk, if any
    generation_params: Optional[Dict[str, Any]] = None  # parameters echoed to the user
    error_message: Optional[str] = None  # failure detail; rendered when set

    def to_text_content(self) -> str:
        """Render the result as the multi-line text sent to the MCP client.

        Sections appear only when they have data, in this order:
        - the message;
        - the error detail;
        - a short fingerprint of each preview (successful results only);
        - the saved file paths;
        - the generation parameters, with a prompt over 100 characters
          truncated.
        """
        lines = [self.message]
        if self.error_message:
            # Previously the error detail was stored but never shown.
            lines.append(f"\n❌ Error: {self.error_message}")
        if self.success and self.preview_images_b64:
            lines.append(f"\n🖼️ Preview Images Generated: {len(self.preview_images_b64)} images (512x512 JPEG)")
            for number, preview_b64 in enumerate(self.preview_images_b64, start=1):
                # Only a prefix is shown; the full base64 would swamp the text.
                lines.append(f"Preview {number} (base64 JPEG): {preview_b64[:50]}...({len(preview_b64)} chars)")
        if self.saved_files:
            lines.append("\n📁 Files saved:")
            lines.extend(f" - {filepath}" for filepath in self.saved_files)
        if self.generation_params:
            lines.append("\n⚙️ Generation Parameters:")
            for key, value in self.generation_params.items():
                text = str(value)
                if key == 'prompt' and len(text) > 100:
                    lines.append(f" - {key}: {text[:100]}...")
                else:
                    lines.append(f" - {key}: {value}")
        return "\n".join(lines)
# ==================== MCP Tool Definitions ====================
def get_tools() -> List[Tool]:
    """Build the catalogue of MCP tools this server advertises.

    Returns:
        Three Tool definitions, in order: ``generate_image``,
        ``regenerate_from_json`` and ``generate_random_seed``. Each carries a
        JSON-Schema ``inputSchema`` describing its arguments.
    """
    generate_image_properties = {
        "prompt": {
            "type": "string",
            "description": "Text prompt for image generation (Korean or English)"
        },
        "negative_prompt": {
            "type": "string",
            "description": "Negative prompt specifying elements not to generate (optional)",
            "default": ""
        },
        "number_of_images": {
            "type": "integer",
            "description": "Number of images to generate (1 or 2 only)",
            "enum": [1, 2],
            "default": 1
        },
        "seed": {
            "type": "integer",
            "description": "Seed value for reproducible results (required, 0 ~ 4294967295 range)",
            "minimum": 0,
            "maximum": 4294967295
        },
        "aspect_ratio": {
            "type": "string",
            "description": "Image aspect ratio",
            "enum": ["1:1", "9:16", "16:9", "3:4", "4:3"],
            "default": "1:1"
        },
        "save_to_file": {
            "type": "boolean",
            "description": "Whether to save generated images to files (default: true)",
            "default": True
        },
    }

    regenerate_properties = {
        "json_file_path": {
            "type": "string",
            "description": "Path to JSON file containing saved parameters"
        },
        "save_to_file": {
            "type": "boolean",
            "description": "Whether to save regenerated images to files (default: true)",
            "default": True
        },
    }

    generate_image_tool = Tool(
        name="generate_image",
        description="Generate 2048x2048 PNG images with 512x512 JPEG previews using Google Imagen 4 AI.",
        inputSchema={
            "type": "object",
            "properties": generate_image_properties,
            "required": ["prompt", "seed"],
        },
    )
    regenerate_tool = Tool(
        name="regenerate_from_json",
        description="Read parameters from JSON file and regenerate images with previews.",
        inputSchema={
            "type": "object",
            "properties": regenerate_properties,
            "required": ["json_file_path"],
        },
    )
    random_seed_tool = Tool(
        name="generate_random_seed",
        description="Generate random seed value for image generation.",
        inputSchema={
            "type": "object",
            "properties": {},
            "required": [],
        },
    )

    return [generate_image_tool, regenerate_tool, random_seed_tool]
# ==================== Utility Functions ====================
def sanitize_args_for_logging(arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of *arguments* that is safe and compact to log.

    String values are rewritten as follows (other values are passed through
    unchanged):
    - Values that look like image data become ``<image_data:N chars>``. A
      value looks like image data if its key is one of the known image keys,
      or if it is longer than 100 characters and starts with a PNG, JPEG or
      GIF base64 signature.
    - Other strings over 1000 characters keep their first 100 characters,
      followed by a truncation note.
    - Everything else, including Korean text, is kept as-is.
    """
    image_keys = ('data', 'image_data', 'base64', 'preview_image_b64')
    image_signatures = ('iVBORw0KGgo', '/9j/', 'R0lGOD')

    def _redact(key: str, value: Any) -> Any:
        if not isinstance(value, str):
            return value
        size = len(value)
        looks_like_image = key in image_keys or (size > 100 and value.startswith(image_signatures))
        if looks_like_image:
            return f"<image_data:{size} chars>"
        if size > 1000:
            return f"{value[:100]}...<truncated:{size} total chars>"
        return value

    return {key: _redact(key, value) for key, value in arguments.items()}
# ==================== Tool Handlers ====================
class Imagen4ToolHandlers:
    """Implementations of the MCP tools exposed by this server.

    Each ``handle_*`` coroutine receives the tool-call ``arguments`` dict and
    returns a list of MCP content items. Errors are caught and reported as a
    single ``TextContent`` message rather than raised to the caller.
    """

    # Maximum time (seconds) to wait for one Imagen request. Shared by every
    # tool that generates images, so they all time out the same way.
    GENERATION_TIMEOUT_SECONDS = 360.0

    def __init__(self, config: Config):
        """Keep *config* and create the Imagen 4 client from it."""
        self.config = config
        self.client = Imagen4Client(config)

    # ---------------- shared helpers ----------------

    @staticmethod
    def _text_response(text: str) -> List[TextContent]:
        """Wrap *text* as the single-item content list returned by a tool."""
        return [TextContent(type="text", text=text)]

    def _timeout_minutes(self) -> str:
        """Format the generation timeout in minutes for messages (e.g. "6")."""
        return f"{self.GENERATION_TIMEOUT_SECONDS / 60:g}"

    async def _generate_with_timeout(self, request: ImageGenerationRequest):
        """Run ``client.generate_image`` bounded by GENERATION_TIMEOUT_SECONDS.

        Raises:
            asyncio.TimeoutError: if the client does not finish in time.
        """
        return await asyncio.wait_for(
            self.client.generate_image(request),
            timeout=self.GENERATION_TIMEOUT_SECONDS,
        )

    @staticmethod
    def _create_previews(images_data: List[bytes]) -> List[str]:
        """Create a 512x512 base64 JPEG preview for each generated image.

        Images whose preview cannot be built are skipped with a warning, so
        the result may be shorter than *images_data*.
        """
        previews = []
        for index, image_data in enumerate(images_data, start=1):
            image_info = get_image_info(image_data)
            if image_info:
                logger.info(f"Original image {index}: {image_info}")
            preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85)
            if preview_b64:
                previews.append(preview_b64)
                logger.info(f"Created preview image {index}: {len(preview_b64)} chars base64 JPEG")
            else:
                logger.warning(f"Failed to create preview for image {index}")
        return previews

    @staticmethod
    def _build_response(result: ImageGenerationResult) -> List[Any]:
        """Return the text summary followed by one JPEG ImageContent per preview.

        The text summary shows only a 50-character prefix of each preview.
        Attaching the previews as image content is what delivers them to the
        client.
        """
        # Local import keeps this class self-contained; ImageContent lives in
        # mcp.types next to TextContent.
        from mcp.types import ImageContent

        contents: List[Any] = [TextContent(type="text", text=result.to_text_content())]
        for preview_b64 in result.preview_images_b64 or []:
            contents.append(ImageContent(type="image", data=preview_b64, mimeType="image/jpeg"))
        return contents

    # ---------------- tool handlers ----------------

    async def handle_generate_random_seed(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Return a random seed in 0..2**32-1, the range generate_image accepts."""
        try:
            random_seed = random.randint(0, 2**32 - 1)
            return self._text_response(f"Generated random seed: {random_seed}")
        except Exception as e:
            logger.error(f"Random seed generation error: {str(e)}")
            return self._text_response(f"Error occurred during random seed generation: {str(e)}")

    async def handle_regenerate_from_json(self, arguments: Dict[str, Any]) -> List[Any]:
        """Regenerate images from parameters saved in a JSON file.

        Args:
            arguments: ``json_file_path`` (required) and ``save_to_file``
                (optional, default True).

        Returns:
            On success, a text summary followed by one JPEG ImageContent per
            preview; otherwise a single TextContent error message.
        """
        try:
            json_file_path = arguments.get("json_file_path")
            save_to_file = arguments.get("save_to_file", True)
            if not json_file_path:
                return self._text_response("Error: JSON file path is required.")

            # Load parameters from the JSON file with explicit UTF-8 decoding.
            try:
                with open(json_file_path, 'r', encoding='utf-8') as f:
                    params = json.load(f)
            except FileNotFoundError:
                return self._text_response(f"Error: Cannot find JSON file: {json_file_path}")
            except json.JSONDecodeError as e:
                return self._text_response(f"Error: JSON parsing error: {str(e)}")

            # Valid JSON that is not an object (e.g. a list) has no .get();
            # reject it clearly instead of failing later with an AttributeError.
            if not isinstance(params, dict):
                return self._text_response(
                    "Error: JSON file must contain an object with the saved generation parameters."
                )

            required_params = ['prompt', 'seed']
            missing_params = [p for p in required_params if p not in params]
            if missing_params:
                return self._text_response(
                    f"Error: Required parameters are missing from JSON file: {', '.join(missing_params)}"
                )

            request = ImageGenerationRequest(
                prompt=params.get('prompt'),
                negative_prompt=params.get('negative_prompt', ''),
                number_of_images=params.get('number_of_images', 1),
                seed=params.get('seed'),
                aspect_ratio=params.get('aspect_ratio', '1:1')
            )
            logger.info(f"Loaded parameters from JSON: {json_file_path}")

            # Bounded by the same timeout as generate_image.
            try:
                response = await self._generate_with_timeout(request)
            except asyncio.TimeoutError:
                minutes = self._timeout_minutes()
                logger.error(f"Image regeneration timed out after {minutes} minutes")
                return self._text_response(
                    f"Error: Image regeneration timed out after {minutes} minutes. Please try again."
                )
            if not response.success:
                return self._text_response(
                    f"Error occurred during image regeneration: {response.error_message}"
                )

            preview_images_b64 = self._create_previews(response.images_data)

            saved_files = []
            if save_to_file:
                regeneration_params = {
                    "prompt": request.prompt,
                    "negative_prompt": request.negative_prompt,
                    "number_of_images": request.number_of_images,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "regenerated_from": json_file_path,
                    "original_generated_at": params.get('generated_at', 'unknown')
                }
                saved_files = save_generated_images(
                    images_data=response.images_data,
                    save_directory=self.config.output_path,
                    seed=request.seed,
                    generation_params=regeneration_params,
                    filename_prefix="imagen4_regen"
                )

            result = ImageGenerationResult(
                success=True,
                message=f"✅ Images have been successfully regenerated from {json_file_path}",
                original_images_count=len(response.images_data),
                preview_images_b64=preview_images_b64,
                saved_files=saved_files,
                generation_params={
                    "prompt": request.prompt,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "number_of_images": request.number_of_images,
                    "negative_prompt": request.negative_prompt
                }
            )
            return self._build_response(result)
        except Exception as e:
            logger.error(f"Error occurred during image regeneration: {str(e)}", exc_info=True)
            return self._text_response(f"Error occurred during image regeneration: {str(e)}")

    async def handle_generate_image(self, arguments: Dict[str, Any]) -> List[Any]:
        """Generate images from a prompt and seed, with 512x512 JPEG previews.

        Args:
            arguments: ``prompt`` and ``seed`` (required), plus optional
                ``negative_prompt``, ``number_of_images``, ``aspect_ratio``
                and ``save_to_file`` (default True).

        Returns:
            On success, a text summary followed by one JPEG ImageContent per
            preview; otherwise a single TextContent error message.
        """
        try:
            # Log arguments with image data redacted; Unicode text is kept.
            safe_args = sanitize_args_for_logging(arguments)
            logger.info(f"handle_generate_image called with arguments: {safe_args}")

            prompt = arguments.get("prompt")
            if not prompt:
                logger.error("No prompt provided")
                return self._text_response("Error: Prompt is required.")

            seed = arguments.get("seed")
            if seed is None:
                logger.error("No seed provided")
                return self._text_response(
                    "Error: Seed value is required. You can use the generate_random_seed tool to generate a random seed."
                )

            request = ImageGenerationRequest(
                prompt=prompt,
                negative_prompt=arguments.get("negative_prompt", ""),
                number_of_images=arguments.get("number_of_images", 1),
                seed=seed,
                aspect_ratio=arguments.get("aspect_ratio", "1:1")
            )
            save_to_file = arguments.get("save_to_file", True)

            prompt_preview = prompt[:50] + "..." if len(prompt) > 50 else prompt
            logger.info(f"Starting image generation: '{prompt_preview}', Seed: {seed}")

            try:
                logger.info("Calling client.generate_image()...")
                response = await self._generate_with_timeout(request)
                logger.info(f"Image generation completed. Success: {response.success}")
            except asyncio.TimeoutError:
                minutes = self._timeout_minutes()
                logger.error(f"Image generation timed out after {minutes} minutes")
                return self._text_response(
                    f"Error: Image generation timed out after {minutes} minutes. Please try again."
                )

            if not response.success:
                logger.error(f"Image generation failed: {response.error_message}")
                return self._text_response(
                    f"Error occurred during image generation: {response.error_message}"
                )

            logger.info(f"Generated {len(response.images_data)} images successfully")
            preview_images_b64 = self._create_previews(response.images_data)

            saved_files = []
            if save_to_file:
                logger.info("Saving files to disk...")
                # NOTE(review): the last four values are fixed here rather than
                # read from the request/client; confirm they match what the
                # client actually sends to the API.
                generation_params = {
                    "prompt": request.prompt,
                    "negative_prompt": request.negative_prompt,
                    "number_of_images": request.number_of_images,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "guidance_scale": 7.5,
                    "safety_filter_level": "block_only_high",
                    "person_generation": "allow_all",
                    "add_watermark": False
                }
                saved_files = save_generated_images(
                    images_data=response.images_data,
                    save_directory=self.config.output_path,
                    seed=request.seed,
                    generation_params=generation_params
                )
                logger.info(f"Files saved: {saved_files}")
                # Check that every reported path exists on disk.
                for file_path in saved_files:
                    if os.path.exists(file_path):
                        size = os.path.getsize(file_path)
                        logger.info(f"  ✓ Verified: {file_path} ({size} bytes)")
                    else:
                        logger.error(f"  ❌ Missing: {file_path}")

            result = ImageGenerationResult(
                success=True,
                message="✅ Images have been successfully generated!",
                original_images_count=len(response.images_data),
                preview_images_b64=preview_images_b64,
                saved_files=saved_files,
                generation_params={
                    "prompt": request.prompt,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "number_of_images": request.number_of_images,
                    "negative_prompt": request.negative_prompt
                }
            )
            logger.info(f"Returning response with {len(preview_images_b64)} preview images")
            return self._build_response(result)
        except Exception as e:
            logger.error(f"Error occurred during image generation: {str(e)}", exc_info=True)
            return self._text_response(f"Error occurred during image generation: {str(e)}")
# ==================== MCP Server ====================
class Imagen4MCPServer:
    """Connect the Imagen 4 tool handlers to an MCP ``Server`` instance."""

    def __init__(self, config: Config):
        """Create the MCP server and tool handlers, then register the callbacks."""
        self.config = config
        self.server = Server("imagen4-mcp-server")
        self.handlers = Imagen4ToolHandlers(config)
        self._register_handlers()

    def _register_handlers(self) -> None:
        """Attach the list_tools and call_tool callbacks to ``self.server``."""

        @self.server.list_tools()
        async def handle_list_tools() -> List[Tool]:
            """Advertise the tools built by get_tools(); errors are logged and re-raised."""
            try:
                logger.info("Listing available tools with preview image support")
                return get_tools()
            except Exception as e:
                logger.error(f"Error listing tools: {e}")
                raise

        @self.server.call_tool()
        async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]:
            """Route a tool call by name to the matching handler coroutine.

            An unknown tool name, or any exception raised while handling the
            call, is reported as a single error TextContent instead of raised.
            """
            try:
                safe_args = sanitize_args_for_logging(arguments)
                logger.info(f"Tool called: {name} with arguments: {safe_args}")

                routes = {
                    "generate_random_seed": self.handlers.handle_generate_random_seed,
                    "regenerate_from_json": self.handlers.handle_regenerate_from_json,
                    "generate_image": self.handlers.handle_generate_image,
                }
                target = routes.get(name)
                if target is not None:
                    return await target(arguments)

                error_msg = f"Unknown tool: {name}"
                logger.error(error_msg)
                return [TextContent(type="text", text=f"Error: {error_msg}")]
            except Exception as e:
                logger.error(f"Error handling tool call '{name}': {e}")
                import traceback
                logger.error(f"Tool call traceback: {traceback.format_exc()}")
                return [TextContent(
                    type="text",
                    text=f"Error occurred while processing tool '{name}': {str(e)}"
                )]

    def get_server(self) -> Server:
        """Return the underlying MCP ``Server`` instance."""
        return self.server
# ==================== Main Function ====================
async def main():
    """Load configuration, build the MCP server and serve it over stdio.

    Returns:
        1 if setup fails with ValueError (treated as a configuration problem),
        0 if interrupted with KeyboardInterrupt, otherwise None when the stdio
        session ends normally.

    Raises:
        Any other exception from setup or from the running server, after
        logging it with a traceback.
    """
    logger.info("Starting Imagen 4 MCP Server with Preview Image Support...")
    try:
        # Setup phase: only here does a ValueError mean "bad configuration".
        # A ValueError raised later by the running stdio session (e.g. I/O on
        # a stream closed by a disconnecting client) is not a configuration
        # problem, so it falls through to the generic handler below.
        try:
            config = Config.from_env()
            logger.info(f"Configuration loaded - Project: {config.project_id}, Location: {config.location}")
            mcp_server = Imagen4MCPServer(config)
        except ValueError as e:
            logger.error(f"Configuration error: {e}")
            logger.error("Please check your .env file or environment variables.")
            return 1

        server = mcp_server.get_server()
        logger.info("Imagen 4 MCP Server initialized successfully")
        logger.info("Features: 512x512 JPEG preview images, base64 encoding, enhanced responses, Unicode support")

        # Run the MCP server over stdio, logging transport failures before re-raising.
        try:
            async with stdio_server() as (read_stream, write_stream):
                logger.info("Starting MCP server with stdio transport")
                await server.run(
                    read_stream,
                    write_stream,
                    InitializationOptions(
                        server_name="imagen4-mcp-server",
                        server_version="3.1.0",
                        capabilities=server.get_capabilities(
                            notification_options=NotificationOptions(),
                            experimental_capabilities={
                                "preview_images": {},
                                "base64_jpeg_previews": {},
                                "unicode_logging": {}
                            },
                        )
                    )
                )
        except Exception as stdio_error:
            logger.error(f"STDIO server error: {stdio_error}")
            if "TaskGroup" in str(stdio_error):
                logger.error("TaskGroup error detected - this may be due to client disconnection")
            raise
    except KeyboardInterrupt:
        logger.info("Server stopped by user (Ctrl+C)")
        return 0
    except Exception as e:
        logger.error(f"Server error: {e}")
        logger.error(f"Error type: {type(e).__name__}")
        import traceback
        logger.error(f"Traceback: {traceback.format_exc()}")
        raise
if __name__ == "__main__":
    try:
        # SIGINT/SIGTERM log a message and exit with status 0.
        import signal

        def _handle_shutdown_signal(signum, frame):
            logger.info(f"Received signal {signum}, shutting down gracefully...")
            sys.exit(0)

        for _shutdown_signal in (signal.SIGINT, signal.SIGTERM):
            signal.signal(_shutdown_signal, _handle_shutdown_signal)

        # main() returns an exit code, or None on a normal shutdown.
        sys.exit(asyncio.run(main()) or 0)
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
        sys.exit(0)
    except SystemExit as e:
        # Also reached by the sys.exit() calls above: log, then exit with the same code.
        logger.info(f"Server exiting with code {e.code}")
        sys.exit(e.code)
    except Exception as e:
        logger.error(f"Fatal error: {e}")
        logger.error(f"Error type: {type(e).__name__}")
        import traceback
        logger.error(f"Full traceback: {traceback.format_exc()}")
        sys.exit(1)