361 lines
13 KiB
Python
361 lines
13 KiB
Python
"""Unit tests for MCP tool handlers"""
|
|
|
|
import unittest
|
|
import tempfile
|
|
import asyncio
|
|
from pathlib import Path
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from PIL import Image
|
|
import io
|
|
|
|
# Add src to path
|
|
import sys
|
|
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
|
|
|
|
from src.server.handlers import ToolHandlers
|
|
from src.connector.config import Config
|
|
from src.connector.flux_client import FluxEditResponse
|
|
from mcp.types import TextContent, ImageContent
|
|
|
|
|
|
class AsyncTestCase(unittest.TestCase):
    """Base class for async tests.

    ``setUp`` creates a fresh event loop per test and installs it as the
    thread's current loop; tests drive coroutines with :meth:`async_test`.
    ``tearDown`` finalizes and closes that loop and uninstalls it again.
    """

    def setUp(self):
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        """Finalize, close, and uninstall the per-test event loop.

        The loop is reset to ``None`` as the thread's current loop so that code
        running after this test never picks up an already-closed loop.
        """
        try:
            # A test may have closed the loop itself; run_until_complete on a
            # closed loop would raise, so only finalize a still-open loop.
            if not self.loop.is_closed():
                self.loop.run_until_complete(self.loop.shutdown_asyncgens())
        finally:
            asyncio.set_event_loop(None)
            self.loop.close()

    def async_test(self, coro):
        """Run *coro* to completion on this test's loop and return its result."""
        return self.loop.run_until_complete(coro)
|
|
|
|
|
|
class TestToolHandlers(AsyncTestCase):
    """Test cases for MCP tool handlers.

    Each test runs against a fresh temporary directory (containing
    ``input_images`` and ``generated_images``) and a ``MagicMock`` stand-in
    for :class:`Config`; FLUX API calls are replaced with mocks where needed.
    """

    def setUp(self):
        import shutil

        super().setUp()

        # Create temporary directory. Its removal is registered as a cleanup
        # immediately: unittest skips tearDown when setUp raises, but always
        # runs registered cleanups, so the directory cannot leak if any later
        # setup step (image creation, ToolHandlers construction) fails.
        self.temp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.temp_dir, ignore_errors=True)
        self.temp_path = Path(self.temp_dir)

        # Mock config whose paths all live under the temporary directory
        self.config = MagicMock(spec=Config)
        self.config.base_path = self.temp_path
        self.config.input_path = self.temp_path / 'input_images'
        self.config.generated_images_path = self.temp_path / 'generated_images'
        self.config.max_image_size_mb = 20
        self.config.default_aspect_ratio = '1:1'
        self.config.safety_tolerance = 2
        self.config.OUTPUT_FORMAT = 'png'
        self.config.prompt_upsampling = False
        self.config.MODEL_NAME = 'flux-kontext-pro'
        self.config.save_parameters = True

        # Setup directories
        self.config.input_path.mkdir(parents=True, exist_ok=True)
        self.config.generated_images_path.mkdir(parents=True, exist_ok=True)

        # Mock config methods with deterministic naming, e.g.
        # ('base', 1, 'png') -> 'base_001.png', so tests can assert on paths.
        self.config.ensure_output_directory.return_value = None
        self.config.generate_base_name.return_value = 'fluxedit_12345_20250826_143022'
        self.config.generate_filename.side_effect = (
            lambda base, num, ext: f'{base}_{num:03d}.{ext}'
        )
        self.config.get_output_path.side_effect = (
            lambda base, num, ext: self.config.generated_images_path / f'{base}_{num:03d}.{ext}'
        )

        # Create test image: a 100x100 solid-blue RGB image as PNG bytes and base64
        self.test_image = Image.new('RGB', (100, 100), color='blue')
        buffer = io.BytesIO()
        self.test_image.save(buffer, format='PNG')
        self.test_image_data = buffer.getvalue()
        self.test_image_b64 = self._encode_image_b64(self.test_image_data)

        # Create test image file in the input directory
        self.test_image_file = self.config.input_path / 'test.png'
        self.test_image.save(self.test_image_file)

        # Initialize handlers
        self.handlers = ToolHandlers(self.config)

    def _encode_image_b64(self, image_data: bytes) -> str:
        """Return *image_data* base64-encoded as a text string."""
        import base64

        return base64.b64encode(image_data).decode('utf-8')

    def _create_mock_flux_response(self, success: bool = True) -> FluxEditResponse:
        """Build a canned ``FluxEditResponse``.

        Args:
            success: When True, return a successful response carrying the test
                image bytes; otherwise a failure with error message
                ``'FLUX edit failed'``.
        """
        if success:
            return FluxEditResponse(
                success=True,
                edited_image_data=self.test_image_data,
                image_size=(100, 100),
                execution_time=5.5,
                request_id='test_request_123',
                result_url='https://example.com/result.png',
                metadata={'seed': 12345},
            )
        return FluxEditResponse(
            success=False,
            error_message='FLUX edit failed',
            execution_time=2.0,
        )

    def test_flux_edit_image_parameter_validation_failure(self):
        """Test flux_edit_image with invalid parameters"""
        async def run_test():
            # Missing required parameter: input_image_b64
            arguments = {
                'prompt': 'Make it blue',
                'seed': 12345,
            }

            result = await self.handlers.handle_flux_edit_image(arguments)

            self.assertEqual(len(result), 1)
            self.assertIsInstance(result[0], TextContent)
            self.assertIn('Parameter validation failed', result[0].text)
            self.assertIn('input_image_b64 is required', result[0].text)

        self.async_test(run_test())

    @patch('src.server.handlers.FluxEditClient')
    def test_flux_edit_image_success(self, mock_client_class):
        """Test successful flux_edit_image"""
        async def run_test():
            # Setup mock client; presumably the handler uses
            # `async with FluxEditClient(...) as client`, which yields mock_client.
            mock_client = AsyncMock()
            mock_client_class.return_value.__aenter__.return_value = mock_client
            mock_client.edit_image.return_value = self._create_mock_flux_response(success=True)

            # Valid arguments
            arguments = {
                'input_image_b64': self.test_image_b64,
                'prompt': 'Make the sky blue',
                'seed': 12345,
                'aspect_ratio': '16:9',
                'save_to_file': True,
            }

            result = await self.handlers.handle_flux_edit_image(arguments)

            # Verify result structure
            self.assertGreater(len(result), 0)
            self.assertIsInstance(result[0], TextContent)
            self.assertIn('✅ Image edited successfully', result[0].text)
            self.assertIn('Seed: 12345', result[0].text)

            # Should have image preview
            if len(result) > 1:
                self.assertIsInstance(result[1], ImageContent)

        self.async_test(run_test())

    @patch('src.server.handlers.FluxEditClient')
    def test_flux_edit_image_flux_failure(self, mock_client_class):
        """Test flux_edit_image with FLUX API failure"""
        async def run_test():
            # Setup mock client to return failure
            mock_client = AsyncMock()
            mock_client_class.return_value.__aenter__.return_value = mock_client
            mock_client.edit_image.return_value = self._create_mock_flux_response(success=False)

            arguments = {
                'input_image_b64': self.test_image_b64,
                'prompt': 'Make it blue',
                'seed': 12345,
            }

            result = await self.handlers.handle_flux_edit_image(arguments)

            self.assertEqual(len(result), 1)
            self.assertIsInstance(result[0], TextContent)
            self.assertIn('❌ FLUX edit failed', result[0].text)

        self.async_test(run_test())

    def test_flux_edit_image_from_file_not_found(self):
        """Test flux_edit_image_from_file with non-existent file"""
        async def run_test():
            arguments = {
                'input_image_name': 'nonexistent.png',
                'prompt': 'Edit this',
                'seed': 12345,
            }

            result = await self.handlers.handle_flux_edit_image_from_file(arguments)

            self.assertEqual(len(result), 1)
            self.assertIsInstance(result[0], TextContent)
            self.assertIn('❌ File not found', result[0].text)
            self.assertIn('nonexistent.png', result[0].text)

        self.async_test(run_test())

    @patch('src.server.handlers.FluxEditClient')
    def test_flux_edit_image_from_file_success(self, mock_client_class):
        """Test successful flux_edit_image_from_file"""
        async def run_test():
            # Setup mock client
            mock_client = AsyncMock()
            mock_client_class.return_value.__aenter__.return_value = mock_client
            mock_client.edit_image.return_value = self._create_mock_flux_response(success=True)

            arguments = {
                'input_image_name': 'test.png',
                'prompt': 'Make it awesome',
                'seed': 54321,
                'save_to_file': True,
            }

            result = await self.handlers.handle_flux_edit_image_from_file(arguments)

            # Verify result
            self.assertGreater(len(result), 0)
            self.assertIsInstance(result[0], TextContent)
            self.assertIn('✅ Image edited successfully from file', result[0].text)
            self.assertIn('test.png', result[0].text)
            self.assertIn('Seed: 54321', result[0].text)

        self.async_test(run_test())

    def test_validate_image_success(self):
        """Test successful image validation"""
        async def run_test():
            arguments = {'image_path': str(self.test_image_file)}

            result = await self.handlers.handle_validate_image(arguments)

            self.assertEqual(len(result), 1)
            self.assertIsInstance(result[0], TextContent)
            self.assertIn('✅ Image validation passed', result[0].text)
            self.assertIn('100x100', result[0].text)

        self.async_test(run_test())

    def test_validate_image_not_found(self):
        """Test image validation with non-existent file"""
        async def run_test():
            arguments = {'image_path': '/nonexistent/path.png'}

            result = await self.handlers.handle_validate_image(arguments)

            self.assertEqual(len(result), 1)
            self.assertIsInstance(result[0], TextContent)
            self.assertIn('❌ Image validation failed', result[0].text)
            self.assertIn('File not found', result[0].text)

        self.async_test(run_test())

    def test_move_temp_to_output_success(self):
        """Test successful file move operation"""
        async def run_test():
            # Create temp directory and file
            temp_dir = self.config.base_path / 'temp'
            temp_dir.mkdir(exist_ok=True)
            temp_file = temp_dir / 'temp_test.png'
            self.test_image.save(temp_file)

            arguments = {
                'temp_file_name': 'temp_test.png',
                'output_file_name': 'moved_test.png',
                'copy_only': False,
            }

            result = await self.handlers.handle_move_temp_to_output(arguments)

            self.assertEqual(len(result), 1)
            self.assertIsInstance(result[0], TextContent)
            self.assertIn('✅ File moved successfully', result[0].text)
            self.assertIn('temp_test.png', result[0].text)
            self.assertIn('moved_test.png', result[0].text)

            # A real move (copy_only=False) must not leave the source behind;
            # this mirrors the copy-only test, which asserts the opposite.
            self.assertFalse(temp_file.exists())

        self.async_test(run_test())

    def test_move_temp_to_output_not_found(self):
        """Test file move with non-existent temp file"""
        async def run_test():
            arguments = {
                'temp_file_name': 'nonexistent.png',
                'copy_only': False,
            }

            result = await self.handlers.handle_move_temp_to_output(arguments)

            self.assertEqual(len(result), 1)
            self.assertIsInstance(result[0], TextContent)
            self.assertIn('❌ Temp file not found', result[0].text)

        self.async_test(run_test())

    def test_move_temp_to_output_copy_only(self):
        """Test copy-only file operation"""
        async def run_test():
            # Create temp directory and file
            temp_dir = self.config.base_path / 'temp'
            temp_dir.mkdir(exist_ok=True)
            temp_file = temp_dir / 'temp_copy.png'
            self.test_image.save(temp_file)

            arguments = {
                'temp_file_name': 'temp_copy.png',
                'copy_only': True,
            }

            result = await self.handlers.handle_move_temp_to_output(arguments)

            self.assertEqual(len(result), 1)
            self.assertIsInstance(result[0], TextContent)
            self.assertIn('✅ File copied successfully', result[0].text)

            # Original file should still exist
            self.assertTrue(temp_file.exists())

        self.async_test(run_test())

    def test_seed_management(self):
        """Test seed creation and reset functionality"""
        # Test seed generation
        seed1 = self.handlers._get_or_create_seed()
        seed2 = self.handlers._get_or_create_seed()

        # Should return same seed for session
        self.assertEqual(seed1, seed2)

        # Test seed reset
        self.handlers._reset_seed()
        seed3 = self.handlers._get_or_create_seed()

        # Should be different after reset.
        # NOTE(review): if _get_or_create_seed draws seeds at random, this can
        # (rarely) collide and flake — confirm against the implementation.
        self.assertNotEqual(seed1, seed3)

    def test_temp_file_operations(self):
        """Test temporary file save and move operations"""
        # Test saving b64 to temp
        filename = 'test_temp.png'
        temp_path = self.handlers._save_b64_to_temp_file(self.test_image_b64, filename)

        self.assertTrue(Path(temp_path).exists())
        self.assertIn(filename, temp_path)

        # Test moving to generated images
        base_name = 'test_base'
        moved_path = self.handlers._move_temp_to_generated(temp_path, base_name, 1)

        self.assertTrue(Path(moved_path).exists())
        self.assertIn('test_base_001.png', moved_path)
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly as a script: discover and run every TestCase in this module.
    unittest.main()
|