clean up code
This commit is contained in:
9
.claude/settings.local.json
Normal file
9
.claude/settings.local.json
Normal file
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(py debug_test.py)"
|
||||
],
|
||||
"deny": [],
|
||||
"ask": []
|
||||
}
|
||||
}
|
||||
@@ -1,150 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
FLUX.1 Edit MCP Server - Dependency Check Script
|
||||
This script checks if all required dependencies are properly installed
|
||||
"""
|
||||
|
||||
import sys
|
||||
import importlib.util
|
||||
from typing import List, Tuple
|
||||
|
||||
def check_dependency(module_name: str, package_name: str = None) -> Tuple[bool, str]:
|
||||
"""
|
||||
Check if a dependency is installed and importable
|
||||
|
||||
Args:
|
||||
module_name: Name of the module to import
|
||||
package_name: Name of the package to install (if different from module)
|
||||
|
||||
Returns:
|
||||
Tuple of (success, message)
|
||||
"""
|
||||
if package_name is None:
|
||||
package_name = module_name
|
||||
|
||||
try:
|
||||
spec = importlib.util.find_spec(module_name)
|
||||
if spec is None:
|
||||
return False, f"[MISSING] {module_name} not found - install with: pip install {package_name}"
|
||||
|
||||
# Try to actually import it
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Get version if available
|
||||
version = getattr(module, '__version__', 'unknown')
|
||||
return True, f"[OK] {module_name} {version}"
|
||||
|
||||
except ImportError as e:
|
||||
return False, f"[ERROR] {module_name} import failed: {e}"
|
||||
except Exception as e:
|
||||
return False, f"[ERROR] {module_name} error: {e}"
|
||||
|
||||
def check_local_modules() -> List[Tuple[bool, str]]:
|
||||
"""Check local project modules"""
|
||||
results = []
|
||||
|
||||
try:
|
||||
# Add src to path temporarily
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
src_path = Path(__file__).parent / 'src'
|
||||
if src_path.exists():
|
||||
sys.path.insert(0, str(src_path))
|
||||
|
||||
# Test local imports
|
||||
try:
|
||||
from connector.config import Config
|
||||
results.append((True, "[OK] Local Config module"))
|
||||
except Exception as e:
|
||||
results.append((False, f"[ERROR] Local Config module: {e}"))
|
||||
|
||||
try:
|
||||
from connector.flux_client import FluxEditClient
|
||||
results.append((True, "[OK] Local FluxEditClient module"))
|
||||
except Exception as e:
|
||||
results.append((False, f"[ERROR] Local FluxEditClient module: {e}"))
|
||||
|
||||
try:
|
||||
from server.mcp_server import FluxEditMCPServer
|
||||
results.append((True, "[OK] Local MCP Server module"))
|
||||
except Exception as e:
|
||||
results.append((False, f"[ERROR] Local MCP Server module: {e}"))
|
||||
|
||||
except Exception as e:
|
||||
results.append((False, f"[ERROR] Local module check failed: {e}"))
|
||||
|
||||
return results
|
||||
|
||||
def main():
|
||||
"""Main dependency check function"""
|
||||
print("FLUX.1 Edit MCP Server - Dependency Check")
|
||||
print("=========================================")
|
||||
print(f"Python version: {sys.version}")
|
||||
print(f"Python executable: {sys.executable}")
|
||||
print()
|
||||
|
||||
# Required dependencies with their install names
|
||||
dependencies = [
|
||||
("aiohttp", "aiohttp==3.11.7"),
|
||||
("httpx", "httpx==0.28.1"),
|
||||
("mcp", "mcp==1.1.0"),
|
||||
("PIL", "Pillow==11.0.0"),
|
||||
("dotenv", "python-dotenv==1.0.1"),
|
||||
("pydantic", "pydantic==2.10.3"),
|
||||
("structlog", "structlog==24.4.0"),
|
||||
]
|
||||
|
||||
# Optional dependencies for development
|
||||
optional_dependencies = [
|
||||
("pytest", "pytest==8.3.4"),
|
||||
("black", "black==24.10.0"),
|
||||
]
|
||||
|
||||
all_good = True
|
||||
|
||||
print("Checking required dependencies...")
|
||||
print("-" * 50)
|
||||
|
||||
for module_name, package_name in dependencies:
|
||||
success, message = check_dependency(module_name, package_name)
|
||||
print(message)
|
||||
if not success:
|
||||
all_good = False
|
||||
|
||||
print()
|
||||
print("Checking optional dependencies...")
|
||||
print("-" * 50)
|
||||
|
||||
for module_name, package_name in optional_dependencies:
|
||||
success, message = check_dependency(module_name, package_name)
|
||||
print(message)
|
||||
|
||||
print()
|
||||
print("Checking local modules...")
|
||||
print("-" * 50)
|
||||
|
||||
local_results = check_local_modules()
|
||||
for success, message in local_results:
|
||||
print(message)
|
||||
if not success:
|
||||
all_good = False
|
||||
|
||||
print()
|
||||
print("=" * 50)
|
||||
|
||||
if all_good:
|
||||
print("[SUCCESS] All required dependencies are installed and working!")
|
||||
print("You can now run: python main.py")
|
||||
return 0
|
||||
else:
|
||||
print("[FAILED] Some dependencies are missing or broken.")
|
||||
print()
|
||||
print("To fix this, try:")
|
||||
print("1. Run: install_dependencies.bat (Windows)")
|
||||
print("2. Or run: pip install -r requirements.txt")
|
||||
print("3. Or run individual pip install commands shown above")
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
189
debug_test.py
Normal file
189
debug_test.py
Normal file
@@ -0,0 +1,189 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
FLUX.1 Edit MCP Server Debug Test
|
||||
|
||||
This script helps diagnose issues with the MCP server
|
||||
"""
|
||||
|
||||
# Force UTF-8 encoding setup
|
||||
import os
|
||||
import sys
|
||||
import locale
|
||||
|
||||
# Force UTF-8 environment variables
|
||||
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
||||
os.environ['PYTHONUTF8'] = '1'
|
||||
os.environ['LC_ALL'] = 'C.UTF-8'
|
||||
|
||||
# Windows-specific UTF-8 setup
|
||||
if sys.platform.startswith('win'):
|
||||
try:
|
||||
os.system('chcp 65001 >nul 2>&1')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
locale.setlocale(locale.LC_ALL, 'C.UTF-8')
|
||||
except locale.Error:
|
||||
try:
|
||||
locale.setlocale(locale.LC_ALL, '')
|
||||
except locale.Error:
|
||||
pass
|
||||
|
||||
import traceback
|
||||
|
||||
# Add paths
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, os.path.join(current_dir, 'src'))
|
||||
sys.path.insert(0, current_dir)
|
||||
|
||||
def test_imports():
|
||||
"""Test all required imports"""
|
||||
print("=== Import Tests ===")
|
||||
|
||||
# Test MCP imports
|
||||
try:
|
||||
import mcp.types as types
|
||||
from mcp.server import Server
|
||||
from mcp.server.stdio import stdio_server
|
||||
print("[SUCCESS] MCP imports successful")
|
||||
except ImportError as e:
|
||||
print(f"[ERROR] MCP import failed: {e}")
|
||||
return False
|
||||
|
||||
# Test local imports
|
||||
try:
|
||||
from src.connector.config import Config
|
||||
print("[SUCCESS] Config import successful")
|
||||
except ImportError as e:
|
||||
print(f"[ERROR] Config import failed: {e}")
|
||||
print(f"Error details: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from src.server.models import TOOL_DEFINITIONS, ToolName
|
||||
print("[SUCCESS] Models import successful")
|
||||
except ImportError as e:
|
||||
print(f"[ERROR] Models import failed: {e}")
|
||||
print(f"Error details: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from src.server.handlers import ToolHandlers
|
||||
print("[SUCCESS] Handlers import successful")
|
||||
except ImportError as e:
|
||||
print(f"[ERROR] Handlers import failed: {e}")
|
||||
print(f"Error details: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
# Test aiohttp
|
||||
try:
|
||||
import aiohttp
|
||||
print(f"[SUCCESS] aiohttp version: {aiohttp.__version__}")
|
||||
except ImportError as e:
|
||||
print(f"[ERROR] aiohttp import failed: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def test_config():
|
||||
"""Test configuration loading"""
|
||||
print("\n=== Configuration Tests ===")
|
||||
|
||||
try:
|
||||
from src.connector.config import Config
|
||||
config = Config()
|
||||
print("[SUCCESS] Config created successfully")
|
||||
|
||||
# Check API key
|
||||
if config.api_key:
|
||||
print(f"[SUCCESS] API key configured: ***{config.api_key[-4:]}")
|
||||
else:
|
||||
print("[WARNING] API key not configured")
|
||||
|
||||
# Check paths
|
||||
print(f"[SUCCESS] Input path: {config.input_path}")
|
||||
print(f"[SUCCESS] Output path: {config.generated_images_path}")
|
||||
|
||||
# Test validation
|
||||
if config.validate():
|
||||
print("[SUCCESS] Configuration validation passed")
|
||||
else:
|
||||
print("[ERROR] Configuration validation failed")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Configuration test failed: {e}")
|
||||
print(f"Error details: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
def test_handlers():
|
||||
"""Test handler creation"""
|
||||
print("\n=== Handler Tests ===")
|
||||
|
||||
try:
|
||||
from src.connector.config import Config
|
||||
from src.server.handlers import ToolHandlers
|
||||
|
||||
config = Config()
|
||||
handlers = ToolHandlers(config)
|
||||
print("[SUCCESS] Handlers created successfully")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Handler test failed: {e}")
|
||||
print(f"Error details: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
def test_server_creation():
|
||||
"""Test MCP server creation"""
|
||||
print("\n=== Server Creation Tests ===")
|
||||
|
||||
try:
|
||||
from mcp.server import Server
|
||||
|
||||
server = Server("flux1-edit-test")
|
||||
print("[SUCCESS] MCP server created successfully")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Server creation failed: {e}")
|
||||
print(f"Error details: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
print("FLUX.1 Edit MCP Server Debug Test")
|
||||
print("=" * 50)
|
||||
|
||||
tests = [
|
||||
test_imports,
|
||||
test_config,
|
||||
test_handlers,
|
||||
test_server_creation
|
||||
]
|
||||
|
||||
passed = 0
|
||||
for test in tests:
|
||||
try:
|
||||
if test():
|
||||
passed += 1
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Test crashed: {e}")
|
||||
|
||||
print(f"\n=== Results ===")
|
||||
print(f"Passed: {passed}/{len(tests)}")
|
||||
|
||||
if passed == len(tests):
|
||||
print("[SUCCESS] All tests passed! The server should work.")
|
||||
return 0
|
||||
else:
|
||||
print("[ERROR] Some tests failed. Check the errors above.")
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
249
diagnostic.py
249
diagnostic.py
@@ -1,249 +0,0 @@
|
||||
"""
|
||||
Test script to diagnose FLUX.1 Edit MCP Server issues
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, str(Path(__file__).parent / 'src'))
|
||||
|
||||
def test_imports():
|
||||
"""Test all required imports"""
|
||||
print("🔍 Testing imports...")
|
||||
|
||||
# Test standard library imports
|
||||
try:
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
print("✅ Standard library imports OK")
|
||||
except ImportError as e:
|
||||
print(f"❌ Standard library import failed: {e}")
|
||||
return False
|
||||
|
||||
# Test third-party dependencies
|
||||
deps = {
|
||||
'aiohttp': '3.11.7',
|
||||
'httpx': '0.28.1',
|
||||
'mcp': '1.1.0+',
|
||||
'PIL': 'Pillow 11.0.0',
|
||||
'dotenv': 'python-dotenv 1.0.1',
|
||||
'pydantic': '2.10.3'
|
||||
}
|
||||
|
||||
for module, expected in deps.items():
|
||||
try:
|
||||
__import__(module)
|
||||
print(f"✅ {module} ({expected}) OK")
|
||||
except ImportError as e:
|
||||
print(f"❌ {module} missing: {e}")
|
||||
return False
|
||||
|
||||
# Test MCP specific imports
|
||||
try:
|
||||
import mcp.types as types
|
||||
import mcp.server
|
||||
from mcp.server import Server
|
||||
from mcp.server.stdio import stdio_server
|
||||
print("✅ MCP imports OK")
|
||||
except ImportError as e:
|
||||
print(f"❌ MCP import failed: {e}")
|
||||
return False
|
||||
|
||||
# Test local module imports
|
||||
try:
|
||||
from src.connector import Config
|
||||
print("✅ Config import OK")
|
||||
except ImportError as e:
|
||||
print(f"❌ Config import failed: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from src.server.models import TOOL_DEFINITIONS, ToolName
|
||||
print("✅ Models import OK")
|
||||
except ImportError as e:
|
||||
print(f"❌ Models import failed: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from src.server.handlers import ToolHandlers
|
||||
print("✅ Handlers import OK")
|
||||
except ImportError as e:
|
||||
print(f"❌ Handlers import failed: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from src.utils import validate_edit_parameters
|
||||
print("✅ Utils import OK")
|
||||
except ImportError as e:
|
||||
print(f"❌ Utils import failed: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def test_config():
|
||||
"""Test configuration loading"""
|
||||
print("\n🔍 Testing configuration...")
|
||||
|
||||
try:
|
||||
from src.connector import Config
|
||||
config = Config()
|
||||
|
||||
print(f"✅ Config loaded")
|
||||
print(f" - API key: {'***' + config.api_key[-4:] if config.api_key else 'Not Set'}")
|
||||
print(f" - Input path: {config.input_path}")
|
||||
print(f" - Output path: {config.generated_images_path}")
|
||||
|
||||
# Test config validation
|
||||
if config.validate():
|
||||
print("✅ Config validation passed")
|
||||
else:
|
||||
print("❌ Config validation failed")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Config test failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_mcp_server_creation():
|
||||
"""Test MCP server creation"""
|
||||
print("\n🔍 Testing MCP server creation...")
|
||||
|
||||
try:
|
||||
from src.server.mcp_server import FluxEditMCPServer
|
||||
|
||||
server = FluxEditMCPServer()
|
||||
print("✅ MCP server created successfully")
|
||||
|
||||
# Test tool definitions
|
||||
from src.server.models import TOOL_DEFINITIONS
|
||||
print(f"✅ Tool definitions loaded: {len(TOOL_DEFINITIONS)} tools")
|
||||
for tool_name in TOOL_DEFINITIONS:
|
||||
print(f" - {tool_name}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ MCP server creation failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
async def test_mcp_server_init():
|
||||
"""Test MCP server initialization"""
|
||||
print("\n🔍 Testing MCP server initialization...")
|
||||
|
||||
try:
|
||||
from src.server.mcp_server import create_server
|
||||
|
||||
server = create_server()
|
||||
print("✅ MCP server initialized")
|
||||
|
||||
# Test configuration validation
|
||||
if server.validate_config():
|
||||
print("✅ Server config validation passed")
|
||||
else:
|
||||
print("❌ Server config validation failed")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ MCP server initialization failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def test_directories():
|
||||
"""Test directory structure"""
|
||||
print("\n🔍 Testing directory structure...")
|
||||
|
||||
base_path = Path(__file__).parent
|
||||
|
||||
required_dirs = [
|
||||
'src',
|
||||
'src/connector',
|
||||
'src/server',
|
||||
'src/utils',
|
||||
'input_images',
|
||||
'generated_images'
|
||||
]
|
||||
|
||||
for dir_path in required_dirs:
|
||||
full_path = base_path / dir_path
|
||||
if full_path.exists():
|
||||
print(f"✅ {dir_path}/")
|
||||
else:
|
||||
print(f"❌ {dir_path}/ (missing)")
|
||||
try:
|
||||
full_path.mkdir(parents=True, exist_ok=True)
|
||||
print(f"✅ Created {dir_path}/")
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to create {dir_path}/: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
"""Main test function"""
|
||||
print("🚀 FLUX.1 Edit MCP Server Diagnostic Tool")
|
||||
print("=" * 50)
|
||||
|
||||
tests = [
|
||||
("Directory Structure", test_directories),
|
||||
("Import Dependencies", test_imports),
|
||||
("Configuration", test_config),
|
||||
("MCP Server Creation", test_mcp_server_creation),
|
||||
("MCP Server Initialization", lambda: asyncio.run(test_mcp_server_init())),
|
||||
]
|
||||
|
||||
results = []
|
||||
for test_name, test_func in tests:
|
||||
print(f"\n📋 Running {test_name}...")
|
||||
try:
|
||||
result = test_func()
|
||||
results.append((test_name, result))
|
||||
if result:
|
||||
print(f"✅ {test_name} PASSED")
|
||||
else:
|
||||
print(f"❌ {test_name} FAILED")
|
||||
except Exception as e:
|
||||
print(f"❌ {test_name} FAILED with exception: {e}")
|
||||
results.append((test_name, False))
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 50)
|
||||
print("📊 DIAGNOSTIC SUMMARY")
|
||||
print("=" * 50)
|
||||
|
||||
passed = sum(1 for _, result in results if result)
|
||||
total = len(results)
|
||||
|
||||
for test_name, result in results:
|
||||
status = "✅ PASS" if result else "❌ FAIL"
|
||||
print(f"{status} {test_name}")
|
||||
|
||||
print(f"\n🎯 Results: {passed}/{total} tests passed")
|
||||
|
||||
if passed == total:
|
||||
print("🎉 All tests passed! Server should work correctly.")
|
||||
else:
|
||||
print("🔧 Some tests failed. Please fix the issues above.")
|
||||
|
||||
return passed == total
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
sys.exit(0 if success else 1)
|
||||
370
main.py
370
main.py
@@ -1,30 +1,362 @@
|
||||
"""Main entry point for FLUX.1 Edit MCP Server"""
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
FLUX.1 Edit MCP Server - Fixed Version
|
||||
|
||||
import asyncio
|
||||
FLUX.1 Kontext를 사용한 AI 이미지 편집 MCP 서버
|
||||
- Enhanced error handling and UTF-8 support
|
||||
- MCP protocol compliance
|
||||
- Based on imagen4 server structure
|
||||
"""
|
||||
|
||||
# ==================== CRITICAL: UTF-8 SETUP MUST BE FIRST ====================
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import locale
|
||||
|
||||
# Force UTF-8 environment variables - set immediately
|
||||
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
||||
os.environ['PYTHONUTF8'] = '1'
|
||||
os.environ['LC_ALL'] = 'C.UTF-8'
|
||||
|
||||
# Windows-specific UTF-8 setup
|
||||
if sys.platform.startswith('win'):
|
||||
try:
|
||||
os.system('chcp 65001 >nul 2>&1')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
locale.setlocale(locale.LC_ALL, 'C.UTF-8')
|
||||
except locale.Error:
|
||||
try:
|
||||
locale.setlocale(locale.LC_ALL, '')
|
||||
except locale.Error:
|
||||
pass
|
||||
|
||||
# Add src to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent / 'src'))
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, os.path.join(current_dir, 'src'))
|
||||
sys.path.insert(0, current_dir)
|
||||
|
||||
# ==================== Imports ====================
|
||||
import asyncio
|
||||
import logging
|
||||
import json
|
||||
from typing import Dict, Any, List
|
||||
from pathlib import Path
|
||||
|
||||
# Safe logging setup
|
||||
class SafeUnicodeHandler(logging.StreamHandler):
|
||||
"""Ultra-safe Unicode stream handler that prevents all encoding issues"""
|
||||
|
||||
def __init__(self, stream=None):
|
||||
super().__init__(stream)
|
||||
self.encoding = 'utf-8'
|
||||
|
||||
def emit(self, record):
|
||||
try:
|
||||
msg = self.format(record)
|
||||
|
||||
# Windows safety: replace problematic Unicode characters
|
||||
if sys.platform.startswith('win'):
|
||||
emoji_replacements = {
|
||||
'✅': '[SUCCESS]', '❌': '[ERROR]', '⚠️': '[WARNING]',
|
||||
'🔄': '[RETRY]', '⏳': '[WAIT]', '🖼️': '[IMAGE]',
|
||||
'📁': '[FILES]', '⚙️': '[PARAMS]', '🎨': '[GENERATE]'
|
||||
}
|
||||
|
||||
for emoji, replacement in emoji_replacements.items():
|
||||
msg = msg.replace(emoji, replacement)
|
||||
|
||||
# Ensure safe encoding
|
||||
msg = msg.encode('utf-8', errors='replace').decode('utf-8', errors='replace')
|
||||
|
||||
# Write safely
|
||||
stream = self.stream
|
||||
terminator = getattr(self, 'terminator', '\n')
|
||||
|
||||
try:
|
||||
if hasattr(stream, 'buffer'):
|
||||
stream.buffer.write((msg + terminator).encode('utf-8', errors='replace'))
|
||||
stream.buffer.flush()
|
||||
else:
|
||||
stream.write(msg + terminator)
|
||||
if hasattr(stream, 'flush'):
|
||||
stream.flush()
|
||||
except (UnicodeEncodeError, UnicodeDecodeError):
|
||||
try:
|
||||
safe_msg = msg.encode('ascii', errors='replace').decode('ascii')
|
||||
if hasattr(stream, 'buffer'):
|
||||
stream.buffer.write((safe_msg + terminator).encode('ascii'))
|
||||
stream.buffer.flush()
|
||||
else:
|
||||
stream.write(safe_msg + terminator)
|
||||
if hasattr(stream, 'flush'):
|
||||
stream.flush()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
# Setup safe logging
|
||||
safe_handler = SafeUnicodeHandler(sys.stderr)
|
||||
safe_handler.setFormatter(logging.Formatter(
|
||||
'%(asctime)s [%(name)s] [%(levelname)s] %(message)s'
|
||||
))
|
||||
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(logging.INFO)
|
||||
root_logger.handlers.clear()
|
||||
root_logger.addHandler(safe_handler)
|
||||
|
||||
logger = logging.getLogger("flux1-edit-mcp")
|
||||
|
||||
# ==================== MCP Imports with Error Handling ====================
|
||||
try:
|
||||
import mcp.types as types
|
||||
from mcp.server import Server
|
||||
from mcp.server.stdio import stdio_server
|
||||
from mcp.server.models import InitializationOptions
|
||||
from mcp.server import NotificationOptions
|
||||
logger.info("MCP imports successful")
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import MCP: {e}")
|
||||
logger.error("Please install MCP: pip install mcp")
|
||||
sys.exit(1)
|
||||
|
||||
# ==================== Local Imports with Error Handling ====================
|
||||
try:
|
||||
from src.connector.config import Config
|
||||
from src.server.models import TOOL_DEFINITIONS, ToolName
|
||||
from src.server.handlers import ToolHandlers
|
||||
logger.info("Local imports successful")
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import local modules: {e}")
|
||||
logger.error(f"Current directory: {current_dir}")
|
||||
logger.error(f"Python path: {sys.path}")
|
||||
|
||||
# List available modules for debugging
|
||||
try:
|
||||
src_path = os.path.join(current_dir, 'src')
|
||||
if os.path.exists(src_path):
|
||||
logger.error(f"Available in src: {os.listdir(src_path)}")
|
||||
for subdir in ['connector', 'server', 'utils']:
|
||||
subdir_path = os.path.join(src_path, subdir)
|
||||
if os.path.exists(subdir_path):
|
||||
logger.error(f"Available in src/{subdir}: {os.listdir(subdir_path)}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
sys.exit(1)
|
||||
|
||||
# ==================== Utility Functions ====================
|
||||
def sanitize_args_for_logging(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Remove or truncate sensitive data from arguments for safe logging"""
|
||||
safe_args = {}
|
||||
for key, value in arguments.items():
|
||||
if isinstance(value, str):
|
||||
if (key.endswith('_b64') or key.endswith('_data') or
|
||||
(len(value) > 100 and any(value.startswith(prefix) for prefix in ['iVBORw0KGgo', '/9j/', 'R0lGOD']))):
|
||||
safe_args[key] = f"<image_data:{len(value)} chars>"
|
||||
elif len(value) > 1000:
|
||||
safe_args[key] = f"{value[:100]}...<truncated:{len(value)} chars>"
|
||||
else:
|
||||
safe_args[key] = value
|
||||
else:
|
||||
safe_args[key] = value
|
||||
return safe_args
|
||||
|
||||
# ==================== MCP Server Class ====================
|
||||
class FluxEditMCPServer:
|
||||
"""FLUX.1 Edit MCP server with enhanced error handling"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize server with comprehensive error handling"""
|
||||
logger.info("Initializing FLUX.1 Edit MCP Server...")
|
||||
|
||||
try:
|
||||
# Create configuration
|
||||
self.config = Config()
|
||||
logger.info("Configuration created successfully")
|
||||
|
||||
# Validate configuration
|
||||
if not self.config.validate():
|
||||
raise RuntimeError("Configuration validation failed")
|
||||
logger.info("Configuration validated successfully")
|
||||
|
||||
# Create MCP server
|
||||
self.server = Server("flux1-edit")
|
||||
logger.info("MCP server instance created")
|
||||
|
||||
# Create tool handlers
|
||||
self.handlers = ToolHandlers(self.config)
|
||||
logger.info("Tool handlers created successfully")
|
||||
|
||||
# Register handlers
|
||||
self._register_handlers()
|
||||
logger.info("Handlers registered successfully")
|
||||
|
||||
logger.info("FLUX.1 Edit MCP Server initialization complete")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize server: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _register_handlers(self):
|
||||
"""Register MCP handlers with comprehensive error handling"""
|
||||
|
||||
@self.server.list_tools()
|
||||
async def handle_list_tools() -> List[types.Tool]:
|
||||
"""List available tools"""
|
||||
try:
|
||||
logger.info("Listing available tools")
|
||||
tools = []
|
||||
|
||||
for tool_name, tool_def in TOOL_DEFINITIONS.items():
|
||||
# Build properties for parameters
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for param in tool_def.parameters:
|
||||
prop_def = {
|
||||
"type": param.type,
|
||||
"description": param.description
|
||||
}
|
||||
|
||||
# Add enum if specified
|
||||
if param.enum:
|
||||
prop_def["enum"] = param.enum
|
||||
|
||||
# Add default if specified
|
||||
if param.default is not None:
|
||||
prop_def["default"] = param.default
|
||||
|
||||
properties[param.name] = prop_def
|
||||
|
||||
if param.required:
|
||||
required.append(param.name)
|
||||
|
||||
# Build tool schema
|
||||
tool = types.Tool(
|
||||
name=tool_def.name,
|
||||
description=tool_def.description,
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required
|
||||
}
|
||||
)
|
||||
tools.append(tool)
|
||||
|
||||
logger.info(f"Listed {len(tools)} tools successfully")
|
||||
return tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing tools: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
@self.server.call_tool()
|
||||
async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[types.TextContent | types.ImageContent]:
|
||||
"""Handle tool calls with comprehensive error handling"""
|
||||
try:
|
||||
# Log tool call safely
|
||||
safe_args = sanitize_args_for_logging(arguments)
|
||||
logger.info(f"Tool call: {name} with args: {safe_args}")
|
||||
|
||||
# Route to appropriate handler
|
||||
if name == ToolName.FLUX_EDIT_IMAGE:
|
||||
return await self.handlers.handle_flux_edit_image(arguments)
|
||||
elif name == ToolName.FLUX_EDIT_IMAGE_FROM_FILE:
|
||||
return await self.handlers.handle_flux_edit_image_from_file(arguments)
|
||||
elif name == ToolName.VALIDATE_IMAGE:
|
||||
return await self.handlers.handle_validate_image(arguments)
|
||||
elif name == ToolName.MOVE_TEMP_TO_OUTPUT:
|
||||
return await self.handlers.handle_move_temp_to_output(arguments)
|
||||
else:
|
||||
error_msg = f"Unknown tool: {name}"
|
||||
logger.error(error_msg)
|
||||
return [types.TextContent(
|
||||
type="text",
|
||||
text=f"[ERROR] {error_msg}"
|
||||
)]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling tool {name}: {e}", exc_info=True)
|
||||
return [types.TextContent(
|
||||
type="text",
|
||||
text=f"[ERROR] Tool execution error: {str(e)}"
|
||||
)]
|
||||
|
||||
async def run(self):
|
||||
"""Run the MCP server"""
|
||||
try:
|
||||
logger.info("Starting FLUX.1 Edit MCP Server...")
|
||||
logger.info(f"API key configured: {'Yes' if self.config.api_key else 'No'}")
|
||||
logger.info(f"Input directory: {self.config.input_path}")
|
||||
logger.info(f"Output directory: {self.config.generated_images_path}")
|
||||
|
||||
# Run server using stdio with proper initialization
|
||||
async with stdio_server() as (read_stream, write_stream):
|
||||
logger.info("MCP server started with stdio transport")
|
||||
|
||||
await self.server.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
InitializationOptions(
|
||||
server_name="flux1-edit",
|
||||
server_version="1.0.0",
|
||||
capabilities=self.server.get_capabilities(
|
||||
notification_options=NotificationOptions(),
|
||||
experimental_capabilities={}
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Server run error: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
# ==================== Main Function ====================
|
||||
async def main():
|
||||
"""Main entry point with comprehensive error handling"""
|
||||
try:
|
||||
logger.info("Starting FLUX.1 Edit MCP Server main function")
|
||||
|
||||
# Create and run server
|
||||
server = FluxEditMCPServer()
|
||||
await server.run()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Server stopped by user (Ctrl+C)")
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.error(f"Fatal server error: {e}", exc_info=True)
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
# Import and run the main server function
|
||||
from src.server.mcp_server import main
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
# Set up signal handling
|
||||
import signal
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
logger.info(f"Received signal {signum}, shutting down gracefully...")
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Run the server
|
||||
logger.info("Starting FLUX.1 Edit MCP Server from main")
|
||||
exit_code = asyncio.run(main())
|
||||
sys.exit(exit_code or 0)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Server stopped by user")
|
||||
sys.exit(0)
|
||||
except SystemExit as e:
|
||||
logger.info(f"Server exiting with code {e.code}")
|
||||
sys.exit(e.code)
|
||||
except Exception as e:
|
||||
# Log to stderr for debugging, but avoid stdout pollution for MCP
|
||||
import logging
|
||||
logging.basicConfig(
|
||||
level=logging.ERROR,
|
||||
format='%(asctime)s [%(name)s] %(levelname)s: %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler('flux1-edit.log', mode='a', encoding='utf-8')
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"Fatal error: {e}", exc_info=True)
|
||||
logger.error(f"Fatal error in main: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
# FLUX.1 Edit MCP Server Dependencies
|
||||
|
||||
# Core MCP Server - Updated version
|
||||
# Core MCP Server
|
||||
mcp==1.2.0
|
||||
|
||||
# HTTP Client for FLUX API
|
||||
httpx==0.28.1
|
||||
aiohttp==3.11.7
|
||||
# HTTP Client for FLUX API - using stable version
|
||||
aiohttp==3.9.5
|
||||
|
||||
# Image Processing
|
||||
Pillow==11.0.0
|
||||
@@ -16,13 +15,7 @@ python-dotenv==1.0.1
|
||||
# Data Validation
|
||||
pydantic==2.10.3
|
||||
|
||||
# Async utilities - asyncio is built into Python 3.7+
|
||||
# asyncio-compat package not needed for modern Python versions
|
||||
|
||||
# Logging
|
||||
structlog==24.4.0
|
||||
|
||||
# Development and Testing (optional)
|
||||
# Development and Testing
|
||||
pytest==8.3.4
|
||||
pytest-asyncio==0.25.0
|
||||
pytest-mock==3.14.0
|
||||
|
||||
126
simple_test.py
126
simple_test.py
@@ -1,126 +0,0 @@
|
||||
"""
|
||||
Simple test script for FLUX.1 Edit MCP Server
|
||||
"""
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, str(Path(__file__).parent / 'src'))
|
||||
|
||||
async def test_server_import():
|
||||
"""Test if the server can be imported and initialized"""
|
||||
print("🔍 Testing server import and initialization...")
|
||||
|
||||
try:
|
||||
# Test imports
|
||||
from src.server.mcp_server import create_server
|
||||
from src.connector import Config
|
||||
|
||||
print("✅ Successfully imported server components")
|
||||
|
||||
# Test configuration
|
||||
config = Config()
|
||||
if config.validate():
|
||||
print("✅ Configuration validation passed")
|
||||
else:
|
||||
print("❌ Configuration validation failed")
|
||||
return False
|
||||
|
||||
# Test server creation
|
||||
server = create_server()
|
||||
print("✅ Server created successfully")
|
||||
|
||||
# Test server validation
|
||||
if server.validate_config():
|
||||
print("✅ Server configuration validated")
|
||||
else:
|
||||
print("❌ Server configuration validation failed")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
async def test_mcp_protocol():
|
||||
"""Test MCP protocol basics"""
|
||||
print("\n🔍 Testing MCP protocol basics...")
|
||||
|
||||
try:
|
||||
import mcp.types as types
|
||||
from mcp.server import Server
|
||||
from mcp.server.stdio import stdio_server
|
||||
|
||||
print("✅ MCP imports successful")
|
||||
|
||||
# Create a minimal server for testing
|
||||
server = Server("test-server")
|
||||
|
||||
@server.list_tools()
|
||||
async def list_tools():
|
||||
return [
|
||||
types.Tool(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"test_param": {
|
||||
"type": "string",
|
||||
"description": "Test parameter"
|
||||
}
|
||||
},
|
||||
"required": ["test_param"]
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
print("✅ MCP server handlers configured")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ MCP protocol test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Main test function"""
|
||||
print("🚀 FLUX.1 Edit MCP Server - Simple Test")
|
||||
print("=" * 50)
|
||||
|
||||
# Run async tests
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
test1_result = loop.run_until_complete(test_server_import())
|
||||
test2_result = loop.run_until_complete(test_mcp_protocol())
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("📊 TEST RESULTS")
|
||||
print("=" * 50)
|
||||
|
||||
print(f"Server Import & Init: {'✅ PASS' if test1_result else '❌ FAIL'}")
|
||||
print(f"MCP Protocol Basic: {'✅ PASS' if test2_result else '❌ FAIL'}")
|
||||
|
||||
if test1_result and test2_result:
|
||||
print("\n🎉 All tests passed! Server should be ready.")
|
||||
print("\n💡 Try running: python main.py")
|
||||
return True
|
||||
else:
|
||||
print("\n🔧 Some tests failed. Check the errors above.")
|
||||
return False
|
||||
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -21,12 +21,12 @@ except ImportError as e:
|
||||
try:
|
||||
import aiohttp
|
||||
except ImportError:
|
||||
logger.error("aiohttp is not installed. Please run: pip install aiohttp==3.11.7")
|
||||
print("\n❌ Missing dependency: aiohttp")
|
||||
logger.error("aiohttp is not installed. Please run: pip install aiohttp==3.9.5")
|
||||
print("\n[ERROR] Missing dependency: aiohttp")
|
||||
print("Please run one of the following commands:")
|
||||
print(" - install_dependencies.bat (on Windows)")
|
||||
print(" - pip install -r requirements.txt")
|
||||
print(" - pip install aiohttp==3.11.7")
|
||||
print(" - pip install aiohttp==3.9.5")
|
||||
sys.exit(1)
|
||||
raise
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Server package for FLUX.1 Edit"""
|
||||
|
||||
from .mcp_server import FluxEditMCPServer, create_server, main
|
||||
from .handlers import ToolHandlers
|
||||
from .models import TOOL_DEFINITIONS, ToolName
|
||||
|
||||
__all__ = ['FluxEditMCPServer', 'create_server', 'main', 'ToolHandlers', 'TOOL_DEFINITIONS', 'ToolName']
|
||||
__all__ = ['ToolHandlers', 'TOOL_DEFINITIONS', 'ToolName']
|
||||
|
||||
@@ -1,605 +0,0 @@
|
||||
"""MCP Tool Handlers for FLUX.1 Edit MCP Server"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from mcp.types import TextContent, ImageContent
|
||||
|
||||
from ..connector import Config, FluxEditClient, FluxEditRequest
|
||||
from ..utils import (
|
||||
validate_edit_parameters,
|
||||
validate_file_parameters,
|
||||
validate_image_path_parameter,
|
||||
validate_move_file_parameters,
|
||||
validate_image_file,
|
||||
save_image,
|
||||
encode_image_base64,
|
||||
decode_image_base64,
|
||||
sanitize_prompt,
|
||||
get_image_dimensions,
|
||||
convert_image_to_base64,
|
||||
get_file_size_mb
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolHandlers:
|
||||
"""Handler class for FLUX.1 Edit MCP tools"""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
"""Initialize handlers with configuration"""
|
||||
self.config = config
|
||||
self.current_seed = None # Track current seed for session
|
||||
|
||||
def _get_or_create_seed(self) -> int:
|
||||
"""Get current seed or create new one"""
|
||||
if self.current_seed is None:
|
||||
self.current_seed = random.randint(0, 999999)
|
||||
return self.current_seed
|
||||
|
||||
def _reset_seed(self):
|
||||
"""Reset seed for new session"""
|
||||
self.current_seed = None
|
||||
|
||||
def _save_b64_to_temp_file(self, b64_data: str, filename: str) -> str:
|
||||
"""Save base64 data to a temporary file with specified filename
|
||||
|
||||
Args:
|
||||
b64_data: Base64 encoded image data
|
||||
filename: Desired filename for the file
|
||||
|
||||
Returns:
|
||||
str: Path to saved file
|
||||
"""
|
||||
try:
|
||||
# Decode base64 data
|
||||
image_data = decode_image_base64(b64_data)
|
||||
|
||||
# Save to local temp directory for processing
|
||||
temp_dir = self.config.base_path / 'temp'
|
||||
temp_dir.mkdir(exist_ok=True)
|
||||
file_path = temp_dir / filename
|
||||
|
||||
if not save_image(image_data, str(file_path)):
|
||||
raise RuntimeError(f"Failed to save image to temp file: {filename}")
|
||||
|
||||
logger.info(f"Saved temp file: {filename} ({len(image_data) / 1024:.1f} KB)")
|
||||
|
||||
return str(file_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving b64 to temp file: {e}")
|
||||
raise
|
||||
|
||||
def _move_temp_to_generated(self, temp_file_path: str, base_name: str, index: int, extension: str = None) -> str:
|
||||
"""
|
||||
Move file from temp directory to generated_images directory
|
||||
|
||||
Args:
|
||||
temp_file_path: Path to temporary file
|
||||
base_name: Base name for the destination file
|
||||
index: Index for the file (0 for input, 1+ for output)
|
||||
extension: File extension (will detect from temp file if not provided)
|
||||
|
||||
Returns:
|
||||
str: Path to moved file in generated_images directory
|
||||
"""
|
||||
try:
|
||||
# Ensure output directory exists
|
||||
self.config.ensure_output_directory()
|
||||
|
||||
temp_path = Path(temp_file_path)
|
||||
|
||||
# Verify source file exists
|
||||
if not temp_path.exists():
|
||||
raise FileNotFoundError(f"Temp file not found: {temp_file_path}")
|
||||
|
||||
# Detect extension from temp file if not provided
|
||||
if extension is None:
|
||||
extension = temp_path.suffix[1:] if temp_path.suffix else 'png'
|
||||
|
||||
# Generate destination filename
|
||||
dest_filename = self.config.generate_filename(base_name, index, extension)
|
||||
dest_path = self.config.generated_images_path / dest_filename
|
||||
|
||||
# Copy file (preserve original in temp for potential reuse)
|
||||
import shutil
|
||||
try:
|
||||
shutil.copy2(temp_file_path, dest_path)
|
||||
|
||||
# Verify copy was successful
|
||||
if not dest_path.exists():
|
||||
raise RuntimeError(f"File copy verification failed: {dest_path}")
|
||||
|
||||
# Check file sizes match
|
||||
if temp_path.stat().st_size != dest_path.stat().st_size:
|
||||
raise RuntimeError(f"File copy size mismatch: {temp_path.stat().st_size} != {dest_path.stat().st_size}")
|
||||
|
||||
except PermissionError as e:
|
||||
raise RuntimeError(f"Permission denied copying file to {dest_path}: {e}")
|
||||
except shutil.Error as e:
|
||||
raise RuntimeError(f"Copy operation failed: {e}")
|
||||
|
||||
logger.info(f"Moved temp file to generated_images: {temp_path.name} → {dest_filename}")
|
||||
|
||||
return str(dest_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error moving temp file to generated_images: {e}")
|
||||
raise
|
||||
|
||||
async def handle_flux_edit_image(self, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]:
|
||||
"""
|
||||
Handle flux_edit_image tool call
|
||||
|
||||
Args:
|
||||
arguments: Tool arguments
|
||||
|
||||
Returns:
|
||||
List of content items
|
||||
"""
|
||||
try:
|
||||
# Validate parameters
|
||||
is_valid, error_msg = validate_edit_parameters(arguments)
|
||||
if not is_valid:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"❌ Parameter validation failed: {error_msg}"
|
||||
)]
|
||||
|
||||
# Extract parameters
|
||||
input_image_b64 = arguments['input_image_b64']
|
||||
prompt = sanitize_prompt(arguments['prompt'])
|
||||
seed = arguments['seed']
|
||||
aspect_ratio = arguments.get('aspect_ratio', self.config.default_aspect_ratio)
|
||||
save_to_file = arguments.get('save_to_file', True)
|
||||
|
||||
logger.info(f"Starting FLUX edit with seed {seed}")
|
||||
|
||||
# Generate base name
|
||||
base_name = self.config.generate_base_name(seed)
|
||||
|
||||
# Save input image to temp and then to generated_images as 000
|
||||
temp_image_name = f'temp_input_{random.randint(1000, 9999)}.png'
|
||||
temp_image_path = self._save_b64_to_temp_file(input_image_b64, temp_image_name)
|
||||
|
||||
# Copy to generated_images as input (000)
|
||||
input_generated_path = self._move_temp_to_generated(temp_image_path, base_name, 0)
|
||||
logger.info(f"Input file saved: {Path(input_generated_path).name}")
|
||||
|
||||
# Create FLUX edit request
|
||||
request = FluxEditRequest(
|
||||
input_image_b64=input_image_b64,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
aspect_ratio=aspect_ratio,
|
||||
safety_tolerance=self.config.safety_tolerance,
|
||||
output_format=self.config.OUTPUT_FORMAT,
|
||||
prompt_upsampling=self.config.prompt_upsampling
|
||||
)
|
||||
|
||||
# Process edit using FLUX API
|
||||
async with FluxEditClient(self.config) as client:
|
||||
response = await client.edit_image(request)
|
||||
|
||||
if not response.success:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"❌ FLUX edit failed: {response.error_message}"
|
||||
)]
|
||||
|
||||
# Save output image and metadata
|
||||
saved_path = None
|
||||
json_path = None
|
||||
|
||||
if save_to_file:
|
||||
output_path = self.config.get_output_path(base_name, 1, 'png')
|
||||
|
||||
if save_image(response.edited_image_data, str(output_path)):
|
||||
saved_path = str(output_path)
|
||||
|
||||
# Save parameters as JSON
|
||||
if self.config.save_parameters:
|
||||
params_dict = {
|
||||
"base_name": base_name,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"model": self.config.MODEL_NAME,
|
||||
"prompt": prompt,
|
||||
"seed": seed,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"safety_tolerance": self.config.safety_tolerance,
|
||||
"output_format": self.config.OUTPUT_FORMAT,
|
||||
"prompt_upsampling": self.config.prompt_upsampling,
|
||||
"input_image_temp": temp_image_name,
|
||||
"input_generated_path": input_generated_path,
|
||||
"output_size": response.image_size,
|
||||
"execution_time": response.execution_time,
|
||||
"request_id": response.request_id,
|
||||
"metadata": response.metadata
|
||||
}
|
||||
|
||||
json_path = self.config.get_output_path(base_name, 1, 'json')
|
||||
with open(json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(params_dict, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Parameters saved to: {json_path}")
|
||||
|
||||
# Prepare response
|
||||
contents = []
|
||||
|
||||
# Add text description
|
||||
text = f"✅ Image edited successfully with FLUX.1 Kontext!\n"
|
||||
text += f"🎲 Seed: {seed}\n"
|
||||
text += f"📁 Base name: {base_name}\n"
|
||||
if response.image_size:
|
||||
text += f"📐 Size: {response.image_size[0]}x{response.image_size[1]}\n"
|
||||
text += f"📏 Aspect ratio: {aspect_ratio}\n"
|
||||
text += f"⏱️ Processing time: {response.execution_time:.1f}s\n"
|
||||
|
||||
if saved_path:
|
||||
text += f"\n💾 Output: {Path(saved_path).name}"
|
||||
text += f"\n📝 Input: {Path(input_generated_path).name}"
|
||||
if json_path:
|
||||
text += f"\n📋 Parameters: {Path(json_path).name}"
|
||||
|
||||
contents.append(TextContent(type="text", text=text))
|
||||
|
||||
# Add image preview
|
||||
if response.edited_image_data:
|
||||
image_b64 = encode_image_base64(response.edited_image_data)
|
||||
contents.append(ImageContent(
|
||||
type="image",
|
||||
data=image_b64,
|
||||
mimeType="image/png"
|
||||
))
|
||||
|
||||
# Reset seed for next session
|
||||
self._reset_seed()
|
||||
|
||||
return contents
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in handle_flux_edit_image: {e}", exc_info=True)
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"❌ Unexpected error: {str(e)}"
|
||||
)]
|
||||
|
||||
async def handle_flux_edit_image_from_file(self, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]:
|
||||
"""
|
||||
Handle flux_edit_image_from_file tool call
|
||||
|
||||
Args:
|
||||
arguments: Tool arguments
|
||||
|
||||
Returns:
|
||||
List of content items
|
||||
"""
|
||||
try:
|
||||
# Validate parameters
|
||||
is_valid, error_msg = validate_file_parameters(arguments)
|
||||
if not is_valid:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"❌ Parameter validation failed: {error_msg}"
|
||||
)]
|
||||
|
||||
# Extract parameters
|
||||
input_image_name = arguments['input_image_name']
|
||||
prompt = sanitize_prompt(arguments['prompt'])
|
||||
seed = arguments['seed']
|
||||
aspect_ratio = arguments.get('aspect_ratio', self.config.default_aspect_ratio)
|
||||
save_to_file = arguments.get('save_to_file', True)
|
||||
|
||||
# Check if file exists in input directory
|
||||
input_file_path = self.config.input_path / input_image_name
|
||||
|
||||
if not input_file_path.exists():
|
||||
# Enhanced error message with debug info
|
||||
error_text = f"❌ File not found in input directory: {input_image_name}\n"
|
||||
error_text += f"📁 Looking in: {self.config.input_path}\n"
|
||||
error_text += f"🔍 Full path: {input_file_path}\n"
|
||||
error_text += f"📂 Input directory exists: {self.config.input_path.exists()}\n"
|
||||
|
||||
# List available files in input directory
|
||||
if self.config.input_path.exists():
|
||||
files = [f.name for f in self.config.input_path.iterdir() if f.is_file()]
|
||||
if files:
|
||||
error_text += f"📋 Available files: {', '.join(files[:10])}"
|
||||
if len(files) > 10:
|
||||
error_text += f" and {len(files) - 10} more..."
|
||||
else:
|
||||
error_text += "📋 No files found in input directory"
|
||||
else:
|
||||
error_text += "⚠️ Input directory does not exist"
|
||||
|
||||
return [TextContent(type="text", text=error_text)]
|
||||
|
||||
# Validate the image file
|
||||
is_valid, size_mb, validation_error = validate_image_file(
|
||||
str(input_file_path),
|
||||
self.config.max_image_size_mb
|
||||
)
|
||||
if not is_valid:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"❌ Image validation failed: {validation_error}"
|
||||
)]
|
||||
|
||||
logger.info(f"Starting FLUX edit from file: {input_image_name} ({size_mb:.2f}MB)")
|
||||
|
||||
# Convert image to base64
|
||||
try:
|
||||
input_image_b64 = convert_image_to_base64(str(input_file_path))
|
||||
except Exception as e:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"❌ Failed to convert image to base64: {str(e)}"
|
||||
)]
|
||||
|
||||
# Generate base name
|
||||
base_name = self.config.generate_base_name(seed)
|
||||
|
||||
# Copy original file to generated_images as input (000)
|
||||
try:
|
||||
with open(input_file_path, 'rb') as f:
|
||||
image_data = f.read()
|
||||
|
||||
input_generated_path = self.config.get_output_path(base_name, 0, 'png')
|
||||
if not save_image(image_data, str(input_generated_path)):
|
||||
raise RuntimeError("Failed to save input to generated_images")
|
||||
|
||||
logger.info(f"Input file copied: {Path(input_generated_path).name}")
|
||||
|
||||
except Exception as e:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"❌ Failed to copy input file: {str(e)}"
|
||||
)]
|
||||
|
||||
# Create FLUX edit request
|
||||
request = FluxEditRequest(
|
||||
input_image_b64=input_image_b64,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
aspect_ratio=aspect_ratio,
|
||||
safety_tolerance=self.config.safety_tolerance,
|
||||
output_format=self.config.OUTPUT_FORMAT,
|
||||
prompt_upsampling=self.config.prompt_upsampling
|
||||
)
|
||||
|
||||
# Process edit using FLUX API
|
||||
async with FluxEditClient(self.config) as client:
|
||||
response = await client.edit_image(request)
|
||||
|
||||
if not response.success:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"❌ FLUX edit failed: {response.error_message}"
|
||||
)]
|
||||
|
||||
# Save output image and metadata
|
||||
saved_path = None
|
||||
json_path = None
|
||||
|
||||
if save_to_file:
|
||||
output_path = self.config.get_output_path(base_name, 1, 'png')
|
||||
|
||||
if save_image(response.edited_image_data, str(output_path)):
|
||||
saved_path = str(output_path)
|
||||
|
||||
# Save parameters as JSON
|
||||
if self.config.save_parameters:
|
||||
params_dict = {
|
||||
"base_name": base_name,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"model": self.config.MODEL_NAME,
|
||||
"prompt": prompt,
|
||||
"seed": seed,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"safety_tolerance": self.config.safety_tolerance,
|
||||
"output_format": self.config.OUTPUT_FORMAT,
|
||||
"prompt_upsampling": self.config.prompt_upsampling,
|
||||
"input_image_name": input_image_name,
|
||||
"input_file_path": str(input_file_path),
|
||||
"input_size": get_image_dimensions(str(input_file_path)),
|
||||
"input_size_mb": size_mb,
|
||||
"output_size": response.image_size,
|
||||
"execution_time": response.execution_time,
|
||||
"request_id": response.request_id,
|
||||
"metadata": response.metadata
|
||||
}
|
||||
|
||||
json_path = self.config.get_output_path(base_name, 1, 'json')
|
||||
with open(json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(params_dict, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Parameters saved to: {json_path}")
|
||||
|
||||
# Prepare response
|
||||
contents = []
|
||||
|
||||
# Add text description
|
||||
text = f"✅ Image edited successfully from file with FLUX.1 Kontext!\n"
|
||||
text += f"📝 Input: {input_image_name} ({size_mb:.2f}MB)\n"
|
||||
text += f"🎲 Seed: {seed}\n"
|
||||
text += f"📁 Base name: {base_name}\n"
|
||||
if response.image_size:
|
||||
text += f"📐 Size: {response.image_size[0]}x{response.image_size[1]}\n"
|
||||
text += f"📏 Aspect ratio: {aspect_ratio}\n"
|
||||
text += f"⏱️ Processing time: {response.execution_time:.1f}s\n"
|
||||
|
||||
if saved_path:
|
||||
text += f"\n💾 Output: {Path(saved_path).name}"
|
||||
text += f"\n📝 Input copy: {Path(input_generated_path).name}"
|
||||
if json_path:
|
||||
text += f"\n📋 Parameters: {Path(json_path).name}"
|
||||
|
||||
contents.append(TextContent(type="text", text=text))
|
||||
|
||||
# Add image preview
|
||||
if response.edited_image_data:
|
||||
image_b64 = encode_image_base64(response.edited_image_data)
|
||||
contents.append(ImageContent(
|
||||
type="image",
|
||||
data=image_b64,
|
||||
mimeType="image/png"
|
||||
))
|
||||
|
||||
# Reset seed for next session
|
||||
self._reset_seed()
|
||||
|
||||
return contents
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in handle_flux_edit_image_from_file: {e}", exc_info=True)
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"❌ File-based edit error: {str(e)}"
|
||||
)]
|
||||
|
||||
async def handle_validate_image(self, arguments: Dict[str, Any]) -> List[TextContent]:
|
||||
"""
|
||||
Handle validate_image tool call
|
||||
|
||||
Args:
|
||||
arguments: Tool arguments
|
||||
|
||||
Returns:
|
||||
List of content items
|
||||
"""
|
||||
try:
|
||||
# Validate parameters
|
||||
is_valid, error_msg = validate_image_path_parameter(arguments)
|
||||
if not is_valid:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"❌ Parameter validation failed: {error_msg}"
|
||||
)]
|
||||
|
||||
image_path = arguments['image_path']
|
||||
|
||||
# Validate image
|
||||
is_valid, size_mb, error_msg = validate_image_file(
|
||||
image_path,
|
||||
self.config.max_image_size_mb
|
||||
)
|
||||
|
||||
# Get additional info if valid
|
||||
if is_valid:
|
||||
width, height = get_image_dimensions(image_path)
|
||||
|
||||
text = f"✅ Image validation passed!\n"
|
||||
text += f"📁 File: {Path(image_path).name}\n"
|
||||
text += f"📐 Dimensions: {width}x{height}\n"
|
||||
text += f"💾 Size: {size_mb:.2f}MB\n"
|
||||
text += f"🎯 Max allowed: {self.config.max_image_size_mb}MB\n"
|
||||
|
||||
# Check aspect ratio compatibility
|
||||
from ..utils import get_optimal_aspect_ratio
|
||||
optimal_ratio = get_optimal_aspect_ratio(width, height)
|
||||
text += f"📏 Optimal aspect ratio: {optimal_ratio}"
|
||||
else:
|
||||
text = f"❌ Image validation failed!\n"
|
||||
text += f"📁 File: {Path(image_path).name}\n"
|
||||
text += f"⚠️ Issue: {error_msg}"
|
||||
|
||||
return [TextContent(type="text", text=text)]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in handle_validate_image: {e}", exc_info=True)
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"❌ Validation error: {str(e)}"
|
||||
)]
|
||||
|
||||
async def handle_move_temp_to_output(self, arguments: Dict[str, Any]) -> List[TextContent]:
|
||||
"""
|
||||
Handle move_temp_to_output tool call
|
||||
|
||||
Args:
|
||||
arguments: Tool arguments
|
||||
|
||||
Returns:
|
||||
List of content items
|
||||
"""
|
||||
try:
|
||||
# Validate parameters
|
||||
is_valid, error_msg = validate_move_file_parameters(arguments)
|
||||
if not is_valid:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"❌ Parameter validation failed: {error_msg}"
|
||||
)]
|
||||
|
||||
temp_file_name = arguments['temp_file_name']
|
||||
output_file_name = arguments.get('output_file_name')
|
||||
copy_only = arguments.get('copy_only', False)
|
||||
|
||||
# Get temp file path
|
||||
temp_file_path = self.config.base_path / 'temp' / temp_file_name
|
||||
|
||||
# Check if temp file exists
|
||||
if not temp_file_path.exists():
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"❌ Temp file not found: {temp_file_name}"
|
||||
)]
|
||||
|
||||
# Generate output file name if not provided
|
||||
if not output_file_name:
|
||||
base_name = self.config.generate_base_name_simple()
|
||||
file_ext = Path(temp_file_name).suffix[1:] or 'png'
|
||||
output_file_name = f"{base_name}_001.{file_ext}"
|
||||
|
||||
# Ensure output directory exists
|
||||
self.config.ensure_output_directory()
|
||||
|
||||
# Get output path
|
||||
output_path = self.config.generated_images_path / output_file_name
|
||||
|
||||
# Move or copy file
|
||||
try:
|
||||
import shutil
|
||||
if copy_only:
|
||||
shutil.copy2(temp_file_path, output_path)
|
||||
operation = "copied"
|
||||
else:
|
||||
shutil.move(str(temp_file_path), str(output_path))
|
||||
operation = "moved"
|
||||
|
||||
# Verify operation was successful
|
||||
if not output_path.exists():
|
||||
raise RuntimeError(f"File {operation} verification failed")
|
||||
|
||||
logger.info(f"📁 File {operation}: {temp_file_name} -> {output_file_name}")
|
||||
|
||||
# Get file size for reporting
|
||||
file_size_mb = output_path.stat().st_size / (1024 * 1024)
|
||||
|
||||
text = f"✅ File {operation} successfully!\n"
|
||||
text += f"📁 From temp: {temp_file_name}\n"
|
||||
text += f"📁 To output: {output_file_name}\n"
|
||||
text += f"💾 Size: {file_size_mb:.2f}MB"
|
||||
|
||||
return [TextContent(type="text", text=text)]
|
||||
|
||||
except PermissionError as e:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"❌ Permission denied: {str(e)}"
|
||||
)]
|
||||
except Exception as e:
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"❌ File operation failed: {str(e)}"
|
||||
)]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in handle_move_temp_to_output: {e}", exc_info=True)
|
||||
return [TextContent(
|
||||
type="text",
|
||||
text=f"❌ File move error: {str(e)}"
|
||||
)]
|
||||
@@ -1,172 +0,0 @@
|
||||
"""MCP Server for FLUX.1 Edit"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Dict, Any, List
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.server import Server
|
||||
from mcp.server.stdio import stdio_server
|
||||
|
||||
from ..connector import Config
|
||||
from .models import TOOL_DEFINITIONS, ToolName
|
||||
from .handlers import ToolHandlers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point for the MCP server"""
|
||||
|
||||
# Setup logging with minimal output for MCP compatibility
|
||||
logging.basicConfig(
|
||||
level=logging.WARNING, # Only warnings and errors
|
||||
format='%(asctime)s [%(name)s] %(levelname)s: %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler('flux1-edit.log', mode='a', encoding='utf-8')
|
||||
]
|
||||
)
|
||||
|
||||
# Silence noisy loggers completely
|
||||
logging.getLogger('aiohttp').setLevel(logging.ERROR)
|
||||
logging.getLogger('PIL').setLevel(logging.ERROR)
|
||||
logging.getLogger('httpx').setLevel(logging.ERROR)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
# Create configuration
|
||||
config = Config()
|
||||
|
||||
# Validate configuration
|
||||
if not config.validate():
|
||||
logger.error("Configuration validation failed")
|
||||
raise RuntimeError("Configuration validation failed")
|
||||
|
||||
# Create MCP server
|
||||
server = Server("flux1-edit")
|
||||
|
||||
# Create tool handlers
|
||||
handlers = ToolHandlers(config)
|
||||
|
||||
logger.info("Setting up MCP server handlers...")
|
||||
|
||||
# Set up list_tools handler
|
||||
@server.list_tools()
|
||||
async def list_tools() -> List[types.Tool]:
|
||||
"""List available tools"""
|
||||
tools = []
|
||||
|
||||
for tool_name, tool_def in TOOL_DEFINITIONS.items():
|
||||
# Build properties for parameters
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for param in tool_def.parameters:
|
||||
prop_def = {
|
||||
"type": param.type,
|
||||
"description": param.description
|
||||
}
|
||||
|
||||
# Add enum if specified
|
||||
if param.enum:
|
||||
prop_def["enum"] = param.enum
|
||||
|
||||
# Add default if specified
|
||||
if param.default is not None:
|
||||
prop_def["default"] = param.default
|
||||
|
||||
properties[param.name] = prop_def
|
||||
|
||||
if param.required:
|
||||
required.append(param.name)
|
||||
|
||||
# Build tool schema
|
||||
tool = types.Tool(
|
||||
name=tool_def.name,
|
||||
description=tool_def.description,
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required
|
||||
}
|
||||
)
|
||||
tools.append(tool)
|
||||
|
||||
logger.debug(f"Listed {len(tools)} tools")
|
||||
return tools
|
||||
|
||||
# Set up call_tool handler
|
||||
@server.call_tool()
|
||||
async def call_tool(name: str, arguments: Dict[str, Any]) -> List[types.TextContent | types.ImageContent]:
|
||||
"""Handle tool calls"""
|
||||
try:
|
||||
logger.info(f"Tool call: {name}")
|
||||
|
||||
# Sanitize arguments for logging
|
||||
safe_args = arguments.copy()
|
||||
if 'input_image_b64' in safe_args:
|
||||
b64_data = safe_args['input_image_b64']
|
||||
safe_args['input_image_b64'] = f"<base64 image data: {len(b64_data)} chars>"
|
||||
if 'prompt' in safe_args and len(safe_args['prompt']) > 100:
|
||||
safe_args['prompt'] = safe_args['prompt'][:100] + '...'
|
||||
|
||||
logger.debug(f"Arguments: {safe_args}")
|
||||
|
||||
# Route to appropriate handler
|
||||
if name == ToolName.FLUX_EDIT_IMAGE:
|
||||
return await handlers.handle_flux_edit_image(arguments)
|
||||
elif name == ToolName.FLUX_EDIT_IMAGE_FROM_FILE:
|
||||
return await handlers.handle_flux_edit_image_from_file(arguments)
|
||||
elif name == ToolName.VALIDATE_IMAGE:
|
||||
return await handlers.handle_validate_image(arguments)
|
||||
elif name == ToolName.MOVE_TEMP_TO_OUTPUT:
|
||||
return await handlers.handle_move_temp_to_output(arguments)
|
||||
else:
|
||||
return [types.TextContent(
|
||||
type="text",
|
||||
text=f"[ERROR] Unknown tool: {name}"
|
||||
)]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling tool {name}: {e}", exc_info=True)
|
||||
return [types.TextContent(
|
||||
type="text",
|
||||
text=f"[ERROR] Tool execution error: {str(e)}"
|
||||
)]
|
||||
|
||||
logger.info("Starting FLUX.1 Edit MCP Server...")
|
||||
logger.info(f"API key configured: {'Yes' if config.api_key else 'No'}")
|
||||
logger.info(f"Input directory: {config.input_path}")
|
||||
logger.info(f"Output directory: {config.generated_images_path}")
|
||||
|
||||
# Run the server using stdio
|
||||
await stdio_server(server)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Server startup failed: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
class FluxEditMCPServer:
|
||||
"""Legacy wrapper class (kept for compatibility)"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the MCP server"""
|
||||
self.config = Config()
|
||||
self.server = Server("flux1-edit")
|
||||
self.handlers = ToolHandlers(self.config)
|
||||
logger.info("FLUX.1 Edit MCP Server initialized")
|
||||
|
||||
def validate_config(self) -> bool:
|
||||
"""Validate server configuration"""
|
||||
return self.config.validate()
|
||||
|
||||
|
||||
def create_server() -> FluxEditMCPServer:
|
||||
"""Create and return a FLUX.1 Edit MCP Server instance"""
|
||||
return FluxEditMCPServer()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -9,7 +9,6 @@ from .image_utils import (
|
||||
save_image,
|
||||
encode_image_base64,
|
||||
decode_image_base64,
|
||||
optimize_image_for_flux,
|
||||
convert_image_to_base64,
|
||||
validate_aspect_ratio,
|
||||
get_optimal_aspect_ratio
|
||||
@@ -37,7 +36,6 @@ __all__ = [
|
||||
'save_image',
|
||||
'encode_image_base64',
|
||||
'decode_image_base64',
|
||||
'optimize_image_for_flux',
|
||||
'convert_image_to_base64',
|
||||
'validate_aspect_ratio',
|
||||
'get_optimal_aspect_ratio',
|
||||
|
||||
@@ -215,81 +215,6 @@ def decode_image_base64(base64_str: str) -> bytes:
|
||||
raise ValueError(f"Failed to decode base64 data: {e}")
|
||||
|
||||
|
||||
def optimize_image_for_flux(image_path: str, max_size_mb: float = 20.0) -> bytes:
|
||||
"""
|
||||
Optimize image for FLUX.1 Kontext API (20MB limit)
|
||||
|
||||
Args:
|
||||
image_path: Path to input image
|
||||
max_size_mb: Maximum size in MB (default: 20 for FLUX)
|
||||
|
||||
Returns:
|
||||
bytes: Optimized image data
|
||||
"""
|
||||
max_size_bytes = max_size_mb * 1024 * 1024
|
||||
|
||||
try:
|
||||
with Image.open(image_path) as img:
|
||||
# For FLUX, we want to preserve quality as much as possible
|
||||
# since 20MB is quite generous
|
||||
|
||||
# Convert to RGB if needed (FLUX typically prefers RGB)
|
||||
if img.mode != 'RGB':
|
||||
if img.mode == 'RGBA':
|
||||
# Create white background for transparent images
|
||||
background = Image.new('RGB', img.size, (255, 255, 255))
|
||||
background.paste(img, mask=img.split()[-1] if img.mode == 'RGBA' else None)
|
||||
img = background
|
||||
else:
|
||||
img = img.convert('RGB')
|
||||
|
||||
# Try PNG first (lossless)
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format='PNG', optimize=True)
|
||||
png_data = buffer.getvalue()
|
||||
|
||||
if len(png_data) <= max_size_bytes:
|
||||
logger.info(f"Image optimized as PNG: {len(png_data) / (1024*1024):.2f}MB")
|
||||
return png_data
|
||||
|
||||
# PNG too large, try JPEG with high quality
|
||||
for quality in [95, 90, 85, 80]:
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format='JPEG', quality=quality, optimize=True)
|
||||
jpeg_data = buffer.getvalue()
|
||||
|
||||
if len(jpeg_data) <= max_size_bytes:
|
||||
size_mb = len(jpeg_data) / (1024 * 1024)
|
||||
logger.info(f"Image optimized as JPEG (quality {quality}): {size_mb:.2f}MB")
|
||||
return jpeg_data
|
||||
|
||||
# Still too large, try resizing (preserve aspect ratio)
|
||||
logger.warning("Image still too large, attempting resize...")
|
||||
|
||||
scale = 0.95
|
||||
while scale > 0.5:
|
||||
new_width = int(img.width * scale)
|
||||
new_height = int(img.height * scale)
|
||||
|
||||
resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
resized.save(buffer, format='JPEG', quality=85, optimize=True)
|
||||
data = buffer.getvalue()
|
||||
|
||||
if len(data) <= max_size_bytes:
|
||||
size_mb = len(data) / (1024 * 1024)
|
||||
logger.warning(f"Image resized to {new_width}x{new_height} ({scale*100:.0f}%): {size_mb:.2f}MB")
|
||||
return data
|
||||
|
||||
scale -= 0.05
|
||||
|
||||
raise ValueError(f"Cannot optimize image to under {max_size_mb}MB")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error optimizing image: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def convert_image_to_base64(image_path: str) -> str:
|
||||
"""
|
||||
@@ -302,16 +227,9 @@ def convert_image_to_base64(image_path: str) -> str:
|
||||
str: Base64 encoded image data
|
||||
"""
|
||||
try:
|
||||
# Check if optimization is needed
|
||||
current_size_mb = get_file_size_mb(image_path)
|
||||
|
||||
if current_size_mb <= 20.0:
|
||||
# Read directly if under limit
|
||||
# Read image file directly (validation should handle size limits)
|
||||
with open(image_path, 'rb') as f:
|
||||
image_data = f.read()
|
||||
else:
|
||||
# Optimize if over limit
|
||||
image_data = optimize_image_for_flux(image_path)
|
||||
|
||||
return encode_image_base64(image_data)
|
||||
|
||||
|
||||
35
start.bat
35
start.bat
@@ -1,35 +0,0 @@
|
||||
@echo off
|
||||
echo FLUX.1 Edit MCP Server - Quick Start
|
||||
echo ====================================
|
||||
|
||||
REM Set UTF-8 encoding to prevent Unicode errors
|
||||
chcp 65001 >nul 2>&1
|
||||
set PYTHONIOENCODING=utf-8
|
||||
set PYTHONUTF8=1
|
||||
|
||||
echo Running simple test first...
|
||||
python simple_test.py
|
||||
if errorlevel 1 (
|
||||
echo.
|
||||
echo Simple test failed. Please check the errors above.
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
echo.
|
||||
echo Simple test passed! Starting MCP server...
|
||||
echo.
|
||||
|
||||
REM Start the main MCP server
|
||||
python main.py
|
||||
|
||||
if errorlevel 1 (
|
||||
echo.
|
||||
echo Server failed to start. Check flux1-edit.log for details.
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
echo.
|
||||
echo Server stopped.
|
||||
pause
|
||||
119
test_fixes.py
119
test_fixes.py
@@ -1,119 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify FLUX.1 Edit MCP Server fixes
|
||||
This script tests if the Unicode and JSON parsing issues are resolved
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Add src to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent / 'src'))
|
||||
|
||||
def test_console_encoding():
|
||||
"""Test console encoding with safe characters"""
|
||||
print("Testing console output with safe characters...")
|
||||
print("[OK] ASCII characters work fine")
|
||||
print("[ERROR] Error messages use brackets instead of Unicode")
|
||||
print("[SUCCESS] Success messages use brackets instead of Unicode")
|
||||
print("[INFO] Info messages work properly")
|
||||
return True
|
||||
|
||||
def test_dependency_imports():
|
||||
"""Test importing dependencies silently"""
|
||||
print("Testing dependency imports...")
|
||||
|
||||
missing_deps = []
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
# Silent check
|
||||
except ImportError:
|
||||
missing_deps.append("aiohttp")
|
||||
|
||||
try:
|
||||
import mcp
|
||||
# Silent check
|
||||
except ImportError:
|
||||
missing_deps.append("mcp")
|
||||
|
||||
try:
|
||||
from src.connector import Config
|
||||
# Silent check
|
||||
except ImportError:
|
||||
missing_deps.append("src.connector.Config")
|
||||
|
||||
try:
|
||||
from src.server import main
|
||||
# Silent check
|
||||
except ImportError:
|
||||
missing_deps.append("src.server.main")
|
||||
|
||||
if missing_deps:
|
||||
print(f"[ERROR] Missing dependencies: {', '.join(missing_deps)}")
|
||||
return False
|
||||
else:
|
||||
print("[SUCCESS] All imports successful")
|
||||
return True
|
||||
|
||||
def test_server_creation():
|
||||
"""Test creating server instance without starting it"""
|
||||
print("Testing server creation...")
|
||||
|
||||
try:
|
||||
from src.server import create_server
|
||||
server = create_server()
|
||||
print("[SUCCESS] Server instance created successfully")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Server creation failed: {e}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Main test function"""
|
||||
print("FLUX.1 Edit MCP Server - Fix Verification")
|
||||
print("=" * 50)
|
||||
|
||||
# Set UTF-8 encoding environment variables
|
||||
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
||||
os.environ['PYTHONUTF8'] = '1'
|
||||
|
||||
tests = [
|
||||
("Console encoding", test_console_encoding),
|
||||
("Dependency imports", test_dependency_imports),
|
||||
("Server creation", test_server_creation),
|
||||
]
|
||||
|
||||
passed = 0
|
||||
total = len(tests)
|
||||
|
||||
for test_name, test_func in tests:
|
||||
print(f"\nRunning test: {test_name}")
|
||||
print("-" * 30)
|
||||
|
||||
try:
|
||||
if test_func():
|
||||
passed += 1
|
||||
print(f"[PASSED] {test_name}")
|
||||
else:
|
||||
print(f"[FAILED] {test_name}")
|
||||
except Exception as e:
|
||||
print(f"[FAILED] {test_name}: {e}")
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print(f"Test Results: {passed}/{total} tests passed")
|
||||
|
||||
if passed == total:
|
||||
print("[SUCCESS] All tests passed! The fixes should work.")
|
||||
print("\nTo run the server:")
|
||||
print("1. Make sure your .env file is configured")
|
||||
print("2. Run: start.bat or run.bat")
|
||||
print("3. The server should start without Unicode errors")
|
||||
return 0
|
||||
else:
|
||||
print(f"[FAILED] {total - passed} tests failed. Check the errors above.")
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -1,106 +0,0 @@
|
||||
"""Test runner for all FLUX.1 Edit MCP Server tests"""
|
||||
|
||||
import unittest
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
|
||||
|
||||
# Import all test modules
|
||||
from test_config import TestConfig
|
||||
from test_image_utils import TestImageUtils
|
||||
from test_validation import TestValidation
|
||||
from test_flux_client import TestFluxEditClient, TestFluxEditClientAsync
|
||||
from test_handlers import TestToolHandlers
|
||||
|
||||
|
||||
def create_test_suite():
|
||||
"""Create and return the complete test suite"""
|
||||
suite = unittest.TestSuite()
|
||||
|
||||
# Add all test cases
|
||||
suite.addTest(unittest.makeSuite(TestConfig))
|
||||
suite.addTest(unittest.makeSuite(TestImageUtils))
|
||||
suite.addTest(unittest.makeSuite(TestValidation))
|
||||
suite.addTest(unittest.makeSuite(TestFluxEditClient))
|
||||
suite.addTest(unittest.makeSuite(TestFluxEditClientAsync))
|
||||
suite.addTest(unittest.makeSuite(TestToolHandlers))
|
||||
|
||||
return suite
|
||||
|
||||
|
||||
def run_tests():
|
||||
"""Run all tests and return results"""
|
||||
# Setup test runner
|
||||
runner = unittest.TextTestRunner(
|
||||
verbosity=2,
|
||||
stream=sys.stdout,
|
||||
descriptions=True,
|
||||
failfast=False
|
||||
)
|
||||
|
||||
# Create and run test suite
|
||||
suite = create_test_suite()
|
||||
result = runner.run(suite)
|
||||
|
||||
# Print summary
|
||||
print(f"\n{'='*70}")
|
||||
print("TEST SUMMARY")
|
||||
print(f"{'='*70}")
|
||||
print(f"Tests run: {result.testsRun}")
|
||||
print(f"Failures: {len(result.failures)}")
|
||||
print(f"Errors: {len(result.errors)}")
|
||||
print(f"Skipped: {len(result.skipped) if hasattr(result, 'skipped') else 0}")
|
||||
|
||||
if result.failures:
|
||||
print(f"\nFAILURES:")
|
||||
for test, traceback in result.failures:
|
||||
print(f" - {test}: {traceback.split('AssertionError: ')[-1].split()[0] if 'AssertionError:' in traceback else 'Unknown failure'}")
|
||||
|
||||
if result.errors:
|
||||
print(f"\nERRORS:")
|
||||
for test, traceback in result.errors:
|
||||
error_msg = traceback.split('\n')[-2] if traceback.split('\n')[-2] else 'Unknown error'
|
||||
print(f" - {test}: {error_msg}")
|
||||
|
||||
success = len(result.failures) == 0 and len(result.errors) == 0
|
||||
print(f"\nOVERALL: {'✅ PASSED' if success else '❌ FAILED'}")
|
||||
|
||||
return success
|
||||
|
||||
|
||||
def run_specific_test(test_name: str):
|
||||
"""Run a specific test module"""
|
||||
test_modules = {
|
||||
'config': TestConfig,
|
||||
'image_utils': TestImageUtils,
|
||||
'validation': TestValidation,
|
||||
'flux_client': TestFluxEditClient,
|
||||
'flux_client_async': TestFluxEditClientAsync,
|
||||
'handlers': TestToolHandlers
|
||||
}
|
||||
|
||||
if test_name not in test_modules:
|
||||
print(f"❌ Unknown test module: {test_name}")
|
||||
print(f"Available modules: {', '.join(test_modules.keys())}")
|
||||
return False
|
||||
|
||||
suite = unittest.makeSuite(test_modules[test_name])
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(suite)
|
||||
|
||||
return len(result.failures) == 0 and len(result.errors) == 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Check for specific test argument
|
||||
if len(sys.argv) > 1:
|
||||
test_name = sys.argv[1]
|
||||
success = run_specific_test(test_name)
|
||||
else:
|
||||
# Run all tests
|
||||
success = run_tests()
|
||||
|
||||
# Exit with appropriate code
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -1,159 +0,0 @@
|
||||
"""Unit tests for Config class"""
|
||||
|
||||
import unittest
|
||||
import tempfile
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
# Add src to path
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
|
||||
|
||||
from src.connector.config import Config
|
||||
|
||||
|
||||
class TestConfig(unittest.TestCase):
|
||||
"""Test cases for Config class"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures"""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.temp_path = Path(self.temp_dir)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures"""
|
||||
import shutil
|
||||
if self.temp_path.exists():
|
||||
shutil.rmtree(self.temp_path)
|
||||
|
||||
@patch.dict(os.environ, {
|
||||
'FLUX_API_KEY': 'test_api_key_12345',
|
||||
'LOG_LEVEL': 'DEBUG',
|
||||
'MAX_IMAGE_SIZE_MB': '25',
|
||||
'DEFAULT_TIMEOUT': '600'
|
||||
})
|
||||
def test_config_initialization(self):
|
||||
"""Test config initialization with environment variables"""
|
||||
config = Config()
|
||||
|
||||
self.assertEqual(config.api_key, 'test_api_key_12345')
|
||||
self.assertEqual(config.log_level, 'DEBUG')
|
||||
self.assertEqual(config.max_image_size_mb, 25)
|
||||
self.assertEqual(config.default_timeout, 600)
|
||||
|
||||
def test_config_defaults(self):
|
||||
"""Test config defaults when env vars not set"""
|
||||
# Clear environment
|
||||
env_vars_to_clear = [
|
||||
'FLUX_API_KEY', 'LOG_LEVEL', 'MAX_IMAGE_SIZE_MB',
|
||||
'DEFAULT_TIMEOUT', 'POLLING_INTERVAL', 'MAX_POLLING_ATTEMPTS'
|
||||
]
|
||||
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
config = Config()
|
||||
|
||||
self.assertEqual(config.api_key, '')
|
||||
self.assertEqual(config.log_level, 'INFO')
|
||||
self.assertEqual(config.max_image_size_mb, 20)
|
||||
self.assertEqual(config.default_timeout, 300)
|
||||
self.assertEqual(config.polling_interval, 2)
|
||||
self.assertEqual(config.max_polling_attempts, 150)
|
||||
|
||||
def test_api_url_generation(self):
|
||||
"""Test API URL generation"""
|
||||
config = Config()
|
||||
|
||||
edit_url = config.get_api_url(config.EDIT_ENDPOINT)
|
||||
result_url = config.get_api_url(config.RESULT_ENDPOINT)
|
||||
|
||||
expected_edit = f"{config.api_base_url}/flux-kontext-pro"
|
||||
expected_result = f"{config.api_base_url}/v1/get_result"
|
||||
|
||||
self.assertEqual(edit_url, expected_edit)
|
||||
self.assertEqual(result_url, expected_result)
|
||||
|
||||
def test_filename_generation(self):
|
||||
"""Test filename generation"""
|
||||
config = Config()
|
||||
|
||||
base_name = "fluxedit_123456_20250826_143022"
|
||||
|
||||
# Test different file numbers and extensions
|
||||
filename_000 = config.generate_filename(base_name, 0, 'png')
|
||||
filename_001 = config.generate_filename(base_name, 1, 'png')
|
||||
filename_json = config.generate_filename(base_name, 1, 'json')
|
||||
|
||||
self.assertEqual(filename_000, "fluxedit_123456_20250826_143022_000.png")
|
||||
self.assertEqual(filename_001, "fluxedit_123456_20250826_143022_001.png")
|
||||
self.assertEqual(filename_json, "fluxedit_123456_20250826_143022_001.json")
|
||||
|
||||
def test_base_name_generation(self):
|
||||
"""Test base name generation"""
|
||||
config = Config()
|
||||
|
||||
# Test with seed
|
||||
base_name_with_seed = config.generate_base_name(12345)
|
||||
self.assertIn("fluxedit_12345_", base_name_with_seed)
|
||||
|
||||
# Test simple generation
|
||||
base_name_simple = config.generate_base_name_simple()
|
||||
self.assertIn("fluxedit_", base_name_simple)
|
||||
self.assertNotIn("_12345_", base_name_simple) # No seed in simple
|
||||
|
||||
@patch.dict(os.environ, {
|
||||
'FLUX_API_KEY': 'valid_key',
|
||||
'MAX_IMAGE_SIZE_MB': '20',
|
||||
'DEFAULT_TIMEOUT': '300'
|
||||
})
|
||||
def test_validation_success(self):
|
||||
"""Test successful validation"""
|
||||
config = Config()
|
||||
self.assertTrue(config.validate())
|
||||
|
||||
def test_validation_failures(self):
|
||||
"""Test validation failures"""
|
||||
# Test missing API key
|
||||
with patch.dict(os.environ, {'FLUX_API_KEY': ''}, clear=True):
|
||||
config = Config()
|
||||
self.assertFalse(config.validate())
|
||||
|
||||
# Test invalid image size
|
||||
with patch.dict(os.environ, {
|
||||
'FLUX_API_KEY': 'valid_key',
|
||||
'MAX_IMAGE_SIZE_MB': '0'
|
||||
}, clear=True):
|
||||
config = Config()
|
||||
self.assertFalse(config.validate())
|
||||
|
||||
# Test invalid timeout
|
||||
with patch.dict(os.environ, {
|
||||
'FLUX_API_KEY': 'valid_key',
|
||||
'DEFAULT_TIMEOUT': '-1'
|
||||
}, clear=True):
|
||||
config = Config()
|
||||
self.assertFalse(config.validate())
|
||||
|
||||
def test_max_image_size_bytes(self):
|
||||
"""Test max image size in bytes calculation"""
|
||||
config = Config()
|
||||
config.max_image_size_mb = 20
|
||||
|
||||
expected_bytes = 20 * 1024 * 1024
|
||||
self.assertEqual(config.get_max_image_size_bytes(), expected_bytes)
|
||||
|
||||
@patch('pathlib.Path.mkdir')
|
||||
@patch('pathlib.Path.exists')
|
||||
def test_directory_creation(self, mock_exists, mock_mkdir):
|
||||
"""Test directory creation logic"""
|
||||
mock_exists.return_value = True
|
||||
|
||||
# This should not raise an exception
|
||||
config = Config()
|
||||
|
||||
# Verify mkdir was called
|
||||
mock_mkdir.assert_called()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,308 +0,0 @@
|
||||
"""Unit tests for FLUX API client"""
|
||||
|
||||
import unittest
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from pathlib import Path
|
||||
|
||||
# Add src to path
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
|
||||
|
||||
from src.connector.flux_client import FluxEditClient, FluxEditRequest, FluxEditResponse
|
||||
from src.connector.config import Config
|
||||
|
||||
|
||||
class TestFluxEditClient(unittest.TestCase):
|
||||
"""Test cases for FLUX API client"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures"""
|
||||
# Mock config
|
||||
self.config = MagicMock(spec=Config)
|
||||
self.config.api_key = 'test_api_key'
|
||||
self.config.default_timeout = 30
|
||||
self.config.polling_interval = 1
|
||||
self.config.max_polling_attempts = 5
|
||||
self.config.get_api_url.side_effect = lambda endpoint: f"https://api.test.com{endpoint}"
|
||||
|
||||
self.client = FluxEditClient(self.config)
|
||||
|
||||
# Sample request
|
||||
self.sample_request = FluxEditRequest(
|
||||
input_image_b64='test_base64_data',
|
||||
prompt='Make the sky blue',
|
||||
seed=12345,
|
||||
aspect_ratio='16:9'
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures"""
|
||||
# Ensure client session is closed
|
||||
if hasattr(self.client, 'session') and self.client.session:
|
||||
asyncio.create_task(self.client.close())
|
||||
|
||||
@patch('aiohttp.ClientSession')
|
||||
async def test_create_edit_request_success(self, mock_session_class):
|
||||
"""Test successful edit request creation"""
|
||||
# Mock session and response
|
||||
mock_session = AsyncMock()
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json.return_value = {'id': 'request_12345'}
|
||||
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
# Test
|
||||
request_id = await self.client._create_edit_request(self.sample_request)
|
||||
|
||||
# Verify
|
||||
self.assertEqual(request_id, 'request_12345')
|
||||
mock_session.post.assert_called_once()
|
||||
|
||||
# Check payload structure
|
||||
call_args = mock_session.post.call_args
|
||||
payload = call_args[1]['json']
|
||||
self.assertEqual(payload['prompt'], 'Make the sky blue')
|
||||
self.assertEqual(payload['seed'], 12345)
|
||||
self.assertEqual(payload['input_image'], 'test_base64_data')
|
||||
|
||||
@patch('aiohttp.ClientSession')
|
||||
async def test_create_edit_request_failure(self, mock_session_class):
|
||||
"""Test edit request creation failure"""
|
||||
# Mock session and response
|
||||
mock_session = AsyncMock()
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 400
|
||||
mock_response.text.return_value = 'Bad Request'
|
||||
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
# Test
|
||||
request_id = await self.client._create_edit_request(self.sample_request)
|
||||
|
||||
# Verify
|
||||
self.assertIsNone(request_id)
|
||||
|
||||
@patch('aiohttp.ClientSession')
|
||||
async def test_poll_result_success(self, mock_session_class):
|
||||
"""Test successful result polling"""
|
||||
# Mock session and responses
|
||||
mock_session = AsyncMock()
|
||||
|
||||
# First response: processing
|
||||
mock_response_processing = AsyncMock()
|
||||
mock_response_processing.status = 200
|
||||
mock_response_processing.json.return_value = {'status': 'processing'}
|
||||
|
||||
# Second response: ready
|
||||
mock_response_ready = AsyncMock()
|
||||
mock_response_ready.status = 200
|
||||
mock_response_ready.json.return_value = {
|
||||
'status': 'ready',
|
||||
'result': {'sample': 'https://example.com/image.png'}
|
||||
}
|
||||
|
||||
# Mock to return processing first, then ready
|
||||
mock_session.get.return_value.__aenter__.side_effect = [
|
||||
mock_response_processing,
|
||||
mock_response_ready
|
||||
]
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
# Test
|
||||
result = await self.client._poll_result('test_request_id')
|
||||
|
||||
# Verify
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result['status'], 'ready')
|
||||
self.assertIn('result', result)
|
||||
self.assertEqual(mock_session.get.call_count, 2)
|
||||
|
||||
@patch('aiohttp.ClientSession')
|
||||
async def test_poll_result_timeout(self, mock_session_class):
|
||||
"""Test polling timeout"""
|
||||
# Mock session to always return processing
|
||||
mock_session = AsyncMock()
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json.return_value = {'status': 'processing'}
|
||||
|
||||
mock_session.get.return_value.__aenter__.return_value = mock_response
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
# Test
|
||||
result = await self.client._poll_result('test_request_id')
|
||||
|
||||
# Verify - should timeout after max attempts
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(mock_session.get.call_count, self.config.max_polling_attempts)
|
||||
|
||||
@patch('aiohttp.ClientSession')
|
||||
async def test_download_result_image_success(self, mock_session_class):
|
||||
"""Test successful image download"""
|
||||
# Mock session and response
|
||||
mock_session = AsyncMock()
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 200
|
||||
mock_response.read.return_value = b'fake_image_data'
|
||||
|
||||
mock_session.get.return_value.__aenter__.return_value = mock_response
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
# Test
|
||||
image_data = await self.client._download_result_image('https://example.com/image.png')
|
||||
|
||||
# Verify
|
||||
self.assertEqual(image_data, b'fake_image_data')
|
||||
mock_session.get.assert_called_once_with('https://example.com/image.png')
|
||||
|
||||
@patch('aiohttp.ClientSession')
|
||||
async def test_download_result_image_failure(self, mock_session_class):
|
||||
"""Test image download failure"""
|
||||
# Mock session and response
|
||||
mock_session = AsyncMock()
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 404
|
||||
|
||||
mock_session.get.return_value.__aenter__.return_value = mock_response
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
# Test
|
||||
image_data = await self.client._download_result_image('https://example.com/image.png')
|
||||
|
||||
# Verify
|
||||
self.assertIsNone(image_data)
|
||||
|
||||
def test_get_image_size(self):
|
||||
"""Test image size detection from bytes"""
|
||||
# Create a small test image in memory
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
# Create 10x20 test image
|
||||
img = Image.new('RGB', (10, 20), color='red')
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format='PNG')
|
||||
image_data = buffer.getvalue()
|
||||
|
||||
# Test
|
||||
size = self.client._get_image_size(image_data)
|
||||
|
||||
# Verify
|
||||
self.assertEqual(size, (10, 20))
|
||||
|
||||
def test_get_image_size_invalid_data(self):
|
||||
"""Test image size detection with invalid data"""
|
||||
size = self.client._get_image_size(b'invalid_image_data')
|
||||
self.assertIsNone(size)
|
||||
|
||||
@patch.object(FluxEditClient, '_create_edit_request')
|
||||
@patch.object(FluxEditClient, '_poll_result')
|
||||
@patch.object(FluxEditClient, '_download_result_image')
|
||||
async def test_edit_image_success(self, mock_download, mock_poll, mock_create):
|
||||
"""Test complete successful edit flow"""
|
||||
# Setup mocks
|
||||
mock_create.return_value = 'request_123'
|
||||
mock_poll.return_value = {
|
||||
'status': 'ready',
|
||||
'result': {'sample': 'https://example.com/result.png'}
|
||||
}
|
||||
mock_download.return_value = b'edited_image_data'
|
||||
|
||||
# Test
|
||||
response = await self.client.edit_image(self.sample_request)
|
||||
|
||||
# Verify
|
||||
self.assertTrue(response.success)
|
||||
self.assertEqual(response.edited_image_data, b'edited_image_data')
|
||||
self.assertEqual(response.request_id, 'request_123')
|
||||
self.assertEqual(response.result_url, 'https://example.com/result.png')
|
||||
self.assertGreater(response.execution_time, 0)
|
||||
|
||||
@patch.object(FluxEditClient, '_create_edit_request')
|
||||
async def test_edit_image_create_failure(self, mock_create):
|
||||
"""Test edit flow with creation failure"""
|
||||
# Setup mock
|
||||
mock_create.return_value = None
|
||||
|
||||
# Test
|
||||
response = await self.client.edit_image(self.sample_request)
|
||||
|
||||
# Verify
|
||||
self.assertFalse(response.success)
|
||||
self.assertIn('Failed to create edit request', response.error_message)
|
||||
|
||||
@patch.object(FluxEditClient, '_create_edit_request')
|
||||
@patch.object(FluxEditClient, '_poll_result')
|
||||
async def test_edit_image_poll_failure(self, mock_poll, mock_create):
|
||||
"""Test edit flow with polling failure"""
|
||||
# Setup mocks
|
||||
mock_create.return_value = 'request_123'
|
||||
mock_poll.return_value = None
|
||||
|
||||
# Test
|
||||
response = await self.client.edit_image(self.sample_request)
|
||||
|
||||
# Verify
|
||||
self.assertFalse(response.success)
|
||||
self.assertIn('Failed to get edit result', response.error_message)
|
||||
|
||||
async def test_context_manager(self):
|
||||
"""Test async context manager functionality"""
|
||||
async with FluxEditClient(self.config) as client:
|
||||
self.assertIsInstance(client, FluxEditClient)
|
||||
|
||||
# Session should be closed after context
|
||||
if hasattr(client, 'session'):
|
||||
self.assertTrue(client.session is None or client.session.closed)
|
||||
|
||||
|
||||
# Test helper to run async tests
|
||||
class AsyncTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
def tearDown(self):
|
||||
self.loop.close()
|
||||
|
||||
def async_test(self, coro):
|
||||
return self.loop.run_until_complete(coro)
|
||||
|
||||
|
||||
class TestFluxEditClientAsync(AsyncTestCase):
|
||||
"""Async test runner for FLUX client tests"""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.config = MagicMock(spec=Config)
|
||||
self.config.api_key = 'test_api_key'
|
||||
self.config.default_timeout = 30
|
||||
self.config.polling_interval = 0.1 # Faster for tests
|
||||
self.config.max_polling_attempts = 3
|
||||
self.config.get_api_url.side_effect = lambda endpoint: f"https://api.test.com{endpoint}"
|
||||
|
||||
self.client = FluxEditClient(self.config)
|
||||
self.sample_request = FluxEditRequest(
|
||||
input_image_b64='test_base64_data',
|
||||
prompt='Make the sky blue',
|
||||
seed=12345,
|
||||
aspect_ratio='16:9'
|
||||
)
|
||||
|
||||
def test_async_context_manager(self):
|
||||
"""Test async context manager"""
|
||||
async def run_test():
|
||||
async with FluxEditClient(self.config) as client:
|
||||
self.assertIsInstance(client, FluxEditClient)
|
||||
return True
|
||||
|
||||
result = self.async_test(run_test())
|
||||
self.assertTrue(result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,360 +0,0 @@
|
||||
"""Unit tests for MCP tool handlers"""
|
||||
|
||||
import unittest
|
||||
import tempfile
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
# Add src to path
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
|
||||
|
||||
from src.server.handlers import ToolHandlers
|
||||
from src.connector.config import Config
|
||||
from src.connector.flux_client import FluxEditResponse
|
||||
from mcp.types import TextContent, ImageContent
|
||||
|
||||
|
||||
class AsyncTestCase(unittest.TestCase):
|
||||
"""Base class for async tests"""
|
||||
|
||||
def setUp(self):
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
def tearDown(self):
|
||||
self.loop.close()
|
||||
|
||||
def async_test(self, coro):
|
||||
return self.loop.run_until_complete(coro)
|
||||
|
||||
|
||||
class TestToolHandlers(AsyncTestCase):
|
||||
"""Test cases for MCP tool handlers"""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
# Create temporary directory
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.temp_path = Path(self.temp_dir)
|
||||
|
||||
# Mock config
|
||||
self.config = MagicMock(spec=Config)
|
||||
self.config.base_path = self.temp_path
|
||||
self.config.input_path = self.temp_path / 'input_images'
|
||||
self.config.generated_images_path = self.temp_path / 'generated_images'
|
||||
self.config.max_image_size_mb = 20
|
||||
self.config.default_aspect_ratio = '1:1'
|
||||
self.config.safety_tolerance = 2
|
||||
self.config.OUTPUT_FORMAT = 'png'
|
||||
self.config.prompt_upsampling = False
|
||||
self.config.MODEL_NAME = 'flux-kontext-pro'
|
||||
self.config.save_parameters = True
|
||||
|
||||
# Setup directories
|
||||
self.config.input_path.mkdir(parents=True, exist_ok=True)
|
||||
self.config.generated_images_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Mock config methods
|
||||
self.config.ensure_output_directory.return_value = None
|
||||
self.config.generate_base_name.return_value = 'fluxedit_12345_20250826_143022'
|
||||
self.config.generate_filename.side_effect = lambda base, num, ext: f'{base}_{num:03d}.{ext}'
|
||||
self.config.get_output_path.side_effect = lambda base, num, ext: self.config.generated_images_path / f'{base}_{num:03d}.{ext}'
|
||||
|
||||
# Create test image
|
||||
self.test_image = Image.new('RGB', (100, 100), color='blue')
|
||||
buffer = io.BytesIO()
|
||||
self.test_image.save(buffer, format='PNG')
|
||||
self.test_image_data = buffer.getvalue()
|
||||
self.test_image_b64 = self._encode_image_b64(self.test_image_data)
|
||||
|
||||
# Create test image file
|
||||
self.test_image_file = self.config.input_path / 'test.png'
|
||||
self.test_image.save(self.test_image_file)
|
||||
|
||||
# Initialize handlers
|
||||
self.handlers = ToolHandlers(self.config)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures"""
|
||||
import shutil
|
||||
super().tearDown()
|
||||
if self.temp_path.exists():
|
||||
shutil.rmtree(self.temp_path)
|
||||
|
||||
def _encode_image_b64(self, image_data: bytes) -> str:
|
||||
"""Helper to encode image as base64"""
|
||||
import base64
|
||||
return base64.b64encode(image_data).decode('utf-8')
|
||||
|
||||
def _create_mock_flux_response(self, success: bool = True) -> FluxEditResponse:
|
||||
"""Helper to create mock FLUX response"""
|
||||
if success:
|
||||
return FluxEditResponse(
|
||||
success=True,
|
||||
edited_image_data=self.test_image_data,
|
||||
image_size=(100, 100),
|
||||
execution_time=5.5,
|
||||
request_id='test_request_123',
|
||||
result_url='https://example.com/result.png',
|
||||
metadata={'seed': 12345}
|
||||
)
|
||||
else:
|
||||
return FluxEditResponse(
|
||||
success=False,
|
||||
error_message='FLUX edit failed',
|
||||
execution_time=2.0
|
||||
)
|
||||
|
||||
def test_flux_edit_image_parameter_validation_failure(self):
|
||||
"""Test flux_edit_image with invalid parameters"""
|
||||
async def run_test():
|
||||
# Missing required parameter
|
||||
arguments = {
|
||||
'prompt': 'Make it blue',
|
||||
'seed': 12345
|
||||
# Missing input_image_b64
|
||||
}
|
||||
|
||||
result = await self.handlers.handle_flux_edit_image(arguments)
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIsInstance(result[0], TextContent)
|
||||
self.assertIn('Parameter validation failed', result[0].text)
|
||||
self.assertIn('input_image_b64 is required', result[0].text)
|
||||
|
||||
self.async_test(run_test())
|
||||
|
||||
@patch('src.server.handlers.FluxEditClient')
|
||||
def test_flux_edit_image_success(self, mock_client_class):
|
||||
"""Test successful flux_edit_image"""
|
||||
async def run_test():
|
||||
# Setup mock client
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
mock_client.edit_image.return_value = self._create_mock_flux_response(success=True)
|
||||
|
||||
# Valid arguments
|
||||
arguments = {
|
||||
'input_image_b64': self.test_image_b64,
|
||||
'prompt': 'Make the sky blue',
|
||||
'seed': 12345,
|
||||
'aspect_ratio': '16:9',
|
||||
'save_to_file': True
|
||||
}
|
||||
|
||||
result = await self.handlers.handle_flux_edit_image(arguments)
|
||||
|
||||
# Verify result structure
|
||||
self.assertGreater(len(result), 0)
|
||||
self.assertIsInstance(result[0], TextContent)
|
||||
self.assertIn('✅ Image edited successfully', result[0].text)
|
||||
self.assertIn('Seed: 12345', result[0].text)
|
||||
|
||||
# Should have image preview
|
||||
if len(result) > 1:
|
||||
self.assertIsInstance(result[1], ImageContent)
|
||||
|
||||
self.async_test(run_test())
|
||||
|
||||
@patch('src.server.handlers.FluxEditClient')
|
||||
def test_flux_edit_image_flux_failure(self, mock_client_class):
|
||||
"""Test flux_edit_image with FLUX API failure"""
|
||||
async def run_test():
|
||||
# Setup mock client to return failure
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
mock_client.edit_image.return_value = self._create_mock_flux_response(success=False)
|
||||
|
||||
arguments = {
|
||||
'input_image_b64': self.test_image_b64,
|
||||
'prompt': 'Make it blue',
|
||||
'seed': 12345
|
||||
}
|
||||
|
||||
result = await self.handlers.handle_flux_edit_image(arguments)
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIsInstance(result[0], TextContent)
|
||||
self.assertIn('❌ FLUX edit failed', result[0].text)
|
||||
|
||||
self.async_test(run_test())
|
||||
|
||||
def test_flux_edit_image_from_file_not_found(self):
|
||||
"""Test flux_edit_image_from_file with non-existent file"""
|
||||
async def run_test():
|
||||
arguments = {
|
||||
'input_image_name': 'nonexistent.png',
|
||||
'prompt': 'Edit this',
|
||||
'seed': 12345
|
||||
}
|
||||
|
||||
result = await self.handlers.handle_flux_edit_image_from_file(arguments)
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIsInstance(result[0], TextContent)
|
||||
self.assertIn('❌ File not found', result[0].text)
|
||||
self.assertIn('nonexistent.png', result[0].text)
|
||||
|
||||
self.async_test(run_test())
|
||||
|
||||
@patch('src.server.handlers.FluxEditClient')
|
||||
def test_flux_edit_image_from_file_success(self, mock_client_class):
|
||||
"""Test successful flux_edit_image_from_file"""
|
||||
async def run_test():
|
||||
# Setup mock client
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
mock_client.edit_image.return_value = self._create_mock_flux_response(success=True)
|
||||
|
||||
arguments = {
|
||||
'input_image_name': 'test.png',
|
||||
'prompt': 'Make it awesome',
|
||||
'seed': 54321,
|
||||
'save_to_file': True
|
||||
}
|
||||
|
||||
result = await self.handlers.handle_flux_edit_image_from_file(arguments)
|
||||
|
||||
# Verify result
|
||||
self.assertGreater(len(result), 0)
|
||||
self.assertIsInstance(result[0], TextContent)
|
||||
self.assertIn('✅ Image edited successfully from file', result[0].text)
|
||||
self.assertIn('test.png', result[0].text)
|
||||
self.assertIn('Seed: 54321', result[0].text)
|
||||
|
||||
self.async_test(run_test())
|
||||
|
||||
def test_validate_image_success(self):
|
||||
"""Test successful image validation"""
|
||||
async def run_test():
|
||||
arguments = {'image_path': str(self.test_image_file)}
|
||||
|
||||
result = await self.handlers.handle_validate_image(arguments)
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIsInstance(result[0], TextContent)
|
||||
self.assertIn('✅ Image validation passed', result[0].text)
|
||||
self.assertIn('100x100', result[0].text)
|
||||
|
||||
self.async_test(run_test())
|
||||
|
||||
def test_validate_image_not_found(self):
|
||||
"""Test image validation with non-existent file"""
|
||||
async def run_test():
|
||||
arguments = {'image_path': '/nonexistent/path.png'}
|
||||
|
||||
result = await self.handlers.handle_validate_image(arguments)
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIsInstance(result[0], TextContent)
|
||||
self.assertIn('❌ Image validation failed', result[0].text)
|
||||
self.assertIn('File not found', result[0].text)
|
||||
|
||||
self.async_test(run_test())
|
||||
|
||||
def test_move_temp_to_output_success(self):
|
||||
"""Test successful file move operation"""
|
||||
async def run_test():
|
||||
# Create temp directory and file
|
||||
temp_dir = self.config.base_path / 'temp'
|
||||
temp_dir.mkdir(exist_ok=True)
|
||||
temp_file = temp_dir / 'temp_test.png'
|
||||
self.test_image.save(temp_file)
|
||||
|
||||
arguments = {
|
||||
'temp_file_name': 'temp_test.png',
|
||||
'output_file_name': 'moved_test.png',
|
||||
'copy_only': False
|
||||
}
|
||||
|
||||
result = await self.handlers.handle_move_temp_to_output(arguments)
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIsInstance(result[0], TextContent)
|
||||
self.assertIn('✅ File moved successfully', result[0].text)
|
||||
self.assertIn('temp_test.png', result[0].text)
|
||||
self.assertIn('moved_test.png', result[0].text)
|
||||
|
||||
self.async_test(run_test())
|
||||
|
||||
def test_move_temp_to_output_not_found(self):
|
||||
"""Test file move with non-existent temp file"""
|
||||
async def run_test():
|
||||
arguments = {
|
||||
'temp_file_name': 'nonexistent.png',
|
||||
'copy_only': False
|
||||
}
|
||||
|
||||
result = await self.handlers.handle_move_temp_to_output(arguments)
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIsInstance(result[0], TextContent)
|
||||
self.assertIn('❌ Temp file not found', result[0].text)
|
||||
|
||||
self.async_test(run_test())
|
||||
|
||||
def test_move_temp_to_output_copy_only(self):
|
||||
"""Test copy-only file operation"""
|
||||
async def run_test():
|
||||
# Create temp directory and file
|
||||
temp_dir = self.config.base_path / 'temp'
|
||||
temp_dir.mkdir(exist_ok=True)
|
||||
temp_file = temp_dir / 'temp_copy.png'
|
||||
self.test_image.save(temp_file)
|
||||
|
||||
arguments = {
|
||||
'temp_file_name': 'temp_copy.png',
|
||||
'copy_only': True
|
||||
}
|
||||
|
||||
result = await self.handlers.handle_move_temp_to_output(arguments)
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIsInstance(result[0], TextContent)
|
||||
self.assertIn('✅ File copied successfully', result[0].text)
|
||||
|
||||
# Original file should still exist
|
||||
self.assertTrue(temp_file.exists())
|
||||
|
||||
self.async_test(run_test())
|
||||
|
||||
def test_seed_management(self):
|
||||
"""Test seed creation and reset functionality"""
|
||||
# Test seed generation
|
||||
seed1 = self.handlers._get_or_create_seed()
|
||||
seed2 = self.handlers._get_or_create_seed()
|
||||
|
||||
# Should return same seed for session
|
||||
self.assertEqual(seed1, seed2)
|
||||
|
||||
# Test seed reset
|
||||
self.handlers._reset_seed()
|
||||
seed3 = self.handlers._get_or_create_seed()
|
||||
|
||||
# Should be different after reset
|
||||
self.assertNotEqual(seed1, seed3)
|
||||
|
||||
def test_temp_file_operations(self):
|
||||
"""Test temporary file save and move operations"""
|
||||
# Test saving b64 to temp
|
||||
filename = 'test_temp.png'
|
||||
temp_path = self.handlers._save_b64_to_temp_file(self.test_image_b64, filename)
|
||||
|
||||
self.assertTrue(Path(temp_path).exists())
|
||||
self.assertIn(filename, temp_path)
|
||||
|
||||
# Test moving to generated images
|
||||
base_name = 'test_base'
|
||||
moved_path = self.handlers._move_temp_to_generated(temp_path, base_name, 1)
|
||||
|
||||
self.assertTrue(Path(moved_path).exists())
|
||||
self.assertIn('test_base_001.png', moved_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,197 +0,0 @@
|
||||
"""Unit tests for image utilities"""
|
||||
|
||||
import unittest
|
||||
import tempfile
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
# Add src to path
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
|
||||
|
||||
from src.utils.image_utils import (
|
||||
get_file_size_mb,
|
||||
validate_image_file,
|
||||
get_image_dimensions,
|
||||
get_image_dimensions_from_bytes,
|
||||
encode_image_base64,
|
||||
decode_image_base64,
|
||||
save_image,
|
||||
get_optimal_aspect_ratio,
|
||||
validate_aspect_ratio,
|
||||
convert_image_to_base64
|
||||
)
|
||||
|
||||
|
||||
class TestImageUtils(unittest.TestCase):
|
||||
"""Test cases for image utility functions"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures"""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.temp_path = Path(self.temp_dir)
|
||||
|
||||
# Create a test image
|
||||
self.test_image = Image.new('RGB', (100, 100), color='red')
|
||||
self.test_image_path = self.temp_path / 'test_image.png'
|
||||
self.test_image.save(self.test_image_path)
|
||||
|
||||
# Create test image data
|
||||
buffer = io.BytesIO()
|
||||
self.test_image.save(buffer, format='PNG')
|
||||
self.test_image_data = buffer.getvalue()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures"""
|
||||
import shutil
|
||||
if self.temp_path.exists():
|
||||
shutil.rmtree(self.temp_path)
|
||||
|
||||
def test_get_file_size_mb(self):
|
||||
"""Test file size calculation"""
|
||||
size_mb = get_file_size_mb(self.test_image_path)
|
||||
self.assertGreater(size_mb, 0)
|
||||
self.assertLess(size_mb, 1) # Small test image should be < 1MB
|
||||
|
||||
def test_validate_image_file_success(self):
|
||||
"""Test successful image validation"""
|
||||
is_valid, size_mb, error = validate_image_file(str(self.test_image_path), 20)
|
||||
|
||||
self.assertTrue(is_valid)
|
||||
self.assertGreater(size_mb, 0)
|
||||
self.assertIsNone(error)
|
||||
|
||||
def test_validate_image_file_not_found(self):
|
||||
"""Test validation of non-existent file"""
|
||||
is_valid, size_mb, error = validate_image_file('nonexistent.png', 20)
|
||||
|
||||
self.assertFalse(is_valid)
|
||||
self.assertEqual(size_mb, 0)
|
||||
self.assertIn('File not found', error)
|
||||
|
||||
def test_validate_image_file_too_large(self):
|
||||
"""Test validation of file too large"""
|
||||
# Test with very small limit
|
||||
is_valid, size_mb, error = validate_image_file(str(self.test_image_path), 0.001)
|
||||
|
||||
self.assertFalse(is_valid)
|
||||
self.assertIn('exceeds', error)
|
||||
|
||||
def test_get_image_dimensions(self):
|
||||
"""Test getting image dimensions"""
|
||||
width, height = get_image_dimensions(str(self.test_image_path))
|
||||
self.assertEqual(width, 100)
|
||||
self.assertEqual(height, 100)
|
||||
|
||||
def test_get_image_dimensions_from_bytes(self):
|
||||
"""Test getting dimensions from image bytes"""
|
||||
width, height = get_image_dimensions_from_bytes(self.test_image_data)
|
||||
self.assertEqual(width, 100)
|
||||
self.assertEqual(height, 100)
|
||||
|
||||
def test_encode_decode_base64(self):
|
||||
"""Test base64 encoding and decoding"""
|
||||
# Encode
|
||||
b64_string = encode_image_base64(self.test_image_data)
|
||||
self.assertIsInstance(b64_string, str)
|
||||
|
||||
# Decode
|
||||
decoded_data = decode_image_base64(b64_string)
|
||||
self.assertEqual(decoded_data, self.test_image_data)
|
||||
|
||||
def test_decode_base64_with_data_url(self):
|
||||
"""Test decoding base64 with data URL prefix"""
|
||||
b64_string = encode_image_base64(self.test_image_data)
|
||||
data_url = f"data:image/png;base64,{b64_string}"
|
||||
|
||||
decoded_data = decode_image_base64(data_url)
|
||||
self.assertEqual(decoded_data, self.test_image_data)
|
||||
|
||||
def test_save_image(self):
|
||||
"""Test saving image data to file"""
|
||||
output_path = self.temp_path / 'output.png'
|
||||
|
||||
success = save_image(self.test_image_data, str(output_path))
|
||||
self.assertTrue(success)
|
||||
self.assertTrue(output_path.exists())
|
||||
|
||||
# Verify file content
|
||||
with open(output_path, 'rb') as f:
|
||||
saved_data = f.read()
|
||||
self.assertEqual(saved_data, self.test_image_data)
|
||||
|
||||
def test_get_optimal_aspect_ratio(self):
|
||||
"""Test optimal aspect ratio calculation"""
|
||||
# Test square image
|
||||
ratio = get_optimal_aspect_ratio(100, 100)
|
||||
self.assertEqual(ratio, "1:1")
|
||||
|
||||
# Test wide image
|
||||
ratio = get_optimal_aspect_ratio(160, 90)
|
||||
self.assertEqual(ratio, "16:9")
|
||||
|
||||
# Test tall image
|
||||
ratio = get_optimal_aspect_ratio(90, 160)
|
||||
self.assertEqual(ratio, "9:16")
|
||||
|
||||
def test_validate_aspect_ratio(self):
|
||||
"""Test aspect ratio validation"""
|
||||
# Test matching ratio
|
||||
self.assertTrue(validate_aspect_ratio(100, 100, "1:1"))
|
||||
self.assertTrue(validate_aspect_ratio(160, 90, "16:9"))
|
||||
|
||||
# Test non-matching ratio (within tolerance)
|
||||
self.assertTrue(validate_aspect_ratio(161, 90, "16:9")) # Small difference
|
||||
|
||||
# Test non-matching ratio (outside tolerance)
|
||||
self.assertFalse(validate_aspect_ratio(200, 100, "1:1"))
|
||||
|
||||
def test_convert_image_to_base64(self):
|
||||
"""Test converting image file to base64"""
|
||||
b64_string = convert_image_to_base64(str(self.test_image_path))
|
||||
|
||||
self.assertIsInstance(b64_string, str)
|
||||
|
||||
# Verify we can decode it back
|
||||
decoded_data = decode_image_base64(b64_string)
|
||||
|
||||
# Images should have same dimensions
|
||||
width, height = get_image_dimensions_from_bytes(decoded_data)
|
||||
self.assertEqual(width, 100)
|
||||
self.assertEqual(height, 100)
|
||||
|
||||
def create_large_image_file(self, size_mb: float) -> Path:
|
||||
"""Helper to create a large image file for testing"""
|
||||
# Calculate dimensions for target size (rough estimate)
|
||||
# PNG compression varies, so this is approximate
|
||||
pixels = int((size_mb * 1024 * 1024) / 4) # 4 bytes per pixel (RGBA)
|
||||
dimension = int(pixels ** 0.5)
|
||||
|
||||
large_image = Image.new('RGBA', (dimension, dimension), color='red')
|
||||
large_image_path = self.temp_path / 'large_image.png'
|
||||
large_image.save(large_image_path)
|
||||
|
||||
return large_image_path
|
||||
|
||||
def test_large_image_handling(self):
|
||||
"""Test handling of large images"""
|
||||
# This test might be slow, so we'll use a smaller "large" image
|
||||
try:
|
||||
large_path = self.create_large_image_file(0.1) # 0.1 MB
|
||||
|
||||
# Test validation
|
||||
is_valid, size_mb, error = validate_image_file(str(large_path), 20)
|
||||
self.assertTrue(is_valid)
|
||||
|
||||
# Test conversion to base64
|
||||
b64_string = convert_image_to_base64(str(large_path))
|
||||
self.assertIsInstance(b64_string, str)
|
||||
|
||||
except Exception as e:
|
||||
self.skipTest(f"Large image test skipped due to: {e}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,218 +0,0 @@
|
||||
"""Unit tests for validation utilities"""
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
# Add src to path
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
|
||||
|
||||
from src.utils.validation import (
|
||||
validate_edit_parameters,
|
||||
validate_file_parameters,
|
||||
validate_move_file_parameters,
|
||||
validate_image_path_parameter,
|
||||
sanitize_prompt,
|
||||
validate_aspect_ratio_format,
|
||||
validate_seed_range,
|
||||
validate_filename_safety
|
||||
)
|
||||
|
||||
|
||||
class TestValidation(unittest.TestCase):
|
||||
"""Test cases for validation utilities"""
|
||||
|
||||
def test_validate_edit_parameters_success(self):
|
||||
"""Test successful validation of edit parameters"""
|
||||
valid_args = {
|
||||
'input_image_b64': 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==',
|
||||
'prompt': 'Make the sky blue',
|
||||
'seed': 12345,
|
||||
'aspect_ratio': '16:9',
|
||||
'save_to_file': True
|
||||
}
|
||||
|
||||
is_valid, error = validate_edit_parameters(valid_args)
|
||||
self.assertTrue(is_valid)
|
||||
self.assertIsNone(error)
|
||||
|
||||
def test_validate_edit_parameters_missing_required(self):
|
||||
"""Test validation fails with missing required parameters"""
|
||||
# Missing input_image_b64
|
||||
args = {'prompt': 'test', 'seed': 12345}
|
||||
is_valid, error = validate_edit_parameters(args)
|
||||
self.assertFalse(is_valid)
|
||||
self.assertIn('input_image_b64 is required', error)
|
||||
|
||||
# Missing prompt
|
||||
args = {'input_image_b64': 'test_b64', 'seed': 12345}
|
||||
is_valid, error = validate_edit_parameters(args)
|
||||
self.assertFalse(is_valid)
|
||||
self.assertIn('prompt is required', error)
|
||||
|
||||
# Missing seed
|
||||
args = {'input_image_b64': 'test_b64', 'prompt': 'test'}
|
||||
is_valid, error = validate_edit_parameters(args)
|
||||
self.assertFalse(is_valid)
|
||||
self.assertIn('seed is required', error)
|
||||
|
||||
def test_validate_edit_parameters_invalid_types(self):
|
||||
"""Test validation fails with invalid parameter types"""
|
||||
# Invalid seed type
|
||||
args = {
|
||||
'input_image_b64': 'valid_b64',
|
||||
'prompt': 'test prompt',
|
||||
'seed': 'not_a_number'
|
||||
}
|
||||
is_valid, error = validate_edit_parameters(args)
|
||||
self.assertFalse(is_valid)
|
||||
self.assertIn('seed must be an integer', error)
|
||||
|
||||
# Invalid aspect ratio
|
||||
args = {
|
||||
'input_image_b64': 'valid_b64',
|
||||
'prompt': 'test prompt',
|
||||
'seed': 12345,
|
||||
'aspect_ratio': 'invalid_ratio'
|
||||
}
|
||||
is_valid, error = validate_edit_parameters(args)
|
||||
self.assertFalse(is_valid)
|
||||
self.assertIn('aspect_ratio must be one of', error)
|
||||
|
||||
def test_validate_edit_parameters_invalid_ranges(self):
|
||||
"""Test validation fails with invalid parameter ranges"""
|
||||
# Seed out of range
|
||||
args = {
|
||||
'input_image_b64': 'valid_b64',
|
||||
'prompt': 'test prompt',
|
||||
'seed': -1
|
||||
}
|
||||
is_valid, error = validate_edit_parameters(args)
|
||||
self.assertFalse(is_valid)
|
||||
self.assertIn('seed must be between', error)
|
||||
|
||||
# Prompt too long
|
||||
args = {
|
||||
'input_image_b64': 'valid_b64',
|
||||
'prompt': 'x' * 10001, # Too long
|
||||
'seed': 12345
|
||||
}
|
||||
is_valid, error = validate_edit_parameters(args)
|
||||
self.assertFalse(is_valid)
|
||||
self.assertIn('prompt is too long', error)
|
||||
|
||||
def test_validate_file_parameters_success(self):
|
||||
"""Test successful validation of file parameters"""
|
||||
valid_args = {
|
||||
'input_image_name': 'test.png',
|
||||
'prompt': 'Edit this image',
|
||||
'seed': 54321,
|
||||
'aspect_ratio': '1:1',
|
||||
'save_to_file': False
|
||||
}
|
||||
|
||||
is_valid, error = validate_file_parameters(valid_args)
|
||||
self.assertTrue(is_valid)
|
||||
self.assertIsNone(error)
|
||||
|
||||
def test_validate_file_parameters_invalid_filename(self):
|
||||
"""Test validation fails with invalid filename"""
|
||||
# Path traversal attempt
|
||||
args = {
|
||||
'input_image_name': '../../../etc/passwd',
|
||||
'prompt': 'test',
|
||||
'seed': 12345
|
||||
}
|
||||
is_valid, error = validate_file_parameters(args)
|
||||
self.assertFalse(is_valid)
|
||||
self.assertIn('cannot contain path separators', error)
|
||||
|
||||
# Invalid extension
|
||||
args = {
|
||||
'input_image_name': 'test.exe',
|
||||
'prompt': 'test',
|
||||
'seed': 12345
|
||||
}
|
||||
is_valid, error = validate_file_parameters(args)
|
||||
self.assertFalse(is_valid)
|
||||
self.assertIn('must have a valid image extension', error)
|
||||
|
||||
def test_validate_move_file_parameters_success(self):
|
||||
"""Test successful validation of move file parameters"""
|
||||
valid_args = {
|
||||
'temp_file_name': 'temp_image.png',
|
||||
'output_file_name': 'output.png',
|
||||
'copy_only': True
|
||||
}
|
||||
|
||||
is_valid, error = validate_move_file_parameters(valid_args)
|
||||
self.assertTrue(is_valid)
|
||||
self.assertIsNone(error)
|
||||
|
||||
def test_validate_image_path_parameter_success(self):
|
||||
"""Test successful validation of image path parameter"""
|
||||
valid_args = {'image_path': '/path/to/image.png'}
|
||||
|
||||
is_valid, error = validate_image_path_parameter(valid_args)
|
||||
self.assertTrue(is_valid)
|
||||
self.assertIsNone(error)
|
||||
|
||||
def test_sanitize_prompt(self):
|
||||
"""Test prompt sanitization"""
|
||||
# Test whitespace normalization
|
||||
prompt = " This has extra whitespace "
|
||||
sanitized = sanitize_prompt(prompt)
|
||||
self.assertEqual(sanitized, "This has extra whitespace")
|
||||
|
||||
# Test null byte removal
|
||||
prompt = "Test\x00with\x00nulls"
|
||||
sanitized = sanitize_prompt(prompt)
|
||||
self.assertEqual(sanitized, "Testwithulls")
|
||||
|
||||
# Test length limiting
|
||||
long_prompt = "x" * 10001
|
||||
sanitized = sanitize_prompt(long_prompt)
|
||||
self.assertEqual(len(sanitized), 10000)
|
||||
|
||||
def test_validate_aspect_ratio_format(self):
|
||||
"""Test aspect ratio format validation"""
|
||||
# Valid formats
|
||||
self.assertTrue(validate_aspect_ratio_format("16:9"))
|
||||
self.assertTrue(validate_aspect_ratio_format("1:1"))
|
||||
self.assertTrue(validate_aspect_ratio_format("4:3"))
|
||||
|
||||
# Invalid formats
|
||||
self.assertFalse(validate_aspect_ratio_format("16-9"))
|
||||
self.assertFalse(validate_aspect_ratio_format("16:9:1"))
|
||||
self.assertFalse(validate_aspect_ratio_format("a:b"))
|
||||
self.assertFalse(validate_aspect_ratio_format("0:1"))
|
||||
|
||||
def test_validate_seed_range(self):
|
||||
"""Test seed range validation"""
|
||||
# Valid seeds
|
||||
self.assertTrue(validate_seed_range(0))
|
||||
self.assertTrue(validate_seed_range(12345))
|
||||
self.assertTrue(validate_seed_range(2**32 - 1))
|
||||
|
||||
# Invalid seeds
|
||||
self.assertFalse(validate_seed_range(-1))
|
||||
self.assertFalse(validate_seed_range(2**32))
|
||||
self.assertFalse(validate_seed_range("not_a_number"))
|
||||
|
||||
def test_validate_filename_safety(self):
|
||||
"""Test filename safety validation"""
|
||||
# Safe filenames
|
||||
self.assertTrue(validate_filename_safety("image.png"))
|
||||
self.assertTrue(validate_filename_safety("my_image_123.jpg"))
|
||||
self.assertTrue(validate_filename_safety("test-file.png"))
|
||||
|
||||
# Unsafe filenames
|
||||
self.assertFalse(validate_filename_safety("../image.png"))
|
||||
self.assertFalse(validate_filename_safety("path/to/image.png"))
|
||||
self.assertFalse(validate_filename_safety("image<>.png"))
|
||||
self.assertFalse(validate_filename_safety("CON.png")) # Windows reserved
|
||||
self.assertFalse(validate_filename_safety("x" * 256)) # Too long
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
150
troubleshoot.bat
150
troubleshoot.bat
@@ -1,150 +0,0 @@
|
||||
@echo off
|
||||
echo FLUX.1 Edit MCP Server - Troubleshooting
|
||||
echo =======================================
|
||||
|
||||
echo 1. Checking Python installation...
|
||||
python --version
|
||||
if errorlevel 1 (
|
||||
echo [ERROR] Python is not installed or not in PATH
|
||||
goto :end
|
||||
) else (
|
||||
echo [OK] Python is available
|
||||
)
|
||||
|
||||
echo.
|
||||
echo 2. Checking pip installation...
|
||||
pip --version
|
||||
if errorlevel 1 (
|
||||
echo [ERROR] pip is not available
|
||||
goto :end
|
||||
) else (
|
||||
echo [OK] pip is available
|
||||
)
|
||||
|
||||
echo.
|
||||
echo 3. Checking virtual environment...
|
||||
if exist "venv" (
|
||||
echo [OK] Virtual environment directory exists
|
||||
call venv\Scripts\activate.bat
|
||||
if errorlevel 1 (
|
||||
echo [ERROR] Cannot activate virtual environment
|
||||
goto :end
|
||||
) else (
|
||||
echo [OK] Virtual environment activated
|
||||
)
|
||||
) else (
|
||||
echo [WARNING] Virtual environment not found
|
||||
echo Creating virtual environment...
|
||||
python -m venv venv
|
||||
if errorlevel 1 (
|
||||
echo [ERROR] Failed to create virtual environment
|
||||
goto :end
|
||||
)
|
||||
call venv\Scripts\activate.bat
|
||||
echo [OK] Virtual environment created and activated
|
||||
)
|
||||
|
||||
echo.
|
||||
echo 4. Checking Python in virtual environment...
|
||||
where python
|
||||
python --version
|
||||
|
||||
echo.
|
||||
echo 5. Checking critical dependencies...
|
||||
python -c "import aiohttp; print(f'aiohttp: {aiohttp.__version__}')" 2>nul
|
||||
if errorlevel 1 (
|
||||
echo [ERROR] aiohttp not found - installing...
|
||||
pip install aiohttp==3.11.7
|
||||
) else (
|
||||
echo [OK] aiohttp is available
|
||||
)
|
||||
|
||||
python -c "import httpx; print(f'httpx: {httpx.__version__}')" 2>nul
|
||||
if errorlevel 1 (
|
||||
echo [ERROR] httpx not found - installing...
|
||||
pip install httpx==0.28.1
|
||||
) else (
|
||||
echo [OK] httpx is available
|
||||
)
|
||||
|
||||
python -c "import mcp" 2>nul
|
||||
if errorlevel 1 (
|
||||
echo [ERROR] mcp not found - installing...
|
||||
pip install mcp==1.1.0
|
||||
) else (
|
||||
echo [OK] mcp is available
|
||||
)
|
||||
|
||||
python -c "from PIL import Image; print('Pillow: available')" 2>nul
|
||||
if errorlevel 1 (
|
||||
echo [ERROR] Pillow not found - installing...
|
||||
pip install Pillow==11.0.0
|
||||
) else (
|
||||
echo [OK] Pillow is available
|
||||
)
|
||||
|
||||
echo.
|
||||
echo 6. Checking configuration files...
|
||||
if exist ".env" (
|
||||
echo [OK] .env file exists
|
||||
) else (
|
||||
echo [WARNING] .env file not found
|
||||
if exist ".env.example" (
|
||||
echo Creating .env from example...
|
||||
copy .env.example .env
|
||||
echo [OK] .env file created from example
|
||||
) else (
|
||||
echo [ERROR] .env.example file not found
|
||||
)
|
||||
)
|
||||
|
||||
echo.
|
||||
echo 7. Checking required directories...
|
||||
if not exist "input_images" mkdir input_images & echo [OK] Created input_images directory
|
||||
if not exist "generated_images" mkdir generated_images & echo [OK] Created generated_images directory
|
||||
if not exist "temp" mkdir temp & echo [OK] Created temp directory
|
||||
|
||||
echo.
|
||||
echo 8. Testing basic imports...
|
||||
python -c "
|
||||
import sys
|
||||
print(f'Python executable: {sys.executable}')
|
||||
print(f'Python version: {sys.version}')
|
||||
print('Testing imports...')
|
||||
try:
|
||||
import aiohttp
|
||||
print('✓ aiohttp imported successfully')
|
||||
except ImportError as e:
|
||||
print(f'✗ aiohttp import failed: {e}')
|
||||
|
||||
try:
|
||||
import mcp
|
||||
print('✓ mcp imported successfully')
|
||||
except ImportError as e:
|
||||
print(f'✗ mcp import failed: {e}')
|
||||
|
||||
try:
|
||||
from src.connector import Config
|
||||
print('✓ Local Config imported successfully')
|
||||
except ImportError as e:
|
||||
print(f'✗ Local Config import failed: {e}')
|
||||
|
||||
try:
|
||||
from src.server import main
|
||||
print('✓ Local server main imported successfully')
|
||||
except ImportError as e:
|
||||
print(f'✗ Local server main import failed: {e}')
|
||||
"
|
||||
|
||||
echo.
|
||||
echo Troubleshooting complete!
|
||||
echo.
|
||||
echo If you still have issues:
|
||||
echo 1. Delete venv folder and run install_dependencies.bat
|
||||
echo 2. Make sure you have a stable internet connection
|
||||
echo 3. Check if your antivirus is blocking Python/pip
|
||||
echo 4. Try running as administrator
|
||||
echo.
|
||||
|
||||
:end
|
||||
pause
|
||||
Reference in New Issue
Block a user