From 60535b1aa35020b08842acdee3e8485173f1381a Mon Sep 17 00:00:00 2001 From: ened Date: Tue, 26 Aug 2025 03:27:45 +0900 Subject: [PATCH] clean up code --- .claude/settings.local.json | 9 + check_dependencies.py | 150 --------- debug_test.py | 189 +++++++++++ diagnostic.py | 249 -------------- main.py | 368 ++++++++++++++++++++- requirements.txt | 15 +- simple_test.py | 126 ------- src/connector/__init__.py | 6 +- src/server/__init__.py | 3 +- src/server/handlers_backup.py | 605 ---------------------------------- src/server/mcp_server.py | 172 ---------- src/utils/__init__.py | 2 - src/utils/image_utils.py | 88 +---- start.bat | 35 -- test_fixes.py | 119 ------- tests/__init__.py | 0 tests/run_tests.py | 106 ------ tests/test_config.py | 159 --------- tests/test_flux_client.py | 308 ----------------- tests/test_handlers.py | 360 -------------------- tests/test_image_utils.py | 197 ----------- tests/test_validation.py | 218 ------------ troubleshoot.bat | 150 --------- 23 files changed, 559 insertions(+), 3075 deletions(-) create mode 100644 .claude/settings.local.json delete mode 100644 check_dependencies.py create mode 100644 debug_test.py delete mode 100644 diagnostic.py delete mode 100644 simple_test.py delete mode 100644 src/server/handlers_backup.py delete mode 100644 src/server/mcp_server.py delete mode 100644 start.bat delete mode 100644 test_fixes.py delete mode 100644 tests/__init__.py delete mode 100644 tests/run_tests.py delete mode 100644 tests/test_config.py delete mode 100644 tests/test_flux_client.py delete mode 100644 tests/test_handlers.py delete mode 100644 tests/test_image_utils.py delete mode 100644 tests/test_validation.py delete mode 100644 troubleshoot.bat diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..ad4dc53 --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,9 @@ +{ + "permissions": { + "allow": [ + "Bash(py debug_test.py)" + ], + "deny": [], + "ask": [] + } +} \ No newline at end of file diff --git a/check_dependencies.py b/check_dependencies.py deleted file mode 100644 index 1e3548e..0000000 --- a/check_dependencies.py +++ /dev/null @@ -1,150 +0,0 @@ -#!/usr/bin/env python3 -""" -FLUX.1 Edit MCP Server - Dependency Check Script -This script checks if all required dependencies are properly installed -""" - -import sys -import importlib.util -from typing import List, Tuple - -def check_dependency(module_name: str, package_name: str = None) -> Tuple[bool, str]: - """ - Check if a dependency is installed and importable - - Args: - module_name: Name of the module to import - package_name: Name of the package to install (if different from module) - - Returns: - Tuple of (success, message) - """ - if package_name is None: - package_name = module_name - - try: - spec = importlib.util.find_spec(module_name) - if spec is None: - return False, f"[MISSING] {module_name} not found - install with: pip install {package_name}" - - # Try to actually import it - module = importlib.import_module(module_name) - - # Get version if available - version = getattr(module, '__version__', 'unknown') - return True, f"[OK] {module_name} {version}" - - except ImportError as e: - return False, f"[ERROR] {module_name} import failed: {e}" - except Exception as e: - return False, f"[ERROR] {module_name} error: {e}" - -def check_local_modules() -> List[Tuple[bool, str]]: - """Check local project modules""" - results = [] - - try: - # Add src to path temporarily - import os - from pathlib import Path - - src_path = Path(__file__).parent / 'src' - if src_path.exists(): - sys.path.insert(0, str(src_path)) - - # Test local imports - try: - from connector.config import Config - results.append((True, "[OK] Local Config module")) - except Exception as e: - results.append((False, f"[ERROR] Local Config module: {e}")) - - try: - from connector.flux_client import FluxEditClient - results.append((True, "[OK] Local FluxEditClient module")) - except Exception as e: - results.append((False, f"[ERROR] Local FluxEditClient module: {e}")) - - try: - from server.mcp_server import FluxEditMCPServer - results.append((True, "[OK] Local MCP Server module")) - except Exception as e: - results.append((False, f"[ERROR] Local MCP Server module: {e}")) - - except Exception as e: - results.append((False, f"[ERROR] Local module check failed: {e}")) - - return results - -def main(): - """Main dependency check function""" - print("FLUX.1 Edit MCP Server - Dependency Check") - print("=========================================") - print(f"Python version: {sys.version}") - print(f"Python executable: {sys.executable}") - print() - - # Required dependencies with their install names - dependencies = [ - ("aiohttp", "aiohttp==3.11.7"), - ("httpx", "httpx==0.28.1"), - ("mcp", "mcp==1.1.0"), - ("PIL", "Pillow==11.0.0"), - ("dotenv", "python-dotenv==1.0.1"), - ("pydantic", "pydantic==2.10.3"), - ("structlog", "structlog==24.4.0"), - ] - - # Optional dependencies for development - optional_dependencies = [ - ("pytest", "pytest==8.3.4"), - ("black", "black==24.10.0"), - ] - - all_good = True - - print("Checking required dependencies...") - print("-" * 50) - - for module_name, package_name in dependencies: - success, message = check_dependency(module_name, package_name) - print(message) - if not success: - all_good = False - - print() - print("Checking optional dependencies...") - print("-" * 50) - - for module_name, package_name in optional_dependencies: - success, message = check_dependency(module_name, package_name) - print(message) - - print() - print("Checking local modules...") - print("-" * 50) - - local_results = check_local_modules() - for success, message in local_results: - print(message) - if not success: - all_good = False - - print() - print("=" * 50) - - if all_good: - print("[SUCCESS] All required dependencies are installed and working!") - print("You can now run: python main.py") - return 0 - else: - print("[FAILED] Some dependencies are missing or broken.") - print() - print("To fix this, try:") - print("1. Run: install_dependencies.bat (Windows)") - print("2. Or run: pip install -r requirements.txt") - print("3. Or run individual pip install commands shown above") - return 1 - -if __name__ == "__main__": - sys.exit(main()) diff --git a/debug_test.py b/debug_test.py new file mode 100644 index 0000000..70f1f17 --- /dev/null +++ b/debug_test.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +FLUX.1 Edit MCP Server Debug Test + +This script helps diagnose issues with the MCP server +""" + +# Force UTF-8 encoding setup +import os +import sys +import locale + +# Force UTF-8 environment variables +os.environ['PYTHONIOENCODING'] = 'utf-8' +os.environ['PYTHONUTF8'] = '1' +os.environ['LC_ALL'] = 'C.UTF-8' + +# Windows-specific UTF-8 setup +if sys.platform.startswith('win'): + try: + os.system('chcp 65001 >nul 2>&1') + except Exception: + pass + + try: + locale.setlocale(locale.LC_ALL, 'C.UTF-8') + except locale.Error: + try: + locale.setlocale(locale.LC_ALL, '') + except locale.Error: + pass + +import traceback + +# Add paths +current_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.join(current_dir, 'src')) +sys.path.insert(0, current_dir) + +def test_imports(): + """Test all required imports""" + print("=== Import Tests ===") + + # Test MCP imports + try: + import mcp.types as types + from mcp.server import Server + from mcp.server.stdio import stdio_server + print("[SUCCESS] MCP imports successful") + except ImportError as e: + print(f"[ERROR] MCP import failed: {e}") + return False + + # Test local imports + try: + from src.connector.config import Config + print("[SUCCESS] Config import successful") + except ImportError as e: + print(f"[ERROR] Config import failed: {e}") + print(f"Error details: {traceback.format_exc()}") + return False + + try: + from src.server.models import TOOL_DEFINITIONS, ToolName + print("[SUCCESS] Models import successful") + except ImportError as e: + print(f"[ERROR] Models import failed: {e}") + print(f"Error details: {traceback.format_exc()}") + return False + + try: + from src.server.handlers import ToolHandlers + print("[SUCCESS] Handlers import successful") + except ImportError as e: + print(f"[ERROR] Handlers import failed: {e}") + print(f"Error details: {traceback.format_exc()}") + return False + + # Test aiohttp + try: + import aiohttp + print(f"[SUCCESS] aiohttp version: {aiohttp.__version__}") + except ImportError as e: + print(f"[ERROR] aiohttp import failed: {e}") + return False + + return True + +def test_config(): + """Test configuration loading""" + print("\n=== Configuration Tests ===") + + try: + from src.connector.config import Config + config = Config() + print("[SUCCESS] Config created successfully") + + # Check API key + if config.api_key: + print(f"[SUCCESS] API key configured: ***{config.api_key[-4:]}") + else: + print("[WARNING] API key not configured") + + # Check paths + print(f"[SUCCESS] Input path: {config.input_path}") + print(f"[SUCCESS] Output path: {config.generated_images_path}") + + # Test validation + if config.validate(): + print("[SUCCESS] Configuration validation passed") + else: + print("[ERROR] Configuration validation failed") + + return True + + except Exception as e: + print(f"[ERROR] Configuration test failed: {e}") + print(f"Error details: {traceback.format_exc()}") + return False + +def test_handlers(): + """Test handler creation""" + print("\n=== Handler Tests ===") + + try: + from src.connector.config import Config + from src.server.handlers import ToolHandlers + + config = Config() + handlers = ToolHandlers(config) + print("[SUCCESS] Handlers created successfully") + + return True + + except Exception as e: + print(f"[ERROR] Handler test failed: {e}") + print(f"Error details: {traceback.format_exc()}") + return False + +def test_server_creation(): + """Test MCP server creation""" + print("\n=== Server Creation Tests ===") + + try: + from mcp.server import Server + + server = Server("flux1-edit-test") + print("[SUCCESS] MCP server created successfully") + + return True + + except Exception as e: + print(f"[ERROR] Server creation failed: {e}") + print(f"Error details: {traceback.format_exc()}") + return False + +def main(): + """Run all tests""" + print("FLUX.1 Edit MCP Server Debug Test") + print("=" * 50) + + tests = [ + test_imports, + test_config, + test_handlers, + test_server_creation + ] + + passed = 0 + for test in tests: + try: + if test(): + passed += 1 + except Exception as e: + print(f"[ERROR] Test crashed: {e}") + + print(f"\n=== Results ===") + print(f"Passed: {passed}/{len(tests)}") + + if passed == len(tests): + print("[SUCCESS] All tests passed! The server should work.") + return 0 + else: + print("[ERROR] Some tests failed. Check the errors above.") + return 1 + +if __name__ == "__main__": + sys.exit(main()) diff --git a/diagnostic.py b/diagnostic.py deleted file mode 100644 index 51fad6f..0000000 --- a/diagnostic.py +++ /dev/null @@ -1,249 +0,0 @@ -""" -Test script to diagnose FLUX.1 Edit MCP Server issues -""" - -import sys -import os -import asyncio -import logging -from pathlib import Path - -# Add src to path -sys.path.insert(0, str(Path(__file__).parent / 'src')) - -def test_imports(): - """Test all required imports""" - print("πŸ” Testing imports...") - - # Test standard library imports - try: - import json - import logging - import asyncio - print("βœ… Standard library imports OK") - except ImportError as e: - print(f"❌ Standard library import failed: {e}") - return False - - # Test third-party dependencies - deps = { - 'aiohttp': '3.11.7', - 'httpx': '0.28.1', - 'mcp': '1.1.0+', - 'PIL': 'Pillow 11.0.0', - 'dotenv': 'python-dotenv 1.0.1', - 'pydantic': '2.10.3' - } - - for module, expected in deps.items(): - try: - __import__(module) - print(f"βœ… {module} ({expected}) OK") - except ImportError as e: - print(f"❌ {module} missing: {e}") - return False - - # Test MCP specific imports - try: - import mcp.types as types - import mcp.server - from mcp.server import Server - from mcp.server.stdio import stdio_server - print("βœ… MCP imports OK") - except ImportError as e: - print(f"❌ MCP import failed: {e}") - return False - - # Test local module imports - try: - from src.connector import Config - print("βœ… Config import OK") - except ImportError as e: - print(f"❌ Config import failed: {e}") - return False - - try: - from src.server.models import TOOL_DEFINITIONS, ToolName - print("βœ… Models import OK") - except ImportError as e: - print(f"❌ Models import failed: {e}") - return False - - try: - from src.server.handlers import ToolHandlers - print("βœ… Handlers import OK") - except ImportError as e: - print(f"❌ Handlers import failed: {e}") - return False - - try: - from src.utils import validate_edit_parameters - print("βœ… Utils import OK") - except ImportError as e: - print(f"❌ Utils import failed: {e}") - return False - - return True - - -def test_config(): - """Test configuration loading""" - print("\nπŸ” Testing configuration...") - - try: - from src.connector import Config - config = Config() - - print(f"βœ… Config loaded") - print(f" - API key: {'***' + config.api_key[-4:] if config.api_key else 'Not Set'}") - print(f" - Input path: {config.input_path}") - print(f" - Output path: {config.generated_images_path}") - - # Test config validation - if config.validate(): - print("βœ… Config validation passed") - else: - print("❌ Config validation failed") - return False - - return True - - except Exception as e: - print(f"❌ Config test failed: {e}") - return False - - -def test_mcp_server_creation(): - """Test MCP server creation""" - print("\nπŸ” Testing MCP server creation...") - - try: - from src.server.mcp_server import FluxEditMCPServer - - server = FluxEditMCPServer() - print("βœ… MCP server created successfully") - - # Test tool definitions - from src.server.models import TOOL_DEFINITIONS - print(f"βœ… Tool definitions loaded: {len(TOOL_DEFINITIONS)} tools") - for tool_name in TOOL_DEFINITIONS: - print(f" - {tool_name}") - - return True - - except Exception as e: - print(f"❌ MCP server creation failed: {e}") - import traceback - traceback.print_exc() - return False - - -async def test_mcp_server_init(): - """Test MCP server initialization""" - print("\nπŸ” Testing MCP server initialization...") - - try: - from src.server.mcp_server import create_server - - server = create_server() - print("βœ… MCP server initialized") - - # Test configuration validation - if server.validate_config(): - print("βœ… Server config validation passed") - else: - print("❌ Server config validation failed") - return False - - return True - - except Exception as e: - print(f"❌ MCP server initialization failed: {e}") - import traceback - traceback.print_exc() - return False - - -def test_directories(): - """Test directory structure""" - print("\nπŸ” Testing directory structure...") - - base_path = Path(__file__).parent - - required_dirs = [ - 'src', - 'src/connector', - 'src/server', - 'src/utils', - 'input_images', - 'generated_images' - ] - - for dir_path in required_dirs: - full_path = base_path / dir_path - if full_path.exists(): - print(f"βœ… {dir_path}/") - else: - print(f"❌ {dir_path}/ (missing)") - try: - full_path.mkdir(parents=True, exist_ok=True) - print(f"βœ… Created {dir_path}/") - except Exception as e: - print(f"❌ Failed to create {dir_path}/: {e}") - return False - - return True - - -def main(): - """Main test function""" - print("πŸš€ FLUX.1 Edit MCP Server Diagnostic Tool") - print("=" * 50) - - tests = [ - ("Directory Structure", test_directories), - ("Import Dependencies", test_imports), - ("Configuration", test_config), - ("MCP Server Creation", test_mcp_server_creation), - ("MCP Server Initialization", lambda: asyncio.run(test_mcp_server_init())), - ] - - results = [] - for test_name, test_func in tests: - print(f"\nπŸ“‹ Running {test_name}...") - try: - result = test_func() - results.append((test_name, result)) - if result: - print(f"βœ… {test_name} PASSED") - else: - print(f"❌ {test_name} FAILED") - except Exception as e: - print(f"❌ {test_name} FAILED with exception: {e}") - results.append((test_name, False)) - - # Summary - print("\n" + "=" * 50) - print("πŸ“Š DIAGNOSTIC SUMMARY") - print("=" * 50) - - passed = sum(1 for _, result in results if result) - total = len(results) - - for test_name, result in results: - status = "βœ… PASS" if result else "❌ FAIL" - print(f"{status} {test_name}") - - print(f"\n🎯 Results: {passed}/{total} tests passed") - - if passed == total: - print("πŸŽ‰ All tests passed! Server should work correctly.") - else: - print("πŸ”§ Some tests failed. Please fix the issues above.") - - return passed == total - - -if __name__ == "__main__": - success = main() - sys.exit(0 if success else 1) diff --git a/main.py b/main.py index 2804160..2e6d0ec 100644 --- a/main.py +++ b/main.py @@ -1,30 +1,362 @@ -"""Main entry point for FLUX.1 Edit MCP Server""" +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +FLUX.1 Edit MCP Server - Fixed Version -import asyncio +FLUX.1 Kontextλ₯Ό μ‚¬μš©ν•œ AI 이미지 νŽΈμ§‘ MCP μ„œλ²„ +- Enhanced error handling and UTF-8 support +- MCP protocol compliance +- Based on imagen4 server structure +""" + +# ==================== CRITICAL: UTF-8 SETUP MUST BE FIRST ==================== +import os import sys -from pathlib import Path +import locale + +# Force UTF-8 environment variables - set immediately +os.environ['PYTHONIOENCODING'] = 'utf-8' +os.environ['PYTHONUTF8'] = '1' +os.environ['LC_ALL'] = 'C.UTF-8' + +# Windows-specific UTF-8 setup +if sys.platform.startswith('win'): + try: + os.system('chcp 65001 >nul 2>&1') + except Exception: + pass + + try: + locale.setlocale(locale.LC_ALL, 'C.UTF-8') + except locale.Error: + try: + locale.setlocale(locale.LC_ALL, '') + except locale.Error: + pass # Add src to path for imports -sys.path.insert(0, str(Path(__file__).parent / 'src')) +current_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.join(current_dir, 'src')) +sys.path.insert(0, current_dir) +# ==================== Imports ==================== +import asyncio +import logging +import json +from typing import Dict, Any, List +from pathlib import Path + +# Safe logging setup +class SafeUnicodeHandler(logging.StreamHandler): + """Ultra-safe Unicode stream handler that prevents all encoding issues""" + + def __init__(self, stream=None): + super().__init__(stream) + self.encoding = 'utf-8' + + def emit(self, record): + try: + msg = self.format(record) + + # Windows safety: replace problematic Unicode characters + if sys.platform.startswith('win'): + emoji_replacements = { + 'βœ…': '[SUCCESS]', '❌': '[ERROR]', '⚠️': '[WARNING]', + 'πŸ”„': '[RETRY]', '⏳': '[WAIT]', 'πŸ–ΌοΈ': '[IMAGE]', + 'πŸ“': '[FILES]', 'βš™οΈ': '[PARAMS]', '🎨': '[GENERATE]' + } + + for emoji, replacement in emoji_replacements.items(): + msg = msg.replace(emoji, replacement) + + # Ensure safe encoding + msg = msg.encode('utf-8', errors='replace').decode('utf-8', errors='replace') + + # Write safely + stream = self.stream + terminator = getattr(self, 'terminator', '\n') + + try: + if hasattr(stream, 'buffer'): + stream.buffer.write((msg + terminator).encode('utf-8', errors='replace')) + stream.buffer.flush() + else: + stream.write(msg + terminator) + if hasattr(stream, 'flush'): + stream.flush() + except (UnicodeEncodeError, UnicodeDecodeError): + try: + safe_msg = msg.encode('ascii', errors='replace').decode('ascii') + if hasattr(stream, 'buffer'): + stream.buffer.write((safe_msg + terminator).encode('ascii')) + stream.buffer.flush() + else: + stream.write(safe_msg + terminator) + if hasattr(stream, 'flush'): + stream.flush() + except Exception: + pass + except Exception: + self.handleError(record) + +# Setup safe logging +safe_handler = SafeUnicodeHandler(sys.stderr) +safe_handler.setFormatter(logging.Formatter( + '%(asctime)s [%(name)s] [%(levelname)s] %(message)s' +)) + +root_logger = logging.getLogger() +root_logger.setLevel(logging.INFO) +root_logger.handlers.clear() +root_logger.addHandler(safe_handler) + +logger = logging.getLogger("flux1-edit-mcp") + +# ==================== MCP Imports with Error Handling ==================== +try: + import mcp.types as types + from mcp.server import Server + from mcp.server.stdio import stdio_server + from mcp.server.models import InitializationOptions + from mcp.server import NotificationOptions + logger.info("MCP imports successful") +except ImportError as e: + logger.error(f"Failed to import MCP: {e}") + logger.error("Please install MCP: pip install mcp") + sys.exit(1) + +# ==================== Local Imports with Error Handling ==================== +try: + from src.connector.config import Config + from src.server.models import TOOL_DEFINITIONS, ToolName + from src.server.handlers import ToolHandlers + logger.info("Local imports successful") +except ImportError as e: + logger.error(f"Failed to import local modules: {e}") + logger.error(f"Current directory: {current_dir}") + logger.error(f"Python path: {sys.path}") + + # List available modules for debugging + try: + src_path = os.path.join(current_dir, 'src') + if os.path.exists(src_path): + logger.error(f"Available in src: {os.listdir(src_path)}") + for subdir in ['connector', 'server', 'utils']: + subdir_path = os.path.join(src_path, subdir) + if os.path.exists(subdir_path): + logger.error(f"Available in src/{subdir}: {os.listdir(subdir_path)}") + except Exception: + pass + + sys.exit(1) + +# ==================== Utility Functions ==================== +def sanitize_args_for_logging(arguments: Dict[str, Any]) -> Dict[str, Any]: + """Remove or truncate sensitive data from arguments for safe logging""" + safe_args = {} + for key, value in arguments.items(): + if isinstance(value, str): + if (key.endswith('_b64') or key.endswith('_data') or + (len(value) > 100 and any(value.startswith(prefix) for prefix in ['iVBORw0KGgo', '/9j/', 'R0lGOD']))): + safe_args[key] = f"" + elif len(value) > 1000: + safe_args[key] = f"{value[:100]}..." + else: + safe_args[key] = value + else: + safe_args[key] = value + return safe_args + +# ==================== MCP Server Class ==================== +class FluxEditMCPServer: + """FLUX.1 Edit MCP server with enhanced error handling""" + + def __init__(self): + """Initialize server with comprehensive error handling""" + logger.info("Initializing FLUX.1 Edit MCP Server...") + + try: + # Create configuration + self.config = Config() + logger.info("Configuration created successfully") + + # Validate configuration + if not self.config.validate(): + raise RuntimeError("Configuration validation failed") + logger.info("Configuration validated successfully") + + # Create MCP server + self.server = Server("flux1-edit") + logger.info("MCP server instance created") + + # Create tool handlers + self.handlers = ToolHandlers(self.config) + logger.info("Tool handlers created successfully") + + # Register handlers + self._register_handlers() + logger.info("Handlers registered successfully") + + logger.info("FLUX.1 Edit MCP Server initialization complete") + + except Exception as e: + logger.error(f"Failed to initialize server: {e}", exc_info=True) + raise + + def _register_handlers(self): + """Register MCP handlers with comprehensive error handling""" + + @self.server.list_tools() + async def handle_list_tools() -> List[types.Tool]: + """List available tools""" + try: + logger.info("Listing available tools") + tools = [] + + for tool_name, tool_def in TOOL_DEFINITIONS.items(): + # Build properties for parameters + properties = {} + required = [] + + for param in tool_def.parameters: + prop_def = { + "type": param.type, + "description": param.description + } + + # Add enum if specified + if param.enum: + prop_def["enum"] = param.enum + + # Add default if specified + if param.default is not None: + prop_def["default"] = param.default + + properties[param.name] = prop_def + + if param.required: + required.append(param.name) + + # Build tool schema + tool = types.Tool( + name=tool_def.name, + description=tool_def.description, + inputSchema={ + "type": "object", + "properties": properties, + "required": required + } + ) + tools.append(tool) + + logger.info(f"Listed {len(tools)} tools successfully") + return tools + + except Exception as e: + logger.error(f"Error listing tools: {e}", exc_info=True) + raise + + @self.server.call_tool() + async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[types.TextContent | types.ImageContent]: + """Handle tool calls with comprehensive error handling""" + try: + # Log tool call safely + safe_args = sanitize_args_for_logging(arguments) + logger.info(f"Tool call: {name} with args: {safe_args}") + + # Route to appropriate handler + if name == ToolName.FLUX_EDIT_IMAGE: + return await self.handlers.handle_flux_edit_image(arguments) + elif name == ToolName.FLUX_EDIT_IMAGE_FROM_FILE: + return await self.handlers.handle_flux_edit_image_from_file(arguments) + elif name == ToolName.VALIDATE_IMAGE: + return await self.handlers.handle_validate_image(arguments) + elif name == ToolName.MOVE_TEMP_TO_OUTPUT: + return await self.handlers.handle_move_temp_to_output(arguments) + else: + error_msg = f"Unknown tool: {name}" + logger.error(error_msg) + return [types.TextContent( + type="text", + text=f"[ERROR] {error_msg}" + )] + + except Exception as e: + logger.error(f"Error calling tool {name}: {e}", exc_info=True) + return [types.TextContent( + type="text", + text=f"[ERROR] Tool execution error: {str(e)}" + )] + + async def run(self): + """Run the MCP server""" + try: + logger.info("Starting FLUX.1 Edit MCP Server...") + logger.info(f"API key configured: {'Yes' if self.config.api_key else 'No'}") + logger.info(f"Input directory: {self.config.input_path}") + logger.info(f"Output directory: {self.config.generated_images_path}") + + # Run server using stdio with proper initialization + async with stdio_server() as (read_stream, write_stream): + logger.info("MCP server started with stdio transport") + + await self.server.run( + read_stream, + write_stream, + InitializationOptions( + server_name="flux1-edit", + server_version="1.0.0", + capabilities=self.server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={} + ) + ) + ) + + except Exception as e: + logger.error(f"Server run error: {e}", exc_info=True) + raise + +# ==================== Main Function ==================== +async def main(): + """Main entry point with comprehensive error handling""" + try: + logger.info("Starting FLUX.1 Edit MCP Server main function") + + # Create and run server + server = FluxEditMCPServer() + await server.run() + + except KeyboardInterrupt: + logger.info("Server stopped by user (Ctrl+C)") + return 0 + except Exception as e: + logger.error(f"Fatal server error: {e}", exc_info=True) + return 1 if __name__ == "__main__": try: - # Import and run the main server function - from src.server.mcp_server import main - asyncio.run(main()) + # Set up signal handling + import signal + + def signal_handler(signum, frame): + logger.info(f"Received signal {signum}, shutting down gracefully...") + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + # Run the server + logger.info("Starting FLUX.1 Edit MCP Server from main") + exit_code = asyncio.run(main()) + sys.exit(exit_code or 0) + except KeyboardInterrupt: + logger.info("Server stopped by user") sys.exit(0) + except SystemExit as e: + logger.info(f"Server exiting with code {e.code}") + sys.exit(e.code) except Exception as e: - # Log to stderr for debugging, but avoid stdout pollution for MCP - import logging - logging.basicConfig( - level=logging.ERROR, - format='%(asctime)s [%(name)s] %(levelname)s: %(message)s', - handlers=[ - logging.FileHandler('flux1-edit.log', mode='a', encoding='utf-8') - ] - ) - logger = logging.getLogger(__name__) - logger.error(f"Fatal error: {e}", exc_info=True) + logger.error(f"Fatal error in main: {e}", exc_info=True) sys.exit(1) diff --git a/requirements.txt b/requirements.txt index aa71b9f..bbbed94 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,10 @@ # FLUX.1 Edit MCP Server Dependencies -# Core MCP Server - Updated version +# Core MCP Server mcp==1.2.0 -# HTTP Client for FLUX API -httpx==0.28.1 -aiohttp==3.11.7 +# HTTP Client for FLUX API - using stable version +aiohttp==3.9.5 # Image Processing Pillow==11.0.0 @@ -16,13 +15,7 @@ python-dotenv==1.0.1 # Data Validation pydantic==2.10.3 -# Async utilities - asyncio is built into Python 3.7+ -# asyncio-compat package not needed for modern Python versions - -# Logging -structlog==24.4.0 - -# Development and Testing (optional) +# Development and Testing pytest==8.3.4 pytest-asyncio==0.25.0 pytest-mock==3.14.0 diff --git a/simple_test.py b/simple_test.py deleted file mode 100644 index c4b06bc..0000000 --- a/simple_test.py +++ /dev/null @@ -1,126 +0,0 @@ -""" -Simple test script for FLUX.1 Edit MCP Server -""" - -import sys -import asyncio -import json -from pathlib import Path - -# Add src to path -sys.path.insert(0, str(Path(__file__).parent / 'src')) - -async def test_server_import(): - """Test if the server can be imported and initialized""" - print("πŸ” Testing server import and initialization...") - - try: - # Test imports - from src.server.mcp_server import create_server - from src.connector import Config - - print("βœ… Successfully imported server components") - - # Test configuration - config = Config() - if config.validate(): - print("βœ… Configuration validation passed") - else: - print("❌ Configuration validation failed") - return False - - # Test server creation - server = create_server() - print("βœ… Server created successfully") - - # Test server validation - if server.validate_config(): - print("βœ… Server configuration validated") - else: - print("❌ Server configuration validation failed") - return False - - return True - - except Exception as e: - print(f"❌ Error: {e}") - import traceback - traceback.print_exc() - return False - -async def test_mcp_protocol(): - """Test MCP protocol basics""" - print("\nπŸ” Testing MCP protocol basics...") - - try: - import mcp.types as types - from mcp.server import Server - from mcp.server.stdio import stdio_server - - print("βœ… MCP imports successful") - - # Create a minimal server for testing - server = Server("test-server") - - @server.list_tools() - async def list_tools(): - return [ - types.Tool( - name="test_tool", - description="Test tool", - inputSchema={ - "type": "object", - "properties": { - "test_param": { - "type": "string", - "description": "Test parameter" - } - }, - "required": ["test_param"] - } - ) - ] - - print("βœ… MCP server handlers configured") - return True - - except Exception as e: - print(f"❌ MCP protocol test failed: {e}") - import traceback - traceback.print_exc() - return False - -def main(): - """Main test function""" - print("πŸš€ FLUX.1 Edit MCP Server - Simple Test") - print("=" * 50) - - # Run async tests - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - test1_result = loop.run_until_complete(test_server_import()) - test2_result = loop.run_until_complete(test_mcp_protocol()) - - print("\n" + "=" * 50) - print("πŸ“Š TEST RESULTS") - print("=" * 50) - - print(f"Server Import & Init: {'βœ… PASS' if test1_result else '❌ FAIL'}") - print(f"MCP Protocol Basic: {'βœ… PASS' if test2_result else '❌ FAIL'}") - - if test1_result and test2_result: - print("\nπŸŽ‰ All tests passed! Server should be ready.") - print("\nπŸ’‘ Try running: python main.py") - return True - else: - print("\nπŸ”§ Some tests failed. Check the errors above.") - return False - - finally: - loop.close() - -if __name__ == "__main__": - success = main() - sys.exit(0 if success else 1) diff --git a/src/connector/__init__.py b/src/connector/__init__.py index 5ac8209..1bacbe7 100644 --- a/src/connector/__init__.py +++ b/src/connector/__init__.py @@ -21,12 +21,12 @@ except ImportError as e: try: import aiohttp except ImportError: - logger.error("aiohttp is not installed. Please run: pip install aiohttp==3.11.7") - print("\n❌ Missing dependency: aiohttp") + logger.error("aiohttp is not installed. Please run: pip install aiohttp==3.9.5") + print("\n[ERROR] Missing dependency: aiohttp") print("Please run one of the following commands:") print(" - install_dependencies.bat (on Windows)") print(" - pip install -r requirements.txt") - print(" - pip install aiohttp==3.11.7") + print(" - pip install aiohttp==3.9.5") sys.exit(1) raise diff --git a/src/server/__init__.py b/src/server/__init__.py index ea6bd16..75bcb74 100644 --- a/src/server/__init__.py +++ b/src/server/__init__.py @@ -1,7 +1,6 @@ """Server package for FLUX.1 Edit""" -from .mcp_server import FluxEditMCPServer, create_server, main from .handlers import ToolHandlers from .models import TOOL_DEFINITIONS, ToolName -__all__ = ['FluxEditMCPServer', 'create_server', 'main', 'ToolHandlers', 'TOOL_DEFINITIONS', 'ToolName'] +__all__ = ['ToolHandlers', 'TOOL_DEFINITIONS', 'ToolName'] diff --git a/src/server/handlers_backup.py b/src/server/handlers_backup.py deleted file mode 100644 index 87ab4bd..0000000 --- a/src/server/handlers_backup.py +++ /dev/null @@ -1,605 +0,0 @@ -"""MCP Tool Handlers for FLUX.1 Edit MCP Server""" - -import json -import logging -import random -from datetime import datetime -from pathlib import Path -from typing import Dict, Any, List, Optional - -from mcp.types import TextContent, ImageContent - -from ..connector import Config, FluxEditClient, FluxEditRequest -from ..utils import ( - validate_edit_parameters, - validate_file_parameters, - validate_image_path_parameter, - validate_move_file_parameters, - validate_image_file, - save_image, - encode_image_base64, - decode_image_base64, - sanitize_prompt, - get_image_dimensions, - convert_image_to_base64, - get_file_size_mb -) - -logger = logging.getLogger(__name__) - - -class ToolHandlers: - """Handler class for FLUX.1 Edit MCP tools""" - - def __init__(self, config: Config): - """Initialize handlers with configuration""" - self.config = config - self.current_seed = None # Track current seed for session - - def _get_or_create_seed(self) -> int: - """Get current seed or create new one""" - if self.current_seed is None: - self.current_seed = random.randint(0, 999999) - return self.current_seed - - def _reset_seed(self): - """Reset seed for new session""" - self.current_seed = None - - def _save_b64_to_temp_file(self, b64_data: str, filename: str) -> str: - """Save base64 data to a temporary file with specified filename - - Args: - b64_data: Base64 encoded image data - filename: Desired filename for the file - - Returns: - str: Path to saved file - """ - try: - # Decode base64 data - image_data = decode_image_base64(b64_data) - - # Save to local temp directory for processing - temp_dir = self.config.base_path / 'temp' - temp_dir.mkdir(exist_ok=True) - file_path = temp_dir / filename - - if not save_image(image_data, str(file_path)): - raise RuntimeError(f"Failed to save image to temp file: {filename}") - - logger.info(f"Saved temp file: {filename} ({len(image_data) / 1024:.1f} KB)") - - return str(file_path) - except Exception as e: - logger.error(f"Error saving b64 to temp file: {e}") - raise - - def _move_temp_to_generated(self, temp_file_path: str, base_name: str, index: int, extension: str = None) -> str: - """ - Move file from temp directory to generated_images directory - - Args: - temp_file_path: Path to temporary file - base_name: Base name for the destination file - index: Index for the file (0 for input, 1+ for output) - extension: File extension (will detect from temp file if not provided) - - Returns: - str: Path to moved file in generated_images directory - """ - try: - # Ensure output directory exists - self.config.ensure_output_directory() - - temp_path = Path(temp_file_path) - - # Verify source file exists - if not temp_path.exists(): - raise FileNotFoundError(f"Temp file not found: {temp_file_path}") - - # Detect extension from temp file if not provided - if extension is None: - extension = temp_path.suffix[1:] if temp_path.suffix else 'png' - - # Generate destination filename - dest_filename = self.config.generate_filename(base_name, index, extension) - dest_path = self.config.generated_images_path / dest_filename - - # Copy file (preserve original in temp for potential reuse) - import shutil - try: - shutil.copy2(temp_file_path, dest_path) - - # Verify copy was successful - if not dest_path.exists(): - raise RuntimeError(f"File copy verification failed: {dest_path}") - - # Check file sizes match - if temp_path.stat().st_size != dest_path.stat().st_size: - raise RuntimeError(f"File copy size mismatch: {temp_path.stat().st_size} != {dest_path.stat().st_size}") - - except PermissionError as e: - raise RuntimeError(f"Permission denied copying file to {dest_path}: {e}") - except shutil.Error as e: - raise RuntimeError(f"Copy operation failed: {e}") - - logger.info(f"Moved temp file to generated_images: {temp_path.name} β†’ {dest_filename}") - - return str(dest_path) - - except Exception as e: - logger.error(f"Error moving temp file to generated_images: {e}") - raise - - async def handle_flux_edit_image(self, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]: - """ - Handle flux_edit_image tool call - - Args: - arguments: Tool arguments - - Returns: - List of content items - """ - try: - # Validate parameters - is_valid, error_msg = validate_edit_parameters(arguments) - if not is_valid: - return [TextContent( - type="text", - text=f"❌ Parameter validation failed: {error_msg}" - )] - - # Extract parameters - input_image_b64 = arguments['input_image_b64'] - prompt = sanitize_prompt(arguments['prompt']) - seed = arguments['seed'] - aspect_ratio = arguments.get('aspect_ratio', self.config.default_aspect_ratio) - save_to_file = arguments.get('save_to_file', True) - - logger.info(f"Starting FLUX edit with seed {seed}") - - # Generate base name - base_name = self.config.generate_base_name(seed) - - # Save input image to temp and then to generated_images as 000 - temp_image_name = f'temp_input_{random.randint(1000, 9999)}.png' - temp_image_path = self._save_b64_to_temp_file(input_image_b64, temp_image_name) - - # Copy to generated_images as input (000) - input_generated_path = self._move_temp_to_generated(temp_image_path, base_name, 0) - logger.info(f"Input file saved: {Path(input_generated_path).name}") - - # Create FLUX edit request - request = FluxEditRequest( - input_image_b64=input_image_b64, - prompt=prompt, - seed=seed, - aspect_ratio=aspect_ratio, - safety_tolerance=self.config.safety_tolerance, - output_format=self.config.OUTPUT_FORMAT, - prompt_upsampling=self.config.prompt_upsampling - ) - - # Process edit using FLUX API - async with FluxEditClient(self.config) as client: - response = await client.edit_image(request) - - if not response.success: - return [TextContent( - type="text", - text=f"❌ FLUX edit failed: {response.error_message}" - )] - - # Save output image and metadata - saved_path = None - json_path = None - - if save_to_file: - output_path = self.config.get_output_path(base_name, 1, 'png') - - if save_image(response.edited_image_data, str(output_path)): - saved_path = str(output_path) - - # Save parameters as JSON - if self.config.save_parameters: - params_dict = { - "base_name": base_name, - "timestamp": datetime.now().isoformat(), - "model": self.config.MODEL_NAME, - "prompt": prompt, - "seed": seed, - "aspect_ratio": aspect_ratio, - "safety_tolerance": self.config.safety_tolerance, - "output_format": self.config.OUTPUT_FORMAT, - "prompt_upsampling": self.config.prompt_upsampling, - "input_image_temp": temp_image_name, - "input_generated_path": input_generated_path, - "output_size": response.image_size, - "execution_time": response.execution_time, - "request_id": response.request_id, - "metadata": response.metadata - } - - json_path = self.config.get_output_path(base_name, 1, 'json') - with open(json_path, 'w', encoding='utf-8') as f: - json.dump(params_dict, f, indent=2, ensure_ascii=False) - logger.info(f"Parameters saved to: {json_path}") - - # Prepare response - contents = [] - - # Add text description - text = f"βœ… Image edited successfully with FLUX.1 Kontext!\n" - text += f"🎲 Seed: {seed}\n" - text += f"πŸ“ Base name: {base_name}\n" - if response.image_size: - text += f"πŸ“ Size: {response.image_size[0]}x{response.image_size[1]}\n" - text += f"πŸ“ Aspect ratio: {aspect_ratio}\n" - text += f"⏱️ Processing time: {response.execution_time:.1f}s\n" - - if saved_path: - text += f"\nπŸ’Ύ Output: {Path(saved_path).name}" - text += f"\nπŸ“ Input: {Path(input_generated_path).name}" - if json_path: - text += f"\nπŸ“‹ Parameters: {Path(json_path).name}" - - contents.append(TextContent(type="text", text=text)) - - # Add image preview - if response.edited_image_data: - image_b64 = encode_image_base64(response.edited_image_data) - contents.append(ImageContent( - type="image", - data=image_b64, - mimeType="image/png" - )) - - # Reset seed for next session - self._reset_seed() - - return contents - - except Exception as e: - logger.error(f"Error in handle_flux_edit_image: {e}", exc_info=True) - return [TextContent( - type="text", - text=f"❌ Unexpected error: {str(e)}" - )] - - async def handle_flux_edit_image_from_file(self, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]: - """ - Handle flux_edit_image_from_file tool call - - Args: - arguments: Tool arguments - - Returns: - List of content items - """ - try: - # Validate parameters - is_valid, error_msg = validate_file_parameters(arguments) - if not is_valid: - return [TextContent( - type="text", - text=f"❌ Parameter validation failed: {error_msg}" - )] - - # Extract parameters - input_image_name = arguments['input_image_name'] - prompt = sanitize_prompt(arguments['prompt']) - seed = arguments['seed'] - aspect_ratio = arguments.get('aspect_ratio', self.config.default_aspect_ratio) - save_to_file = arguments.get('save_to_file', True) - - # Check if file exists in input directory - input_file_path = self.config.input_path / input_image_name - - if not input_file_path.exists(): - # Enhanced error message with debug info - error_text = f"❌ File not found in input directory: {input_image_name}\n" - error_text += f"πŸ“ Looking in: {self.config.input_path}\n" - error_text += f"πŸ” Full path: {input_file_path}\n" - error_text += f"πŸ“‚ Input directory exists: {self.config.input_path.exists()}\n" - - # List available files in input directory - if self.config.input_path.exists(): - files = [f.name for f in self.config.input_path.iterdir() if f.is_file()] - if files: - error_text += f"πŸ“‹ Available files: {', '.join(files[:10])}" - if len(files) > 10: - error_text += f" and {len(files) - 10} more..." - else: - error_text += "πŸ“‹ No files found in input directory" - else: - error_text += "⚠️ Input directory does not exist" - - return [TextContent(type="text", text=error_text)] - - # Validate the image file - is_valid, size_mb, validation_error = validate_image_file( - str(input_file_path), - self.config.max_image_size_mb - ) - if not is_valid: - return [TextContent( - type="text", - text=f"❌ Image validation failed: {validation_error}" - )] - - logger.info(f"Starting FLUX edit from file: {input_image_name} ({size_mb:.2f}MB)") - - # Convert image to base64 - try: - input_image_b64 = convert_image_to_base64(str(input_file_path)) - except Exception as e: - return [TextContent( - type="text", - text=f"❌ Failed to convert image to base64: {str(e)}" - )] - - # Generate base name - base_name = self.config.generate_base_name(seed) - - # Copy original file to generated_images as input (000) - try: - with open(input_file_path, 'rb') as f: - image_data = f.read() - - input_generated_path = self.config.get_output_path(base_name, 0, 'png') - if not save_image(image_data, str(input_generated_path)): - raise RuntimeError("Failed to save input to generated_images") - - logger.info(f"Input file copied: {Path(input_generated_path).name}") - - except Exception as e: - return [TextContent( - type="text", - text=f"❌ Failed to copy input file: {str(e)}" - )] - - # Create FLUX edit request - request = FluxEditRequest( - input_image_b64=input_image_b64, - prompt=prompt, - seed=seed, - aspect_ratio=aspect_ratio, - safety_tolerance=self.config.safety_tolerance, - output_format=self.config.OUTPUT_FORMAT, - prompt_upsampling=self.config.prompt_upsampling - ) - - # Process edit using FLUX API - async with FluxEditClient(self.config) as client: - response = await client.edit_image(request) - - if not response.success: - return [TextContent( - type="text", - text=f"❌ FLUX edit failed: {response.error_message}" - )] - - # Save output image and metadata - saved_path = None - json_path = None - - if save_to_file: - output_path = self.config.get_output_path(base_name, 1, 'png') - - if save_image(response.edited_image_data, str(output_path)): - saved_path = str(output_path) - - # Save parameters as JSON - if self.config.save_parameters: - params_dict = { - "base_name": base_name, - "timestamp": datetime.now().isoformat(), - "model": self.config.MODEL_NAME, - "prompt": prompt, - "seed": seed, - "aspect_ratio": aspect_ratio, - "safety_tolerance": self.config.safety_tolerance, - "output_format": self.config.OUTPUT_FORMAT, - "prompt_upsampling": self.config.prompt_upsampling, - "input_image_name": input_image_name, - "input_file_path": str(input_file_path), - "input_size": get_image_dimensions(str(input_file_path)), - "input_size_mb": size_mb, - "output_size": response.image_size, - "execution_time": response.execution_time, - "request_id": response.request_id, - "metadata": response.metadata - } - - json_path = self.config.get_output_path(base_name, 1, 'json') - with open(json_path, 'w', encoding='utf-8') as f: - json.dump(params_dict, f, indent=2, ensure_ascii=False) - logger.info(f"Parameters saved to: {json_path}") - - # Prepare response - contents = [] - - # Add text description - text = f"βœ… Image edited successfully from file with FLUX.1 Kontext!\n" - text += f"πŸ“ Input: {input_image_name} ({size_mb:.2f}MB)\n" - text += f"🎲 Seed: {seed}\n" - text += f"πŸ“ Base name: {base_name}\n" - if response.image_size: - text += f"πŸ“ Size: {response.image_size[0]}x{response.image_size[1]}\n" - text += f"πŸ“ Aspect ratio: {aspect_ratio}\n" - text += f"⏱️ Processing time: {response.execution_time:.1f}s\n" - - if saved_path: - text += f"\nπŸ’Ύ Output: {Path(saved_path).name}" - text += f"\nπŸ“ Input copy: {Path(input_generated_path).name}" - if json_path: - text += f"\nπŸ“‹ Parameters: {Path(json_path).name}" - - contents.append(TextContent(type="text", text=text)) - - # Add image preview - if response.edited_image_data: - image_b64 = encode_image_base64(response.edited_image_data) - contents.append(ImageContent( - type="image", - data=image_b64, - mimeType="image/png" - )) - - # Reset seed for next session - self._reset_seed() - - return contents - - except Exception as e: - logger.error(f"Error in handle_flux_edit_image_from_file: {e}", exc_info=True) - return [TextContent( - type="text", - text=f"❌ File-based edit error: {str(e)}" - )] - - async def handle_validate_image(self, arguments: Dict[str, Any]) -> List[TextContent]: - """ - Handle validate_image tool call - - Args: - arguments: Tool arguments - - Returns: - List of content items - """ - try: - # Validate parameters - is_valid, error_msg = validate_image_path_parameter(arguments) - if not is_valid: - return [TextContent( - type="text", - text=f"❌ Parameter validation failed: {error_msg}" - )] - - image_path = arguments['image_path'] - - # Validate image - is_valid, size_mb, error_msg = validate_image_file( - image_path, - self.config.max_image_size_mb - ) - - # Get additional info if valid - if is_valid: - width, height = get_image_dimensions(image_path) - - text = f"βœ… Image validation passed!\n" - text += f"πŸ“ File: {Path(image_path).name}\n" - text += f"πŸ“ Dimensions: {width}x{height}\n" - text += f"πŸ’Ύ Size: {size_mb:.2f}MB\n" - text += f"🎯 Max allowed: {self.config.max_image_size_mb}MB\n" - - # Check aspect ratio compatibility - from ..utils import get_optimal_aspect_ratio - optimal_ratio = get_optimal_aspect_ratio(width, height) - text += f"πŸ“ Optimal aspect ratio: {optimal_ratio}" - else: - text = f"❌ Image validation failed!\n" - text += f"πŸ“ File: {Path(image_path).name}\n" - text += f"⚠️ Issue: {error_msg}" - - return [TextContent(type="text", text=text)] - - except Exception as e: - logger.error(f"Error in handle_validate_image: {e}", exc_info=True) - return [TextContent( - type="text", - text=f"❌ Validation error: {str(e)}" - )] - - async def handle_move_temp_to_output(self, arguments: Dict[str, Any]) -> List[TextContent]: - """ - Handle move_temp_to_output tool call - - Args: - arguments: Tool arguments - - Returns: - List of content items - """ - try: - # Validate parameters - is_valid, error_msg = validate_move_file_parameters(arguments) - if not is_valid: - return [TextContent( - type="text", - text=f"❌ Parameter validation failed: {error_msg}" - )] - - temp_file_name = arguments['temp_file_name'] - output_file_name = arguments.get('output_file_name') - copy_only = arguments.get('copy_only', False) - - # Get temp file path - temp_file_path = self.config.base_path / 'temp' / temp_file_name - - # Check if temp file exists - if not temp_file_path.exists(): - return [TextContent( - type="text", - text=f"❌ Temp file not found: {temp_file_name}" - )] - - # Generate output file name if not provided - if not output_file_name: - base_name = self.config.generate_base_name_simple() - file_ext = Path(temp_file_name).suffix[1:] or 'png' - output_file_name = f"{base_name}_001.{file_ext}" - - # Ensure output directory exists - self.config.ensure_output_directory() - - # Get output path - output_path = self.config.generated_images_path / output_file_name - - # Move or copy file - try: - import shutil - if copy_only: - shutil.copy2(temp_file_path, output_path) - operation = "copied" - else: - shutil.move(str(temp_file_path), str(output_path)) - operation = "moved" - - # Verify operation was successful - if not output_path.exists(): - raise RuntimeError(f"File {operation} verification failed") - - logger.info(f"πŸ“ File {operation}: {temp_file_name} -> {output_file_name}") - - # Get file size for reporting - file_size_mb = output_path.stat().st_size / (1024 * 1024) - - text = f"βœ… File {operation} successfully!\n" - text += f"πŸ“ From temp: {temp_file_name}\n" - text += f"πŸ“ To output: {output_file_name}\n" - text += f"πŸ’Ύ Size: {file_size_mb:.2f}MB" - - return [TextContent(type="text", text=text)] - - except PermissionError as e: - return [TextContent( - type="text", - text=f"❌ Permission denied: {str(e)}" - )] - except Exception as e: - return [TextContent( - type="text", - text=f"❌ File operation failed: {str(e)}" - )] - - except Exception as e: - logger.error(f"Error in handle_move_temp_to_output: {e}", exc_info=True) - return [TextContent( - type="text", - text=f"❌ File move error: {str(e)}" - )] diff --git a/src/server/mcp_server.py b/src/server/mcp_server.py deleted file mode 100644 index e266e61..0000000 --- a/src/server/mcp_server.py +++ /dev/null @@ -1,172 +0,0 @@ -"""MCP Server for FLUX.1 Edit""" - -import logging -import asyncio -from typing import Dict, Any, List - -import mcp.types as types -from mcp.server import Server -from mcp.server.stdio import stdio_server - -from ..connector import Config -from .models import TOOL_DEFINITIONS, ToolName -from .handlers import ToolHandlers - -logger = logging.getLogger(__name__) - - -async def main(): - """Main entry point for the MCP server""" - - # Setup logging with minimal output for MCP compatibility - logging.basicConfig( - level=logging.WARNING, # Only warnings and errors - format='%(asctime)s [%(name)s] %(levelname)s: %(message)s', - handlers=[ - logging.FileHandler('flux1-edit.log', mode='a', encoding='utf-8') - ] - ) - - # Silence noisy loggers completely - logging.getLogger('aiohttp').setLevel(logging.ERROR) - logging.getLogger('PIL').setLevel(logging.ERROR) - logging.getLogger('httpx').setLevel(logging.ERROR) - - logger = logging.getLogger(__name__) - - try: - # Create configuration - config = Config() - - # Validate configuration - if not config.validate(): - logger.error("Configuration validation failed") - raise RuntimeError("Configuration validation failed") - - # Create MCP server - server = Server("flux1-edit") - - # Create tool handlers - handlers = ToolHandlers(config) - - logger.info("Setting up MCP server handlers...") - - # Set up list_tools handler - @server.list_tools() - async def list_tools() -> List[types.Tool]: - """List available tools""" - tools = [] - - for tool_name, tool_def in TOOL_DEFINITIONS.items(): - # Build properties for parameters - properties = {} - required = [] - - for param in tool_def.parameters: - prop_def = { - "type": param.type, - "description": param.description - } - - # Add enum if specified - if param.enum: - prop_def["enum"] = param.enum - - # Add default if specified - if param.default is not None: - prop_def["default"] = param.default - - properties[param.name] = prop_def - - if param.required: - required.append(param.name) - - # Build tool schema - tool = types.Tool( - name=tool_def.name, - description=tool_def.description, - inputSchema={ - "type": "object", - "properties": properties, - "required": required - } - ) - tools.append(tool) - - logger.debug(f"Listed {len(tools)} tools") - return tools - - # Set up call_tool handler - @server.call_tool() - async def call_tool(name: str, arguments: Dict[str, Any]) -> List[types.TextContent | types.ImageContent]: - """Handle tool calls""" - try: - logger.info(f"Tool call: {name}") - - # Sanitize arguments for logging - safe_args = arguments.copy() - if 'input_image_b64' in safe_args: - b64_data = safe_args['input_image_b64'] - safe_args['input_image_b64'] = f"" - if 'prompt' in safe_args and len(safe_args['prompt']) > 100: - safe_args['prompt'] = safe_args['prompt'][:100] + '...' - - logger.debug(f"Arguments: {safe_args}") - - # Route to appropriate handler - if name == ToolName.FLUX_EDIT_IMAGE: - return await handlers.handle_flux_edit_image(arguments) - elif name == ToolName.FLUX_EDIT_IMAGE_FROM_FILE: - return await handlers.handle_flux_edit_image_from_file(arguments) - elif name == ToolName.VALIDATE_IMAGE: - return await handlers.handle_validate_image(arguments) - elif name == ToolName.MOVE_TEMP_TO_OUTPUT: - return await handlers.handle_move_temp_to_output(arguments) - else: - return [types.TextContent( - type="text", - text=f"[ERROR] Unknown tool: {name}" - )] - - except Exception as e: - logger.error(f"Error calling tool {name}: {e}", exc_info=True) - return [types.TextContent( - type="text", - text=f"[ERROR] Tool execution error: {str(e)}" - )] - - logger.info("Starting FLUX.1 Edit MCP Server...") - logger.info(f"API key configured: {'Yes' if config.api_key else 'No'}") - logger.info(f"Input directory: {config.input_path}") - logger.info(f"Output directory: {config.generated_images_path}") - - # Run the server using stdio - await stdio_server(server) - - except Exception as e: - logger.error(f"Server startup failed: {e}", exc_info=True) - raise - - -class FluxEditMCPServer: - """Legacy wrapper class (kept for compatibility)""" - - def __init__(self): - """Initialize the MCP server""" - self.config = Config() - self.server = Server("flux1-edit") - self.handlers = ToolHandlers(self.config) - logger.info("FLUX.1 Edit MCP Server initialized") - - def validate_config(self) -> bool: - """Validate server configuration""" - return self.config.validate() - - -def create_server() -> FluxEditMCPServer: - """Create and return a FLUX.1 Edit MCP Server instance""" - return FluxEditMCPServer() - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 1caca02..86fcf36 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -9,7 +9,6 @@ from .image_utils import ( save_image, encode_image_base64, decode_image_base64, - optimize_image_for_flux, convert_image_to_base64, validate_aspect_ratio, get_optimal_aspect_ratio @@ -37,7 +36,6 @@ __all__ = [ 'save_image', 'encode_image_base64', 'decode_image_base64', - 'optimize_image_for_flux', 'convert_image_to_base64', 'validate_aspect_ratio', 'get_optimal_aspect_ratio', diff --git a/src/utils/image_utils.py b/src/utils/image_utils.py index d198a54..d6e7a89 100644 --- a/src/utils/image_utils.py +++ b/src/utils/image_utils.py @@ -215,81 +215,6 @@ def decode_image_base64(base64_str: str) -> bytes: raise ValueError(f"Failed to decode base64 data: {e}") -def optimize_image_for_flux(image_path: str, max_size_mb: float = 20.0) -> bytes: - """ - Optimize image for FLUX.1 Kontext API (20MB limit) - - Args: - image_path: Path to input image - max_size_mb: Maximum size in MB (default: 20 for FLUX) - - Returns: - bytes: Optimized image data - """ - max_size_bytes = max_size_mb * 1024 * 1024 - - try: - with Image.open(image_path) as img: - # For FLUX, we want to preserve quality as much as possible - # since 20MB is quite generous - - # Convert to RGB if needed (FLUX typically prefers RGB) - if img.mode != 'RGB': - if img.mode == 'RGBA': - # Create white background for transparent images - background = Image.new('RGB', img.size, (255, 255, 255)) - background.paste(img, mask=img.split()[-1] if img.mode == 'RGBA' else None) - img = background - else: - img = img.convert('RGB') - - # Try PNG first (lossless) - buffer = io.BytesIO() - img.save(buffer, format='PNG', optimize=True) - png_data = buffer.getvalue() - - if len(png_data) <= max_size_bytes: - logger.info(f"Image optimized as PNG: {len(png_data) / (1024*1024):.2f}MB") - return png_data - - # PNG too large, try JPEG with high quality - for quality in [95, 90, 85, 80]: - buffer = io.BytesIO() - img.save(buffer, format='JPEG', quality=quality, optimize=True) - jpeg_data = buffer.getvalue() - - if len(jpeg_data) <= max_size_bytes: - size_mb = len(jpeg_data) / (1024 * 1024) - logger.info(f"Image optimized as JPEG (quality {quality}): {size_mb:.2f}MB") - return jpeg_data - - # Still too large, try resizing (preserve aspect ratio) - logger.warning("Image still too large, attempting resize...") - - scale = 0.95 - while scale > 0.5: - new_width = int(img.width * scale) - new_height = int(img.height * scale) - - resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS) - - buffer = io.BytesIO() - resized.save(buffer, format='JPEG', quality=85, optimize=True) - data = buffer.getvalue() - - if len(data) <= max_size_bytes: - size_mb = len(data) / (1024 * 1024) - logger.warning(f"Image resized to {new_width}x{new_height} ({scale*100:.0f}%): {size_mb:.2f}MB") - return data - - scale -= 0.05 - - raise ValueError(f"Cannot optimize image to under {max_size_mb}MB") - - except Exception as e: - logger.error(f"Error optimizing image: {e}") - raise - def convert_image_to_base64(image_path: str) -> str: """ @@ -302,16 +227,9 @@ def convert_image_to_base64(image_path: str) -> str: str: Base64 encoded image data """ try: - # Check if optimization is needed - current_size_mb = get_file_size_mb(image_path) - - if current_size_mb <= 20.0: - # Read directly if under limit - with open(image_path, 'rb') as f: - image_data = f.read() - else: - # Optimize if over limit - image_data = optimize_image_for_flux(image_path) + # Read image file directly (validation should handle size limits) + with open(image_path, 'rb') as f: + image_data = f.read() return encode_image_base64(image_data) diff --git a/start.bat b/start.bat deleted file mode 100644 index ed1400e..0000000 --- a/start.bat +++ /dev/null @@ -1,35 +0,0 @@ -@echo off -echo FLUX.1 Edit MCP Server - Quick Start -echo ==================================== - -REM Set UTF-8 encoding to prevent Unicode errors -chcp 65001 >nul 2>&1 -set PYTHONIOENCODING=utf-8 -set PYTHONUTF8=1 - -echo Running simple test first... -python simple_test.py -if errorlevel 1 ( - echo. - echo Simple test failed. Please check the errors above. - pause - exit /b 1 -) - -echo. -echo Simple test passed! Starting MCP server... -echo. - -REM Start the main MCP server -python main.py - -if errorlevel 1 ( - echo. - echo Server failed to start. Check flux1-edit.log for details. - pause - exit /b 1 -) - -echo. -echo Server stopped. -pause diff --git a/test_fixes.py b/test_fixes.py deleted file mode 100644 index f0261b1..0000000 --- a/test_fixes.py +++ /dev/null @@ -1,119 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to verify FLUX.1 Edit MCP Server fixes -This script tests if the Unicode and JSON parsing issues are resolved -""" - -import sys -import os -from pathlib import Path - -# Add src to path for imports -sys.path.insert(0, str(Path(__file__).parent / 'src')) - -def test_console_encoding(): - """Test console encoding with safe characters""" - print("Testing console output with safe characters...") - print("[OK] ASCII characters work fine") - print("[ERROR] Error messages use brackets instead of Unicode") - print("[SUCCESS] Success messages use brackets instead of Unicode") - print("[INFO] Info messages work properly") - return True - -def test_dependency_imports(): - """Test importing dependencies silently""" - print("Testing dependency imports...") - - missing_deps = [] - - try: - import aiohttp - # Silent check - except ImportError: - missing_deps.append("aiohttp") - - try: - import mcp - # Silent check - except ImportError: - missing_deps.append("mcp") - - try: - from src.connector import Config - # Silent check - except ImportError: - missing_deps.append("src.connector.Config") - - try: - from src.server import main - # Silent check - except ImportError: - missing_deps.append("src.server.main") - - if missing_deps: - print(f"[ERROR] Missing dependencies: {', '.join(missing_deps)}") - return False - else: - print("[SUCCESS] All imports successful") - return True - -def test_server_creation(): - """Test creating server instance without starting it""" - print("Testing server creation...") - - try: - from src.server import create_server - server = create_server() - print("[SUCCESS] Server instance created successfully") - return True - except Exception as e: - print(f"[ERROR] Server creation failed: {e}") - return False - -def main(): - """Main test function""" - print("FLUX.1 Edit MCP Server - Fix Verification") - print("=" * 50) - - # Set UTF-8 encoding environment variables - os.environ['PYTHONIOENCODING'] = 'utf-8' - os.environ['PYTHONUTF8'] = '1' - - tests = [ - ("Console encoding", test_console_encoding), - ("Dependency imports", test_dependency_imports), - ("Server creation", test_server_creation), - ] - - passed = 0 - total = len(tests) - - for test_name, test_func in tests: - print(f"\nRunning test: {test_name}") - print("-" * 30) - - try: - if test_func(): - passed += 1 - print(f"[PASSED] {test_name}") - else: - print(f"[FAILED] {test_name}") - except Exception as e: - print(f"[FAILED] {test_name}: {e}") - - print("\n" + "=" * 50) - print(f"Test Results: {passed}/{total} tests passed") - - if passed == total: - print("[SUCCESS] All tests passed! The fixes should work.") - print("\nTo run the server:") - print("1. Make sure your .env file is configured") - print("2. Run: start.bat or run.bat") - print("3. The server should start without Unicode errors") - return 0 - else: - print(f"[FAILED] {total - passed} tests failed. Check the errors above.") - return 1 - -if __name__ == "__main__": - sys.exit(main()) diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/run_tests.py b/tests/run_tests.py deleted file mode 100644 index b2a2988..0000000 --- a/tests/run_tests.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Test runner for all FLUX.1 Edit MCP Server tests""" - -import unittest -import sys -from pathlib import Path - -# Add src to path -sys.path.insert(0, str(Path(__file__).parent.parent / 'src')) - -# Import all test modules -from test_config import TestConfig -from test_image_utils import TestImageUtils -from test_validation import TestValidation -from test_flux_client import TestFluxEditClient, TestFluxEditClientAsync -from test_handlers import TestToolHandlers - - -def create_test_suite(): - """Create and return the complete test suite""" - suite = unittest.TestSuite() - - # Add all test cases - suite.addTest(unittest.makeSuite(TestConfig)) - suite.addTest(unittest.makeSuite(TestImageUtils)) - suite.addTest(unittest.makeSuite(TestValidation)) - suite.addTest(unittest.makeSuite(TestFluxEditClient)) - suite.addTest(unittest.makeSuite(TestFluxEditClientAsync)) - suite.addTest(unittest.makeSuite(TestToolHandlers)) - - return suite - - -def run_tests(): - """Run all tests and return results""" - # Setup test runner - runner = unittest.TextTestRunner( - verbosity=2, - stream=sys.stdout, - descriptions=True, - failfast=False - ) - - # Create and run test suite - suite = create_test_suite() - result = runner.run(suite) - - # Print summary - print(f"\n{'='*70}") - print("TEST SUMMARY") - print(f"{'='*70}") - print(f"Tests run: {result.testsRun}") - print(f"Failures: {len(result.failures)}") - print(f"Errors: {len(result.errors)}") - print(f"Skipped: {len(result.skipped) if hasattr(result, 'skipped') else 0}") - - if result.failures: - print(f"\nFAILURES:") - for test, traceback in result.failures: - print(f" - {test}: {traceback.split('AssertionError: ')[-1].split()[0] if 'AssertionError:' in traceback else 'Unknown failure'}") - - if result.errors: - print(f"\nERRORS:") - for test, traceback in result.errors: - error_msg = traceback.split('\n')[-2] if traceback.split('\n')[-2] else 'Unknown error' - print(f" - {test}: {error_msg}") - - success = len(result.failures) == 0 and len(result.errors) == 0 - print(f"\nOVERALL: {'βœ… PASSED' if success else '❌ FAILED'}") - - return success - - -def run_specific_test(test_name: str): - """Run a specific test module""" - test_modules = { - 'config': TestConfig, - 'image_utils': TestImageUtils, - 'validation': TestValidation, - 'flux_client': TestFluxEditClient, - 'flux_client_async': TestFluxEditClientAsync, - 'handlers': TestToolHandlers - } - - if test_name not in test_modules: - print(f"❌ Unknown test module: {test_name}") - print(f"Available modules: {', '.join(test_modules.keys())}") - return False - - suite = unittest.makeSuite(test_modules[test_name]) - runner = unittest.TextTestRunner(verbosity=2) - result = runner.run(suite) - - return len(result.failures) == 0 and len(result.errors) == 0 - - -if __name__ == '__main__': - # Check for specific test argument - if len(sys.argv) > 1: - test_name = sys.argv[1] - success = run_specific_test(test_name) - else: - # Run all tests - success = run_tests() - - # Exit with appropriate code - sys.exit(0 if success else 1) diff --git a/tests/test_config.py b/tests/test_config.py deleted file mode 100644 index 29d4df6..0000000 --- a/tests/test_config.py +++ /dev/null @@ -1,159 +0,0 @@ -"""Unit tests for Config class""" - -import unittest -import tempfile -import os -from pathlib import Path -from unittest.mock import patch, MagicMock - -# Add src to path -import sys -sys.path.insert(0, str(Path(__file__).parent.parent / 'src')) - -from src.connector.config import Config - - -class TestConfig(unittest.TestCase): - """Test cases for Config class""" - - def setUp(self): - """Set up test fixtures""" - self.temp_dir = tempfile.mkdtemp() - self.temp_path = Path(self.temp_dir) - - def tearDown(self): - """Clean up test fixtures""" - import shutil - if self.temp_path.exists(): - shutil.rmtree(self.temp_path) - - @patch.dict(os.environ, { - 'FLUX_API_KEY': 'test_api_key_12345', - 'LOG_LEVEL': 'DEBUG', - 'MAX_IMAGE_SIZE_MB': '25', - 'DEFAULT_TIMEOUT': '600' - }) - def test_config_initialization(self): - """Test config initialization with environment variables""" - config = Config() - - self.assertEqual(config.api_key, 'test_api_key_12345') - self.assertEqual(config.log_level, 'DEBUG') - self.assertEqual(config.max_image_size_mb, 25) - self.assertEqual(config.default_timeout, 600) - - def test_config_defaults(self): - """Test config defaults when env vars not set""" - # Clear environment - env_vars_to_clear = [ - 'FLUX_API_KEY', 'LOG_LEVEL', 'MAX_IMAGE_SIZE_MB', - 'DEFAULT_TIMEOUT', 'POLLING_INTERVAL', 'MAX_POLLING_ATTEMPTS' - ] - - with patch.dict(os.environ, {}, clear=True): - config = Config() - - self.assertEqual(config.api_key, '') - self.assertEqual(config.log_level, 'INFO') - self.assertEqual(config.max_image_size_mb, 20) - self.assertEqual(config.default_timeout, 300) - self.assertEqual(config.polling_interval, 2) - self.assertEqual(config.max_polling_attempts, 150) - - def test_api_url_generation(self): - """Test API URL generation""" - config = Config() - - edit_url = config.get_api_url(config.EDIT_ENDPOINT) - result_url = config.get_api_url(config.RESULT_ENDPOINT) - - expected_edit = f"{config.api_base_url}/flux-kontext-pro" - expected_result = f"{config.api_base_url}/v1/get_result" - - self.assertEqual(edit_url, expected_edit) - self.assertEqual(result_url, expected_result) - - def test_filename_generation(self): - """Test filename generation""" - config = Config() - - base_name = "fluxedit_123456_20250826_143022" - - # Test different file numbers and extensions - filename_000 = config.generate_filename(base_name, 0, 'png') - filename_001 = config.generate_filename(base_name, 1, 'png') - filename_json = config.generate_filename(base_name, 1, 'json') - - self.assertEqual(filename_000, "fluxedit_123456_20250826_143022_000.png") - self.assertEqual(filename_001, "fluxedit_123456_20250826_143022_001.png") - self.assertEqual(filename_json, "fluxedit_123456_20250826_143022_001.json") - - def test_base_name_generation(self): - """Test base name generation""" - config = Config() - - # Test with seed - base_name_with_seed = config.generate_base_name(12345) - self.assertIn("fluxedit_12345_", base_name_with_seed) - - # Test simple generation - base_name_simple = config.generate_base_name_simple() - self.assertIn("fluxedit_", base_name_simple) - self.assertNotIn("_12345_", base_name_simple) # No seed in simple - - @patch.dict(os.environ, { - 'FLUX_API_KEY': 'valid_key', - 'MAX_IMAGE_SIZE_MB': '20', - 'DEFAULT_TIMEOUT': '300' - }) - def test_validation_success(self): - """Test successful validation""" - config = Config() - self.assertTrue(config.validate()) - - def test_validation_failures(self): - """Test validation failures""" - # Test missing API key - with patch.dict(os.environ, {'FLUX_API_KEY': ''}, clear=True): - config = Config() - self.assertFalse(config.validate()) - - # Test invalid image size - with patch.dict(os.environ, { - 'FLUX_API_KEY': 'valid_key', - 'MAX_IMAGE_SIZE_MB': '0' - }, clear=True): - config = Config() - self.assertFalse(config.validate()) - - # Test invalid timeout - with patch.dict(os.environ, { - 'FLUX_API_KEY': 'valid_key', - 'DEFAULT_TIMEOUT': '-1' - }, clear=True): - config = Config() - self.assertFalse(config.validate()) - - def test_max_image_size_bytes(self): - """Test max image size in bytes calculation""" - config = Config() - config.max_image_size_mb = 20 - - expected_bytes = 20 * 1024 * 1024 - self.assertEqual(config.get_max_image_size_bytes(), expected_bytes) - - @patch('pathlib.Path.mkdir') - @patch('pathlib.Path.exists') - def test_directory_creation(self, mock_exists, mock_mkdir): - """Test directory creation logic""" - mock_exists.return_value = True - - # This should not raise an exception - config = Config() - - # Verify mkdir was called - mock_mkdir.assert_called() - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_flux_client.py b/tests/test_flux_client.py deleted file mode 100644 index 18b15af..0000000 --- a/tests/test_flux_client.py +++ /dev/null @@ -1,308 +0,0 @@ -"""Unit tests for FLUX API client""" - -import unittest -import asyncio -from unittest.mock import AsyncMock, MagicMock, patch -from pathlib import Path - -# Add src to path -import sys -sys.path.insert(0, str(Path(__file__).parent.parent / 'src')) - -from src.connector.flux_client import FluxEditClient, FluxEditRequest, FluxEditResponse -from src.connector.config import Config - - -class TestFluxEditClient(unittest.TestCase): - """Test cases for FLUX API client""" - - def setUp(self): - """Set up test fixtures""" - # Mock config - self.config = MagicMock(spec=Config) - self.config.api_key = 'test_api_key' - self.config.default_timeout = 30 - self.config.polling_interval = 1 - self.config.max_polling_attempts = 5 - self.config.get_api_url.side_effect = lambda endpoint: f"https://api.test.com{endpoint}" - - self.client = FluxEditClient(self.config) - - # Sample request - self.sample_request = FluxEditRequest( - input_image_b64='test_base64_data', - prompt='Make the sky blue', - seed=12345, - aspect_ratio='16:9' - ) - - def tearDown(self): - """Clean up test fixtures""" - # Ensure client session is closed - if hasattr(self.client, 'session') and self.client.session: - asyncio.create_task(self.client.close()) - - @patch('aiohttp.ClientSession') - async def test_create_edit_request_success(self, mock_session_class): - """Test successful edit request creation""" - # Mock session and response - mock_session = AsyncMock() - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json.return_value = {'id': 'request_12345'} - - mock_session.post.return_value.__aenter__.return_value = mock_response - mock_session_class.return_value = mock_session - - # Test - request_id = await self.client._create_edit_request(self.sample_request) - - # Verify - self.assertEqual(request_id, 'request_12345') - mock_session.post.assert_called_once() - - # Check payload structure - call_args = mock_session.post.call_args - payload = call_args[1]['json'] - self.assertEqual(payload['prompt'], 'Make the sky blue') - self.assertEqual(payload['seed'], 12345) - self.assertEqual(payload['input_image'], 'test_base64_data') - - @patch('aiohttp.ClientSession') - async def test_create_edit_request_failure(self, mock_session_class): - """Test edit request creation failure""" - # Mock session and response - mock_session = AsyncMock() - mock_response = AsyncMock() - mock_response.status = 400 - mock_response.text.return_value = 'Bad Request' - - mock_session.post.return_value.__aenter__.return_value = mock_response - mock_session_class.return_value = mock_session - - # Test - request_id = await self.client._create_edit_request(self.sample_request) - - # Verify - self.assertIsNone(request_id) - - @patch('aiohttp.ClientSession') - async def test_poll_result_success(self, mock_session_class): - """Test successful result polling""" - # Mock session and responses - mock_session = AsyncMock() - - # First response: processing - mock_response_processing = AsyncMock() - mock_response_processing.status = 200 - mock_response_processing.json.return_value = {'status': 'processing'} - - # Second response: ready - mock_response_ready = AsyncMock() - mock_response_ready.status = 200 - mock_response_ready.json.return_value = { - 'status': 'ready', - 'result': {'sample': 'https://example.com/image.png'} - } - - # Mock to return processing first, then ready - mock_session.get.return_value.__aenter__.side_effect = [ - mock_response_processing, - mock_response_ready - ] - mock_session_class.return_value = mock_session - - # Test - result = await self.client._poll_result('test_request_id') - - # Verify - self.assertIsNotNone(result) - self.assertEqual(result['status'], 'ready') - self.assertIn('result', result) - self.assertEqual(mock_session.get.call_count, 2) - - @patch('aiohttp.ClientSession') - async def test_poll_result_timeout(self, mock_session_class): - """Test polling timeout""" - # Mock session to always return processing - mock_session = AsyncMock() - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json.return_value = {'status': 'processing'} - - mock_session.get.return_value.__aenter__.return_value = mock_response - mock_session_class.return_value = mock_session - - # Test - result = await self.client._poll_result('test_request_id') - - # Verify - should timeout after max attempts - self.assertIsNone(result) - self.assertEqual(mock_session.get.call_count, self.config.max_polling_attempts) - - @patch('aiohttp.ClientSession') - async def test_download_result_image_success(self, mock_session_class): - """Test successful image download""" - # Mock session and response - mock_session = AsyncMock() - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.read.return_value = b'fake_image_data' - - mock_session.get.return_value.__aenter__.return_value = mock_response - mock_session_class.return_value = mock_session - - # Test - image_data = await self.client._download_result_image('https://example.com/image.png') - - # Verify - self.assertEqual(image_data, b'fake_image_data') - mock_session.get.assert_called_once_with('https://example.com/image.png') - - @patch('aiohttp.ClientSession') - async def test_download_result_image_failure(self, mock_session_class): - """Test image download failure""" - # Mock session and response - mock_session = AsyncMock() - mock_response = AsyncMock() - mock_response.status = 404 - - mock_session.get.return_value.__aenter__.return_value = mock_response - mock_session_class.return_value = mock_session - - # Test - image_data = await self.client._download_result_image('https://example.com/image.png') - - # Verify - self.assertIsNone(image_data) - - def test_get_image_size(self): - """Test image size detection from bytes""" - # Create a small test image in memory - from PIL import Image - import io - - # Create 10x20 test image - img = Image.new('RGB', (10, 20), color='red') - buffer = io.BytesIO() - img.save(buffer, format='PNG') - image_data = buffer.getvalue() - - # Test - size = self.client._get_image_size(image_data) - - # Verify - self.assertEqual(size, (10, 20)) - - def test_get_image_size_invalid_data(self): - """Test image size detection with invalid data""" - size = self.client._get_image_size(b'invalid_image_data') - self.assertIsNone(size) - - @patch.object(FluxEditClient, '_create_edit_request') - @patch.object(FluxEditClient, '_poll_result') - @patch.object(FluxEditClient, '_download_result_image') - async def test_edit_image_success(self, mock_download, mock_poll, mock_create): - """Test complete successful edit flow""" - # Setup mocks - mock_create.return_value = 'request_123' - mock_poll.return_value = { - 'status': 'ready', - 'result': {'sample': 'https://example.com/result.png'} - } - mock_download.return_value = b'edited_image_data' - - # Test - response = await self.client.edit_image(self.sample_request) - - # Verify - self.assertTrue(response.success) - self.assertEqual(response.edited_image_data, b'edited_image_data') - self.assertEqual(response.request_id, 'request_123') - self.assertEqual(response.result_url, 'https://example.com/result.png') - self.assertGreater(response.execution_time, 0) - - @patch.object(FluxEditClient, '_create_edit_request') - async def test_edit_image_create_failure(self, mock_create): - """Test edit flow with creation failure""" - # Setup mock - mock_create.return_value = None - - # Test - response = await self.client.edit_image(self.sample_request) - - # Verify - self.assertFalse(response.success) - self.assertIn('Failed to create edit request', response.error_message) - - @patch.object(FluxEditClient, '_create_edit_request') - @patch.object(FluxEditClient, '_poll_result') - async def test_edit_image_poll_failure(self, mock_poll, mock_create): - """Test edit flow with polling failure""" - # Setup mocks - mock_create.return_value = 'request_123' - mock_poll.return_value = None - - # Test - response = await self.client.edit_image(self.sample_request) - - # Verify - self.assertFalse(response.success) - self.assertIn('Failed to get edit result', response.error_message) - - async def test_context_manager(self): - """Test async context manager functionality""" - async with FluxEditClient(self.config) as client: - self.assertIsInstance(client, FluxEditClient) - - # Session should be closed after context - if hasattr(client, 'session'): - self.assertTrue(client.session is None or client.session.closed) - - -# Test helper to run async tests -class AsyncTestCase(unittest.TestCase): - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - def tearDown(self): - self.loop.close() - - def async_test(self, coro): - return self.loop.run_until_complete(coro) - - -class TestFluxEditClientAsync(AsyncTestCase): - """Async test runner for FLUX client tests""" - - def setUp(self): - super().setUp() - self.config = MagicMock(spec=Config) - self.config.api_key = 'test_api_key' - self.config.default_timeout = 30 - self.config.polling_interval = 0.1 # Faster for tests - self.config.max_polling_attempts = 3 - self.config.get_api_url.side_effect = lambda endpoint: f"https://api.test.com{endpoint}" - - self.client = FluxEditClient(self.config) - self.sample_request = FluxEditRequest( - input_image_b64='test_base64_data', - prompt='Make the sky blue', - seed=12345, - aspect_ratio='16:9' - ) - - def test_async_context_manager(self): - """Test async context manager""" - async def run_test(): - async with FluxEditClient(self.config) as client: - self.assertIsInstance(client, FluxEditClient) - return True - - result = self.async_test(run_test()) - self.assertTrue(result) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_handlers.py b/tests/test_handlers.py deleted file mode 100644 index eecf77c..0000000 --- a/tests/test_handlers.py +++ /dev/null @@ -1,360 +0,0 @@ -"""Unit tests for MCP tool handlers""" - -import unittest -import tempfile -import asyncio -from pathlib import Path -from unittest.mock import AsyncMock, MagicMock, patch -from PIL import Image -import io - -# Add src to path -import sys -sys.path.insert(0, str(Path(__file__).parent.parent / 'src')) - -from src.server.handlers import ToolHandlers -from src.connector.config import Config -from src.connector.flux_client import FluxEditResponse -from mcp.types import TextContent, ImageContent - - -class AsyncTestCase(unittest.TestCase): - """Base class for async tests""" - - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - def tearDown(self): - self.loop.close() - - def async_test(self, coro): - return self.loop.run_until_complete(coro) - - -class TestToolHandlers(AsyncTestCase): - """Test cases for MCP tool handlers""" - - def setUp(self): - super().setUp() - - # Create temporary directory - self.temp_dir = tempfile.mkdtemp() - self.temp_path = Path(self.temp_dir) - - # Mock config - self.config = MagicMock(spec=Config) - self.config.base_path = self.temp_path - self.config.input_path = self.temp_path / 'input_images' - self.config.generated_images_path = self.temp_path / 'generated_images' - self.config.max_image_size_mb = 20 - self.config.default_aspect_ratio = '1:1' - self.config.safety_tolerance = 2 - self.config.OUTPUT_FORMAT = 'png' - self.config.prompt_upsampling = False - self.config.MODEL_NAME = 'flux-kontext-pro' - self.config.save_parameters = True - - # Setup directories - self.config.input_path.mkdir(parents=True, exist_ok=True) - self.config.generated_images_path.mkdir(parents=True, exist_ok=True) - - # Mock config methods - self.config.ensure_output_directory.return_value = None - self.config.generate_base_name.return_value = 'fluxedit_12345_20250826_143022' - self.config.generate_filename.side_effect = lambda base, num, ext: f'{base}_{num:03d}.{ext}' - self.config.get_output_path.side_effect = lambda base, num, ext: self.config.generated_images_path / f'{base}_{num:03d}.{ext}' - - # Create test image - self.test_image = Image.new('RGB', (100, 100), color='blue') - buffer = io.BytesIO() - self.test_image.save(buffer, format='PNG') - self.test_image_data = buffer.getvalue() - self.test_image_b64 = self._encode_image_b64(self.test_image_data) - - # Create test image file - self.test_image_file = self.config.input_path / 'test.png' - self.test_image.save(self.test_image_file) - - # Initialize handlers - self.handlers = ToolHandlers(self.config) - - def tearDown(self): - """Clean up test fixtures""" - import shutil - super().tearDown() - if self.temp_path.exists(): - shutil.rmtree(self.temp_path) - - def _encode_image_b64(self, image_data: bytes) -> str: - """Helper to encode image as base64""" - import base64 - return base64.b64encode(image_data).decode('utf-8') - - def _create_mock_flux_response(self, success: bool = True) -> FluxEditResponse: - """Helper to create mock FLUX response""" - if success: - return FluxEditResponse( - success=True, - edited_image_data=self.test_image_data, - image_size=(100, 100), - execution_time=5.5, - request_id='test_request_123', - result_url='https://example.com/result.png', - metadata={'seed': 12345} - ) - else: - return FluxEditResponse( - success=False, - error_message='FLUX edit failed', - execution_time=2.0 - ) - - def test_flux_edit_image_parameter_validation_failure(self): - """Test flux_edit_image with invalid parameters""" - async def run_test(): - # Missing required parameter - arguments = { - 'prompt': 'Make it blue', - 'seed': 12345 - # Missing input_image_b64 - } - - result = await self.handlers.handle_flux_edit_image(arguments) - - self.assertEqual(len(result), 1) - self.assertIsInstance(result[0], TextContent) - self.assertIn('Parameter validation failed', result[0].text) - self.assertIn('input_image_b64 is required', result[0].text) - - self.async_test(run_test()) - - @patch('src.server.handlers.FluxEditClient') - def test_flux_edit_image_success(self, mock_client_class): - """Test successful flux_edit_image""" - async def run_test(): - # Setup mock client - mock_client = AsyncMock() - mock_client_class.return_value.__aenter__.return_value = mock_client - mock_client.edit_image.return_value = self._create_mock_flux_response(success=True) - - # Valid arguments - arguments = { - 'input_image_b64': self.test_image_b64, - 'prompt': 'Make the sky blue', - 'seed': 12345, - 'aspect_ratio': '16:9', - 'save_to_file': True - } - - result = await self.handlers.handle_flux_edit_image(arguments) - - # Verify result structure - self.assertGreater(len(result), 0) - self.assertIsInstance(result[0], TextContent) - self.assertIn('βœ… Image edited successfully', result[0].text) - self.assertIn('Seed: 12345', result[0].text) - - # Should have image preview - if len(result) > 1: - self.assertIsInstance(result[1], ImageContent) - - self.async_test(run_test()) - - @patch('src.server.handlers.FluxEditClient') - def test_flux_edit_image_flux_failure(self, mock_client_class): - """Test flux_edit_image with FLUX API failure""" - async def run_test(): - # Setup mock client to return failure - mock_client = AsyncMock() - mock_client_class.return_value.__aenter__.return_value = mock_client - mock_client.edit_image.return_value = self._create_mock_flux_response(success=False) - - arguments = { - 'input_image_b64': self.test_image_b64, - 'prompt': 'Make it blue', - 'seed': 12345 - } - - result = await self.handlers.handle_flux_edit_image(arguments) - - self.assertEqual(len(result), 1) - self.assertIsInstance(result[0], TextContent) - self.assertIn('❌ FLUX edit failed', result[0].text) - - self.async_test(run_test()) - - def test_flux_edit_image_from_file_not_found(self): - """Test flux_edit_image_from_file with non-existent file""" - async def run_test(): - arguments = { - 'input_image_name': 'nonexistent.png', - 'prompt': 'Edit this', - 'seed': 12345 - } - - result = await self.handlers.handle_flux_edit_image_from_file(arguments) - - self.assertEqual(len(result), 1) - self.assertIsInstance(result[0], TextContent) - self.assertIn('❌ File not found', result[0].text) - self.assertIn('nonexistent.png', result[0].text) - - self.async_test(run_test()) - - @patch('src.server.handlers.FluxEditClient') - def test_flux_edit_image_from_file_success(self, mock_client_class): - """Test successful flux_edit_image_from_file""" - async def run_test(): - # Setup mock client - mock_client = AsyncMock() - mock_client_class.return_value.__aenter__.return_value = mock_client - mock_client.edit_image.return_value = self._create_mock_flux_response(success=True) - - arguments = { - 'input_image_name': 'test.png', - 'prompt': 'Make it awesome', - 'seed': 54321, - 'save_to_file': True - } - - result = await self.handlers.handle_flux_edit_image_from_file(arguments) - - # Verify result - self.assertGreater(len(result), 0) - self.assertIsInstance(result[0], TextContent) - self.assertIn('βœ… Image edited successfully from file', result[0].text) - self.assertIn('test.png', result[0].text) - self.assertIn('Seed: 54321', result[0].text) - - self.async_test(run_test()) - - def test_validate_image_success(self): - """Test successful image validation""" - async def run_test(): - arguments = {'image_path': str(self.test_image_file)} - - result = await self.handlers.handle_validate_image(arguments) - - self.assertEqual(len(result), 1) - self.assertIsInstance(result[0], TextContent) - self.assertIn('βœ… Image validation passed', result[0].text) - self.assertIn('100x100', result[0].text) - - self.async_test(run_test()) - - def test_validate_image_not_found(self): - """Test image validation with non-existent file""" - async def run_test(): - arguments = {'image_path': '/nonexistent/path.png'} - - result = await self.handlers.handle_validate_image(arguments) - - self.assertEqual(len(result), 1) - self.assertIsInstance(result[0], TextContent) - self.assertIn('❌ Image validation failed', result[0].text) - self.assertIn('File not found', result[0].text) - - self.async_test(run_test()) - - def test_move_temp_to_output_success(self): - """Test successful file move operation""" - async def run_test(): - # Create temp directory and file - temp_dir = self.config.base_path / 'temp' - temp_dir.mkdir(exist_ok=True) - temp_file = temp_dir / 'temp_test.png' - self.test_image.save(temp_file) - - arguments = { - 'temp_file_name': 'temp_test.png', - 'output_file_name': 'moved_test.png', - 'copy_only': False - } - - result = await self.handlers.handle_move_temp_to_output(arguments) - - self.assertEqual(len(result), 1) - self.assertIsInstance(result[0], TextContent) - self.assertIn('βœ… File moved successfully', result[0].text) - self.assertIn('temp_test.png', result[0].text) - self.assertIn('moved_test.png', result[0].text) - - self.async_test(run_test()) - - def test_move_temp_to_output_not_found(self): - """Test file move with non-existent temp file""" - async def run_test(): - arguments = { - 'temp_file_name': 'nonexistent.png', - 'copy_only': False - } - - result = await self.handlers.handle_move_temp_to_output(arguments) - - self.assertEqual(len(result), 1) - self.assertIsInstance(result[0], TextContent) - self.assertIn('❌ Temp file not found', result[0].text) - - self.async_test(run_test()) - - def test_move_temp_to_output_copy_only(self): - """Test copy-only file operation""" - async def run_test(): - # Create temp directory and file - temp_dir = self.config.base_path / 'temp' - temp_dir.mkdir(exist_ok=True) - temp_file = temp_dir / 'temp_copy.png' - self.test_image.save(temp_file) - - arguments = { - 'temp_file_name': 'temp_copy.png', - 'copy_only': True - } - - result = await self.handlers.handle_move_temp_to_output(arguments) - - self.assertEqual(len(result), 1) - self.assertIsInstance(result[0], TextContent) - self.assertIn('βœ… File copied successfully', result[0].text) - - # Original file should still exist - self.assertTrue(temp_file.exists()) - - self.async_test(run_test()) - - def test_seed_management(self): - """Test seed creation and reset functionality""" - # Test seed generation - seed1 = self.handlers._get_or_create_seed() - seed2 = self.handlers._get_or_create_seed() - - # Should return same seed for session - self.assertEqual(seed1, seed2) - - # Test seed reset - self.handlers._reset_seed() - seed3 = self.handlers._get_or_create_seed() - - # Should be different after reset - self.assertNotEqual(seed1, seed3) - - def test_temp_file_operations(self): - """Test temporary file save and move operations""" - # Test saving b64 to temp - filename = 'test_temp.png' - temp_path = self.handlers._save_b64_to_temp_file(self.test_image_b64, filename) - - self.assertTrue(Path(temp_path).exists()) - self.assertIn(filename, temp_path) - - # Test moving to generated images - base_name = 'test_base' - moved_path = self.handlers._move_temp_to_generated(temp_path, base_name, 1) - - self.assertTrue(Path(moved_path).exists()) - self.assertIn('test_base_001.png', moved_path) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_image_utils.py b/tests/test_image_utils.py deleted file mode 100644 index da33b7e..0000000 --- a/tests/test_image_utils.py +++ /dev/null @@ -1,197 +0,0 @@ -"""Unit tests for image utilities""" - -import unittest -import tempfile -import base64 -from pathlib import Path -from PIL import Image -import io - -# Add src to path -import sys -sys.path.insert(0, str(Path(__file__).parent.parent / 'src')) - -from src.utils.image_utils import ( - get_file_size_mb, - validate_image_file, - get_image_dimensions, - get_image_dimensions_from_bytes, - encode_image_base64, - decode_image_base64, - save_image, - get_optimal_aspect_ratio, - validate_aspect_ratio, - convert_image_to_base64 -) - - -class TestImageUtils(unittest.TestCase): - """Test cases for image utility functions""" - - def setUp(self): - """Set up test fixtures""" - self.temp_dir = tempfile.mkdtemp() - self.temp_path = Path(self.temp_dir) - - # Create a test image - self.test_image = Image.new('RGB', (100, 100), color='red') - self.test_image_path = self.temp_path / 'test_image.png' - self.test_image.save(self.test_image_path) - - # Create test image data - buffer = io.BytesIO() - self.test_image.save(buffer, format='PNG') - self.test_image_data = buffer.getvalue() - - def tearDown(self): - """Clean up test fixtures""" - import shutil - if self.temp_path.exists(): - shutil.rmtree(self.temp_path) - - def test_get_file_size_mb(self): - """Test file size calculation""" - size_mb = get_file_size_mb(self.test_image_path) - self.assertGreater(size_mb, 0) - self.assertLess(size_mb, 1) # Small test image should be < 1MB - - def test_validate_image_file_success(self): - """Test successful image validation""" - is_valid, size_mb, error = validate_image_file(str(self.test_image_path), 20) - - self.assertTrue(is_valid) - self.assertGreater(size_mb, 0) - self.assertIsNone(error) - - def test_validate_image_file_not_found(self): - """Test validation of non-existent file""" - is_valid, size_mb, error = validate_image_file('nonexistent.png', 20) - - self.assertFalse(is_valid) - self.assertEqual(size_mb, 0) - self.assertIn('File not found', error) - - def test_validate_image_file_too_large(self): - """Test validation of file too large""" - # Test with very small limit - is_valid, size_mb, error = validate_image_file(str(self.test_image_path), 0.001) - - self.assertFalse(is_valid) - self.assertIn('exceeds', error) - - def test_get_image_dimensions(self): - """Test getting image dimensions""" - width, height = get_image_dimensions(str(self.test_image_path)) - self.assertEqual(width, 100) - self.assertEqual(height, 100) - - def test_get_image_dimensions_from_bytes(self): - """Test getting dimensions from image bytes""" - width, height = get_image_dimensions_from_bytes(self.test_image_data) - self.assertEqual(width, 100) - self.assertEqual(height, 100) - - def test_encode_decode_base64(self): - """Test base64 encoding and decoding""" - # Encode - b64_string = encode_image_base64(self.test_image_data) - self.assertIsInstance(b64_string, str) - - # Decode - decoded_data = decode_image_base64(b64_string) - self.assertEqual(decoded_data, self.test_image_data) - - def test_decode_base64_with_data_url(self): - """Test decoding base64 with data URL prefix""" - b64_string = encode_image_base64(self.test_image_data) - data_url = f"data:image/png;base64,{b64_string}" - - decoded_data = decode_image_base64(data_url) - self.assertEqual(decoded_data, self.test_image_data) - - def test_save_image(self): - """Test saving image data to file""" - output_path = self.temp_path / 'output.png' - - success = save_image(self.test_image_data, str(output_path)) - self.assertTrue(success) - self.assertTrue(output_path.exists()) - - # Verify file content - with open(output_path, 'rb') as f: - saved_data = f.read() - self.assertEqual(saved_data, self.test_image_data) - - def test_get_optimal_aspect_ratio(self): - """Test optimal aspect ratio calculation""" - # Test square image - ratio = get_optimal_aspect_ratio(100, 100) - self.assertEqual(ratio, "1:1") - - # Test wide image - ratio = get_optimal_aspect_ratio(160, 90) - self.assertEqual(ratio, "16:9") - - # Test tall image - ratio = get_optimal_aspect_ratio(90, 160) - self.assertEqual(ratio, "9:16") - - def test_validate_aspect_ratio(self): - """Test aspect ratio validation""" - # Test matching ratio - self.assertTrue(validate_aspect_ratio(100, 100, "1:1")) - self.assertTrue(validate_aspect_ratio(160, 90, "16:9")) - - # Test non-matching ratio (within tolerance) - self.assertTrue(validate_aspect_ratio(161, 90, "16:9")) # Small difference - - # Test non-matching ratio (outside tolerance) - self.assertFalse(validate_aspect_ratio(200, 100, "1:1")) - - def test_convert_image_to_base64(self): - """Test converting image file to base64""" - b64_string = convert_image_to_base64(str(self.test_image_path)) - - self.assertIsInstance(b64_string, str) - - # Verify we can decode it back - decoded_data = decode_image_base64(b64_string) - - # Images should have same dimensions - width, height = get_image_dimensions_from_bytes(decoded_data) - self.assertEqual(width, 100) - self.assertEqual(height, 100) - - def create_large_image_file(self, size_mb: float) -> Path: - """Helper to create a large image file for testing""" - # Calculate dimensions for target size (rough estimate) - # PNG compression varies, so this is approximate - pixels = int((size_mb * 1024 * 1024) / 4) # 4 bytes per pixel (RGBA) - dimension = int(pixels ** 0.5) - - large_image = Image.new('RGBA', (dimension, dimension), color='red') - large_image_path = self.temp_path / 'large_image.png' - large_image.save(large_image_path) - - return large_image_path - - def test_large_image_handling(self): - """Test handling of large images""" - # This test might be slow, so we'll use a smaller "large" image - try: - large_path = self.create_large_image_file(0.1) # 0.1 MB - - # Test validation - is_valid, size_mb, error = validate_image_file(str(large_path), 20) - self.assertTrue(is_valid) - - # Test conversion to base64 - b64_string = convert_image_to_base64(str(large_path)) - self.assertIsInstance(b64_string, str) - - except Exception as e: - self.skipTest(f"Large image test skipped due to: {e}") - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_validation.py b/tests/test_validation.py deleted file mode 100644 index 7e30b85..0000000 --- a/tests/test_validation.py +++ /dev/null @@ -1,218 +0,0 @@ -"""Unit tests for validation utilities""" - -import unittest -from pathlib import Path - -# Add src to path -import sys -sys.path.insert(0, str(Path(__file__).parent.parent / 'src')) - -from src.utils.validation import ( - validate_edit_parameters, - validate_file_parameters, - validate_move_file_parameters, - validate_image_path_parameter, - sanitize_prompt, - validate_aspect_ratio_format, - validate_seed_range, - validate_filename_safety -) - - -class TestValidation(unittest.TestCase): - """Test cases for validation utilities""" - - def test_validate_edit_parameters_success(self): - """Test successful validation of edit parameters""" - valid_args = { - 'input_image_b64': 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==', - 'prompt': 'Make the sky blue', - 'seed': 12345, - 'aspect_ratio': '16:9', - 'save_to_file': True - } - - is_valid, error = validate_edit_parameters(valid_args) - self.assertTrue(is_valid) - self.assertIsNone(error) - - def test_validate_edit_parameters_missing_required(self): - """Test validation fails with missing required parameters""" - # Missing input_image_b64 - args = {'prompt': 'test', 'seed': 12345} - is_valid, error = validate_edit_parameters(args) - self.assertFalse(is_valid) - self.assertIn('input_image_b64 is required', error) - - # Missing prompt - args = {'input_image_b64': 'test_b64', 'seed': 12345} - is_valid, error = validate_edit_parameters(args) - self.assertFalse(is_valid) - self.assertIn('prompt is required', error) - - # Missing seed - args = {'input_image_b64': 'test_b64', 'prompt': 'test'} - is_valid, error = validate_edit_parameters(args) - self.assertFalse(is_valid) - self.assertIn('seed is required', error) - - def test_validate_edit_parameters_invalid_types(self): - """Test validation fails with invalid parameter types""" - # Invalid seed type - args = { - 'input_image_b64': 'valid_b64', - 'prompt': 'test prompt', - 'seed': 'not_a_number' - } - is_valid, error = validate_edit_parameters(args) - self.assertFalse(is_valid) - self.assertIn('seed must be an integer', error) - - # Invalid aspect ratio - args = { - 'input_image_b64': 'valid_b64', - 'prompt': 'test prompt', - 'seed': 12345, - 'aspect_ratio': 'invalid_ratio' - } - is_valid, error = validate_edit_parameters(args) - self.assertFalse(is_valid) - self.assertIn('aspect_ratio must be one of', error) - - def test_validate_edit_parameters_invalid_ranges(self): - """Test validation fails with invalid parameter ranges""" - # Seed out of range - args = { - 'input_image_b64': 'valid_b64', - 'prompt': 'test prompt', - 'seed': -1 - } - is_valid, error = validate_edit_parameters(args) - self.assertFalse(is_valid) - self.assertIn('seed must be between', error) - - # Prompt too long - args = { - 'input_image_b64': 'valid_b64', - 'prompt': 'x' * 10001, # Too long - 'seed': 12345 - } - is_valid, error = validate_edit_parameters(args) - self.assertFalse(is_valid) - self.assertIn('prompt is too long', error) - - def test_validate_file_parameters_success(self): - """Test successful validation of file parameters""" - valid_args = { - 'input_image_name': 'test.png', - 'prompt': 'Edit this image', - 'seed': 54321, - 'aspect_ratio': '1:1', - 'save_to_file': False - } - - is_valid, error = validate_file_parameters(valid_args) - self.assertTrue(is_valid) - self.assertIsNone(error) - - def test_validate_file_parameters_invalid_filename(self): - """Test validation fails with invalid filename""" - # Path traversal attempt - args = { - 'input_image_name': '../../../etc/passwd', - 'prompt': 'test', - 'seed': 12345 - } - is_valid, error = validate_file_parameters(args) - self.assertFalse(is_valid) - self.assertIn('cannot contain path separators', error) - - # Invalid extension - args = { - 'input_image_name': 'test.exe', - 'prompt': 'test', - 'seed': 12345 - } - is_valid, error = validate_file_parameters(args) - self.assertFalse(is_valid) - self.assertIn('must have a valid image extension', error) - - def test_validate_move_file_parameters_success(self): - """Test successful validation of move file parameters""" - valid_args = { - 'temp_file_name': 'temp_image.png', - 'output_file_name': 'output.png', - 'copy_only': True - } - - is_valid, error = validate_move_file_parameters(valid_args) - self.assertTrue(is_valid) - self.assertIsNone(error) - - def test_validate_image_path_parameter_success(self): - """Test successful validation of image path parameter""" - valid_args = {'image_path': '/path/to/image.png'} - - is_valid, error = validate_image_path_parameter(valid_args) - self.assertTrue(is_valid) - self.assertIsNone(error) - - def test_sanitize_prompt(self): - """Test prompt sanitization""" - # Test whitespace normalization - prompt = " This has extra whitespace " - sanitized = sanitize_prompt(prompt) - self.assertEqual(sanitized, "This has extra whitespace") - - # Test null byte removal - prompt = "Test\x00with\x00nulls" - sanitized = sanitize_prompt(prompt) - self.assertEqual(sanitized, "Testwithulls") - - # Test length limiting - long_prompt = "x" * 10001 - sanitized = sanitize_prompt(long_prompt) - self.assertEqual(len(sanitized), 10000) - - def test_validate_aspect_ratio_format(self): - """Test aspect ratio format validation""" - # Valid formats - self.assertTrue(validate_aspect_ratio_format("16:9")) - self.assertTrue(validate_aspect_ratio_format("1:1")) - self.assertTrue(validate_aspect_ratio_format("4:3")) - - # Invalid formats - self.assertFalse(validate_aspect_ratio_format("16-9")) - self.assertFalse(validate_aspect_ratio_format("16:9:1")) - self.assertFalse(validate_aspect_ratio_format("a:b")) - self.assertFalse(validate_aspect_ratio_format("0:1")) - - def test_validate_seed_range(self): - """Test seed range validation""" - # Valid seeds - self.assertTrue(validate_seed_range(0)) - self.assertTrue(validate_seed_range(12345)) - self.assertTrue(validate_seed_range(2**32 - 1)) - - # Invalid seeds - self.assertFalse(validate_seed_range(-1)) - self.assertFalse(validate_seed_range(2**32)) - self.assertFalse(validate_seed_range("not_a_number")) - - def test_validate_filename_safety(self): - """Test filename safety validation""" - # Safe filenames - self.assertTrue(validate_filename_safety("image.png")) - self.assertTrue(validate_filename_safety("my_image_123.jpg")) - self.assertTrue(validate_filename_safety("test-file.png")) - - # Unsafe filenames - self.assertFalse(validate_filename_safety("../image.png")) - self.assertFalse(validate_filename_safety("path/to/image.png")) - self.assertFalse(validate_filename_safety("image<>.png")) - self.assertFalse(validate_filename_safety("CON.png")) # Windows reserved - self.assertFalse(validate_filename_safety("x" * 256)) # Too long - - -if __name__ == '__main__': - unittest.main() diff --git a/troubleshoot.bat b/troubleshoot.bat deleted file mode 100644 index df1def2..0000000 --- a/troubleshoot.bat +++ /dev/null @@ -1,150 +0,0 @@ -@echo off -echo FLUX.1 Edit MCP Server - Troubleshooting -echo ======================================= - -echo 1. Checking Python installation... -python --version -if errorlevel 1 ( - echo [ERROR] Python is not installed or not in PATH - goto :end -) else ( - echo [OK] Python is available -) - -echo. -echo 2. Checking pip installation... -pip --version -if errorlevel 1 ( - echo [ERROR] pip is not available - goto :end -) else ( - echo [OK] pip is available -) - -echo. -echo 3. Checking virtual environment... -if exist "venv" ( - echo [OK] Virtual environment directory exists - call venv\Scripts\activate.bat - if errorlevel 1 ( - echo [ERROR] Cannot activate virtual environment - goto :end - ) else ( - echo [OK] Virtual environment activated - ) -) else ( - echo [WARNING] Virtual environment not found - echo Creating virtual environment... - python -m venv venv - if errorlevel 1 ( - echo [ERROR] Failed to create virtual environment - goto :end - ) - call venv\Scripts\activate.bat - echo [OK] Virtual environment created and activated -) - -echo. -echo 4. Checking Python in virtual environment... -where python -python --version - -echo. -echo 5. Checking critical dependencies... -python -c "import aiohttp; print(f'aiohttp: {aiohttp.__version__}')" 2>nul -if errorlevel 1 ( - echo [ERROR] aiohttp not found - installing... - pip install aiohttp==3.11.7 -) else ( - echo [OK] aiohttp is available -) - -python -c "import httpx; print(f'httpx: {httpx.__version__}')" 2>nul -if errorlevel 1 ( - echo [ERROR] httpx not found - installing... - pip install httpx==0.28.1 -) else ( - echo [OK] httpx is available -) - -python -c "import mcp" 2>nul -if errorlevel 1 ( - echo [ERROR] mcp not found - installing... - pip install mcp==1.1.0 -) else ( - echo [OK] mcp is available -) - -python -c "from PIL import Image; print('Pillow: available')" 2>nul -if errorlevel 1 ( - echo [ERROR] Pillow not found - installing... - pip install Pillow==11.0.0 -) else ( - echo [OK] Pillow is available -) - -echo. -echo 6. Checking configuration files... -if exist ".env" ( - echo [OK] .env file exists -) else ( - echo [WARNING] .env file not found - if exist ".env.example" ( - echo Creating .env from example... - copy .env.example .env - echo [OK] .env file created from example - ) else ( - echo [ERROR] .env.example file not found - ) -) - -echo. -echo 7. Checking required directories... -if not exist "input_images" mkdir input_images & echo [OK] Created input_images directory -if not exist "generated_images" mkdir generated_images & echo [OK] Created generated_images directory -if not exist "temp" mkdir temp & echo [OK] Created temp directory - -echo. -echo 8. Testing basic imports... -python -c " -import sys -print(f'Python executable: {sys.executable}') -print(f'Python version: {sys.version}') -print('Testing imports...') -try: - import aiohttp - print('βœ“ aiohttp imported successfully') -except ImportError as e: - print(f'βœ— aiohttp import failed: {e}') - -try: - import mcp - print('βœ“ mcp imported successfully') -except ImportError as e: - print(f'βœ— mcp import failed: {e}') - -try: - from src.connector import Config - print('βœ“ Local Config imported successfully') -except ImportError as e: - print(f'βœ— Local Config import failed: {e}') - -try: - from src.server import main - print('βœ“ Local server main imported successfully') -except ImportError as e: - print(f'βœ— Local server main import failed: {e}') -" - -echo. -echo Troubleshooting complete! -echo. -echo If you still have issues: -echo 1. Delete venv folder and run install_dependencies.bat -echo 2. Make sure you have a stable internet connection -echo 3. Check if your antivirus is blocking Python/pip -echo 4. Try running as administrator -echo. - -:end -pause