Files
flux1-edit/main.py
2025-08-26 04:33:33 +09:00

407 lines
16 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
FLUX.1 Edit MCP Server - Fixed Version
AI image editing MCP server using FLUX.1 Kontext
- Enhanced error handling and UTF-8 support
- MCP protocol compliance
- Based on imagen4 server structure
"""
# ==================== CRITICAL: UTF-8 SETUP MUST BE FIRST ====================
import os
import sys
import locale

# UTF-8 environment variables.
# NOTE: PYTHONIOENCODING and PYTHONUTF8 are only read at interpreter start-up,
# so assigning them here does NOT change this process's stdio encoding; they
# (and LC_ALL) take effect for any child processes spawned afterwards.
os.environ['PYTHONIOENCODING'] = 'utf-8'
os.environ['PYTHONUTF8'] = '1'
os.environ['LC_ALL'] = 'C.UTF-8'

# Windows-specific UTF-8 setup: switch the attached console's input and output
# code pages to UTF-8 (65001). This is what `chcp 65001` does, but calling the
# Win32 API directly avoids spawning a cmd.exe subshell via os.system().
# Best effort: without an attached console the calls simply fail (return 0).
if sys.platform.startswith('win'):
    try:
        import ctypes
        ctypes.windll.kernel32.SetConsoleCP(65001)
        ctypes.windll.kernel32.SetConsoleOutputCP(65001)
    except Exception:
        pass

# Prefer the C.UTF-8 locale; fall back to the user's default locale when the
# platform does not provide it, and give up silently if neither is available.
try:
    locale.setlocale(locale.LC_ALL, 'C.UTF-8')
except locale.Error:
    try:
        locale.setlocale(locale.LC_ALL, '')
    except locale.Error:
        pass

# Add src to path for imports: after both inserts the project root is first
# (so `src.*` imports resolve) and src/ itself is second.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, 'src'))
sys.path.insert(0, current_dir)
# ==================== Imports ====================
import asyncio
import logging
import json
from typing import Dict, Any, List
from pathlib import Path
# Safe logging setup
class SafeUnicodeHandler(logging.StreamHandler):
    """Stream handler that never lets a Unicode/encoding problem break logging.

    Each formatted record is written as UTF-8 bytes to ``stream.buffer`` when
    the stream exposes one, otherwise as text via ``stream.write``. If that
    raises a Unicode error, the message is retried with ASCII replacement.
    On Windows, known emoji markers are first replaced with ASCII tags.
    Any other failure is routed to ``logging.Handler.handleError``.
    """

    # Emoji -> ASCII tag map applied on Windows. Keys are written as escapes so
    # the source stays ASCII and no key can silently degrade to '' -- an empty
    # key makes str.replace insert the tag between every character of the line.
    # NOTE(review): the first three glyphs were lost upstream; they are presumed
    # to be the check mark, cross mark and hourglass emoji -- confirm.
    _EMOJI_REPLACEMENTS = {
        '\u2705': '[SUCCESS]',           # white heavy check mark
        '\u274c': '[ERROR]',             # cross mark
        '\u26a0\ufe0f': '[WARNING]',     # warning sign (emoji presentation)
        '\U0001f504': '[RETRY]',         # anticlockwise arrows
        '\u23f3': '[WAIT]',              # hourglass with flowing sand
        '\U0001f5bc\ufe0f': '[IMAGE]',   # frame with picture
        '\U0001f4c1': '[FILES]',         # file folder
        '\u2699\ufe0f': '[PARAMS]',      # gear
        '\U0001f3a8': '[GENERATE]',      # artist palette
    }

    def __init__(self, stream=None):
        super().__init__(stream)
        self.encoding = 'utf-8'

    @classmethod
    def _replace_emoji(cls, msg):
        """Return *msg* with every known emoji replaced by its ASCII tag."""
        for emoji, replacement in cls._EMOJI_REPLACEMENTS.items():
            if emoji:  # guard: replacing '' would inject the tag everywhere
                msg = msg.replace(emoji, replacement)
        return msg

    def emit(self, record):
        try:
            msg = self.format(record)
            # Windows safety: replace problematic Unicode characters
            if sys.platform.startswith('win'):
                msg = self._replace_emoji(msg)
            # Round-trip through UTF-8 so unencodable code points (e.g. lone
            # surrogates) become replacement characters.
            msg = msg.encode('utf-8', errors='replace').decode('utf-8', errors='replace')
            stream = self.stream
            terminator = getattr(self, 'terminator', '\n')
            try:
                if hasattr(stream, 'buffer'):
                    # Bypass the text layer's encoding by writing UTF-8 bytes.
                    stream.buffer.write((msg + terminator).encode('utf-8', errors='replace'))
                    stream.buffer.flush()
                else:
                    stream.write(msg + terminator)
                    if hasattr(stream, 'flush'):
                        stream.flush()
            except (UnicodeEncodeError, UnicodeDecodeError):
                # Last resort: pure ASCII; drop the record if even that fails.
                try:
                    safe_msg = msg.encode('ascii', errors='replace').decode('ascii')
                    if hasattr(stream, 'buffer'):
                        stream.buffer.write((safe_msg + terminator).encode('ascii'))
                        stream.buffer.flush()
                    else:
                        stream.write(safe_msg + terminator)
                        if hasattr(stream, 'flush'):
                            stream.flush()
                except Exception:
                    pass
        except Exception:
            self.handleError(record)
# Install the Unicode-safe handler on stderr as the root logger's only handler,
# keeping log output off stdout (which the stdio transport presumably uses).
safe_handler = SafeUnicodeHandler(sys.stderr)
safe_handler.setFormatter(
    logging.Formatter('%(asctime)s [%(name)s] [%(levelname)s] %(message)s')
)

root_logger = logging.getLogger()
root_logger.handlers.clear()  # discard any handlers configured before this point
root_logger.addHandler(safe_handler)
root_logger.setLevel(logging.INFO)

# Named logger used throughout this module.
logger = logging.getLogger("flux1-edit-mcp")
# ==================== MCP Imports with Error Handling ====================
try:
import mcp.types as types
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.server.models import InitializationOptions
from mcp.server import NotificationOptions
logger.info("MCP imports successful")
except ImportError as e:
logger.error(f"Failed to import MCP: {e}")
logger.error("Please install MCP: pip install mcp")
sys.exit(1)
# ==================== Local Imports with Error Handling ====================
try:
    from src.connector.config import Config
    from src.server.models import TOOL_DEFINITIONS, ToolName
    from src.server.handlers import ToolHandlers
    logger.info("Local imports successful")
except ImportError as import_error:
    # Report the failure and where Python looked, then exit with status 1.
    for diagnostic in (
        f"Failed to import local modules: {import_error}",
        f"Current directory: {current_dir}",
        f"Python path: {sys.path}",
    ):
        logger.error(diagnostic)
    # Best-effort listing of what actually exists under src/ to aid debugging.
    try:
        src_dir = os.path.join(current_dir, 'src')
        if os.path.exists(src_dir):
            logger.error(f"Available in src: {os.listdir(src_dir)}")
            for package in ('connector', 'server', 'utils'):
                package_dir = os.path.join(src_dir, package)
                if os.path.exists(package_dir):
                    logger.error(f"Available in src/{package}: {os.listdir(package_dir)}")
    except Exception:
        pass
    sys.exit(1)
# ==================== Utility Functions ====================
# Base64 prefixes of common image encodings (PNG, JPEG, GIF, RIFF/WebP) plus
# data URIs; long strings starting with one of these are treated as image data.
_IMAGE_DATA_PREFIXES = ('iVBORw0KGgo', '/9j/', 'R0lGOD', 'UklGR', 'data:image/')


def _sanitize_value_for_logging(key: Any, value: Any) -> Any:
    """Return a log-safe version of one argument value, recursing into containers."""
    if isinstance(value, dict):
        return sanitize_args_for_logging(value)
    if isinstance(value, (list, tuple)):
        # Items inherit the parent key, so e.g. a list under `images_b64` is masked.
        items = [_sanitize_value_for_logging(key, item) for item in value]
        return tuple(items) if isinstance(value, tuple) else items
    if not isinstance(value, str):
        return value
    named_as_image = isinstance(key, str) and key.endswith(('_b64', '_data'))
    if named_as_image or (len(value) > 100 and value.startswith(_IMAGE_DATA_PREFIXES)):
        return f"<image_data:{len(value)} chars>"
    if len(value) > 1000:
        return f"{value[:100]}...<truncated:{len(value)} chars>"
    return value


def sanitize_args_for_logging(arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of tool arguments that is safe and compact to log.

    String values whose key ends in ``_b64``/``_data``, or that look like
    base64 image data / data URIs (over 100 chars), are replaced with
    ``<image_data:N chars>``; other strings over 1000 chars are truncated to
    their first 100 chars. Nested dicts, lists and tuples are sanitized
    recursively; all other values are passed through unchanged.

    Args:
        arguments: Tool-call arguments; ``None`` or empty yields ``{}``.

    Returns:
        A new dict; the input is not modified.
    """
    if not arguments:
        return {}
    return {key: _sanitize_value_for_logging(key, value) for key, value in arguments.items()}
# ==================== MCP Server Class ====================
class FluxEditMCPServer:
"""FLUX.1 Edit MCP server with enhanced error handling"""
def __init__(self):
"""Initialize server with comprehensive error handling"""
logger.info("Initializing FLUX.1 Edit MCP Server...")
try:
# Create configuration
self.config = Config()
logger.info("Configuration created successfully")
# Validate configuration
if not self.config.validate():
raise RuntimeError("Configuration validation failed")
logger.info("Configuration validated successfully")
# Create MCP server
self.server = Server("flux1-edit")
logger.info("MCP server instance created")
# Create tool handlers
self.handlers = ToolHandlers(self.config)
logger.info("Tool handlers created successfully")
# Register handlers
self._register_handlers()
logger.info("Handlers registered successfully")
logger.info("FLUX.1 Edit MCP Server initialization complete")
except Exception as e:
logger.error(f"Failed to initialize server: {e}", exc_info=True)
raise
def _register_handlers(self):
"""Register MCP handlers with comprehensive error handling"""
@self.server.list_tools()
async def handle_list_tools() -> List[types.Tool]:
"""List available tools"""
try:
logger.info("Listing available tools")
tools = []
for tool_name, tool_def in TOOL_DEFINITIONS.items():
# Build properties for parameters
properties = {}
required = []
for param in tool_def.parameters:
prop_def = {
"type": param.type,
"description": param.description
}
# Add enum if specified
if param.enum:
prop_def["enum"] = param.enum
# Add default if specified
if param.default is not None:
prop_def["default"] = param.default
properties[param.name] = prop_def
if param.required:
required.append(param.name)
# Build tool schema
tool = types.Tool(
name=tool_def.name,
description=tool_def.description,
inputSchema={
"type": "object",
"properties": properties,
"required": required
}
)
tools.append(tool)
logger.info(f"Listed {len(tools)} tools successfully")
return tools
except Exception as e:
logger.error(f"Error listing tools: {e}", exc_info=True)
raise
@self.server.list_prompts()
async def handle_list_prompts() -> List[types.Prompt]:
"""List available prompts (empty for this server)"""
try:
logger.info("Listing available prompts")
# This server doesn't provide prompt templates
return []
except Exception as e:
logger.error(f"Error listing prompts: {e}", exc_info=True)
raise
@self.server.list_resources()
async def handle_list_resources() -> List[types.Resource]:
"""List available resources (empty for this server)"""
try:
logger.info("Listing available resources")
# This server doesn't provide static resources
return []
except Exception as e:
logger.error(f"Error listing resources: {e}", exc_info=True)
raise
@self.server.call_tool()
async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[types.TextContent | types.ImageContent]:
"""Handle tool calls with comprehensive error handling"""
try:
# Log tool call safely
safe_args = sanitize_args_for_logging(arguments)
logger.info(f"Tool call: {name} with args: {safe_args}")
# Route to appropriate handler
if name == ToolName.FLUX_EDIT_IMAGE:
return await self.handlers.handle_flux_edit_image(arguments)
elif name == ToolName.FLUX_EDIT_IMAGE_FROM_FILE:
return await self.handlers.handle_flux_edit_image_from_file(arguments)
elif name == ToolName.VALIDATE_IMAGE:
return await self.handlers.handle_validate_image(arguments)
else:
error_msg = f"Unknown tool: {name}"
logger.error(error_msg)
return [types.TextContent(
type="text",
text=f"[ERROR] {error_msg}"
)]
except Exception as e:
logger.error(f"Error calling tool {name}: {e}", exc_info=True)
return [types.TextContent(
type="text",
text=f"[ERROR] Tool execution error: {str(e)}"
)]
# Add cancellation notification handler
try:
# This handles request cancellations gracefully
@self.server.set_logging_level()
async def handle_set_logging_level(level: types.LoggingLevel):
"""Handle logging level changes"""
try:
logger.info(f"Setting logging level to: {level}")
# Update logger level based on MCP level
if level == types.LoggingLevel.DEBUG:
logging.getLogger().setLevel(logging.DEBUG)
elif level == types.LoggingLevel.INFO:
logging.getLogger().setLevel(logging.INFO)
elif level == types.LoggingLevel.WARNING:
logging.getLogger().setLevel(logging.WARNING)
elif level == types.LoggingLevel.ERROR:
logging.getLogger().setLevel(logging.ERROR)
except Exception as e:
logger.error(f"Error setting logging level: {e}", exc_info=True)
except AttributeError:
# If set_logging_level is not available, skip
logger.info("set_logging_level not available in this MCP version")
async def run(self):
"""Run the MCP server"""
try:
logger.info("Starting FLUX.1 Edit MCP Server...")
logger.info(f"API key configured: {'Yes' if self.config.api_key else 'No'}")
logger.info(f"Input directory: {self.config.input_path}")
logger.info(f"Output directory: {self.config.generated_images_path}")
# Run server using stdio
async with stdio_server() as (read_stream, write_stream):
logger.info("MCP server started with stdio transport")
await self.server.run(
read_stream,
write_stream,
InitializationOptions(
server_name="flux1-edit",
server_version="1.0.0",
capabilities=self.server.get_capabilities(
notification_options=NotificationOptions(),
experimental_capabilities={}
)
)
)
except Exception as e:
logger.error(f"Server run error: {e}", exc_info=True)
raise
# ==================== Main Function ====================
async def main():
    """Construct the MCP server and serve until it stops.

    Returns:
        0 if interrupted by KeyboardInterrupt, 1 if any other exception
        escaped, and None (treated as success by the caller) otherwise.
    """
    try:
        logger.info("Starting FLUX.1 Edit MCP Server main function")
        await FluxEditMCPServer().run()
    except KeyboardInterrupt:
        logger.info("Server stopped by user (Ctrl+C)")
        return 0
    except Exception as fatal:
        logger.error(f"Fatal server error: {fatal}", exc_info=True)
        return 1
if __name__ == "__main__":
    try:
        import signal

        def signal_handler(signum, frame):
            """Log the received signal and exit with status 0 (raises SystemExit)."""
            logger.info(f"Received signal {signum}, shutting down gracefully...")
            sys.exit(0)

        # Same graceful-shutdown handler for Ctrl+C and for termination requests.
        for shutdown_signal in (signal.SIGINT, signal.SIGTERM):
            signal.signal(shutdown_signal, signal_handler)

        logger.info("Starting FLUX.1 Edit MCP Server from main")
        sys.exit(asyncio.run(main()) or 0)
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
        sys.exit(0)
    except SystemExit as exit_request:
        # Also reached on every normal shutdown, since sys.exit() above raises SystemExit.
        logger.info(f"Server exiting with code {exit_request.code}")
        sys.exit(exit_request.code)
    except Exception as unexpected:
        logger.error(f"Fatal error in main: {unexpected}", exc_info=True)
        sys.exit(1)