Files
flux1-edit/tests/test_flux_client.py
2025-08-26 02:35:44 +09:00

309 lines
11 KiB
Python

"""Unit tests for FLUX API client"""
import asyncio
import sys
import unittest
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

# Make the project importable no matter which directory the tests are started
# from. The imports below are spelled ``src.connector...``, which requires the
# project root (the parent of ``src``) to be on sys.path. ``src`` itself is
# kept on the path as before, in case modules inside it import
# ``connector...`` directly (not visible from this file - keep until confirmed).
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
for _path_entry in (_PROJECT_ROOT / 'src', _PROJECT_ROOT):
    if str(_path_entry) not in sys.path:
        sys.path.insert(0, str(_path_entry))

from src.connector.flux_client import FluxEditClient, FluxEditRequest, FluxEditResponse  # noqa: E402,F401
from src.connector.config import Config  # noqa: E402
class TestFluxEditClient(unittest.IsolatedAsyncioTestCase):
    """Test cases for FLUX API client.

    Derives from ``unittest.IsolatedAsyncioTestCase`` so that the
    ``async def`` test methods below are actually awaited. On a plain
    ``unittest.TestCase`` a coroutine test method is merely called: the
    coroutine never runs and the test is reported as passing without
    executing a single assertion.
    """

    def setUp(self):
        """Set up test fixtures"""
        # Mock config
        self.config = MagicMock(spec=Config)
        self.config.api_key = 'test_api_key'
        self.config.default_timeout = 30
        # Short interval so the polling tests (the timeout case walks through
        # every attempt) don't spend whole seconds in real sleeps.
        self.config.polling_interval = 0.01
        self.config.max_polling_attempts = 5
        self.config.get_api_url.side_effect = lambda endpoint: f"https://api.test.com{endpoint}"
        self.client = FluxEditClient(self.config)
        # Sample request
        self.sample_request = FluxEditRequest(
            input_image_b64='test_base64_data',
            prompt='Make the sky blue',
            seed=12345,
            aspect_ratio='16:9',
        )

    async def asyncTearDown(self):
        """Close the client's HTTP session if a test left one open."""
        # Awaited on the test's own event loop. A synchronous tearDown cannot
        # do this: asyncio.create_task() raises RuntimeError without a
        # running loop.
        if getattr(self.client, 'session', None):
            await self.client.close()

    @staticmethod
    def _mock_session():
        """Build a mock standing in for an ``aiohttp.ClientSession``.

        ``post``/``get`` stay plain MagicMock methods, so ``session.post(...)``
        returns an object usable in ``async with`` (MagicMock provides
        ``__aenter__``/``__aexit__`` as AsyncMocks). On an AsyncMock session
        those calls would return coroutines, which are not async context
        managers.
        """
        session = MagicMock()
        # Yield the same mock whether the client uses the session directly or
        # through ``async with aiohttp.ClientSession() as session``.
        session.__aenter__.return_value = session
        # Let ``await session.close()`` work during client cleanup.
        session.close = AsyncMock()
        session.closed = False
        return session

    @staticmethod
    def _mock_response(status):
        """Build a response mock with the given status and awaitable body methods."""
        # AsyncMock children make ``await resp.json()/text()/read()`` work.
        response = AsyncMock()
        response.status = status
        return response

    @patch('aiohttp.ClientSession')
    async def test_create_edit_request_success(self, mock_session_class):
        """Test successful edit request creation"""
        # Mock session and response
        mock_session = self._mock_session()
        mock_response = self._mock_response(200)
        mock_response.json.return_value = {'id': 'request_12345'}
        mock_session.post.return_value.__aenter__.return_value = mock_response
        mock_session_class.return_value = mock_session
        # Test
        request_id = await self.client._create_edit_request(self.sample_request)
        # Verify
        self.assertEqual(request_id, 'request_12345')
        mock_session.post.assert_called_once()
        # Check payload structure
        payload = mock_session.post.call_args[1]['json']
        self.assertEqual(payload['prompt'], 'Make the sky blue')
        self.assertEqual(payload['seed'], 12345)
        self.assertEqual(payload['input_image'], 'test_base64_data')

    @patch('aiohttp.ClientSession')
    async def test_create_edit_request_failure(self, mock_session_class):
        """Test edit request creation failure"""
        # Mock session and response
        mock_session = self._mock_session()
        mock_response = self._mock_response(400)
        mock_response.text.return_value = 'Bad Request'
        mock_session.post.return_value.__aenter__.return_value = mock_response
        mock_session_class.return_value = mock_session
        # Test
        request_id = await self.client._create_edit_request(self.sample_request)
        # Verify
        self.assertIsNone(request_id)

    @patch('aiohttp.ClientSession')
    async def test_poll_result_success(self, mock_session_class):
        """Test successful result polling"""
        # Mock session and responses
        mock_session = self._mock_session()
        # First response: processing
        mock_response_processing = self._mock_response(200)
        mock_response_processing.json.return_value = {'status': 'processing'}
        # Second response: ready
        mock_response_ready = self._mock_response(200)
        mock_response_ready.json.return_value = {
            'status': 'ready',
            'result': {'sample': 'https://example.com/image.png'},
        }
        # Successive ``async with session.get(...)`` blocks yield processing, then ready
        mock_session.get.return_value.__aenter__.side_effect = [
            mock_response_processing,
            mock_response_ready,
        ]
        mock_session_class.return_value = mock_session
        # Test
        result = await self.client._poll_result('test_request_id')
        # Verify
        self.assertIsNotNone(result)
        self.assertEqual(result['status'], 'ready')
        self.assertIn('result', result)
        self.assertEqual(mock_session.get.call_count, 2)

    @patch('aiohttp.ClientSession')
    async def test_poll_result_timeout(self, mock_session_class):
        """Test polling timeout"""
        # Mock session to always return processing
        mock_session = self._mock_session()
        mock_response = self._mock_response(200)
        mock_response.json.return_value = {'status': 'processing'}
        mock_session.get.return_value.__aenter__.return_value = mock_response
        mock_session_class.return_value = mock_session
        # Test
        result = await self.client._poll_result('test_request_id')
        # Verify - should timeout after max attempts
        self.assertIsNone(result)
        self.assertEqual(mock_session.get.call_count, self.config.max_polling_attempts)

    @patch('aiohttp.ClientSession')
    async def test_download_result_image_success(self, mock_session_class):
        """Test successful image download"""
        # Mock session and response
        mock_session = self._mock_session()
        mock_response = self._mock_response(200)
        mock_response.read.return_value = b'fake_image_data'
        mock_session.get.return_value.__aenter__.return_value = mock_response
        mock_session_class.return_value = mock_session
        # Test
        image_data = await self.client._download_result_image('https://example.com/image.png')
        # Verify
        self.assertEqual(image_data, b'fake_image_data')
        mock_session.get.assert_called_once_with('https://example.com/image.png')

    @patch('aiohttp.ClientSession')
    async def test_download_result_image_failure(self, mock_session_class):
        """Test image download failure"""
        # Mock session and response
        mock_session = self._mock_session()
        mock_response = self._mock_response(404)
        mock_session.get.return_value.__aenter__.return_value = mock_response
        mock_session_class.return_value = mock_session
        # Test
        image_data = await self.client._download_result_image('https://example.com/image.png')
        # Verify
        self.assertIsNone(image_data)

    def test_get_image_size(self):
        """Test image size detection from bytes"""
        # Create a small test image in memory
        from PIL import Image
        import io
        # Create 10x20 test image
        img = Image.new('RGB', (10, 20), color='red')
        buffer = io.BytesIO()
        img.save(buffer, format='PNG')
        image_data = buffer.getvalue()
        # Test
        size = self.client._get_image_size(image_data)
        # Verify
        self.assertEqual(size, (10, 20))

    def test_get_image_size_invalid_data(self):
        """Test image size detection with invalid data"""
        size = self.client._get_image_size(b'invalid_image_data')
        self.assertIsNone(size)

    # Decorators apply bottom-up, so the mocks arrive as (download, poll, create).
    @patch.object(FluxEditClient, '_create_edit_request')
    @patch.object(FluxEditClient, '_poll_result')
    @patch.object(FluxEditClient, '_download_result_image')
    async def test_edit_image_success(self, mock_download, mock_poll, mock_create):
        """Test complete successful edit flow"""
        # Setup mocks (patch.object yields AsyncMocks for async methods)
        mock_create.return_value = 'request_123'
        mock_poll.return_value = {
            'status': 'ready',
            'result': {'sample': 'https://example.com/result.png'},
        }
        mock_download.return_value = b'edited_image_data'
        # Test
        response = await self.client.edit_image(self.sample_request)
        # Verify
        self.assertTrue(response.success)
        self.assertEqual(response.edited_image_data, b'edited_image_data')
        self.assertEqual(response.request_id, 'request_123')
        self.assertEqual(response.result_url, 'https://example.com/result.png')
        self.assertGreater(response.execution_time, 0)

    @patch.object(FluxEditClient, '_create_edit_request')
    async def test_edit_image_create_failure(self, mock_create):
        """Test edit flow with creation failure"""
        # Setup mock
        mock_create.return_value = None
        # Test
        response = await self.client.edit_image(self.sample_request)
        # Verify
        self.assertFalse(response.success)
        self.assertIn('Failed to create edit request', response.error_message)

    @patch.object(FluxEditClient, '_create_edit_request')
    @patch.object(FluxEditClient, '_poll_result')
    async def test_edit_image_poll_failure(self, mock_poll, mock_create):
        """Test edit flow with polling failure"""
        # Setup mocks
        mock_create.return_value = 'request_123'
        mock_poll.return_value = None
        # Test
        response = await self.client.edit_image(self.sample_request)
        # Verify
        self.assertFalse(response.success)
        self.assertIn('Failed to get edit result', response.error_message)

    async def test_context_manager(self):
        """Test async context manager functionality"""
        async with FluxEditClient(self.config) as client:
            self.assertIsInstance(client, FluxEditClient)
        # Session should be closed after context
        if hasattr(client, 'session'):
            self.assertTrue(client.session is None or client.session.closed)
# Test helper to run async tests
class AsyncTestCase(unittest.TestCase):
    """TestCase base that gives every test its own private event loop.

    Subclasses call ``self.async_test(coro)`` to drive a coroutine to
    completion on ``self.loop``.
    """

    def setUp(self):
        """Create a fresh event loop and install it as the current loop."""
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        """Finalize async generators, close the loop and uninstall it."""
        try:
            # Let any still-suspended async generators run their cleanup
            # before the loop disappears.
            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
        finally:
            self.loop.close()
            # Otherwise the closed loop stays installed as "current" and later
            # asyncio.get_event_loop() callers receive a dead loop.
            asyncio.set_event_loop(None)

    def async_test(self, coro):
        """Run *coro* to completion on this test's loop and return its result."""
        return self.loop.run_until_complete(coro)
class TestFluxEditClientAsync(AsyncTestCase):
    """Async test runner for FLUX client tests.

    Coroutines are driven through the per-test event loop that
    ``AsyncTestCase`` provides via ``self.async_test``.
    """

    def setUp(self):
        super().setUp()
        cfg = MagicMock(spec=Config)
        cfg.api_key = 'test_api_key'
        cfg.default_timeout = 30
        cfg.polling_interval = 0.1  # Faster for tests
        cfg.max_polling_attempts = 3
        cfg.get_api_url.side_effect = lambda endpoint: f"https://api.test.com{endpoint}"
        self.config = cfg
        self.client = FluxEditClient(cfg)
        self.sample_request = FluxEditRequest(
            input_image_b64='test_base64_data',
            prompt='Make the sky blue',
            seed=12345,
            aspect_ratio='16:9',
        )

    def test_async_context_manager(self):
        """Test async context manager"""
        async def scenario():
            async with FluxEditClient(self.config) as opened:
                self.assertIsInstance(opened, FluxEditClient)
                return True

        self.assertTrue(self.async_test(scenario()))
if __name__ == "__main__":
    # Allow running this file directly as a script; discovers the TestCase
    # classes defined in this module.
    unittest.main(module='__main__')