309 lines
11 KiB
Python
309 lines
11 KiB
Python
"""Unit tests for FLUX API client"""
|
|
|
|
import unittest
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from pathlib import Path
|
|
|
|
# Make the project importable regardless of the current working directory.
# The imports below are package-qualified (``src.connector...``), which
# requires the repository root itself on sys.path; inserting only ``src``
# does not make ``src.connector`` resolvable. The ``src`` entry is kept for
# backward compatibility with anything importing ``connector...`` directly.
import sys

_PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(_PROJECT_ROOT / 'src'))
sys.path.insert(0, str(_PROJECT_ROOT))
|
|
|
|
from src.connector.flux_client import FluxEditClient, FluxEditRequest, FluxEditResponse
|
|
from src.connector.config import Config
|
|
|
|
|
|
class TestFluxEditClient(unittest.IsolatedAsyncioTestCase):
    """Test cases for FLUX API client.

    Based on ``IsolatedAsyncioTestCase`` so that the ``async def test_*``
    methods are actually awaited. On a plain ``unittest.TestCase`` each of
    them merely returns an un-awaited coroutine and is reported as passing
    without executing a single assertion.
    """

    def setUp(self):
        """Set up test fixtures"""
        # Mock config
        self.config = MagicMock(spec=Config)
        self.config.api_key = 'test_api_key'
        self.config.default_timeout = 30
        # Short interval so the polling tests (the timeout test exhausts
        # every attempt) finish quickly instead of sleeping for seconds.
        self.config.polling_interval = 0.01
        self.config.max_polling_attempts = 5
        self.config.get_api_url.side_effect = lambda endpoint: f"https://api.test.com{endpoint}"

        self.client = FluxEditClient(self.config)

        # Sample request
        self.sample_request = FluxEditRequest(
            input_image_b64='test_base64_data',
            prompt='Make the sky blue',
            seed=12345,
            aspect_ratio='16:9',
        )

    async def asyncTearDown(self):
        """Close the client's session if a test left one open.

        This is awaited on the test's own event loop. Scheduling it with
        ``asyncio.create_task`` from a synchronous ``tearDown`` fails
        because no loop is running there.
        """
        if getattr(self.client, 'session', None):
            await self.client.close()

    @staticmethod
    def _mock_session():
        """Build a stand-in for an ``aiohttp.ClientSession`` instance.

        ``post``/``get`` stay plain MagicMocks because aiohttp's request
        methods are synchronous calls returning an async context manager.
        An AsyncMock would return a coroutine, which cannot be entered with
        ``async with``. ``close`` is awaitable like the real method.
        """
        session = MagicMock()
        session.closed = False
        session.close = AsyncMock()
        # Also support ``async with aiohttp.ClientSession() as session:``.
        session.__aenter__.return_value = session
        return session

    @staticmethod
    def _mock_response(status, *, json_data=None, text='', body=b''):
        """Build a response mock with an int ``status`` and awaitable readers."""
        response = MagicMock()
        response.status = status
        response.json = AsyncMock(return_value=json_data)
        response.text = AsyncMock(return_value=text)
        response.read = AsyncMock(return_value=body)
        return response

    @patch('aiohttp.ClientSession')
    async def test_create_edit_request_success(self, mock_session_class):
        """Test successful edit request creation"""
        # Mock session and response
        mock_session = self._mock_session()
        mock_response = self._mock_response(200, json_data={'id': 'request_12345'})
        mock_session.post.return_value.__aenter__.return_value = mock_response
        mock_session_class.return_value = mock_session

        # Test
        request_id = await self.client._create_edit_request(self.sample_request)

        # Verify
        self.assertEqual(request_id, 'request_12345')
        mock_session.post.assert_called_once()

        # Check payload structure
        payload = mock_session.post.call_args[1]['json']
        self.assertEqual(payload['prompt'], 'Make the sky blue')
        self.assertEqual(payload['seed'], 12345)
        self.assertEqual(payload['input_image'], 'test_base64_data')

    @patch('aiohttp.ClientSession')
    async def test_create_edit_request_failure(self, mock_session_class):
        """Test edit request creation failure"""
        # Mock session and response
        mock_session = self._mock_session()
        mock_response = self._mock_response(400, text='Bad Request')
        mock_session.post.return_value.__aenter__.return_value = mock_response
        mock_session_class.return_value = mock_session

        # Test
        request_id = await self.client._create_edit_request(self.sample_request)

        # Verify
        self.assertIsNone(request_id)

    @patch('aiohttp.ClientSession')
    async def test_poll_result_success(self, mock_session_class):
        """Test successful result polling"""
        mock_session = self._mock_session()

        # First response: processing; second response: ready
        mock_response_processing = self._mock_response(
            200, json_data={'status': 'processing'}
        )
        mock_response_ready = self._mock_response(
            200,
            json_data={
                'status': 'ready',
                'result': {'sample': 'https://example.com/image.png'},
            },
        )

        # Successive ``async with session.get(...)`` entries yield these in order
        mock_session.get.return_value.__aenter__.side_effect = [
            mock_response_processing,
            mock_response_ready,
        ]
        mock_session_class.return_value = mock_session

        # Test
        result = await self.client._poll_result('test_request_id')

        # Verify
        self.assertIsNotNone(result)
        self.assertEqual(result['status'], 'ready')
        self.assertIn('result', result)
        self.assertEqual(mock_session.get.call_count, 2)

    @patch('aiohttp.ClientSession')
    async def test_poll_result_timeout(self, mock_session_class):
        """Test polling timeout"""
        # Mock session to always return processing
        mock_session = self._mock_session()
        mock_response = self._mock_response(200, json_data={'status': 'processing'})
        mock_session.get.return_value.__aenter__.return_value = mock_response
        mock_session_class.return_value = mock_session

        # Test
        result = await self.client._poll_result('test_request_id')

        # Verify - should timeout after max attempts
        self.assertIsNone(result)
        self.assertEqual(mock_session.get.call_count, self.config.max_polling_attempts)

    @patch('aiohttp.ClientSession')
    async def test_download_result_image_success(self, mock_session_class):
        """Test successful image download"""
        # Mock session and response
        mock_session = self._mock_session()
        mock_response = self._mock_response(200, body=b'fake_image_data')
        mock_session.get.return_value.__aenter__.return_value = mock_response
        mock_session_class.return_value = mock_session

        # Test
        image_data = await self.client._download_result_image('https://example.com/image.png')

        # Verify
        self.assertEqual(image_data, b'fake_image_data')
        mock_session.get.assert_called_once_with('https://example.com/image.png')

    @patch('aiohttp.ClientSession')
    async def test_download_result_image_failure(self, mock_session_class):
        """Test image download failure"""
        # Mock session and response
        mock_session = self._mock_session()
        mock_response = self._mock_response(404)
        mock_session.get.return_value.__aenter__.return_value = mock_response
        mock_session_class.return_value = mock_session

        # Test
        image_data = await self.client._download_result_image('https://example.com/image.png')

        # Verify
        self.assertIsNone(image_data)

    def test_get_image_size(self):
        """Test image size detection from bytes"""
        # Create a small test image in memory
        from PIL import Image
        import io

        # Create 10x20 test image
        img = Image.new('RGB', (10, 20), color='red')
        buffer = io.BytesIO()
        img.save(buffer, format='PNG')
        image_data = buffer.getvalue()

        # Test
        size = self.client._get_image_size(image_data)

        # Verify
        self.assertEqual(size, (10, 20))

    def test_get_image_size_invalid_data(self):
        """Test image size detection with invalid data"""
        size = self.client._get_image_size(b'invalid_image_data')
        self.assertIsNone(size)

    # Decorators apply bottom-up, so the mocks arrive as download, poll, create.
    @patch.object(FluxEditClient, '_create_edit_request')
    @patch.object(FluxEditClient, '_poll_result')
    @patch.object(FluxEditClient, '_download_result_image')
    async def test_edit_image_success(self, mock_download, mock_poll, mock_create):
        """Test complete successful edit flow"""
        # Setup mocks
        mock_create.return_value = 'request_123'
        mock_poll.return_value = {
            'status': 'ready',
            'result': {'sample': 'https://example.com/result.png'},
        }
        mock_download.return_value = b'edited_image_data'

        # Test
        response = await self.client.edit_image(self.sample_request)

        # Verify
        self.assertTrue(response.success)
        self.assertEqual(response.edited_image_data, b'edited_image_data')
        self.assertEqual(response.request_id, 'request_123')
        self.assertEqual(response.result_url, 'https://example.com/result.png')
        # The mocked steps complete instantly, so on coarse clocks the
        # measured duration can legitimately be 0.0.
        self.assertGreaterEqual(response.execution_time, 0)

    @patch.object(FluxEditClient, '_create_edit_request')
    async def test_edit_image_create_failure(self, mock_create):
        """Test edit flow with creation failure"""
        # Setup mock
        mock_create.return_value = None

        # Test
        response = await self.client.edit_image(self.sample_request)

        # Verify
        self.assertFalse(response.success)
        self.assertIn('Failed to create edit request', response.error_message)

    @patch.object(FluxEditClient, '_create_edit_request')
    @patch.object(FluxEditClient, '_poll_result')
    async def test_edit_image_poll_failure(self, mock_poll, mock_create):
        """Test edit flow with polling failure"""
        # Setup mocks
        mock_create.return_value = 'request_123'
        mock_poll.return_value = None

        # Test
        response = await self.client.edit_image(self.sample_request)

        # Verify
        self.assertFalse(response.success)
        self.assertIn('Failed to get edit result', response.error_message)

    async def test_context_manager(self):
        """Test async context manager functionality"""
        async with FluxEditClient(self.config) as client:
            self.assertIsInstance(client, FluxEditClient)

        # Session should be closed after context
        if hasattr(client, 'session'):
            self.assertTrue(client.session is None or client.session.closed)
|
|
|
|
|
|
# Test helper to run async tests
class AsyncTestCase(unittest.TestCase):
    """Base TestCase that gives each test its own private event loop.

    Synchronous ``test_*`` methods in subclasses call :meth:`async_test` to
    drive a coroutine to completion on that loop.
    """

    def setUp(self):
        # A fresh loop per test keeps loop state from leaking between tests.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        # Finalize pending async generators, then close the loop and stop it
        # from remaining the thread's "current" loop once it is unusable.
        try:
            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
        finally:
            self.loop.close()
            asyncio.set_event_loop(None)

    def async_test(self, coro):
        """Run *coro* to completion on this test's loop and return its result."""
        return self.loop.run_until_complete(coro)
|
|
|
|
|
|
class TestFluxEditClientAsync(AsyncTestCase):
    """Async test runner for FLUX client tests.

    Coroutines are driven on the private loop provided by ``AsyncTestCase``
    through ``self.async_test``.
    """

    def setUp(self):
        super().setUp()
        self.config = MagicMock(spec=Config)
        self.config.api_key = 'test_api_key'
        self.config.default_timeout = 30
        self.config.polling_interval = 0.1  # Faster for tests
        self.config.max_polling_attempts = 3
        self.config.get_api_url.side_effect = lambda endpoint: f"https://api.test.com{endpoint}"

        self.client = FluxEditClient(self.config)
        self.sample_request = FluxEditRequest(
            input_image_b64='test_base64_data',
            prompt='Make the sky blue',
            seed=12345,
            aspect_ratio='16:9',
        )

    def tearDown(self):
        """Close the fixture client, then tear down the event loop.

        The client is closed first, while the loop it would use is still
        open. Previously the fixture client was never closed.
        """
        try:
            if getattr(self.client, 'session', None):
                self.async_test(self.client.close())
        finally:
            super().tearDown()

    def test_async_context_manager(self):
        """Test async context manager"""
        async def run_test():
            async with FluxEditClient(self.config) as client:
                self.assertIsInstance(client, FluxEditClient)
            return True

        result = self.async_test(run_test())
        self.assertTrue(result)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly (not imported): discover and run this module's tests.
    unittest.main()
|