"""
Tests for the Imagen4 MCP server: tool definitions, tool handlers,
and server construction.
"""

import pytest
|
|
from unittest.mock import Mock, AsyncMock, patch
|
|
|
|
from src.connector import Config
|
|
from src.server import Imagen4MCPServer, ToolHandlers, MCPToolDefinitions
|
|
from src.server.models import MCPToolDefinitions
|
|
|
|
|
|
class TestMCPToolDefinitions:
    """Checks the static MCP tool schemas exposed by MCPToolDefinitions."""

    def test_get_generate_image_tool(self):
        """The image tool is named, mentions Imagen 4, takes a prompt and requires a seed."""
        image_tool = MCPToolDefinitions.get_generate_image_tool()
        schema = image_tool.inputSchema

        assert image_tool.name == "generate_image"
        assert "Google Imagen 4" in image_tool.description
        assert "prompt" in schema["properties"]
        assert "seed" in schema["required"]

    def test_get_regenerate_from_json_tool(self):
        """The JSON regeneration tool exposes a json_file_path property."""
        regen_tool = MCPToolDefinitions.get_regenerate_from_json_tool()
        properties = regen_tool.inputSchema["properties"]

        assert regen_tool.name == "regenerate_from_json"
        assert "JSON file" in regen_tool.description
        assert "json_file_path" in properties

    def test_get_generate_random_seed_tool(self):
        """The random seed tool has the expected name and description."""
        seed_tool = MCPToolDefinitions.get_generate_random_seed_tool()

        assert seed_tool.name == "generate_random_seed"
        assert "random seed" in seed_tool.description

    def test_get_all_tools(self):
        """get_all_tools returns exactly three tools covering every known name."""
        expected_names = {
            "generate_image",
            "regenerate_from_json",
            "generate_random_seed",
        }
        all_tools = MCPToolDefinitions.get_all_tools()

        assert len(all_tools) == 3
        # Three tools that include all three expected names means no duplicates
        # and no unexpected extras.
        assert expected_names <= {t.name for t in all_tools}
|
|
|
|
|
|
class TestToolHandlers:
    """Argument-validation tests for the ToolHandlers MCP request handlers."""

    @pytest.fixture
    def config(self):
        """Minimal Config pointing at a dummy project id."""
        return Config(project_id="test-project")

    @pytest.fixture
    def handlers(self, config):
        """ToolHandlers constructed while Imagen4Client is patched out.

        The patch is active only for the duration of the constructor call.
        Presumably ToolHandlers builds its Imagen4Client there, so it holds
        a mock rather than a real client — confirm against ToolHandlers.
        """
        with patch('src.server.handlers.Imagen4Client'):
            instance = ToolHandlers(config)
        return instance

    @staticmethod
    def _assert_single_text(result, expected_fragment):
        """Assert *result* holds exactly one text item containing *expected_fragment*."""
        assert len(result) == 1
        only_item = result[0]
        assert only_item.type == "text"
        assert expected_fragment in only_item.text

    @pytest.mark.asyncio
    async def test_handle_generate_random_seed(self, handlers):
        """The seed handler replies with a single 'Generated random seed:' message."""
        response = await handlers.handle_generate_random_seed({})
        self._assert_single_text(response, "Generated random seed:")

    @pytest.mark.asyncio
    async def test_handle_generate_image_missing_prompt(self, handlers):
        """Omitting the prompt yields a 'Prompt is required' error message."""
        response = await handlers.handle_generate_image({"seed": 12345})
        self._assert_single_text(response, "Prompt is required")

    @pytest.mark.asyncio
    async def test_handle_generate_image_missing_seed(self, handlers):
        """Omitting the seed yields a 'Seed value is required' error message."""
        response = await handlers.handle_generate_image({"prompt": "Test prompt"})
        self._assert_single_text(response, "Seed value is required")

    @pytest.mark.asyncio
    async def test_handle_regenerate_from_json_missing_path(self, handlers):
        """Omitting the JSON path yields a 'JSON file path is required' message."""
        response = await handlers.handle_regenerate_from_json({})
        self._assert_single_text(response, "JSON file path is required")
|
|
|
|
|
|
class TestImagen4MCPServer:
    """Construction tests for Imagen4MCPServer with ToolHandlers patched out."""

    @pytest.fixture
    def config(self):
        """Minimal Config pointing at a dummy project id."""
        return Config(project_id="test-project")

    def test_server_initialization(self, config):
        """The server keeps its config and exposes non-None server/handlers attributes."""
        with patch('src.server.mcp_server.ToolHandlers'):
            subject = Imagen4MCPServer(config)

        assert subject.config == config
        for attribute in ("server", "handlers"):
            assert getattr(subject, attribute) is not None

    def test_get_server(self, config):
        """get_server returns a non-None server named 'imagen4-mcp-server'."""
        with patch('src.server.mcp_server.ToolHandlers'):
            underlying = Imagen4MCPServer(config).get_server()

        assert underlying is not None
        assert underlying.name == "imagen4-mcp-server"
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns the session exit code instead of raising
    # SystemExit. Forward it so running this file directly reports test
    # failures to the shell/CI rather than always exiting with status 0.
    raise SystemExit(pytest.main([__file__]))
|