179 lines
5.9 KiB
Python
179 lines
5.9 KiB
Python
"""Tests for the Imagen4 connector."""
|
|
|
|
import pytest
|
|
import asyncio
|
|
from unittest.mock import Mock, patch
|
|
|
|
from src.connector import Config, Imagen4Client
|
|
from src.connector.imagen4_client import ImageGenerationRequest
|
|
|
|
|
|
class TestConfig:
    """Tests for the Config class."""

    def test_config_creation(self):
        """Constructor keyword arguments are stored unchanged as attributes."""
        settings = {
            "project_id": "test-project",
            "location": "us-central1",
            "output_path": "./test_images",
        }
        config = Config(**settings)

        # Check each field in declaration order against what was passed in.
        for field_name, expected in settings.items():
            assert getattr(config, field_name) == expected

    def test_config_validation(self):
        """validate() accepts a non-empty project id and rejects an empty one."""
        # A non-empty project id is a valid configuration.
        assert Config(project_id="test-project").validate() is True

        # An empty project id is an invalid configuration.
        assert Config(project_id="").validate() is False

    @patch.dict('os.environ', {
        'GOOGLE_CLOUD_PROJECT_ID': 'test-project',
        'GOOGLE_CLOUD_LOCATION': 'us-west1',
        'GENERATED_IMAGES_PATH': './custom_path',
    })
    def test_config_from_env(self):
        """from_env() picks up project id, location and output path from the environment."""
        loaded = Config.from_env()

        actual = (loaded.project_id, loaded.location, loaded.output_path)
        assert actual == ("test-project", "us-west1", "./custom_path")

    @patch.dict('os.environ', {}, clear=True)
    def test_config_from_env_missing_project_id(self):
        """from_env() raises ValueError when the required project id variable is absent."""
        expected_message = "GOOGLE_CLOUD_PROJECT_ID environment variable is required"

        with pytest.raises(ValueError, match=expected_message):
            Config.from_env()
|
|
|
|
|
|
class TestImageGenerationRequest:
    """Tests for ImageGenerationRequest.validate().

    Besides one representative valid/invalid value per field, the range
    limits stated in the error messages this suite expects ("between 0 and
    4294967295", "1 or 2 only") are exercised on both sides, so an
    off-by-one in either range check is caught.
    """

    # Expected error messages. pytest.raises(match=...) treats them as
    # regular expressions; neither contains regex metacharacters.
    _SEED_ERROR = "Seed value must be an integer between 0 and 4294967295"
    _COUNT_ERROR = "Number of images must be 1 or 2 only"

    def test_valid_request(self):
        """A non-empty prompt with an in-range seed validates without raising."""
        request = ImageGenerationRequest(
            prompt="A beautiful landscape",
            seed=12345,
        )

        # validate() must complete without raising.
        request.validate()

    @pytest.mark.parametrize(
        ("seed", "number_of_images"),
        [
            (0, 1),           # lowest allowed seed
            (4294967295, 1),  # highest allowed seed (2**32 - 1)
            (12345, 2),       # largest allowed image count
        ],
    )
    def test_boundary_values_are_valid(self, seed, number_of_images):
        """Values at the edges of the allowed ranges are accepted."""
        request = ImageGenerationRequest(
            prompt="Test prompt",
            seed=seed,
            number_of_images=number_of_images,
        )

        request.validate()

    def test_invalid_prompt(self):
        """An empty prompt is rejected."""
        request = ImageGenerationRequest(
            prompt="",
            seed=12345,
        )

        with pytest.raises(ValueError, match="Prompt is required"):
            request.validate()

    def test_invalid_seed(self):
        """A negative seed is rejected."""
        request = ImageGenerationRequest(
            prompt="Test prompt",
            seed=-1,
        )

        with pytest.raises(ValueError, match=self._SEED_ERROR):
            request.validate()

    def test_seed_above_upper_bound(self):
        """A seed one past the upper limit (2**32) is rejected."""
        request = ImageGenerationRequest(
            prompt="Test prompt",
            seed=4294967296,
        )

        with pytest.raises(ValueError, match=self._SEED_ERROR):
            request.validate()

    def test_invalid_number_of_images(self):
        """Requesting more than two images is rejected."""
        request = ImageGenerationRequest(
            prompt="Test prompt",
            seed=12345,
            number_of_images=3,
        )

        with pytest.raises(ValueError, match=self._COUNT_ERROR):
            request.validate()

    def test_zero_number_of_images(self):
        """Requesting zero images is rejected."""
        request = ImageGenerationRequest(
            prompt="Test prompt",
            seed=12345,
            number_of_images=0,
        )

        with pytest.raises(ValueError, match=self._COUNT_ERROR):
            request.validate()
|
|
|
|
|
|
class TestImagen4Client:
    """Tests for Imagen4Client construction and generate_image()."""

    # Patch target for the SDK client class that Imagen4Client instantiates.
    _GENAI_CLIENT = 'src.connector.imagen4_client.genai.Client'

    def test_client_initialization(self):
        """A client built from a valid config keeps that config."""
        config = Config(project_id="test-project")

        with patch(self._GENAI_CLIENT):
            assert Imagen4Client(config).config == config

    def test_client_invalid_config(self):
        """An empty project id makes the constructor raise ValueError."""
        bad_config = Config(project_id="")

        with pytest.raises(ValueError, match="Invalid configuration"):
            Imagen4Client(bad_config)

    @pytest.mark.asyncio
    async def test_generate_image_success(self):
        """A successful API response is surfaced as a single image payload."""
        config = Config(project_id="test-project")

        image_bytes = b"fake_image_data"
        # Fake response shaped as response.generated_images[0].image.image_bytes.
        fake_response = Mock(
            generated_images=[Mock(image=Mock(image_bytes=image_bytes))]
        )

        with patch(self._GENAI_CLIENT) as client_cls:
            fake_client = Mock()
            client_cls.return_value = fake_client
            fake_client.models.generate_images = Mock(return_value=fake_response)

            client = Imagen4Client(config)
            request = ImageGenerationRequest(prompt="Test prompt", seed=12345)

            # Replace asyncio.to_thread so awaiting it yields the fake response
            # directly instead of running work on a thread.
            with patch('asyncio.to_thread', return_value=fake_response):
                response = await client.generate_image(request)

            assert response.success is True
            assert len(response.images_data) == 1
            assert response.images_data[0] == image_bytes

    @pytest.mark.asyncio
    async def test_generate_image_failure(self):
        """An exception from the API call is reported as a failed response."""
        config = Config(project_id="test-project")

        with patch(self._GENAI_CLIENT) as client_cls:
            client_cls.return_value = Mock()

            client = Imagen4Client(config)
            request = ImageGenerationRequest(prompt="Test prompt", seed=12345)

            # Make the awaited asyncio.to_thread call raise, simulating an API failure.
            with patch('asyncio.to_thread', side_effect=Exception("API Error")):
                response = await client.generate_image(request)

            assert response.success is False
            assert "API Error" in response.error_message
            assert len(response.images_data) == 0
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns the session's exit code rather than exiting, so
    # propagate it; otherwise a direct `python <file>` run always exits 0,
    # even when tests fail.
    raise SystemExit(pytest.main([__file__]))
|