Files
flux1-edit/tests/test_handlers.py
2025-08-26 02:35:44 +09:00

361 lines
13 KiB
Python

"""Unit tests for MCP tool handlers"""
import unittest
import tempfile
import asyncio
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
from PIL import Image
import io
# Add src to path
# NOTE(review): this inserts "<two directories above this file>/src" at the
# front of sys.path, yet the project imports that follow use the `src.`
# package prefix, which is resolved from the directory that *contains* `src`
# rather than from `src` itself. Confirm whether the parent directory should
# be added here instead (or as well); as written, those imports presumably
# rely on the tests being launched from the project root.
import sys
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
from src.server.handlers import ToolHandlers
from src.connector.config import Config
from src.connector.flux_client import FluxEditResponse
from mcp.types import TextContent, ImageContent
class AsyncTestCase(unittest.TestCase):
    """Base class for async tests.

    Each test gets its own event loop: ``setUp`` creates it and installs it as
    the current loop, ``tearDown`` uninstalls and closes it, and
    ``async_test`` drives a coroutine on it to completion.
    """

    def setUp(self):
        """Create a fresh event loop and make it the current loop."""
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        """Uninstall and close the per-test event loop."""
        # Unregister the loop before closing it; otherwise the closed loop
        # stays installed as the current loop and anything that later calls
        # asyncio.get_event_loop() would be handed a loop it cannot use.
        asyncio.set_event_loop(None)
        self.loop.close()

    def async_test(self, coro):
        """Run ``coro`` on this test's loop and return its result.

        Exceptions raised by the coroutine propagate to the caller.
        """
        return self.loop.run_until_complete(coro)
class TestToolHandlers(AsyncTestCase):
    """Test cases for MCP tool handlers.

    Every test runs against a throw-away directory tree and a ``MagicMock``
    built with ``spec=Config``. Tests that exercise the edit path patch
    ``src.server.handlers.FluxEditClient`` and feed it a canned
    ``FluxEditResponse``.
    """

    def setUp(self):
        """Create the temp workspace, the mocked config, a test image and the handlers."""
        super().setUp()

        # Create temporary directory. Its removal is registered right away:
        # unittest skips tearDown when setUp raises, but still runs cleanups,
        # so a failure further down cannot leak the directory.
        self.temp_dir = tempfile.mkdtemp()
        self.temp_path = Path(self.temp_dir)
        self.addCleanup(self._remove_temp_dir)

        # Mock config
        self.config = MagicMock(spec=Config)
        self.config.base_path = self.temp_path
        self.config.input_path = self.temp_path / 'input_images'
        self.config.generated_images_path = self.temp_path / 'generated_images'
        self.config.max_image_size_mb = 20
        self.config.default_aspect_ratio = '1:1'
        self.config.safety_tolerance = 2
        self.config.OUTPUT_FORMAT = 'png'
        self.config.prompt_upsampling = False
        self.config.MODEL_NAME = 'flux-kontext-pro'
        self.config.save_parameters = True

        # Setup directories
        self.config.input_path.mkdir(parents=True, exist_ok=True)
        self.config.generated_images_path.mkdir(parents=True, exist_ok=True)

        # Mock config methods
        self.config.ensure_output_directory.return_value = None
        self.config.generate_base_name.return_value = 'fluxedit_12345_20250826_143022'
        self.config.generate_filename.side_effect = (
            lambda base, num, ext: f'{base}_{num:03d}.{ext}'
        )
        self.config.get_output_path.side_effect = (
            lambda base, num, ext: self.config.generated_images_path / f'{base}_{num:03d}.{ext}'
        )

        # Create test image (kept as a PIL image, raw PNG bytes and base64 text)
        self.test_image = Image.new('RGB', (100, 100), color='blue')
        buffer = io.BytesIO()
        self.test_image.save(buffer, format='PNG')
        self.test_image_data = buffer.getvalue()
        self.test_image_b64 = self._encode_image_b64(self.test_image_data)

        # Create test image file
        self.test_image_file = self.config.input_path / 'test.png'
        self.test_image.save(self.test_image_file)

        # Initialize handlers
        self.handlers = ToolHandlers(self.config)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _remove_temp_dir(self):
        """Delete the per-test temporary directory if it still exists."""
        import shutil

        if self.temp_path.exists():
            shutil.rmtree(self.temp_path)

    def _encode_image_b64(self, image_data: bytes) -> str:
        """Helper to encode image as base64"""
        import base64

        return base64.b64encode(image_data).decode('utf-8')

    def _create_mock_flux_response(self, success: bool = True) -> FluxEditResponse:
        """Helper to create mock FLUX response.

        Args:
            success: When true, build a successful response carrying the test
                image bytes; otherwise build a failed response with an error
                message.
        """
        if success:
            return FluxEditResponse(
                success=True,
                edited_image_data=self.test_image_data,
                image_size=(100, 100),
                execution_time=5.5,
                request_id='test_request_123',
                result_url='https://example.com/result.png',
                metadata={'seed': 12345},
            )
        return FluxEditResponse(
            success=False,
            error_message='FLUX edit failed',
            execution_time=2.0,
        )

    def _install_mock_client(self, mock_client_class, success: bool):
        """Wire the patched client class to a mock returning a canned response.

        The mock is what ``async with <patched class>(...)`` yields, and its
        ``edit_image`` resolves to ``_create_mock_flux_response(success)``.
        Returns the mock client.
        """
        mock_client = AsyncMock()
        mock_client_class.return_value.__aenter__.return_value = mock_client
        mock_client.edit_image.return_value = self._create_mock_flux_response(success=success)
        return mock_client

    def _create_temp_image(self, name: str) -> Path:
        """Save the test image as ``<base_path>/temp/<name>`` and return its path."""
        temp_dir = self.config.base_path / 'temp'
        temp_dir.mkdir(exist_ok=True)
        temp_file = temp_dir / name
        self.test_image.save(temp_file)
        return temp_file

    def _assert_single_text(self, result) -> str:
        """Assert ``result`` holds exactly one ``TextContent`` and return its text."""
        self.assertEqual(len(result), 1)
        self.assertIsInstance(result[0], TextContent)
        return result[0].text

    # ------------------------------------------------------------------
    # flux_edit_image
    # ------------------------------------------------------------------

    def test_flux_edit_image_parameter_validation_failure(self):
        """Test flux_edit_image with invalid parameters"""
        async def run_test():
            # Missing required parameter: input_image_b64
            arguments = {
                'prompt': 'Make it blue',
                'seed': 12345,
            }
            result = await self.handlers.handle_flux_edit_image(arguments)
            text = self._assert_single_text(result)
            self.assertIn('Parameter validation failed', text)
            self.assertIn('input_image_b64 is required', text)

        self.async_test(run_test())

    @patch('src.server.handlers.FluxEditClient')
    def test_flux_edit_image_success(self, mock_client_class):
        """Test successful flux_edit_image"""
        async def run_test():
            # Setup mock client
            self._install_mock_client(mock_client_class, success=True)
            # Valid arguments
            arguments = {
                'input_image_b64': self.test_image_b64,
                'prompt': 'Make the sky blue',
                'seed': 12345,
                'aspect_ratio': '16:9',
                'save_to_file': True,
            }
            result = await self.handlers.handle_flux_edit_image(arguments)
            # Verify result structure
            self.assertGreater(len(result), 0)
            self.assertIsInstance(result[0], TextContent)
            self.assertIn('✅ Image edited successfully', result[0].text)
            self.assertIn('Seed: 12345', result[0].text)
            # An image preview is only checked when the handler returned one
            if len(result) > 1:
                self.assertIsInstance(result[1], ImageContent)

        self.async_test(run_test())

    @patch('src.server.handlers.FluxEditClient')
    def test_flux_edit_image_flux_failure(self, mock_client_class):
        """Test flux_edit_image with FLUX API failure"""
        async def run_test():
            # Setup mock client to return failure
            self._install_mock_client(mock_client_class, success=False)
            arguments = {
                'input_image_b64': self.test_image_b64,
                'prompt': 'Make it blue',
                'seed': 12345,
            }
            result = await self.handlers.handle_flux_edit_image(arguments)
            text = self._assert_single_text(result)
            self.assertIn('❌ FLUX edit failed', text)

        self.async_test(run_test())

    # ------------------------------------------------------------------
    # flux_edit_image_from_file
    # ------------------------------------------------------------------

    def test_flux_edit_image_from_file_not_found(self):
        """Test flux_edit_image_from_file with non-existent file"""
        async def run_test():
            arguments = {
                'input_image_name': 'nonexistent.png',
                'prompt': 'Edit this',
                'seed': 12345,
            }
            result = await self.handlers.handle_flux_edit_image_from_file(arguments)
            text = self._assert_single_text(result)
            self.assertIn('❌ File not found', text)
            self.assertIn('nonexistent.png', text)

        self.async_test(run_test())

    @patch('src.server.handlers.FluxEditClient')
    def test_flux_edit_image_from_file_success(self, mock_client_class):
        """Test successful flux_edit_image_from_file"""
        async def run_test():
            # Setup mock client
            self._install_mock_client(mock_client_class, success=True)
            arguments = {
                'input_image_name': 'test.png',
                'prompt': 'Make it awesome',
                'seed': 54321,
                'save_to_file': True,
            }
            result = await self.handlers.handle_flux_edit_image_from_file(arguments)
            # Verify result
            self.assertGreater(len(result), 0)
            self.assertIsInstance(result[0], TextContent)
            self.assertIn('✅ Image edited successfully from file', result[0].text)
            self.assertIn('test.png', result[0].text)
            self.assertIn('Seed: 54321', result[0].text)

        self.async_test(run_test())

    # ------------------------------------------------------------------
    # validate_image
    # ------------------------------------------------------------------

    def test_validate_image_success(self):
        """Test successful image validation"""
        async def run_test():
            arguments = {'image_path': str(self.test_image_file)}
            result = await self.handlers.handle_validate_image(arguments)
            text = self._assert_single_text(result)
            self.assertIn('✅ Image validation passed', text)
            self.assertIn('100x100', text)

        self.async_test(run_test())

    def test_validate_image_not_found(self):
        """Test image validation with non-existent file"""
        async def run_test():
            arguments = {'image_path': '/nonexistent/path.png'}
            result = await self.handlers.handle_validate_image(arguments)
            text = self._assert_single_text(result)
            self.assertIn('❌ Image validation failed', text)
            self.assertIn('File not found', text)

        self.async_test(run_test())

    # ------------------------------------------------------------------
    # move_temp_to_output
    # ------------------------------------------------------------------

    def test_move_temp_to_output_success(self):
        """Test successful file move operation"""
        async def run_test():
            # Create temp directory and file
            self._create_temp_image('temp_test.png')
            arguments = {
                'temp_file_name': 'temp_test.png',
                'output_file_name': 'moved_test.png',
                'copy_only': False,
            }
            result = await self.handlers.handle_move_temp_to_output(arguments)
            text = self._assert_single_text(result)
            self.assertIn('✅ File moved successfully', text)
            self.assertIn('temp_test.png', text)
            self.assertIn('moved_test.png', text)

        self.async_test(run_test())

    def test_move_temp_to_output_not_found(self):
        """Test file move with non-existent temp file"""
        async def run_test():
            arguments = {
                'temp_file_name': 'nonexistent.png',
                'copy_only': False,
            }
            result = await self.handlers.handle_move_temp_to_output(arguments)
            text = self._assert_single_text(result)
            self.assertIn('❌ Temp file not found', text)

        self.async_test(run_test())

    def test_move_temp_to_output_copy_only(self):
        """Test copy-only file operation"""
        async def run_test():
            # Create temp directory and file
            temp_file = self._create_temp_image('temp_copy.png')
            arguments = {
                'temp_file_name': 'temp_copy.png',
                'copy_only': True,
            }
            result = await self.handlers.handle_move_temp_to_output(arguments)
            text = self._assert_single_text(result)
            self.assertIn('✅ File copied successfully', text)
            # Original file should still exist
            self.assertTrue(temp_file.exists())

        self.async_test(run_test())

    # ------------------------------------------------------------------
    # Internal helpers of ToolHandlers
    # ------------------------------------------------------------------

    def test_seed_management(self):
        """Test seed creation and reset functionality"""
        # Test seed generation
        seed1 = self.handlers._get_or_create_seed()
        seed2 = self.handlers._get_or_create_seed()
        # Should return same seed for session
        self.assertEqual(seed1, seed2)

        # Test seed reset
        self.handlers._reset_seed()
        seed3 = self.handlers._get_or_create_seed()
        # Should be different after reset
        self.assertNotEqual(seed1, seed3)

    def test_temp_file_operations(self):
        """Test temporary file save and move operations"""
        # Test saving b64 to temp
        filename = 'test_temp.png'
        temp_path = self.handlers._save_b64_to_temp_file(self.test_image_b64, filename)
        self.assertTrue(Path(temp_path).exists())
        self.assertIn(filename, temp_path)

        # Test moving to generated images
        base_name = 'test_base'
        moved_path = self.handlers._move_temp_to_generated(temp_path, base_name, 1)
        self.assertTrue(Path(moved_path).exists())
        self.assertIn('test_base_001.png', moved_path)
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()