Files
imagen4/tests/test_connector.py

179 lines
5.9 KiB
Python

"""
Tests for the Imagen4 connector: Config, ImageGenerationRequest and Imagen4Client.
"""
import pytest
import asyncio
from unittest.mock import Mock, patch
from src.connector import Config, Imagen4Client
from src.connector.imagen4_client import ImageGenerationRequest
class TestConfig:
    """Unit tests for the ``Config`` settings object."""

    def test_config_creation(self):
        """Constructor arguments are stored unchanged on the instance."""
        cfg = Config(
            project_id="test-project",
            location="us-central1",
            output_path="./test_images",
        )

        actual = (cfg.project_id, cfg.location, cfg.output_path)
        assert actual == ("test-project", "us-central1", "./test_images")

    def test_config_validation(self):
        """``validate()`` accepts a non-empty project id and rejects an empty one."""
        # A populated project id is a valid configuration.
        assert Config(project_id="test-project").validate() is True
        # An empty project id is not.
        assert Config(project_id="").validate() is False

    @patch.dict('os.environ', {
        'GOOGLE_CLOUD_PROJECT_ID': 'test-project',
        'GOOGLE_CLOUD_LOCATION': 'us-west1',
        'GENERATED_IMAGES_PATH': './custom_path',
    })
    def test_config_from_env(self):
        """``from_env()`` reads project, location and output path from the environment."""
        cfg = Config.from_env()

        assert (cfg.project_id, cfg.location, cfg.output_path) == (
            "test-project",
            "us-west1",
            "./custom_path",
        )

    @patch.dict('os.environ', {}, clear=True)
    def test_config_from_env_missing_project_id(self):
        """``from_env()`` raises ``ValueError`` when the project id variable is unset."""
        expected = "GOOGLE_CLOUD_PROJECT_ID environment variable is required"
        with pytest.raises(ValueError, match=expected):
            Config.from_env()
class TestImageGenerationRequest:
    """Validation rules enforced by ``ImageGenerationRequest.validate()``."""

    @staticmethod
    def _build(**overrides):
        """Return a request with a baseline prompt and seed, updated by ``overrides``."""
        fields = {"prompt": "Test prompt", "seed": 12345}
        fields.update(overrides)
        return ImageGenerationRequest(**fields)

    def test_valid_request(self):
        """A well-formed request validates without raising."""
        self._build(prompt="A beautiful landscape").validate()

    def test_invalid_prompt(self):
        """An empty prompt is rejected."""
        candidate = self._build(prompt="")
        with pytest.raises(ValueError, match="Prompt is required"):
            candidate.validate()

    def test_invalid_seed(self):
        """A negative seed is rejected."""
        candidate = self._build(seed=-1)
        with pytest.raises(ValueError, match="Seed value must be an integer between 0 and 4294967295"):
            candidate.validate()

    def test_invalid_number_of_images(self):
        """Requesting three images is rejected."""
        candidate = self._build(number_of_images=3)
        with pytest.raises(ValueError, match="Number of images must be 1 or 2 only"):
            candidate.validate()
class TestImagen4Client:
    """Tests for ``Imagen4Client`` with the GenAI SDK client patched out."""

    def test_client_initialization(self):
        """A valid config is accepted and kept on the client."""
        config = Config(project_id="test-project")

        with patch('src.connector.imagen4_client.genai.Client'):
            client = Imagen4Client(config)
            assert client.config == config

    def test_client_invalid_config(self):
        """An invalid config is rejected with ``ValueError``."""
        invalid_config = Config(project_id="")

        # Patch the SDK client too, so this test never builds a real one
        # (and never needs credentials), regardless of whether the
        # constructor creates the SDK client before or after validating.
        with patch('src.connector.imagen4_client.genai.Client'), \
                pytest.raises(ValueError, match="Invalid configuration"):
            Imagen4Client(invalid_config)

    @pytest.mark.asyncio
    async def test_generate_image_success(self):
        """A successful SDK response is converted into ``images_data``."""
        config = Config(project_id="test-project")

        # Fake SDK response exposing generated_images[0].image.image_bytes.
        mock_image_data = b"fake_image_data"
        mock_response = Mock()
        mock_response.generated_images = [Mock()]
        mock_response.generated_images[0].image = Mock()
        mock_response.generated_images[0].image.image_bytes = mock_image_data

        with patch('src.connector.imagen4_client.genai.Client') as mock_client_class:
            mock_client = Mock()
            mock_client_class.return_value = mock_client
            mock_client.models.generate_images = Mock(return_value=mock_response)

            client = Imagen4Client(config)
            request = ImageGenerationRequest(
                prompt="Test prompt",
                seed=12345
            )

            # asyncio.to_thread is a coroutine function, so patch() substitutes
            # an AsyncMock; awaiting it returns mock_response with no thread.
            with patch('asyncio.to_thread', return_value=mock_response) as mock_to_thread:
                response = await client.generate_image(request)

            # The result must actually have come through the offloaded call.
            mock_to_thread.assert_awaited()

            assert response.success is True
            assert len(response.images_data) == 1
            assert response.images_data[0] == mock_image_data

    @pytest.mark.asyncio
    async def test_generate_image_failure(self):
        """An exception from the SDK call is reported as a failed response."""
        config = Config(project_id="test-project")

        with patch('src.connector.imagen4_client.genai.Client') as mock_client_class:
            mock_client = Mock()
            mock_client_class.return_value = mock_client

            client = Imagen4Client(config)
            request = ImageGenerationRequest(
                prompt="Test prompt",
                seed=12345
            )

            # Simulate an API call failure: awaiting the patched to_thread raises.
            with patch('asyncio.to_thread', side_effect=Exception("API Error")) as mock_to_thread:
                response = await client.generate_image(request)

            # Make sure the failure originated from the offloaded call.
            mock_to_thread.assert_awaited()

            assert response.success is False
            assert "API Error" in response.error_message
            assert len(response.images_data) == 0
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly exits
    # non-zero when any test fails (pytest.main only returns the code).
    raise SystemExit(pytest.main([__file__]))