imagen4 mcp server, mcp connector implementation
This commit is contained in:
13
.env.example
Normal file
13
.env.example
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# Google Cloud 설정
|
||||||
|
GOOGLE_CLOUD_PROJECT_ID=your-project-id
|
||||||
|
GOOGLE_CLOUD_LOCATION=us-central1
|
||||||
|
|
||||||
|
# 이미지 저장 설정
|
||||||
|
GENERATED_IMAGES_PATH=./generated_images
|
||||||
|
|
||||||
|
# Imagen4 모델 설정
|
||||||
|
# IMAGEN4_DEFAULT_MODEL=imagen-4.0-generate-001 # 기본 모델
|
||||||
|
# IMAGEN4_DEFAULT_MODEL=imagen-4.0-ultra-generate-001 # Ultra 모델 (더 높은 품질)
|
||||||
|
|
||||||
|
# 선택적 설정
|
||||||
|
# LOG_LEVEL=INFO
|
||||||
119
.gitignore
vendored
Normal file
119
.gitignore
vendored
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
# 바이트 컴파일된 파일 / 최적화된 파일 / DLL 파일
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C 확장
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# 배포
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# 단위 테스트 / 커버리지 보고서
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# 환경
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder IDE
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope 프로젝트
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# Google Cloud 인증 키
|
||||||
|
*.json
|
||||||
|
!claude_desktop_config.json
|
||||||
|
|
||||||
|
# 로그 파일
|
||||||
|
*.log
|
||||||
|
|
||||||
|
# 생성된 이미지
|
||||||
|
generated_images/
|
||||||
|
*.png
|
||||||
|
*.jpg
|
||||||
|
*.jpeg
|
||||||
|
|
||||||
|
# VS Code
|
||||||
|
.vscode/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
.idea/
|
||||||
|
|
||||||
|
# 임시 파일들
|
||||||
|
temp_*
|
||||||
|
cleanup*
|
||||||
204
README.md
Normal file
204
README.md
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
# Imagen4 MCP Server with Preview Images
|
||||||
|
|
||||||
|
Google Imagen 4를 사용한 AI 이미지 생성 MCP 서버입니다. 하나의 `main.py` 파일로 모든 기능이 통합되어 있습니다.
|
||||||
|
|
||||||
|
## 🚀 주요 기능
|
||||||
|
|
||||||
|
- **고품질 이미지 생성**: 2048x2048 PNG 이미지 생성
|
||||||
|
- **미리보기 이미지**: 512x512 JPEG 미리보기를 base64로 제공
|
||||||
|
- **파일 저장**: 원본 PNG + 메타데이터 JSON 저장
|
||||||
|
- **재생성**: JSON 파일로부터 동일한 설정으로 재생성
|
||||||
|
- **랜덤 시드**: 재현 가능한 결과를 위한 시드 생성
|
||||||
|
|
||||||
|
## 📦 설치
|
||||||
|
|
||||||
|
### 1. 의존성 설치
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
필요한 패키지:
|
||||||
|
- `google-genai>=0.2.0` - Google Imagen 4 API
|
||||||
|
- `mcp>=1.0.0` - MCP 프로토콜
|
||||||
|
- `python-dotenv>=1.0.0` - 환경변수 관리
|
||||||
|
- `Pillow>=10.0.0` - 이미지 처리
|
||||||
|
|
||||||
|
### 2. 환경 설정
|
||||||
|
|
||||||
|
`.env` 파일을 생성하고 Google Cloud 자격 증명을 설정:
|
||||||
|
|
||||||
|
```env
|
||||||
|
GOOGLE_APPLICATION_CREDENTIALS=path/to/your/service-account-key.json
|
||||||
|
PROJECT_ID=your-gcp-project-id
|
||||||
|
LOCATION=us-central1
|
||||||
|
OUTPUT_PATH=./generated_images
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🎮 사용법
|
||||||
|
|
||||||
|
### 서버 실행
|
||||||
|
|
||||||
|
#### Windows
|
||||||
|
```bash
|
||||||
|
run.bat
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 직접 실행
|
||||||
|
```bash
|
||||||
|
python main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### Claude Desktop 설정
|
||||||
|
|
||||||
|
`claude_desktop_config.json` 내용을 Claude Desktop 설정에 추가:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"imagen4": {
|
||||||
|
"command": "python",
|
||||||
|
"args": ["main.py"],
|
||||||
|
"cwd": "D:\\Project\\little-fairy\\imagen4"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🛠️ 사용 가능한 도구
|
||||||
|
|
||||||
|
### 1. generate_image
|
||||||
|
고품질 이미지를 생성하고 미리보기를 제공합니다.
|
||||||
|
|
||||||
|
**파라미터:**
|
||||||
|
- `prompt` (필수): 이미지 생성 프롬프트
|
||||||
|
- `seed` (필수): 재현 가능한 결과를 위한 시드 값
|
||||||
|
- `negative_prompt` (선택): 제외할 요소 지정
|
||||||
|
- `number_of_images` (선택): 생성할 이미지 수 (1-2)
|
||||||
|
- `aspect_ratio` (선택): 종횡비 ("1:1", "9:16", "16:9", "3:4", "4:3")
|
||||||
|
- `save_to_file` (선택): 파일 저장 여부
|
||||||
|
|
||||||
|
**응답 예시:**
|
||||||
|
```
|
||||||
|
✅ Images have been successfully generated!
|
||||||
|
|
||||||
|
🖼️ Preview Images Generated: 1 images (512x512 JPEG)
|
||||||
|
Preview 1 (base64 JPEG): /9j/4AAQSkZJRgABAQAAAQABAAD/2wBD...(67234 chars)
|
||||||
|
|
||||||
|
📁 Files saved:
|
||||||
|
- ./generated_images/imagen4_20250821_143052_seed_1234567890.png
|
||||||
|
- ./generated_images/imagen4_20250821_143052_seed_1234567890.json
|
||||||
|
|
||||||
|
⚙️ Generation Parameters:
|
||||||
|
- prompt: A beautiful sunset over mountains
|
||||||
|
- seed: 1234567890
|
||||||
|
- aspect_ratio: 1:1
|
||||||
|
- number_of_images: 1
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. generate_random_seed
|
||||||
|
이미지 생성용 랜덤 시드를 생성합니다.
|
||||||
|
|
||||||
|
### 3. regenerate_from_json
|
||||||
|
저장된 JSON 파라미터 파일로부터 이미지를 재생성합니다.
|
||||||
|
|
||||||
|
## 🖼️ 미리보기 이미지 특징
|
||||||
|
|
||||||
|
### 변환 과정
|
||||||
|
1. **원본**: 2048x2048 PNG (약 8-15MB)
|
||||||
|
2. **변환**: 512x512 JPEG (약 50-200KB)
|
||||||
|
3. **최적화**: 품질 85, 최적화 활성화
|
||||||
|
4. **인코딩**: Base64 문자열
|
||||||
|
|
||||||
|
### 장점
|
||||||
|
- **95% 크기 감소**: 빠른 전송 및 표시
|
||||||
|
- **즉시 사용**: 웹 환경에서 바로 표시 가능
|
||||||
|
- **투명도 처리**: PNG 투명도를 흰색 배경으로 변환
|
||||||
|
- **종횡비 유지**: 원본 비율을 유지하면서 크기 조정
|
||||||
|
|
||||||
|
## 🧪 테스트
|
||||||
|
|
||||||
|
기능을 테스트하려면:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python test_main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
이 스크립트는 다음을 확인합니다:
|
||||||
|
- 모든 클래스 및 함수 import
|
||||||
|
- 이미지 처리 기능
|
||||||
|
- 미리보기 생성 기능
|
||||||
|
|
||||||
|
## 📁 프로젝트 구조
|
||||||
|
|
||||||
|
```
|
||||||
|
imagen4/
|
||||||
|
├── main.py # 통합된 메인 서버 파일
|
||||||
|
├── run.bat # Windows 실행 파일
|
||||||
|
├── test_main.py # 테스트 스크립트
|
||||||
|
├── claude_desktop_config.json # Claude Desktop 설정
|
||||||
|
├── requirements.txt # 의존성 목록
|
||||||
|
├── .env # 환경 변수 (사용자가 생성)
|
||||||
|
├── generated_images/ # 생성된 이미지 저장소
|
||||||
|
└── src/
|
||||||
|
└── connector/ # Google API 연결 모듈
|
||||||
|
├── __init__.py
|
||||||
|
├── config.py
|
||||||
|
├── imagen4_client.py
|
||||||
|
└── utils.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🔧 기술적 세부사항
|
||||||
|
|
||||||
|
### 이미지 처리
|
||||||
|
- **PIL (Pillow)** 사용으로 안정적인 이미지 변환
|
||||||
|
- **LANCZOS 리샘플링**으로 고품질 크기 조정
|
||||||
|
- **투명도 자동 처리**: RGBA → RGB 변환 시 흰색 배경 적용
|
||||||
|
- **중앙 정렬**: 목표 크기보다 작은 이미지는 중앙에 배치
|
||||||
|
|
||||||
|
### MCP 프로토콜
|
||||||
|
- **버전 3.0.0**: 미리보기 이미지 지원
|
||||||
|
- **비동기 처리**: asyncio 기반 안정적인 비동기 작업
|
||||||
|
- **오류 처리**: 상세한 로깅 및 사용자 친화적 오류 메시지
|
||||||
|
- **타임아웃**: 6분 타임아웃으로 안전한 작업 보장
|
||||||
|
|
||||||
|
## 🐛 문제 해결
|
||||||
|
|
||||||
|
### 일반적인 문제
|
||||||
|
|
||||||
|
1. **Pillow 설치 오류**
|
||||||
|
```bash
|
||||||
|
pip install --upgrade pip
|
||||||
|
pip install Pillow
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Google Cloud 인증 오류**
|
||||||
|
- 서비스 계정 키 파일 경로 확인
|
||||||
|
- PROJECT_ID가 올바른지 확인
|
||||||
|
- Imagen API가 활성화되어 있는지 확인
|
||||||
|
|
||||||
|
3. **메모리 부족**
|
||||||
|
- 이미지 수를 1개로 제한
|
||||||
|
- 시스템 메모리 확인 (최소 8GB 권장)
|
||||||
|
|
||||||
|
### 로그 확인
|
||||||
|
|
||||||
|
서버 실행 시 자세한 로그가 stderr로 출력됩니다:
|
||||||
|
- 이미지 생성 진행 상황
|
||||||
|
- 미리보기 생성 과정
|
||||||
|
- 파일 저장 상태
|
||||||
|
- 오류 및 경고 메시지
|
||||||
|
|
||||||
|
## 📄 라이센스
|
||||||
|
|
||||||
|
MIT License - 자세한 내용은 LICENSE 파일을 참조하세요.
|
||||||
|
|
||||||
|
## 🎯 요약
|
||||||
|
|
||||||
|
이 통합된 `main.py`는 다음과 같은 이점을 제공합니다:
|
||||||
|
|
||||||
|
- **단순성**: 하나의 파일로 모든 기능 제공
|
||||||
|
- **효율성**: 512x512 JPEG 미리보기로 빠른 응답
|
||||||
|
- **호환성**: 기존 MCP 클라이언트와 완벽 호환
|
||||||
|
- **확장성**: 필요에 따라 쉽게 기능 추가 가능
|
||||||
|
|
||||||
|
Google Imagen 4의 강력한 이미지 생성 기능을 MCP 프로토콜을 통해 효율적으로 활용할 수 있습니다!
|
||||||
12
claude_desktop_config.json
Normal file
12
claude_desktop_config.json
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"imagen4": {
|
||||||
|
"command": "python",
|
||||||
|
"args": ["main.py"],
|
||||||
|
"cwd": "D:\\Project\\imagen4",
|
||||||
|
"env": {
|
||||||
|
"PYTHONPATH": "D:\\Project\\imagen4"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
706
main.py
Normal file
706
main.py
Normal file
@@ -0,0 +1,706 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Imagen4 MCP Server with Preview Image Support
|
||||||
|
|
||||||
|
Google Imagen 4를 사용한 AI 이미지 생성 MCP 서버
|
||||||
|
- 2048x2048 PNG 원본 이미지 생성
|
||||||
|
- 512x512 JPEG 미리보기 이미지 base64 제공
|
||||||
|
- 향상된 응답 형식
|
||||||
|
|
||||||
|
Run from imagen4 root directory
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import io
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
|
try:
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv()
|
||||||
|
except ImportError:
|
||||||
|
print("Warning: python-dotenv not installed", file=sys.stderr)
|
||||||
|
|
||||||
|
# Add current directory to PYTHONPATH
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
sys.path.insert(0, current_dir)
|
||||||
|
|
||||||
|
# MCP imports
|
||||||
|
try:
|
||||||
|
from mcp.server.models import InitializationOptions
|
||||||
|
from mcp.server import NotificationOptions
|
||||||
|
from mcp.server.stdio import stdio_server
|
||||||
|
from mcp.server import Server
|
||||||
|
from mcp.types import Tool, TextContent
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"Error importing MCP: {e}", file=sys.stderr)
|
||||||
|
print("Please install required packages: pip install mcp", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Connector imports
|
||||||
|
from src.connector import Config, Imagen4Client
|
||||||
|
from src.connector.imagen4_client import ImageGenerationRequest
|
||||||
|
from src.connector.utils import save_generated_images
|
||||||
|
|
||||||
|
# Image processing import
|
||||||
|
try:
|
||||||
|
from PIL import Image
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"Error importing Pillow: {e}", file=sys.stderr)
|
||||||
|
print("Please install Pillow: pip install Pillow", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Logging configuration
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[logging.StreamHandler(sys.stderr)]
|
||||||
|
)
|
||||||
|
logger = logging.getLogger("imagen4-mcp-server")
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Image Processing Utilities ====================
|
||||||
|
|
||||||
|
def create_preview_image_b64(image_data: bytes, target_size: int = 512, quality: int = 85) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Convert PNG image data to JPEG preview with specified size and return as base64
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_data: Original PNG image data in bytes
|
||||||
|
target_size: Target size for the preview (default: 512x512)
|
||||||
|
quality: JPEG quality (1-100, default: 85)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Base64 encoded JPEG image string, or None if conversion fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Open image from bytes
|
||||||
|
with Image.open(io.BytesIO(image_data)) as img:
|
||||||
|
# Convert to RGB if necessary (PNG might have alpha channel)
|
||||||
|
if img.mode in ('RGBA', 'LA', 'P'):
|
||||||
|
# Create white background for transparent images
|
||||||
|
background = Image.new('RGB', img.size, (255, 255, 255))
|
||||||
|
if img.mode == 'P':
|
||||||
|
img = img.convert('RGBA')
|
||||||
|
background.paste(img, mask=img.split()[-1] if img.mode in ('RGBA', 'LA') else None)
|
||||||
|
img = background
|
||||||
|
elif img.mode != 'RGB':
|
||||||
|
img = img.convert('RGB')
|
||||||
|
|
||||||
|
# Resize to target size maintaining aspect ratio
|
||||||
|
img.thumbnail((target_size, target_size), Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
|
# If image is smaller than target size, pad it to exact size
|
||||||
|
if img.size != (target_size, target_size):
|
||||||
|
# Create new image with target size and white background
|
||||||
|
new_img = Image.new('RGB', (target_size, target_size), (255, 255, 255))
|
||||||
|
# Center the resized image
|
||||||
|
x = (target_size - img.size[0]) // 2
|
||||||
|
y = (target_size - img.size[1]) // 2
|
||||||
|
new_img.paste(img, (x, y))
|
||||||
|
img = new_img
|
||||||
|
|
||||||
|
# Convert to JPEG and encode as base64
|
||||||
|
output_buffer = io.BytesIO()
|
||||||
|
img.save(output_buffer, format='JPEG', quality=quality, optimize=True)
|
||||||
|
jpeg_data = output_buffer.getvalue()
|
||||||
|
|
||||||
|
# Encode to base64
|
||||||
|
base64_data = base64.b64encode(jpeg_data).decode('utf-8')
|
||||||
|
|
||||||
|
logger.info(f"Preview image created: {target_size}x{target_size} JPEG, {len(jpeg_data)} bytes")
|
||||||
|
return base64_data
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create preview image: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_info(image_data: bytes) -> Optional[dict]:
|
||||||
|
"""Get image information"""
|
||||||
|
try:
|
||||||
|
with Image.open(io.BytesIO(image_data)) as img:
|
||||||
|
return {
|
||||||
|
'format': img.format,
|
||||||
|
'size': img.size,
|
||||||
|
'mode': img.mode,
|
||||||
|
'bytes': len(image_data)
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get image info: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Response Models ====================
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ImageGenerationResult:
|
||||||
|
"""Enhanced image generation result with preview support"""
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
original_images_count: int
|
||||||
|
preview_images_b64: Optional[list[str]] = None # List of base64 JPEG previews (512x512)
|
||||||
|
saved_files: Optional[list[str]] = None
|
||||||
|
generation_params: Optional[Dict[str, Any]] = None
|
||||||
|
error_message: Optional[str] = None
|
||||||
|
|
||||||
|
def to_text_content(self) -> str:
|
||||||
|
"""Convert result to text format for MCP response"""
|
||||||
|
lines = [self.message]
|
||||||
|
|
||||||
|
if self.success and self.preview_images_b64:
|
||||||
|
lines.append(f"\n🖼️ Preview Images Generated: {len(self.preview_images_b64)} images (512x512 JPEG)")
|
||||||
|
for i, preview_b64 in enumerate(self.preview_images_b64):
|
||||||
|
lines.append(f"Preview {i+1} (base64 JPEG): {preview_b64[:50]}...({len(preview_b64)} chars)")
|
||||||
|
|
||||||
|
if self.saved_files:
|
||||||
|
lines.append(f"\n📁 Files saved:")
|
||||||
|
for filepath in self.saved_files:
|
||||||
|
lines.append(f" - {filepath}")
|
||||||
|
|
||||||
|
if self.generation_params:
|
||||||
|
lines.append(f"\n⚙️ Generation Parameters:")
|
||||||
|
for key, value in self.generation_params.items():
|
||||||
|
if key == 'prompt' and len(str(value)) > 100:
|
||||||
|
lines.append(f" - {key}: {str(value)[:100]}...")
|
||||||
|
else:
|
||||||
|
lines.append(f" - {key}: {value}")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== MCP Tool Definitions ====================
|
||||||
|
|
||||||
|
def get_tools() -> List[Tool]:
|
||||||
|
"""Return all tool definitions"""
|
||||||
|
return [
|
||||||
|
Tool(
|
||||||
|
name="generate_image",
|
||||||
|
description="Generate 2048x2048 PNG images with 512x512 JPEG previews using Google Imagen 4 AI.",
|
||||||
|
inputSchema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"prompt": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Text prompt for image generation (Korean or English)"
|
||||||
|
},
|
||||||
|
"negative_prompt": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Negative prompt specifying elements not to generate (optional)",
|
||||||
|
"default": ""
|
||||||
|
},
|
||||||
|
"number_of_images": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Number of images to generate (1 or 2 only)",
|
||||||
|
"enum": [1, 2],
|
||||||
|
"default": 1
|
||||||
|
},
|
||||||
|
"seed": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Seed value for reproducible results (required, 0 ~ 4294967295 range)",
|
||||||
|
"minimum": 0,
|
||||||
|
"maximum": 4294967295
|
||||||
|
},
|
||||||
|
"aspect_ratio": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Image aspect ratio",
|
||||||
|
"enum": ["1:1", "9:16", "16:9", "3:4", "4:3"],
|
||||||
|
"default": "1:1"
|
||||||
|
},
|
||||||
|
"save_to_file": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Whether to save generated images to files (default: true)",
|
||||||
|
"default": True
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["prompt", "seed"]
|
||||||
|
}
|
||||||
|
),
|
||||||
|
Tool(
|
||||||
|
name="regenerate_from_json",
|
||||||
|
description="Read parameters from JSON file and regenerate images with previews.",
|
||||||
|
inputSchema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"json_file_path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Path to JSON file containing saved parameters"
|
||||||
|
},
|
||||||
|
"save_to_file": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Whether to save regenerated images to files (default: true)",
|
||||||
|
"default": True
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["json_file_path"]
|
||||||
|
}
|
||||||
|
),
|
||||||
|
Tool(
|
||||||
|
name="generate_random_seed",
|
||||||
|
description="Generate random seed value for image generation.",
|
||||||
|
inputSchema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
"required": []
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Utility Functions ====================
|
||||||
|
|
||||||
|
def sanitize_args_for_logging(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Remove or truncate sensitive data from arguments for safe logging"""
|
||||||
|
safe_args = {}
|
||||||
|
for key, value in arguments.items():
|
||||||
|
if isinstance(value, str):
|
||||||
|
# Check if it's likely base64 image data
|
||||||
|
if (key in ['data', 'image_data', 'base64', 'preview_image_b64'] or
|
||||||
|
(len(value) > 100 and value.startswith(('iVBORw0KGgo', '/9j/', 'R0lGOD')))):
|
||||||
|
# Truncate long image data
|
||||||
|
safe_args[key] = f"<image_data:{len(value)} chars>"
|
||||||
|
elif len(value) > 1000:
|
||||||
|
# Truncate any very long strings
|
||||||
|
safe_args[key] = f"{value[:100]}...<truncated:{len(value)} total chars>"
|
||||||
|
else:
|
||||||
|
safe_args[key] = value
|
||||||
|
else:
|
||||||
|
safe_args[key] = value
|
||||||
|
return safe_args
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Tool Handlers ====================
|
||||||
|
|
||||||
|
class Imagen4ToolHandlers:
|
||||||
|
"""MCP tool handler class with preview image support"""
|
||||||
|
|
||||||
|
def __init__(self, config: Config):
|
||||||
|
"""Initialize handler"""
|
||||||
|
self.config = config
|
||||||
|
self.client = Imagen4Client(config)
|
||||||
|
|
||||||
|
async def handle_generate_random_seed(self, arguments: Dict[str, Any]) -> List[TextContent]:
|
||||||
|
"""Random seed generation handler"""
|
||||||
|
try:
|
||||||
|
random_seed = random.randint(0, 2**32 - 1)
|
||||||
|
return [TextContent(
|
||||||
|
type="text",
|
||||||
|
text=f"Generated random seed: {random_seed}"
|
||||||
|
)]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Random seed generation error: {str(e)}")
|
||||||
|
return [TextContent(
|
||||||
|
type="text",
|
||||||
|
text=f"Error occurred during random seed generation: {str(e)}"
|
||||||
|
)]
|
||||||
|
|
||||||
|
async def handle_regenerate_from_json(self, arguments: Dict[str, Any]) -> List[TextContent]:
|
||||||
|
"""Image regeneration from JSON file handler with preview support"""
|
||||||
|
try:
|
||||||
|
json_file_path = arguments.get("json_file_path")
|
||||||
|
save_to_file = arguments.get("save_to_file", True)
|
||||||
|
|
||||||
|
if not json_file_path:
|
||||||
|
return [TextContent(
|
||||||
|
type="text",
|
||||||
|
text="Error: JSON file path is required."
|
||||||
|
)]
|
||||||
|
|
||||||
|
# Load parameters from JSON file
|
||||||
|
try:
|
||||||
|
with open(json_file_path, 'r', encoding='utf-8') as f:
|
||||||
|
params = json.load(f)
|
||||||
|
except FileNotFoundError:
|
||||||
|
return [TextContent(
|
||||||
|
type="text",
|
||||||
|
text=f"Error: Cannot find JSON file: {json_file_path}"
|
||||||
|
)]
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
return [TextContent(
|
||||||
|
type="text",
|
||||||
|
text=f"Error: JSON parsing error: {str(e)}"
|
||||||
|
)]
|
||||||
|
|
||||||
|
# Check required parameters
|
||||||
|
required_params = ['prompt', 'seed']
|
||||||
|
missing_params = [p for p in required_params if p not in params]
|
||||||
|
|
||||||
|
if missing_params:
|
||||||
|
return [TextContent(
|
||||||
|
type="text",
|
||||||
|
text=f"Error: Required parameters are missing from JSON file: {', '.join(missing_params)}"
|
||||||
|
)]
|
||||||
|
|
||||||
|
# Create image generation request object
|
||||||
|
request = ImageGenerationRequest(
|
||||||
|
prompt=params.get('prompt'),
|
||||||
|
negative_prompt=params.get('negative_prompt', ''),
|
||||||
|
number_of_images=params.get('number_of_images', 1),
|
||||||
|
seed=params.get('seed'),
|
||||||
|
aspect_ratio=params.get('aspect_ratio', '1:1')
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Loaded parameters from JSON: {json_file_path}")
|
||||||
|
|
||||||
|
# Generate image
|
||||||
|
response = await self.client.generate_image(request)
|
||||||
|
|
||||||
|
if not response.success:
|
||||||
|
return [TextContent(
|
||||||
|
type="text",
|
||||||
|
text=f"Error occurred during image regeneration: {response.error_message}"
|
||||||
|
)]
|
||||||
|
|
||||||
|
# Create preview images
|
||||||
|
preview_images_b64 = []
|
||||||
|
for image_data in response.images_data:
|
||||||
|
preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85)
|
||||||
|
if preview_b64:
|
||||||
|
preview_images_b64.append(preview_b64)
|
||||||
|
logger.info(f"Created preview image: {len(preview_b64)} chars base64 JPEG")
|
||||||
|
|
||||||
|
# Save files (optional)
|
||||||
|
saved_files = []
|
||||||
|
if save_to_file:
|
||||||
|
regeneration_params = {
|
||||||
|
"prompt": request.prompt,
|
||||||
|
"negative_prompt": request.negative_prompt,
|
||||||
|
"number_of_images": request.number_of_images,
|
||||||
|
"seed": request.seed,
|
||||||
|
"aspect_ratio": request.aspect_ratio,
|
||||||
|
"regenerated_from": json_file_path,
|
||||||
|
"original_generated_at": params.get('generated_at', 'unknown')
|
||||||
|
}
|
||||||
|
|
||||||
|
saved_files = save_generated_images(
|
||||||
|
images_data=response.images_data,
|
||||||
|
save_directory=self.config.output_path,
|
||||||
|
seed=request.seed,
|
||||||
|
generation_params=regeneration_params,
|
||||||
|
filename_prefix="imagen4_regen"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create result with preview images
|
||||||
|
result = ImageGenerationResult(
|
||||||
|
success=True,
|
||||||
|
message=f"✅ Images have been successfully regenerated from {json_file_path}",
|
||||||
|
original_images_count=len(response.images_data),
|
||||||
|
preview_images_b64=preview_images_b64,
|
||||||
|
saved_files=saved_files,
|
||||||
|
generation_params={
|
||||||
|
"prompt": request.prompt,
|
||||||
|
"seed": request.seed,
|
||||||
|
"aspect_ratio": request.aspect_ratio,
|
||||||
|
"number_of_images": request.number_of_images
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return [TextContent(
|
||||||
|
type="text",
|
||||||
|
text=result.to_text_content()
|
||||||
|
)]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error occurred during image regeneration: {str(e)}")
|
||||||
|
return [TextContent(
|
||||||
|
type="text",
|
||||||
|
text=f"Error occurred during image regeneration: {str(e)}"
|
||||||
|
)]
|
||||||
|
|
||||||
|
async def handle_generate_image(self, arguments: Dict[str, Any]) -> List[TextContent]:
|
||||||
|
"""Image generation handler with preview image support"""
|
||||||
|
try:
|
||||||
|
# Log arguments safely without exposing image data
|
||||||
|
safe_args = sanitize_args_for_logging(arguments)
|
||||||
|
logger.info(f"handle_generate_image called with arguments: {safe_args}")
|
||||||
|
|
||||||
|
# Extract and validate arguments
|
||||||
|
prompt = arguments.get("prompt")
|
||||||
|
if not prompt:
|
||||||
|
logger.error("No prompt provided")
|
||||||
|
return [TextContent(
|
||||||
|
type="text",
|
||||||
|
text="Error: Prompt is required."
|
||||||
|
)]
|
||||||
|
|
||||||
|
seed = arguments.get("seed")
|
||||||
|
if seed is None:
|
||||||
|
logger.error("No seed provided")
|
||||||
|
return [TextContent(
|
||||||
|
type="text",
|
||||||
|
text="Error: Seed value is required. You can use the generate_random_seed tool to generate a random seed."
|
||||||
|
)]
|
||||||
|
|
||||||
|
# Create image generation request object
|
||||||
|
request = ImageGenerationRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=arguments.get("negative_prompt", ""),
|
||||||
|
number_of_images=arguments.get("number_of_images", 1),
|
||||||
|
seed=seed,
|
||||||
|
aspect_ratio=arguments.get("aspect_ratio", "1:1")
|
||||||
|
)
|
||||||
|
|
||||||
|
save_to_file = arguments.get("save_to_file", True)
|
||||||
|
|
||||||
|
logger.info(f"Starting image generation: '{prompt[:50]}...', Seed: {seed}")
|
||||||
|
|
||||||
|
# Generate image with timeout
|
||||||
|
try:
|
||||||
|
logger.info("Calling client.generate_image()...")
|
||||||
|
response = await asyncio.wait_for(
|
||||||
|
self.client.generate_image(request),
|
||||||
|
timeout=360.0 # 6 minute timeout
|
||||||
|
)
|
||||||
|
logger.info(f"Image generation completed. Success: {response.success}")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error("Image generation timed out after 6 minutes")
|
||||||
|
return [TextContent(
|
||||||
|
type="text",
|
||||||
|
text="Error: Image generation timed out after 6 minutes. Please try again."
|
||||||
|
)]
|
||||||
|
|
||||||
|
if not response.success:
|
||||||
|
logger.error(f"Image generation failed: {response.error_message}")
|
||||||
|
return [TextContent(
|
||||||
|
type="text",
|
||||||
|
text=f"Error occurred during image generation: {response.error_message}"
|
||||||
|
)]
|
||||||
|
|
||||||
|
logger.info(f"Generated {len(response.images_data)} images successfully")
|
||||||
|
|
||||||
|
# Create preview images (512x512 JPEG base64)
|
||||||
|
preview_images_b64 = []
|
||||||
|
for i, image_data in enumerate(response.images_data):
|
||||||
|
# Log original image info
|
||||||
|
image_info = get_image_info(image_data)
|
||||||
|
if image_info:
|
||||||
|
logger.info(f"Original image {i+1}: {image_info}")
|
||||||
|
|
||||||
|
# Create preview
|
||||||
|
preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85)
|
||||||
|
if preview_b64:
|
||||||
|
preview_images_b64.append(preview_b64)
|
||||||
|
logger.info(f"Created preview image {i+1}: {len(preview_b64)} chars base64 JPEG")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Failed to create preview for image {i+1}")
|
||||||
|
|
||||||
|
# Save files if requested
|
||||||
|
saved_files = []
|
||||||
|
if save_to_file:
|
||||||
|
logger.info("Saving files to disk...")
|
||||||
|
generation_params = {
|
||||||
|
"prompt": request.prompt,
|
||||||
|
"negative_prompt": request.negative_prompt,
|
||||||
|
"number_of_images": request.number_of_images,
|
||||||
|
"seed": request.seed,
|
||||||
|
"aspect_ratio": request.aspect_ratio,
|
||||||
|
"guidance_scale": 7.5,
|
||||||
|
"safety_filter_level": "block_only_high",
|
||||||
|
"person_generation": "allow_all",
|
||||||
|
"add_watermark": False
|
||||||
|
}
|
||||||
|
|
||||||
|
saved_files = save_generated_images(
|
||||||
|
images_data=response.images_data,
|
||||||
|
save_directory=self.config.output_path,
|
||||||
|
seed=request.seed,
|
||||||
|
generation_params=generation_params
|
||||||
|
)
|
||||||
|
logger.info(f"Files saved: {saved_files}")
|
||||||
|
|
||||||
|
# Verify files were created
|
||||||
|
for file_path in saved_files:
|
||||||
|
if os.path.exists(file_path):
|
||||||
|
size = os.path.getsize(file_path)
|
||||||
|
logger.info(f" ✓ Verified: {file_path} ({size} bytes)")
|
||||||
|
else:
|
||||||
|
logger.error(f" ❌ Missing: {file_path}")
|
||||||
|
|
||||||
|
# Create enhanced result with preview images
|
||||||
|
result = ImageGenerationResult(
|
||||||
|
success=True,
|
||||||
|
message=f"✅ Images have been successfully generated!",
|
||||||
|
original_images_count=len(response.images_data),
|
||||||
|
preview_images_b64=preview_images_b64,
|
||||||
|
saved_files=saved_files,
|
||||||
|
generation_params={
|
||||||
|
"prompt": request.prompt,
|
||||||
|
"seed": request.seed,
|
||||||
|
"aspect_ratio": request.aspect_ratio,
|
||||||
|
"number_of_images": request.number_of_images,
|
||||||
|
"negative_prompt": request.negative_prompt
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Returning response with {len(preview_images_b64)} preview images")
|
||||||
|
return [TextContent(
|
||||||
|
type="text",
|
||||||
|
text=result.to_text_content()
|
||||||
|
)]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error occurred during image generation: {str(e)}", exc_info=True)
|
||||||
|
return [TextContent(
|
||||||
|
type="text",
|
||||||
|
text=f"Error occurred during image generation: {str(e)}"
|
||||||
|
)]
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== MCP Server ====================
|
||||||
|
|
||||||
|
class Imagen4MCPServer:
|
||||||
|
"""Imagen4 MCP server class with preview image support"""
|
||||||
|
|
||||||
|
def __init__(self, config: Config):
|
||||||
|
"""Initialize server"""
|
||||||
|
self.config = config
|
||||||
|
self.server = Server("imagen4-mcp-server")
|
||||||
|
self.handlers = Imagen4ToolHandlers(config)
|
||||||
|
|
||||||
|
# Register handlers
|
||||||
|
self._register_handlers()
|
||||||
|
|
||||||
|
def _register_handlers(self) -> None:
|
||||||
|
"""Register MCP handlers"""
|
||||||
|
|
||||||
|
@self.server.list_tools()
|
||||||
|
async def handle_list_tools() -> List[Tool]:
|
||||||
|
"""Return list of available tools"""
|
||||||
|
try:
|
||||||
|
logger.info("Listing available tools with preview image support")
|
||||||
|
return get_tools()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error listing tools: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
@self.server.call_tool()
|
||||||
|
async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]:
|
||||||
|
"""Handle tool calls with preview image support"""
|
||||||
|
try:
|
||||||
|
# Log tool call safely without exposing sensitive data
|
||||||
|
safe_args = sanitize_args_for_logging(arguments)
|
||||||
|
logger.info(f"Tool called: {name} with arguments: {safe_args}")
|
||||||
|
|
||||||
|
if name == "generate_random_seed":
|
||||||
|
return await self.handlers.handle_generate_random_seed(arguments)
|
||||||
|
elif name == "regenerate_from_json":
|
||||||
|
return await self.handlers.handle_regenerate_from_json(arguments)
|
||||||
|
elif name == "generate_image":
|
||||||
|
return await self.handlers.handle_generate_image(arguments)
|
||||||
|
else:
|
||||||
|
error_msg = f"Unknown tool: {name}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return [TextContent(
|
||||||
|
type="text",
|
||||||
|
text=f"Error: {error_msg}"
|
||||||
|
)]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error handling tool call '{name}': {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(f"Tool call traceback: {traceback.format_exc()}")
|
||||||
|
return [TextContent(
|
||||||
|
type="text",
|
||||||
|
text=f"Error occurred while processing tool '{name}': {str(e)}"
|
||||||
|
)]
|
||||||
|
|
||||||
|
def get_server(self) -> Server:
|
||||||
|
"""Return MCP server instance"""
|
||||||
|
return self.server
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Main Function ====================
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Main function"""
|
||||||
|
logger.info("Starting Imagen 4 MCP Server with Preview Image Support...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load configuration
|
||||||
|
config = Config.from_env()
|
||||||
|
logger.info(f"Configuration loaded - Project: {config.project_id}, Location: {config.location}")
|
||||||
|
|
||||||
|
# Create MCP server
|
||||||
|
mcp_server = Imagen4MCPServer(config)
|
||||||
|
server = mcp_server.get_server()
|
||||||
|
|
||||||
|
logger.info("Imagen 4 MCP Server initialized successfully")
|
||||||
|
logger.info("Features: 512x512 JPEG preview images, base64 encoding, enhanced responses")
|
||||||
|
|
||||||
|
# Run MCP server with better error handling
|
||||||
|
try:
|
||||||
|
async with stdio_server() as (read_stream, write_stream):
|
||||||
|
logger.info("Starting MCP server with stdio transport")
|
||||||
|
await server.run(
|
||||||
|
read_stream,
|
||||||
|
write_stream,
|
||||||
|
InitializationOptions(
|
||||||
|
server_name="imagen4-mcp-server",
|
||||||
|
server_version="3.0.0",
|
||||||
|
capabilities=server.get_capabilities(
|
||||||
|
notification_options=NotificationOptions(),
|
||||||
|
experimental_capabilities={
|
||||||
|
"preview_images": {},
|
||||||
|
"base64_jpeg_previews": {}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as stdio_error:
|
||||||
|
logger.error(f"STDIO server error: {stdio_error}")
|
||||||
|
# Try to handle specific MCP protocol errors
|
||||||
|
if "TaskGroup" in str(stdio_error):
|
||||||
|
logger.error("TaskGroup error detected - this may be due to client disconnection")
|
||||||
|
raise
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Configuration error: {e}")
|
||||||
|
logger.error("Please check your .env file or environment variables.")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Server stopped by user (Ctrl+C)")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Server error: {e}")
|
||||||
|
logger.error(f"Error type: {type(e).__name__}")
|
||||||
|
import traceback
|
||||||
|
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
# Set up signal handling for graceful shutdown
|
||||||
|
import signal
|
||||||
|
|
||||||
|
def signal_handler(signum, frame):
|
||||||
|
logger.info(f"Received signal {signum}, shutting down gracefully...")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
|
||||||
|
# Run the server
|
||||||
|
exit_code = asyncio.run(main())
|
||||||
|
sys.exit(exit_code or 0)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Server stopped by user")
|
||||||
|
sys.exit(0)
|
||||||
|
except SystemExit as e:
|
||||||
|
logger.info(f"Server exiting with code {e.code}")
|
||||||
|
sys.exit(e.code)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Fatal error: {e}")
|
||||||
|
logger.error(f"Error type: {type(e).__name__}")
|
||||||
|
import traceback
|
||||||
|
logger.error(f"Full traceback: {traceback.format_exc()}")
|
||||||
|
sys.exit(1)
|
||||||
7
requirements.txt
Normal file
7
requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# Requirements for separated Imagen4 architecture
|
||||||
|
google-genai>=0.2.0
|
||||||
|
mcp>=1.0.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
pytest>=7.0.0
|
||||||
|
pytest-asyncio>=0.21.0
|
||||||
|
Pillow>=10.0.0
|
||||||
50
run.bat
Normal file
50
run.bat
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
@echo off
|
||||||
|
echo Starting Imagen4 MCP Server with Preview Image Support...
|
||||||
|
echo Features: 512x512 JPEG preview images, base64 encoding
|
||||||
|
echo.
|
||||||
|
|
||||||
|
REM Check if virtual environment exists
|
||||||
|
if not exist "venv\Scripts\activate.bat" (
|
||||||
|
echo Creating virtual environment...
|
||||||
|
python -m venv venv
|
||||||
|
if errorlevel 1 (
|
||||||
|
echo Failed to create virtual environment
|
||||||
|
pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
REM Activate virtual environment
|
||||||
|
echo Activating virtual environment...
|
||||||
|
call venv\Scripts\activate.bat
|
||||||
|
|
||||||
|
REM Check if requirements are installed
|
||||||
|
python -c "import PIL" 2>nul
|
||||||
|
if errorlevel 1 (
|
||||||
|
echo Installing requirements...
|
||||||
|
pip install -r requirements.txt
|
||||||
|
if errorlevel 1 (
|
||||||
|
echo Failed to install requirements
|
||||||
|
pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
REM Check if .env file exists
|
||||||
|
if not exist ".env" (
|
||||||
|
echo Warning: .env file not found
|
||||||
|
echo Please copy .env.example to .env and configure your API credentials
|
||||||
|
pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
REM Run server
|
||||||
|
echo Starting MCP server...
|
||||||
|
python main.py
|
||||||
|
|
||||||
|
REM Keep window open if there's an error
|
||||||
|
if errorlevel 1 (
|
||||||
|
echo.
|
||||||
|
echo Server exited with error
|
||||||
|
pause
|
||||||
|
)
|
||||||
10
src/__init__.py
Normal file
10
src/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
Imagen4 Package
|
||||||
|
|
||||||
|
분리된 아키텍처를 가진 Google Imagen 4 MCP 서버
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 기본 컴포넌트만 export
|
||||||
|
# 자세한 import는 각 모듈에서 직접 하도록 함
|
||||||
|
|
||||||
|
__version__ = "2.1.0"
|
||||||
9
src/async_task/__init__.py
Normal file
9
src/async_task/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
"""
|
||||||
|
Async Task Management Module for MCP Server
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .models import TaskStatus, TaskResult, generate_task_id
|
||||||
|
from .task_manager import TaskManager
|
||||||
|
from .worker_pool import WorkerPool
|
||||||
|
|
||||||
|
__all__ = ['TaskManager', 'TaskStatus', 'TaskResult', 'WorkerPool']
|
||||||
48
src/async_task/models.py
Normal file
48
src/async_task/models.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""
|
||||||
|
Task Status and Result Models
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Optional, Dict
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatus(Enum):
|
||||||
|
"""Task execution status"""
|
||||||
|
PENDING = "pending"
|
||||||
|
RUNNING = "running"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
CANCELLED = "cancelled"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TaskResult:
|
||||||
|
"""Task execution result"""
|
||||||
|
task_id: str
|
||||||
|
status: TaskStatus
|
||||||
|
result: Optional[Any] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
created_at: datetime = field(default_factory=datetime.now)
|
||||||
|
started_at: Optional[datetime] = None
|
||||||
|
completed_at: Optional[datetime] = None
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def duration(self) -> Optional[float]:
|
||||||
|
"""Calculate task execution duration in seconds"""
|
||||||
|
if self.started_at and self.completed_at:
|
||||||
|
return (self.completed_at - self.started_at).total_seconds()
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_finished(self) -> bool:
|
||||||
|
"""Check if task is finished (completed, failed, or cancelled)"""
|
||||||
|
return self.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_task_id() -> str:
|
||||||
|
"""Generate unique task ID"""
|
||||||
|
return str(uuid.uuid4())
|
||||||
324
src/async_task/task_manager.py
Normal file
324
src/async_task/task_manager.py
Normal file
@@ -0,0 +1,324 @@
|
|||||||
|
"""
|
||||||
|
Task Manager for MCP Server
|
||||||
|
Manages background task execution with status tracking and result retrieval
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, Optional, Dict, List
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from .models import TaskResult, TaskStatus, generate_task_id
|
||||||
|
from .worker_pool import WorkerPool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskManager:
|
||||||
|
"""Main task manager for MCP server"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_workers: Optional[int] = None,
|
||||||
|
use_process_pool: bool = False,
|
||||||
|
default_timeout: float = 600.0, # 10 minutes
|
||||||
|
result_retention_hours: int = 24, # Keep results for 24 hours
|
||||||
|
max_retained_results: int = 200
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize task manager
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_workers: Maximum number of worker threads/processes
|
||||||
|
use_process_pool: Use process pool instead of thread pool
|
||||||
|
default_timeout: Default timeout for tasks in seconds
|
||||||
|
result_retention_hours: How long to keep finished task results
|
||||||
|
max_retained_results: Maximum number of results to retain
|
||||||
|
"""
|
||||||
|
self.worker_pool = WorkerPool(
|
||||||
|
max_workers=max_workers,
|
||||||
|
use_process_pool=use_process_pool,
|
||||||
|
task_timeout=default_timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
self.result_retention_hours = result_retention_hours
|
||||||
|
self.max_retained_results = max_retained_results
|
||||||
|
|
||||||
|
# Start cleanup task
|
||||||
|
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
|
||||||
|
|
||||||
|
logger.info(f"TaskManager initialized with {self.worker_pool.max_workers} workers")
|
||||||
|
|
||||||
|
async def submit_task(
|
||||||
|
self,
|
||||||
|
func: Callable,
|
||||||
|
*args,
|
||||||
|
task_id: Optional[str] = None,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Submit a task for background execution
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: Function to execute
|
||||||
|
*args: Function arguments
|
||||||
|
task_id: Optional custom task ID
|
||||||
|
timeout: Task timeout in seconds
|
||||||
|
metadata: Additional metadata for the task
|
||||||
|
**kwargs: Function keyword arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Task ID for tracking progress
|
||||||
|
"""
|
||||||
|
task_id = await self.worker_pool.submit_task(
|
||||||
|
func, *args, task_id=task_id, timeout=timeout, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add additional metadata if provided
|
||||||
|
if metadata:
|
||||||
|
result = self.worker_pool.get_task_result(task_id)
|
||||||
|
if result:
|
||||||
|
result.metadata.update(metadata)
|
||||||
|
|
||||||
|
return task_id
|
||||||
|
|
||||||
|
def get_task_status(self, task_id: str) -> Optional[TaskStatus]:
|
||||||
|
"""Get task status by ID"""
|
||||||
|
result = self.worker_pool.get_task_result(task_id)
|
||||||
|
return result.status if result else None
|
||||||
|
|
||||||
|
def get_task_result(self, task_id: str) -> Optional[TaskResult]:
|
||||||
|
"""Get complete task result by ID with circular reference protection"""
|
||||||
|
try:
|
||||||
|
raw_result = self.worker_pool.get_task_result(task_id)
|
||||||
|
if not raw_result:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Create a safe copy to avoid circular references
|
||||||
|
safe_result = TaskResult(
|
||||||
|
task_id=raw_result.task_id,
|
||||||
|
status=raw_result.status,
|
||||||
|
result=None, # We'll set this safely
|
||||||
|
error=raw_result.error,
|
||||||
|
created_at=raw_result.created_at,
|
||||||
|
started_at=raw_result.started_at,
|
||||||
|
completed_at=raw_result.completed_at,
|
||||||
|
metadata=raw_result.metadata.copy() if raw_result.metadata else {}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Safely copy the result data
|
||||||
|
if raw_result.result is not None:
|
||||||
|
try:
|
||||||
|
# If it's a dict, create a clean copy
|
||||||
|
if isinstance(raw_result.result, dict):
|
||||||
|
safe_result.result = {
|
||||||
|
'success': raw_result.result.get('success', False),
|
||||||
|
'error': raw_result.result.get('error'),
|
||||||
|
'images_b64': raw_result.result.get('images_b64', []),
|
||||||
|
'saved_files': raw_result.result.get('saved_files', []),
|
||||||
|
'request': raw_result.result.get('request', {}),
|
||||||
|
'image_count': raw_result.result.get('image_count', 0)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# For non-dict results, just copy directly
|
||||||
|
safe_result.result = raw_result.result
|
||||||
|
except Exception as copy_error:
|
||||||
|
logger.error(f"Error creating safe copy of result: {str(copy_error)}")
|
||||||
|
safe_result.result = {
|
||||||
|
'success': False,
|
||||||
|
'error': f'Result data corrupted: {str(copy_error)}'
|
||||||
|
}
|
||||||
|
|
||||||
|
return safe_result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in get_task_result: {str(e)}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_task_progress(self, task_id: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get task progress information
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with task status, timing info, and metadata
|
||||||
|
"""
|
||||||
|
result = self.get_task_result(task_id)
|
||||||
|
if not result:
|
||||||
|
return {'error': 'Task not found'}
|
||||||
|
|
||||||
|
progress = {
|
||||||
|
'task_id': result.task_id,
|
||||||
|
'status': result.status.value,
|
||||||
|
'created_at': result.created_at.isoformat(),
|
||||||
|
'metadata': result.metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.started_at:
|
||||||
|
progress['started_at'] = result.started_at.isoformat()
|
||||||
|
|
||||||
|
if result.status == TaskStatus.RUNNING:
|
||||||
|
elapsed = (datetime.now() - result.started_at).total_seconds()
|
||||||
|
progress['elapsed_seconds'] = elapsed
|
||||||
|
|
||||||
|
if result.completed_at:
|
||||||
|
progress['completed_at'] = result.completed_at.isoformat()
|
||||||
|
progress['duration_seconds'] = result.duration
|
||||||
|
|
||||||
|
if result.error:
|
||||||
|
progress['error'] = result.error
|
||||||
|
|
||||||
|
return progress
|
||||||
|
|
||||||
|
def list_tasks(
|
||||||
|
self,
|
||||||
|
status_filter: Optional[TaskStatus] = None,
|
||||||
|
limit: int = 50
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
List tasks with optional filtering
|
||||||
|
|
||||||
|
Args:
|
||||||
|
status_filter: Filter by task status
|
||||||
|
limit: Maximum number of tasks to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of task progress dictionaries
|
||||||
|
"""
|
||||||
|
all_results = self.worker_pool.get_all_results()
|
||||||
|
|
||||||
|
# Filter by status if requested
|
||||||
|
if status_filter:
|
||||||
|
filtered_results = {
|
||||||
|
k: v for k, v in all_results.items()
|
||||||
|
if v.status == status_filter
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
filtered_results = all_results
|
||||||
|
|
||||||
|
# Sort by creation time (most recent first)
|
||||||
|
sorted_items = sorted(
|
||||||
|
filtered_results.items(),
|
||||||
|
key=lambda x: x[1].created_at,
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply limit and return progress info
|
||||||
|
return [
|
||||||
|
self.get_task_progress(task_id)
|
||||||
|
for task_id, _ in sorted_items[:limit]
|
||||||
|
]
|
||||||
|
|
||||||
|
def cancel_task(self, task_id: str) -> bool:
|
||||||
|
"""Cancel a running task"""
|
||||||
|
return self.worker_pool.cancel_task(task_id)
|
||||||
|
|
||||||
|
async def wait_for_task(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
poll_interval: float = 1.0
|
||||||
|
) -> TaskResult:
|
||||||
|
"""
|
||||||
|
Wait for a task to complete
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Task ID to wait for
|
||||||
|
timeout: Maximum time to wait in seconds
|
||||||
|
poll_interval: How often to check task status
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TaskResult: Final task result
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
asyncio.TimeoutError: If timeout is reached
|
||||||
|
ValueError: If task doesn't exist
|
||||||
|
"""
|
||||||
|
start_time = datetime.now()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
result = self.get_task_result(task_id)
|
||||||
|
if not result:
|
||||||
|
raise ValueError(f"Task {task_id} not found")
|
||||||
|
|
||||||
|
if result.is_finished:
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Check timeout
|
||||||
|
if timeout:
|
||||||
|
elapsed = (datetime.now() - start_time).total_seconds()
|
||||||
|
if elapsed >= timeout:
|
||||||
|
raise asyncio.TimeoutError(f"Timeout waiting for task {task_id}")
|
||||||
|
|
||||||
|
await asyncio.sleep(poll_interval)
|
||||||
|
|
||||||
|
async def _periodic_cleanup(self) -> None:
|
||||||
|
"""Periodic cleanup of old task results"""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Wait 1 hour between cleanups
|
||||||
|
await asyncio.sleep(3600)
|
||||||
|
|
||||||
|
# Clean up old results
|
||||||
|
cutoff_time = datetime.now() - timedelta(hours=self.result_retention_hours)
|
||||||
|
all_results = self.worker_pool.get_all_results()
|
||||||
|
|
||||||
|
to_remove = []
|
||||||
|
for task_id, result in all_results.items():
|
||||||
|
if (result.is_finished and
|
||||||
|
result.completed_at and
|
||||||
|
result.completed_at < cutoff_time):
|
||||||
|
to_remove.append(task_id)
|
||||||
|
|
||||||
|
# Remove old results
|
||||||
|
for task_id in to_remove:
|
||||||
|
if task_id in self.worker_pool._results:
|
||||||
|
del self.worker_pool._results[task_id]
|
||||||
|
|
||||||
|
if to_remove:
|
||||||
|
logger.info(f"Cleaned up {len(to_remove)} old task results")
|
||||||
|
|
||||||
|
# Also clean up excess results
|
||||||
|
self.worker_pool.cleanup_finished_tasks(self.max_retained_results)
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in periodic cleanup: {e}")
|
||||||
|
|
||||||
|
async def shutdown(self, wait: bool = True) -> None:
|
||||||
|
"""Shutdown task manager"""
|
||||||
|
logger.info("Shutting down task manager...")
|
||||||
|
|
||||||
|
# Cancel cleanup task
|
||||||
|
if not self._cleanup_task.done():
|
||||||
|
self._cleanup_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._cleanup_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Shutdown worker pool
|
||||||
|
await self.worker_pool.shutdown(wait=wait)
|
||||||
|
|
||||||
|
logger.info("Task manager shutdown complete")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get task manager statistics"""
|
||||||
|
all_results = self.worker_pool.get_all_results()
|
||||||
|
|
||||||
|
status_counts = {}
|
||||||
|
for status in TaskStatus:
|
||||||
|
status_counts[status.value] = sum(
|
||||||
|
1 for r in all_results.values() if r.status == status
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'total_tasks': len(all_results),
|
||||||
|
'active_tasks': self.worker_pool.active_task_count,
|
||||||
|
'max_workers': self.worker_pool.max_workers,
|
||||||
|
'worker_type': 'process' if self.worker_pool.use_process_pool else 'thread',
|
||||||
|
'status_counts': status_counts
|
||||||
|
}
|
||||||
258
src/async_task/worker_pool.py
Normal file
258
src/async_task/worker_pool.py
Normal file
@@ -0,0 +1,258 @@
"""
Worker Pool for Background Task Execution
"""

import asyncio
import logging
import multiprocessing
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from typing import Any, Callable, Optional, Union
from datetime import datetime

from .models import TaskResult, TaskStatus, generate_task_id

logger = logging.getLogger(__name__)


class WorkerPool:
    """Worker pool for executing background tasks"""

    def __init__(
        self,
        max_workers: Optional[int] = None,
        use_process_pool: bool = False,
        task_timeout: float = 600.0  # 10 minutes default
    ):
        """
        Initialize worker pool

        Args:
            max_workers: Maximum number of worker threads/processes
            use_process_pool: Use ProcessPoolExecutor instead of ThreadPoolExecutor
            task_timeout: Default timeout for tasks in seconds
        """
        self.max_workers = max_workers or min(32, (multiprocessing.cpu_count() or 1) + 4)
        self.use_process_pool = use_process_pool
        self.task_timeout = task_timeout

        # Initialize executor
        if use_process_pool:
            self.executor = ProcessPoolExecutor(max_workers=self.max_workers)
            logger.info(f"Initialized ProcessPoolExecutor with {self.max_workers} workers")
        else:
            self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
            logger.info(f"Initialized ThreadPoolExecutor with {self.max_workers} workers")

        # Track running tasks
        self._running_tasks: dict[str, asyncio.Task] = {}
        self._results: dict[str, TaskResult] = {}

    async def submit_task(
        self,
        func: Callable,
        *args,
        task_id: Optional[str] = None,
        timeout: Optional[float] = None,
        **kwargs
    ) -> str:
        """
        Submit a task for background execution

        Args:
            func: Function to execute
            *args: Function arguments
            task_id: Optional custom task ID
            timeout: Task timeout in seconds
            **kwargs: Function keyword arguments

        Returns:
            str: Task ID for tracking
        """
        if task_id is None:
            task_id = generate_task_id()

        timeout = timeout or self.task_timeout

        # Create task result
        task_result = TaskResult(
            task_id=task_id,
            status=TaskStatus.PENDING,
            metadata={
                'function_name': func.__name__,
                'timeout': timeout,
                'worker_type': 'process' if self.use_process_pool else 'thread'
            }
        )

        self._results[task_id] = task_result

        # Create and start background task
        background_task = asyncio.create_task(
            self._execute_task(func, args, kwargs, task_result, timeout)
        )

        self._running_tasks[task_id] = background_task

        logger.info(f"Task {task_id} submitted for execution (timeout: {timeout}s)")
        return task_id

    async def _execute_task(
        self,
        func: Callable,
        args: tuple,
        kwargs: dict,
        task_result: TaskResult,
        timeout: float
    ) -> None:
        """Execute task in background"""
        try:
            task_result.status = TaskStatus.RUNNING
            task_result.started_at = datetime.now()

            logger.info(f"Starting task {task_result.task_id}: {func.__name__}")

            # Execute function with timeout
            if self.use_process_pool:
                # For process pool, function must be pickleable
                future = self.executor.submit(func, *args, **kwargs)
                result = await asyncio.wait_for(
                    asyncio.wrap_future(future),
                    timeout=timeout
                )
            else:
                # For thread pool, can use lambda/closure
                result = await asyncio.wait_for(
                    asyncio.to_thread(func, *args, **kwargs),
                    timeout=timeout
                )

            task_result.result = result
            task_result.status = TaskStatus.COMPLETED
            task_result.completed_at = datetime.now()

            logger.info(f"Task {task_result.task_id} completed successfully in {task_result.duration:.2f}s")

        except asyncio.TimeoutError:
            task_result.status = TaskStatus.FAILED
            task_result.error = f"Task timed out after {timeout} seconds"
            task_result.completed_at = datetime.now()
            logger.error(f"Task {task_result.task_id} timed out after {timeout}s")

        except asyncio.CancelledError:
            task_result.status = TaskStatus.CANCELLED
            task_result.error = "Task was cancelled"
            task_result.completed_at = datetime.now()
            logger.info(f"Task {task_result.task_id} was cancelled")

        except Exception as e:
            task_result.status = TaskStatus.FAILED
            task_result.error = str(e)
            task_result.completed_at = datetime.now()
            logger.error(f"Task {task_result.task_id} failed: {str(e)}", exc_info=True)

        finally:
            # Clean up running task reference
            if task_result.task_id in self._running_tasks:
                del self._running_tasks[task_result.task_id]

    def get_task_result(self, task_id: str) -> Optional[TaskResult]:
        """Get task result by ID"""
        return self._results.get(task_id)

    def get_all_results(self) -> dict[str, TaskResult]:
        """Get all task results"""
        return self._results.copy()

    def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a running task

        Args:
            task_id: Task ID to cancel

        Returns:
            bool: True if task was cancelled, False if not found or already finished
        """
        if task_id in self._running_tasks:
            task = self._running_tasks[task_id]
            if not task.done():
                task.cancel()
                logger.info(f"Task {task_id} cancellation requested")
                return True
        return False

    def cleanup_finished_tasks(self, max_results: int = 100) -> int:
        """
        Clean up finished task results to prevent memory buildup

        Args:
            max_results: Maximum number of results to keep

        Returns:
            int: Number of results cleaned up
        """
        if len(self._results) <= max_results:
            return 0

        # Sort by completion time, keep most recent
        finished_results = {
            k: v for k, v in self._results.items()
            if v.is_finished and v.completed_at
        }

        if len(finished_results) <= max_results:
            return 0

        sorted_results = sorted(
            finished_results.items(),
            key=lambda x: x[1].completed_at,
            reverse=True
        )

        # Keep most recent max_results
        to_keep = set(k for k, _ in sorted_results[:max_results])

        # Remove old results
        to_remove = [k for k in finished_results.keys() if k not in to_keep]
        for task_id in to_remove:
            del self._results[task_id]

        cleaned_count = len(to_remove)
        if cleaned_count > 0:
            logger.info(f"Cleaned up {cleaned_count} old task results")

        return cleaned_count

    async def shutdown(self, wait: bool = True) -> None:
        """
        Shutdown worker pool

        Args:
            wait: Whether to wait for running tasks to complete
        """
        logger.info("Shutting down worker pool...")

        if wait:
            # Cancel all running tasks
            for task_id, task in self._running_tasks.items():
                if not task.done():
                    task.cancel()
                    logger.info(f"Cancelled task {task_id}")

            # Wait for tasks to complete cancellation
            if self._running_tasks:
                await asyncio.gather(*self._running_tasks.values(), return_exceptions=True)

        # Shutdown executor
        self.executor.shutdown(wait=wait)
        logger.info("Worker pool shutdown complete")

    @property
    def active_task_count(self) -> int:
        """Get number of currently running tasks"""
        return len([t for t in self._running_tasks.values() if not t.done()])

    @property
    def total_task_count(self) -> int:
        """Get total number of tracked tasks"""
        return len(self._results)
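
A minimal usage sketch for the pool above, assuming the TaskResult/TaskStatus models from src/async_task/models behave as referenced here; the blocking function and its argument are made up for illustration.

import asyncio
from src.async_task.worker_pool import WorkerPool

def slow_square(x: int) -> int:
    # Stand-in blocking callable; any CPU- or IO-bound function works.
    import time
    time.sleep(1)
    return x * x

async def main() -> None:
    pool = WorkerPool(max_workers=4, use_process_pool=False, task_timeout=30.0)
    task_id = await pool.submit_task(slow_square, 7)

    # Simple polling loop; a caller could also await the asyncio.Task directly.
    while pool.active_task_count > 0:
        await asyncio.sleep(0.1)

    result = pool.get_task_result(task_id)
    print(result.status, result.result)  # expected: TaskStatus.COMPLETED 49
    await pool.shutdown(wait=True)

asyncio.run(main())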
10
src/connector/__init__.py
Normal file
@@ -0,0 +1,10 @@
"""
Imagen4 Connector Package

Connector module that handles communication with the Google Imagen 4 API
"""

from .imagen4_client import Imagen4Client
from .config import Config

__all__ = ['Imagen4Client', 'Config']
101
src/connector/config.py
Normal file
@@ -0,0 +1,101 @@
"""
Configuration management for Imagen4 connector
"""

import os
from typing import Optional
from dataclasses import dataclass


@dataclass
class Config:
    """Imagen4 configuration class"""
    project_id: str
    location: str = "us-central1"
    output_path: str = "./generated_images"
    default_model: str = "imagen-4.0-generate-001"  # Default model

    @classmethod
    def from_env(cls, env_file: Optional[str] = None) -> 'Config':
        """
        Load configuration from environment variables

        Args:
            env_file: Optional path to .env file to load first
        """
        # Load a .env file when one is specified, or when the project ID is not already in the environment
        if env_file or not os.getenv("GOOGLE_CLOUD_PROJECT_ID"):
            try:
                from dotenv import load_dotenv

                if env_file:
                    load_dotenv(env_file)
                else:
                    # Try to find .env file in current directory or parent directories
                    current_dir = os.getcwd()
                    env_paths = [
                        os.path.join(current_dir, '.env'),
                        os.path.join(os.path.dirname(current_dir), '.env'),
                        os.path.join(os.path.dirname(__file__), '..', '..', '.env')
                    ]

                    for env_path in env_paths:
                        if os.path.exists(env_path):
                            load_dotenv(env_path)
                            break

            except ImportError:
                pass  # python-dotenv not available, continue with system env vars

        # Try multiple environment variable names for better compatibility
        project_id = (
            os.getenv("GOOGLE_CLOUD_PROJECT_ID") or
            os.getenv("GOOGLE_CLOUD_PROJECT") or
            os.getenv("GCP_PROJECT")
        )

        if not project_id:
            raise ValueError(
                "Google Cloud Project ID is required. Please set one of these environment variables: "
                "GOOGLE_CLOUD_PROJECT_ID, GOOGLE_CLOUD_PROJECT, or GCP_PROJECT"
            )

        # Also set GOOGLE_CLOUD_PROJECT for Google Cloud libraries if not already set
        if not os.getenv("GOOGLE_CLOUD_PROJECT"):
            os.environ["GOOGLE_CLOUD_PROJECT"] = project_id

        return cls(
            project_id=project_id,
            location=os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1"),
            output_path=os.getenv("GENERATED_IMAGES_PATH", "./generated_images"),
            default_model=os.getenv("IMAGEN4_DEFAULT_MODEL", "imagen-4.0-generate-001")
        )

    @classmethod
    def for_testing(cls, project_id: str = "test-project-id") -> 'Config':
        """Create a config instance for testing purposes"""
        return cls(
            project_id=project_id,
            location="us-central1",
            output_path="./test_generated_images",
            default_model="imagen-4.0-generate-001"
        )

    def validate(self) -> bool:
        """Validate configuration"""
        if not self.project_id:
            return False
        if not self.location:
            return False
        if not self.output_path:
            return False
        return True

    def __str__(self) -> str:
        """String representation of config"""
        return (
            f"Config(project_id='{self.project_id}', "
            f"location='{self.location}', "
            f"output_path='{self.output_path}', "
            f"default_model='{self.default_model}')"
        )
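
A short sketch of how this configuration class is meant to be loaded; the environment values below are placeholders, not real project settings.

import os
from src.connector.config import Config

# Placeholder values for illustration only.
os.environ.setdefault("GOOGLE_CLOUD_PROJECT_ID", "my-example-project")
os.environ.setdefault("GOOGLE_CLOUD_LOCATION", "us-central1")

config = Config.from_env()
assert config.validate()
print(config)  # Config(project_id='my-example-project', location='us-central1', ...)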
215
src/connector/imagen4_client.py
Normal file
@@ -0,0 +1,215 @@
"""
Google Imagen 4 API Client
"""

import asyncio
import logging
from typing import List, Optional
from dataclasses import dataclass

try:
    from google import genai
    from google.genai.types import GenerateImagesConfig
except ImportError as e:
    raise ImportError(f"Google GenAI library not found: {e}. Install with: pip install google-genai")

from .config import Config

logger = logging.getLogger(__name__)


@dataclass
class ImageGenerationRequest:
    """Image generation request data class"""
    prompt: str
    negative_prompt: str = ""
    number_of_images: int = 1
    seed: int = 0
    aspect_ratio: str = "1:1"
    model: str = "imagen-4.0-generate-001"  # Model selection

    def validate(self) -> None:
        """Validate request data"""
        # Import here to avoid circular imports
        from ..utils.token_utils import validate_prompt_length

        if not self.prompt:
            raise ValueError("Prompt is required")

        # Validate the prompt token count
        is_valid, token_count, error_msg = validate_prompt_length(self.prompt)
        if not is_valid:
            raise ValueError(f"Prompt validation failed: {error_msg}")

        # Validate negative_prompt as well (with a more lenient limit)
        if self.negative_prompt:
            neg_is_valid, neg_token_count, neg_error_msg = validate_prompt_length(
                self.negative_prompt, max_tokens=240  # negative prompt limited to half
            )
            if not neg_is_valid:
                raise ValueError(f"Negative prompt validation failed: {neg_error_msg}")

        if not isinstance(self.seed, int) or self.seed < 0 or self.seed > 4294967295:
            raise ValueError("Seed value must be an integer between 0 and 4294967295")

        if self.number_of_images not in [1, 2]:
            raise ValueError("Number of images must be 1 or 2 only")

        # Validate model name
        valid_models = ["imagen-4.0-generate-001", "imagen-4.0-ultra-generate-001"]
        if self.model not in valid_models:
            raise ValueError(f"Model must be one of: {valid_models}")


@dataclass
class ImageGenerationResponse:
    """Image generation response data class"""
    images_data: List[bytes]
    request: ImageGenerationRequest
    success: bool = True
    error_message: Optional[str] = None


class Imagen4Client:
    """Google Imagen 4 API client"""

    def __init__(self, config: Config):
        """
        Initialize client

        Args:
            config: Imagen4 configuration object
        """
        if not config.validate():
            raise ValueError("Invalid configuration")

        self.config = config
        self._client = None
        self._initialize_client()

    def _initialize_client(self) -> None:
        """Initialize Google GenAI client"""
        try:
            self._client = genai.Client(
                vertexai=True,
                project=self.config.project_id,
                location=self.config.location
            )

            # Log client information
            if not self._client.vertexai:
                logger.info("Using Gemini Developer API.")
            elif self._client._api_client.project:
                logger.info(f"Using Vertex AI with project: {self._client._api_client.project} in location: {self._client._api_client.location}")
            elif self._client._api_client.api_key:
                logger.info(f"Using Vertex AI in express mode with API key: {self._client._api_client.api_key[:5]}...{self._client._api_client.api_key[-5:]}")

        except Exception as e:
            logger.error(f"Failed to initialize Imagen4Client: {e}")
            raise

    async def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse:
        """
        Process image generation request with MCP-safe execution

        Args:
            request: Image generation request object

        Returns:
            ImageGenerationResponse: Generation result
        """
        try:
            # Validate request
            request.validate()

            # Log token count information
            from ..utils.token_utils import get_prompt_stats
            prompt_stats = get_prompt_stats(request.prompt)
            logger.info(f"Prompt token analysis: {prompt_stats['estimated_tokens']}/{prompt_stats['max_tokens']} tokens")

            if request.negative_prompt:
                neg_stats = get_prompt_stats(request.negative_prompt)
                logger.info(f"Negative prompt token analysis: {neg_stats['estimated_tokens']} tokens")

            logger.info(f"Starting image generation - Prompt: '{request.prompt[:50]}...', Seed: {request.seed}, Model: {request.model}")
            print(f"Aspect Ratio: {request.aspect_ratio}, Number of Images: {request.number_of_images}, Model: {request.model}")
            print(f"Prompt Tokens: {prompt_stats['estimated_tokens']}/{prompt_stats['max_tokens']}")

            # Run the blocking SDK call in a worker thread so the MCP event loop stays responsive
            def _sync_generate_images():
                """Synchronous wrapper for API call"""
                return self._client.models.generate_images(
                    model=request.model,
                    prompt=request.prompt,
                    config=GenerateImagesConfig(
                        add_watermark=False,
                        aspect_ratio=request.aspect_ratio,
                        image_size="2K",  # 2048x2048
                        negative_prompt=request.negative_prompt,
                        number_of_images=request.number_of_images,
                        output_mime_type="image/png",
                        person_generation="allow_all",
                        safety_filter_level="block_only_high",
                        seed=request.seed,
                    )
                )

            # Run in thread with explicit timeout
            logger.info("Making API call with thread executor...")
            response = await asyncio.wait_for(
                asyncio.to_thread(_sync_generate_images),
                timeout=300.0  # 5 minute timeout (increased from 90s)
            )
            logger.info("API call completed successfully")

            print(f"Get response: {response}")

            # Extract image data from response
            images_data = []
            if response.generated_images:
                for i, gen_image in enumerate(response.generated_images):
                    if hasattr(gen_image, 'image') and hasattr(gen_image.image, 'image_bytes'):
                        image_bytes = gen_image.image.image_bytes
                        if image_bytes:
                            images_data.append(image_bytes)
                            logger.info(f"Image {i+1} extraction complete (size: {len(image_bytes)} bytes)")
                        else:
                            logger.warning(f"Image {i+1}'s image_bytes is empty.")
                    else:
                        logger.warning(f"Cannot find image_bytes in image {i+1}.")

            if not images_data:
                raise Exception("Cannot find generated image data.")

            logger.info(f"Total {len(images_data)} images generated successfully")

            return ImageGenerationResponse(
                images_data=images_data,
                request=request,
                success=True
            )

        except asyncio.TimeoutError:
            error_msg = "Image generation timed out after 5 minutes"
            logger.error(error_msg)
            return ImageGenerationResponse(
                images_data=[],
                request=request,
                success=False,
                error_message=error_msg
            )
        except Exception as e:
            logger.error(f"Error occurred during image generation: {str(e)}")
            return ImageGenerationResponse(
                images_data=[],
                request=request,
                success=False,
                error_message=str(e)
            )

    def health_check(self) -> bool:
        """Check client status"""
        try:
            return self._client is not None and self.config.validate()
        except Exception:
            return False
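
A minimal end-to-end sketch of the client above. It assumes valid Google Cloud credentials and the environment variables consumed by Config.from_env(); the prompt and seed values are illustrative only.

import asyncio
from src.connector.config import Config
from src.connector.imagen4_client import Imagen4Client, ImageGenerationRequest

async def main() -> None:
    # Requires valid Google Cloud credentials and project settings in the environment.
    config = Config.from_env()
    client = Imagen4Client(config)

    request = ImageGenerationRequest(
        prompt="A watercolor painting of a lighthouse at dawn",  # illustrative prompt
        seed=12345,
        aspect_ratio="16:9",
        number_of_images=1,
    )
    response = await client.generate_image(request)

    if response.success:
        print(f"Received {len(response.images_data)} image(s)")
    else:
        print(f"Generation failed: {response.error_message}")

asyncio.run(main())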
70
src/connector/utils.py
Normal file
@@ -0,0 +1,70 @@
"""
Utility functions for Imagen4 connector
"""

import os
import json
import logging
from datetime import datetime
from typing import List, Dict, Any, Optional

logger = logging.getLogger(__name__)


def save_generated_images(
    images_data: List[bytes],
    save_directory: str = "./generated_images",
    filename_prefix: str = "imagen4",
    seed: Optional[int] = None,
    generation_params: Optional[Dict[str, Any]] = None
) -> List[str]:
    """Save generated images to files"""
    os.makedirs(save_directory, exist_ok=True)
    saved_files = []
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    for i, image_data in enumerate(images_data):
        try:
            if seed is not None:
                base_filename = f"{filename_prefix}_{seed}_{timestamp}_{i+1:03d}"
            else:
                base_filename = f"{filename_prefix}_{timestamp}_{i+1:03d}"

            # Save image file
            image_filename = f"{base_filename}.png"
            image_filepath = os.path.join(save_directory, image_filename)

            with open(image_filepath, 'wb') as f:
                f.write(image_data)

            saved_files.append(image_filepath)
            logger.info(f"Image saved successfully: {image_filepath}")

            # Save JSON parameter file
            if generation_params:
                json_filename = f"{base_filename}.json"
                json_filepath = os.path.join(save_directory, json_filename)

                params_with_metadata = {
                    **generation_params,
                    "generated_at": datetime.now().isoformat(),
                    "image_filename": image_filename,
                    "image_index": i + 1,
                    "total_images": len(images_data),
                    "file_size_bytes": len(image_data),
                    "model_version": "imagen-4.0-generate-001",
                    "image_format": "PNG",
                    "image_size": "2048x2048"
                }

                with open(json_filepath, 'w', encoding='utf-8') as f:
                    json.dump(params_with_metadata, f, ensure_ascii=False, indent=2)

                saved_files.append(json_filepath)
                logger.info(f"Parameters saved successfully: {json_filepath}")

        except Exception as e:
            logger.error(f"Failed to save image {i+1}: {str(e)}")
            continue

    return saved_files
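
A quick sketch of the on-disk layout this helper produces; the byte payload here is a dummy stand-in for real PNG data returned by the API.

from src.connector.utils import save_generated_images

# Dummy bytes stand in for real PNG data from the API.
fake_png = b"\x89PNG\r\n\x1a\n" + b"\x00" * 32

saved = save_generated_images(
    images_data=[fake_png],
    save_directory="./generated_images",
    seed=12345,
    generation_params={"prompt": "example prompt", "seed": 12345},
)
# Expected layout: imagen4_12345_<timestamp>_001.png plus a matching .json sidecar.
print(saved)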
26
src/server/__init__.py
Normal file
@@ -0,0 +1,26 @@
"""
Imagen4 Server Package

Module responsible for the MCP server implementation
"""

# Import basic server components
from src.server.mcp_server import Imagen4MCPServer
from src.server.handlers import ToolHandlers
from src.server.models import MCPToolDefinitions

# Import enhanced components
try:
    from src.server.enhanced_mcp_server import EnhancedImagen4MCPServer
    from src.server.enhanced_handlers import EnhancedToolHandlers
    from src.server.enhanced_models import ImageGenerationResult, PreviewImageResponse

    __all__ = [
        'Imagen4MCPServer', 'ToolHandlers', 'MCPToolDefinitions',
        'EnhancedImagen4MCPServer', 'EnhancedToolHandlers',
        'ImageGenerationResult', 'PreviewImageResponse'
    ]
except ImportError as e:
    # Fallback if enhanced modules are not available
    print(f"Warning: Enhanced features not available: {e}")
    __all__ = ['Imagen4MCPServer', 'ToolHandlers', 'MCPToolDefinitions']
322
src/server/enhanced_handlers.py
Normal file
@@ -0,0 +1,322 @@
"""
Enhanced Tool Handlers for MCP Server with Preview Image Support
"""

import asyncio
import base64
import json
import random
import logging
from typing import List, Dict, Any

from mcp.types import TextContent, ImageContent

from src.connector import Imagen4Client, Config
from src.connector.imagen4_client import ImageGenerationRequest
from src.connector.utils import save_generated_images
from src.utils.image_utils import create_preview_image_b64, get_image_info
from src.server.enhanced_models import ImageGenerationResult, PreviewImageResponse

logger = logging.getLogger(__name__)


def sanitize_args_for_logging(arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Remove or truncate sensitive data from arguments for safe logging"""
    safe_args = {}
    for key, value in arguments.items():
        if isinstance(value, str):
            # Check if it's likely base64 image data
            if (key in ['data', 'image_data', 'base64', 'preview_image_b64'] or
                    (len(value) > 100 and value.startswith(('iVBORw0KGgo', '/9j/', 'R0lGOD')))):
                # Truncate long image data
                safe_args[key] = f"<image_data:{len(value)} chars>"
            elif len(value) > 1000:
                # Truncate any very long strings
                safe_args[key] = f"{value[:100]}...<truncated:{len(value)} total chars>"
            else:
                safe_args[key] = value
        else:
            safe_args[key] = value
    return safe_args


class EnhancedToolHandlers:
    """Enhanced MCP tool handler class with preview image support"""

    def __init__(self, config: Config):
        """Initialize handler"""
        self.config = config
        self.client = Imagen4Client(config)

    async def handle_generate_random_seed(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Random seed generation handler"""
        try:
            random_seed = random.randint(0, 2**32 - 1)
            return [TextContent(
                type="text",
                text=f"Generated random seed: {random_seed}"
            )]
        except Exception as e:
            logger.error(f"Random seed generation error: {str(e)}")
            return [TextContent(
                type="text",
                text=f"Error occurred during random seed generation: {str(e)}"
            )]

    async def handle_regenerate_from_json(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Image regeneration from JSON file handler with preview support"""
        try:
            json_file_path = arguments.get("json_file_path")
            save_to_file = arguments.get("save_to_file", True)

            if not json_file_path:
                return [TextContent(
                    type="text",
                    text="Error: JSON file path is required."
                )]

            # Load parameters from JSON file
            try:
                with open(json_file_path, 'r', encoding='utf-8') as f:
                    params = json.load(f)
            except FileNotFoundError:
                return [TextContent(
                    type="text",
                    text=f"Error: Cannot find JSON file: {json_file_path}"
                )]
            except json.JSONDecodeError as e:
                return [TextContent(
                    type="text",
                    text=f"Error: JSON parsing error: {str(e)}"
                )]

            # Check required parameters
            required_params = ['prompt', 'seed']
            missing_params = [p for p in required_params if p not in params]

            if missing_params:
                return [TextContent(
                    type="text",
                    text=f"Error: Required parameters are missing from JSON file: {', '.join(missing_params)}"
                )]

            # Create image generation request object
            request = ImageGenerationRequest(
                prompt=params.get('prompt'),
                negative_prompt=params.get('negative_prompt', ''),
                number_of_images=params.get('number_of_images', 1),
                seed=params.get('seed'),
                aspect_ratio=params.get('aspect_ratio', '1:1'),
                model=params.get('model', self.config.default_model)
            )

            logger.info(f"Loaded parameters from JSON: {json_file_path}")

            # Generate image
            response = await self.client.generate_image(request)

            if not response.success:
                return [TextContent(
                    type="text",
                    text=f"Error occurred during image regeneration: {response.error_message}"
                )]

            # Create preview images
            preview_images_b64 = []
            for image_data in response.images_data:
                preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85)
                if preview_b64:
                    preview_images_b64.append(preview_b64)
                    logger.info(f"Created preview image: {len(preview_b64)} chars base64 JPEG")

            # Save files (optional)
            saved_files = []
            if save_to_file:
                regeneration_params = {
                    "prompt": request.prompt,
                    "negative_prompt": request.negative_prompt,
                    "number_of_images": request.number_of_images,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "model": request.model,
                    "regenerated_from": json_file_path,
                    "original_generated_at": params.get('generated_at', 'unknown')
                }

                saved_files = save_generated_images(
                    images_data=response.images_data,
                    save_directory=self.config.output_path,
                    seed=request.seed,
                    generation_params=regeneration_params,
                    filename_prefix="imagen4_regen"
                )

            # Create result with preview images
            result = ImageGenerationResult(
                success=True,
                message=f"✅ Images have been successfully regenerated from {json_file_path}",
                original_images_count=len(response.images_data),
                preview_images_b64=preview_images_b64,
                saved_files=saved_files,
                generation_params={
                    "prompt": request.prompt,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "number_of_images": request.number_of_images,
                    "model": request.model
                }
            )

            return [TextContent(
                type="text",
                text=result.to_text_content()
            )]

        except Exception as e:
            logger.error(f"Error occurred during image regeneration: {str(e)}")
            return [TextContent(
                type="text",
                text=f"Error occurred during image regeneration: {str(e)}"
            )]

    async def handle_generate_image(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Enhanced image generation handler with preview image support"""
        try:
            # Log arguments safely without exposing image data
            safe_args = sanitize_args_for_logging(arguments)
            logger.info(f"handle_generate_image called with arguments: {safe_args}")

            # Extract and validate arguments
            prompt = arguments.get("prompt")
            if not prompt:
                logger.error("No prompt provided")
                return [TextContent(
                    type="text",
                    text="Error: Prompt is required."
                )]

            seed = arguments.get("seed")
            if seed is None:
                logger.error("No seed provided")
                return [TextContent(
                    type="text",
                    text="Error: Seed value is required. You can use the generate_random_seed tool to generate a random seed."
                )]

            # Create image generation request object
            request = ImageGenerationRequest(
                prompt=prompt,
                negative_prompt=arguments.get("negative_prompt", ""),
                number_of_images=arguments.get("number_of_images", 1),
                seed=seed,
                aspect_ratio=arguments.get("aspect_ratio", "1:1"),
                model=arguments.get("model", self.config.default_model)
            )

            save_to_file = arguments.get("save_to_file", True)

            logger.info(f"Starting image generation: '{prompt[:50]}...', Seed: {seed}")

            # Generate image with timeout
            try:
                logger.info("Calling client.generate_image()...")
                response = await asyncio.wait_for(
                    self.client.generate_image(request),
                    timeout=360.0  # 6 minute timeout
                )
                logger.info(f"Image generation completed. Success: {response.success}")
            except asyncio.TimeoutError:
                logger.error("Image generation timed out after 6 minutes")
                return [TextContent(
                    type="text",
                    text="Error: Image generation timed out after 6 minutes. Please try again."
                )]

            if not response.success:
                logger.error(f"Image generation failed: {response.error_message}")
                return [TextContent(
                    type="text",
                    text=f"Error occurred during image generation: {response.error_message}"
                )]

            logger.info(f"Generated {len(response.images_data)} images successfully")

            # Create preview images (512x512 JPEG base64)
            preview_images_b64 = []
            for i, image_data in enumerate(response.images_data):
                # Log original image info
                image_info = get_image_info(image_data)
                if image_info:
                    logger.info(f"Original image {i+1}: {image_info}")

                # Create preview
                preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85)
                if preview_b64:
                    preview_images_b64.append(preview_b64)
                    logger.info(f"Created preview image {i+1}: {len(preview_b64)} chars base64 JPEG")
                else:
                    logger.warning(f"Failed to create preview for image {i+1}")

            # Save files if requested
            saved_files = []
            if save_to_file:
                logger.info("Saving files to disk...")
                generation_params = {
                    "prompt": request.prompt,
                    "negative_prompt": request.negative_prompt,
                    "number_of_images": request.number_of_images,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "model": request.model,
                    "guidance_scale": 7.5,
                    "safety_filter_level": "block_only_high",
                    "person_generation": "allow_all",
                    "add_watermark": False
                }

                saved_files = save_generated_images(
                    images_data=response.images_data,
                    save_directory=self.config.output_path,
                    seed=request.seed,
                    generation_params=generation_params
                )
                logger.info(f"Files saved: {saved_files}")

                # Verify files were created
                import os
                for file_path in saved_files:
                    if os.path.exists(file_path):
                        size = os.path.getsize(file_path)
                        logger.info(f" ✓ Verified: {file_path} ({size} bytes)")
                    else:
                        logger.error(f" ❌ Missing: {file_path}")

            # Create enhanced result with preview images
            result = ImageGenerationResult(
                success=True,
                message="✅ Images have been successfully generated!",
                original_images_count=len(response.images_data),
                preview_images_b64=preview_images_b64,
                saved_files=saved_files,
                generation_params={
                    "prompt": request.prompt,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "number_of_images": request.number_of_images,
                    "negative_prompt": request.negative_prompt,
                    "model": request.model
                }
            )

            logger.info(f"Returning response with {len(preview_images_b64)} preview images")
            return [TextContent(
                type="text",
                text=result.to_text_content()
            )]

        except Exception as e:
            logger.error(f"Error occurred during image generation: {str(e)}", exc_info=True)
            return [TextContent(
                type="text",
                text=f"Error occurred during image generation: {str(e)}"
            )]
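
A quick illustration of what the argument sanitizer above does to tool arguments before they reach the logs; the values are fabricated.

from src.server.enhanced_handlers import sanitize_args_for_logging

args = {
    "prompt": "A red fox in the snow",
    "seed": 42,
    "image_data": "iVBORw0KGgo" + "A" * 5000,  # fabricated base64-looking payload
}
print(sanitize_args_for_logging(args))
# {'prompt': 'A red fox in the snow', 'seed': 42, 'image_data': '<image_data:5011 chars>'}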
63
src/server/enhanced_mcp_server.py
Normal file
@@ -0,0 +1,63 @@
"""
Enhanced MCP Server implementation for Imagen4 with Preview Image Support
"""

import logging
from typing import Dict, Any, List

from mcp.server import Server
from mcp.types import Tool, TextContent

from src.connector import Config
from src.server.models import MCPToolDefinitions
from src.server.enhanced_handlers import EnhancedToolHandlers, sanitize_args_for_logging

logger = logging.getLogger(__name__)


class EnhancedImagen4MCPServer:
    """Enhanced Imagen4 MCP server class with preview image support"""

    def __init__(self, config: Config):
        """Initialize server"""
        self.config = config
        self.server = Server("imagen4-enhanced-mcp-server")
        self.handlers = EnhancedToolHandlers(config)

        # Register handlers
        self._register_handlers()

    def _register_handlers(self) -> None:
        """Register MCP handlers"""

        @self.server.list_tools()
        async def handle_list_tools() -> List[Tool]:
            """Return list of available tools"""
            logger.info("Listing available tools (enhanced version)")
            return MCPToolDefinitions.get_all_tools()

        @self.server.call_tool()
        async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]:
            """Handle tool calls with enhanced preview image support"""
            # Log tool call safely without exposing sensitive data
            safe_args = sanitize_args_for_logging(arguments)
            logger.info(f"Enhanced tool called: {name} with arguments: {safe_args}")

            if name == "generate_random_seed":
                return await self.handlers.handle_generate_random_seed(arguments)
            elif name == "regenerate_from_json":
                return await self.handlers.handle_regenerate_from_json(arguments)
            elif name == "generate_image":
                return await self.handlers.handle_generate_image(arguments)
            else:
                raise ValueError(f"Unknown tool: {name}")

    def get_server(self) -> Server:
        """Return MCP server instance"""
        return self.server


# Create a factory function for easier import
def create_enhanced_server(config: Config) -> EnhancedImagen4MCPServer:
    """Factory function to create enhanced server"""
    return EnhancedImagen4MCPServer(config)
71
src/server/enhanced_models.py
Normal file
@@ -0,0 +1,71 @@
"""
Enhanced MCP response models with preview image support
"""

from typing import Optional, Dict, Any
from dataclasses import dataclass


@dataclass
class ImageGenerationResult:
    """Enhanced image generation result with preview support"""
    success: bool
    message: str
    original_images_count: int
    preview_images_b64: Optional[list[str]] = None  # List of base64 JPEG previews (512x512)
    saved_files: Optional[list[str]] = None
    generation_params: Optional[Dict[str, Any]] = None
    error_message: Optional[str] = None

    def to_text_content(self) -> str:
        """Convert result to text format for MCP response"""
        lines = [self.message]

        if self.success and self.preview_images_b64:
            lines.append(f"\n🖼️ Preview Images Generated: {len(self.preview_images_b64)} images (512x512 JPEG)")
            for i, preview_b64 in enumerate(self.preview_images_b64):
                lines.append(f"Preview {i+1} (base64 JPEG): {preview_b64[:50]}...({len(preview_b64)} chars)")

        if self.saved_files:
            lines.append("\n📁 Files saved:")
            for filepath in self.saved_files:
                lines.append(f" - {filepath}")

        if self.generation_params:
            lines.append("\n⚙️ Generation Parameters:")
            for key, value in self.generation_params.items():
                if key == 'prompt' and len(str(value)) > 100:
                    lines.append(f" - {key}: {str(value)[:100]}...")
                else:
                    lines.append(f" - {key}: {value}")

        return "\n".join(lines)


@dataclass
class PreviewImageResponse:
    """Response containing preview images in base64 format"""
    preview_images_b64: list[str]  # Base64 encoded JPEG images (512x512)
    original_count: int
    message: str

    @classmethod
    def from_image_data(cls, images_data: list[bytes], message: str = "") -> 'PreviewImageResponse':
        """Create response from original PNG image data"""
        from src.utils.image_utils import create_preview_image_b64

        preview_images = []
        for i, image_data in enumerate(images_data):
            preview_b64 = create_preview_image_b64(image_data, target_size=512, quality=85)
            if preview_b64:
                preview_images.append(preview_b64)
            else:
                # Fallback - use original if preview creation fails
                import base64
                preview_images.append(base64.b64encode(image_data).decode('utf-8'))

        return cls(
            preview_images_b64=preview_images,
            original_count=len(images_data),
            message=message or f"Generated {len(preview_images)} preview images"
        )
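
A minimal sketch of the text rendering produced by ImageGenerationResult above; every field value is fabricated for illustration.

from src.server.enhanced_models import ImageGenerationResult

result = ImageGenerationResult(
    success=True,
    message="✅ Images have been successfully generated!",
    original_images_count=1,
    preview_images_b64=["/9j/" + "A" * 200],  # fabricated base64 JPEG stub
    saved_files=["./generated_images/imagen4_42_20250101_000000_001.png"],
    generation_params={"prompt": "example prompt", "seed": 42},
)
print(result.to_text_content())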
308
src/server/handlers.py
Normal file
@@ -0,0 +1,308 @@
"""
Tool Handlers for MCP Server
"""

import asyncio
import base64
import json
import random
import logging
from typing import List, Dict, Any

from mcp.types import TextContent, ImageContent

from src.connector import Imagen4Client, Config
from src.connector.imagen4_client import ImageGenerationRequest
from src.connector.utils import save_generated_images

logger = logging.getLogger(__name__)


def sanitize_args_for_logging(arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Remove or truncate sensitive data from arguments for safe logging"""
    safe_args = {}
    for key, value in arguments.items():
        if isinstance(value, str):
            # Check if it's likely base64 image data
            if (key in ['data', 'image_data', 'base64'] or
                    (len(value) > 100 and value.startswith(('iVBORw0KGgo', '/9j/', 'R0lGOD')))):
                # Truncate long image data
                safe_args[key] = f"<image_data:{len(value)} chars>"
            elif len(value) > 1000:
                # Truncate any very long strings
                safe_args[key] = f"{value[:100]}...<truncated:{len(value)} total chars>"
            else:
                safe_args[key] = value
        else:
            safe_args[key] = value
    return safe_args


class ToolHandlers:
    """MCP tool handler class"""

    def __init__(self, config: Config):
        """Initialize handler"""
        self.config = config
        self.client = Imagen4Client(config)

    async def handle_generate_random_seed(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Random seed generation handler"""
        try:
            random_seed = random.randint(0, 2**32 - 1)
            return [TextContent(
                type="text",
                text=f"Generated random seed: {random_seed}"
            )]
        except Exception as e:
            logger.error(f"Random seed generation error: {str(e)}")
            return [TextContent(
                type="text",
                text=f"Error occurred during random seed generation: {str(e)}"
            )]

    async def handle_regenerate_from_json(self, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]:
        """Image regeneration from JSON file handler"""
        try:
            json_file_path = arguments.get("json_file_path")
            save_to_file = arguments.get("save_to_file", True)

            if not json_file_path:
                return [TextContent(
                    type="text",
                    text="Error: JSON file path is required."
                )]

            # Load parameters from JSON file
            try:
                with open(json_file_path, 'r', encoding='utf-8') as f:
                    params = json.load(f)
            except FileNotFoundError:
                return [TextContent(
                    type="text",
                    text=f"Error: Cannot find JSON file: {json_file_path}"
                )]
            except json.JSONDecodeError as e:
                return [TextContent(
                    type="text",
                    text=f"Error: JSON parsing error: {str(e)}"
                )]

            # Check required parameters
            required_params = ['prompt', 'seed']
            missing_params = [p for p in required_params if p not in params]

            if missing_params:
                return [TextContent(
                    type="text",
                    text=f"Error: Required parameters are missing from JSON file: {', '.join(missing_params)}"
                )]

            # Create image generation request object
            request = ImageGenerationRequest(
                prompt=params.get('prompt'),
                negative_prompt=params.get('negative_prompt', ''),
                number_of_images=params.get('number_of_images', 1),
                seed=params.get('seed'),
                aspect_ratio=params.get('aspect_ratio', '1:1')
            )

            logger.info(f"Loaded parameters from JSON: {json_file_path}")

            # Generate image
            response = await self.client.generate_image(request)

            if not response.success:
                return [TextContent(
                    type="text",
                    text=f"Error occurred during image regeneration: {response.error_message}"
                )]

            # Save files (optional)
            saved_files = []
            if save_to_file:
                regeneration_params = {
                    "prompt": request.prompt,
                    "negative_prompt": request.negative_prompt,
                    "number_of_images": request.number_of_images,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "regenerated_from": json_file_path,
                    "original_generated_at": params.get('generated_at', 'unknown')
                }

                saved_files = save_generated_images(
                    images_data=response.images_data,
                    save_directory=self.config.output_path,
                    seed=request.seed,
                    generation_params=regeneration_params,
                    filename_prefix="imagen4_regen"
                )

            # Generate result message
            message_parts = [
                "Images have been successfully regenerated.",
                f"JSON file: {json_file_path}",
                "Size: 2048x2048 PNG",
                f"Count: {len(response.images_data)} images",
                f"Seed: {request.seed}",
                f"Prompt: {request.prompt}"
            ]

            if saved_files:
                message_parts.append("\nRegenerated files saved successfully:")
                for filepath in saved_files:
                    message_parts.append(f"- {filepath}")

            # Generate result
            result = [
                TextContent(
                    type="text",
                    text="\n".join(message_parts)
                )
            ]

            # Add Base64 encoded images
            for i, image_data in enumerate(response.images_data):
                image_base64 = base64.b64encode(image_data).decode('utf-8')
                result.append(
                    ImageContent(
                        type="image",
                        data=image_base64,
                        mimeType="image/png"
                    )
                )
                logger.info(f"Added image {i+1} to response: {len(image_base64)} chars base64 data")

            return result

        except Exception as e:
            logger.error(f"Error occurred during image regeneration: {str(e)}")
            return [TextContent(
                type="text",
                text=f"Error occurred during image regeneration: {str(e)}"
            )]

    async def handle_generate_image(self, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]:
        """Image generation handler with synchronous processing"""
        try:
            # Log arguments safely without exposing image data
            safe_args = sanitize_args_for_logging(arguments)
            logger.info(f"handle_generate_image called with arguments: {safe_args}")

            # Extract and validate arguments
            prompt = arguments.get("prompt")
            if not prompt:
                logger.error("No prompt provided")
                return [TextContent(
                    type="text",
                    text="Error: Prompt is required."
                )]

            seed = arguments.get("seed")
            if seed is None:
                logger.error("No seed provided")
                return [TextContent(
                    type="text",
                    text="Error: Seed value is required. You can use the generate_random_seed tool to generate a random seed."
                )]

            # Create image generation request object
            request = ImageGenerationRequest(
                prompt=prompt,
                negative_prompt=arguments.get("negative_prompt", ""),
                number_of_images=arguments.get("number_of_images", 1),
                seed=seed,
                aspect_ratio=arguments.get("aspect_ratio", "1:1")
            )

            save_to_file = arguments.get("save_to_file", True)

            logger.info(f"Starting SYNCHRONOUS image generation: '{prompt[:50]}...', Seed: {seed}")

            # Generate image synchronously with longer timeout
            try:
                logger.info("Calling client.generate_image() synchronously...")
                response = await asyncio.wait_for(
                    self.client.generate_image(request),
                    timeout=360.0  # 6 minute timeout
                )
                logger.info(f"Image generation completed. Success: {response.success}")
            except asyncio.TimeoutError:
                logger.error("Image generation timed out after 6 minutes")
                return [TextContent(
                    type="text",
                    text="Error: Image generation timed out after 6 minutes. Please try again."
                )]

            if not response.success:
                logger.error(f"Image generation failed: {response.error_message}")
                return [TextContent(
                    type="text",
                    text=f"Error occurred during image generation: {response.error_message}"
                )]

            logger.info(f"Generated {len(response.images_data)} images successfully")

            # Save files if requested
            saved_files = []
            if save_to_file:
                logger.info("Saving files to disk...")
                generation_params = {
                    "prompt": request.prompt,
                    "negative_prompt": request.negative_prompt,
                    "number_of_images": request.number_of_images,
                    "seed": request.seed,
                    "aspect_ratio": request.aspect_ratio,
                    "guidance_scale": 7.5,
                    "safety_filter_level": "block_only_high",
                    "person_generation": "allow_all",
                    "add_watermark": False
                }

                saved_files = save_generated_images(
                    images_data=response.images_data,
                    save_directory=self.config.output_path,
                    seed=request.seed,
                    generation_params=generation_params
                )
                logger.info(f"Files saved: {saved_files}")

                # Verify files were created
                import os
                for file_path in saved_files:
                    if os.path.exists(file_path):
                        size = os.path.getsize(file_path)
                        logger.info(f" ✓ Verified: {file_path} ({size} bytes)")
                    else:
                        logger.error(f" ❌ Missing: {file_path}")

            # Generate result message
            message_parts = [
                "✅ Images have been successfully generated!",
                f"Prompt: {request.prompt}",
                f"Seed: {request.seed}",
                "Size: 2048x2048 PNG",
                f"Count: {len(response.images_data)} images"
            ]

            if saved_files:
                message_parts.append("\n📁 Files saved successfully:")
                for filepath in saved_files:
                    message_parts.append(f" - {filepath}")
            else:
                message_parts.append("\nℹ️ Images generated but not saved to file. Use save_to_file=true to save.")

            logger.info("Returning synchronous response")
            return [TextContent(
                type="text",
                text="\n".join(message_parts)
            )]

        except Exception as e:
            logger.error(f"Error occurred during image generation: {str(e)}", exc_info=True)
            return [TextContent(
                type="text",
                text=f"Error occurred during image generation: {str(e)}"
            )]
57
src/server/mcp_server.py
Normal file
@@ -0,0 +1,57 @@
"""
MCP Server implementation for Imagen4
"""

import logging
from typing import Dict, Any, List

from mcp.server import Server
from mcp.types import Tool, TextContent, ImageContent

from src.connector import Config
from src.server.models import MCPToolDefinitions
from src.server.handlers import ToolHandlers, sanitize_args_for_logging

logger = logging.getLogger(__name__)


class Imagen4MCPServer:
    """Imagen4 MCP server class"""

    def __init__(self, config: Config):
        """Initialize server"""
        self.config = config
        self.server = Server("imagen4-mcp-server")
        self.handlers = ToolHandlers(config)

        # Register handlers
        self._register_handlers()

    def _register_handlers(self) -> None:
        """Register MCP handlers"""

        @self.server.list_tools()
        async def handle_list_tools() -> List[Tool]:
            """Return list of available tools"""
            logger.info("Listing available tools")
            return MCPToolDefinitions.get_all_tools()

        @self.server.call_tool()
        async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent | ImageContent]:
            """Handle tool calls"""
            # Log tool call safely without exposing sensitive data
            safe_args = sanitize_args_for_logging(arguments)
            logger.info(f"Tool called: {name} with arguments: {safe_args}")

            if name == "generate_random_seed":
                return await self.handlers.handle_generate_random_seed(arguments)
            elif name == "regenerate_from_json":
                return await self.handlers.handle_regenerate_from_json(arguments)
            elif name == "generate_image":
                return await self.handlers.handle_generate_image(arguments)
            else:
                raise ValueError(f"Unknown tool: {name}")

    def get_server(self) -> Server:
        """Return MCP server instance"""
        return self.server
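
A minimal sketch of wiring this server class to a transport, assuming the stdio helpers shipped with the mcp Python package (stdio_server, Server.run, create_initialization_options); the actual entry point in this repo's main.py may differ.

import asyncio
from mcp.server.stdio import stdio_server

from src.connector import Config
from src.server.mcp_server import Imagen4MCPServer

async def main() -> None:
    config = Config.from_env()
    mcp_server = Imagen4MCPServer(config).get_server()

    # Serve the MCP protocol over stdin/stdout, as Claude Desktop expects.
    async with stdio_server() as (read_stream, write_stream):
        await mcp_server.run(
            read_stream,
            write_stream,
            mcp_server.create_initialization_options(),
        )

asyncio.run(main())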
39
src/server/minimal_handler.py
Normal file
@@ -0,0 +1,39 @@
"""
Minimal Task Result Handler for Emergency Use
"""

import logging
from typing import List, Dict, Any

from mcp.types import TextContent

logger = logging.getLogger(__name__)


async def minimal_get_task_result(task_manager, task_id: str) -> List[TextContent]:
    """Absolutely minimal task result handler"""
    try:
        logger.info(f"Minimal handler for task: {task_id}")

        # Just return the task status for now
        status = task_manager.get_task_status(task_id)

        if status is None:
            return [TextContent(
                type="text",
                text=f"❌ Task '{task_id}' not found."
            )]

        return [TextContent(
            type="text",
            text=f"📋 Task '{task_id}' Status: {status.value}\n"
                 f"Note: This is a minimal emergency handler.\n"
                 f"If you see this message, the original handler has a serious issue."
        )]

    except Exception as e:
        logger.error(f"Even minimal handler failed: {str(e)}", exc_info=True)
        return [TextContent(
            type="text",
            text=f"Complete failure: {str(e)}"
        )]
132
src/server/models.py
Normal file
@@ -0,0 +1,132 @@
"""
MCP Tool Models and Definitions
"""

from mcp.types import Tool


class MCPToolDefinitions:
    """MCP tool definition class"""

    @staticmethod
    def get_generate_image_tool() -> Tool:
        """Image generation tool definition"""
        return Tool(
            name="generate_image",
            description="Generate 2048x2048 PNG images from text prompts using Google Imagen 4 AI.",
            inputSchema={
                "type": "object",
                "properties": {
                    "prompt": {
                        "type": "string",
                        "description": "Text prompt for image generation (Korean or English). Maximum 480 tokens allowed.",
                        "maxLength": 2000  # Approximate character limit for safety
                    },
                    "negative_prompt": {
                        "type": "string",
                        "description": "Negative prompt specifying elements not to generate (optional). Maximum 240 tokens allowed.",
                        "default": "",
                        "maxLength": 1000  # Approximate character limit for safety
                    },
                    "number_of_images": {
                        "type": "integer",
                        "description": "Number of images to generate (1 or 2 only)",
                        "enum": [1, 2],
                        "default": 1
                    },
                    "seed": {
                        "type": "integer",
                        "description": "Seed value for reproducible results (required, 0 ~ 4294967295 range)",
                        "minimum": 0,
                        "maximum": 4294967295
                    },
                    "aspect_ratio": {
                        "type": "string",
                        "description": "Image aspect ratio",
                        "enum": ["1:1", "9:16", "16:9", "3:4", "4:3"],
                        "default": "1:1"
                    },
                    "save_to_file": {
                        "type": "boolean",
                        "description": "Whether to save generated images to files (default: true)",
                        "default": True
                    },
                    "model": {
                        "type": "string",
                        "description": "Imagen 4 model to use for generation",
                        "enum": ["imagen-4.0-generate-001", "imagen-4.0-ultra-generate-001"],
                        "default": "imagen-4.0-generate-001"
                    }
                },
                "required": ["prompt", "seed"]
            }
        )

    @staticmethod
    def get_regenerate_from_json_tool() -> Tool:
        """Regenerate from JSON tool definition"""
        return Tool(
            name="regenerate_from_json",
            description="Read parameters from JSON file and regenerate images with the same settings.",
            inputSchema={
                "type": "object",
                "properties": {
                    "json_file_path": {
                        "type": "string",
                        "description": "Path to JSON file containing saved parameters"
                    },
                    "save_to_file": {
                        "type": "boolean",
                        "description": "Whether to save regenerated images to files (default: true)",
                        "default": True
                    }
                },
                "required": ["json_file_path"]
            }
        )

    @staticmethod
    def get_generate_random_seed_tool() -> Tool:
        """Random seed generation tool definition"""
        return Tool(
            name="generate_random_seed",
            description="Generate random seed value for image generation.",
            inputSchema={
                "type": "object",
                "properties": {},
                "required": []
            }
        )

    @staticmethod
    def get_validate_prompt_tool() -> Tool:
        """Prompt validation tool definition"""
        return Tool(
            name="validate_prompt",
            description="Validate prompt length and estimate token count before image generation.",
            inputSchema={
                "type": "object",
                "properties": {
                    "prompt": {
                        "type": "string",
                        "description": "Text prompt to validate"
                    },
                    "negative_prompt": {
                        "type": "string",
                        "description": "Negative prompt to validate (optional)",
                        "default": ""
                    }
                },
                "required": ["prompt"]
            }
        )

    @classmethod
    def get_all_tools(cls) -> list[Tool]:
        """Return all tool definitions"""
        return [
            cls.get_generate_image_tool(),
            cls.get_regenerate_from_json_tool(),
            cls.get_generate_random_seed_tool(),
            cls.get_validate_prompt_tool()
        ]
80
src/server/safe_result_handler.py
Normal file
@@ -0,0 +1,80 @@
"""
Safe Get Task Result Handler - Debug Version
"""

import logging
from typing import List, Dict, Any

from mcp.types import TextContent

logger = logging.getLogger(__name__)


async def safe_get_task_result(task_manager, task_id: str) -> List[TextContent]:
    """Minimal safe version of get_task_result without image processing"""
    try:
        logger.info(f"Safe get_task_result called for: {task_id}")

        # Step 1: Try to get raw result
        try:
            raw_result = task_manager.worker_pool.get_task_result(task_id)
            logger.info(f"Raw result type: {type(raw_result)}")
        except Exception as e:
            logger.error(f"Failed to get raw result: {str(e)}")
            return [TextContent(
                type="text",
                text=f"Error accessing task data: {str(e)}"
            )]

        if not raw_result:
            return [TextContent(
                type="text",
                text=f"❌ Task '{task_id}' not found."
            )]

        # Step 2: Check status safely
        try:
            status = raw_result.status
            logger.info(f"Task status: {status}")
        except Exception as e:
            logger.error(f"Failed to get status: {str(e)}")
            return [TextContent(
                type="text",
                text=f"Error reading task status: {str(e)}"
            )]

        # Step 3: Return basic info without touching result.result
        try:
            duration = raw_result.duration if hasattr(raw_result, 'duration') else None
            created = raw_result.created_at.isoformat() if hasattr(raw_result, 'created_at') and raw_result.created_at else "unknown"

            message = f"✅ Task '{task_id}' Status Report:\n"
            message += f"Status: {status.value}\n"
            message += f"Created: {created}\n"

            if duration:
                message += f"Duration: {duration:.2f}s\n"

            if hasattr(raw_result, 'error') and raw_result.error:
                message += f"Error: {raw_result.error}\n"

            message += "\nNote: This is a safe diagnostic version."

            return [TextContent(
                type="text",
                text=message
            )]

        except Exception as e:
            logger.error(f"Failed to format response: {str(e)}")
            return [TextContent(
                type="text",
                text=f"Task exists but data formatting failed: {str(e)}"
            )]

    except Exception as e:
        logger.error(f"Critical error in safe_get_task_result: {str(e)}", exc_info=True)
        return [TextContent(
            type="text",
            text=f"Critical error: {str(e)}"
        )]
1
src/utils/__init__.py
Normal file
@@ -0,0 +1 @@
# Utils package
107
src/utils/image_utils.py
Normal file
@@ -0,0 +1,107 @@
"""
Image processing utilities for preview generation
"""

import base64
import io
import logging
from typing import Optional
from PIL import Image

logger = logging.getLogger(__name__)


def create_preview_image_b64(image_data: bytes, target_size: int = 512, quality: int = 85) -> Optional[str]:
    """
    Convert PNG image data to JPEG preview with specified size and return as base64

    Args:
        image_data: Original PNG image data in bytes
        target_size: Target size for the preview (default: 512x512)
        quality: JPEG quality (1-100, default: 85)

    Returns:
        Base64 encoded JPEG image string, or None if conversion fails
    """
    try:
        # Open image from bytes
        with Image.open(io.BytesIO(image_data)) as img:
            # Convert to RGB if necessary (PNG might have alpha channel)
            if img.mode in ('RGBA', 'LA', 'P'):
                # Create white background for transparent images
                background = Image.new('RGB', img.size, (255, 255, 255))
                if img.mode == 'P':
                    img = img.convert('RGBA')
                background.paste(img, mask=img.split()[-1] if img.mode in ('RGBA', 'LA') else None)
                img = background
            elif img.mode != 'RGB':
                img = img.convert('RGB')

            # Resize to target size maintaining aspect ratio
            img.thumbnail((target_size, target_size), Image.Resampling.LANCZOS)

            # If image is smaller than target size, pad it to exact size
            if img.size != (target_size, target_size):
                # Create new image with target size and white background
                new_img = Image.new('RGB', (target_size, target_size), (255, 255, 255))
                # Center the resized image
                x = (target_size - img.size[0]) // 2
                y = (target_size - img.size[1]) // 2
                new_img.paste(img, (x, y))
                img = new_img

            # Convert to JPEG and encode as base64
            output_buffer = io.BytesIO()
            img.save(output_buffer, format='JPEG', quality=quality, optimize=True)
            jpeg_data = output_buffer.getvalue()

            # Encode to base64
            base64_data = base64.b64encode(jpeg_data).decode('utf-8')

            logger.info(f"Preview image created: {target_size}x{target_size} JPEG, {len(jpeg_data)} bytes, base64 length: {len(base64_data)}")
            return base64_data

    except Exception as e:
        logger.error(f"Failed to create preview image: {str(e)}")
        return None


def validate_image_data(image_data: bytes) -> bool:
    """
    Validate if image data is a valid image

    Args:
        image_data: Image data in bytes

    Returns:
        True if valid image, False otherwise
    """
    try:
        with Image.open(io.BytesIO(image_data)) as img:
            img.verify()  # Verify image integrity
            return True
    except Exception:
        return False


def get_image_info(image_data: bytes) -> Optional[dict]:
    """
    Get image information

    Args:
        image_data: Image data in bytes

    Returns:
        Dictionary with image info (format, size, mode) or None if invalid
    """
    try:
        with Image.open(io.BytesIO(image_data)) as img:
            return {
                'format': img.format,
                'size': img.size,
                'mode': img.mode,
                'bytes': len(image_data)
            }
    except Exception as e:
        logger.error(f"Failed to get image info: {str(e)}")
        return None
200
src/utils/token_utils.py
Normal file
@@ -0,0 +1,200 @@
"""
Token counting utilities for Imagen4 prompts
"""

import re
import logging
from typing import Optional

logger = logging.getLogger(__name__)

# Default token estimation constants
AVERAGE_CHARS_PER_TOKEN = 4  # Average for English text
KOREAN_CHARS_PER_TOKEN = 2  # Average for Korean text
MAX_PROMPT_TOKENS = 480  # Maximum number of prompt tokens


def estimate_token_count(text: str) -> int:
    """
    Estimate the number of tokens in a text.

    An exact count would require the actual tokenizer; an estimate is used
    here for quick validation before calling the API.

    Args:
        text: Text to estimate the token count for

    Returns:
        int: Estimated token count
    """
    if not text:
        return 0

    # Normalize text
    text = text.strip()
    if not text:
        return 0

    # Separate Korean and English characters
    korean_chars = len(re.findall(r'[가-힣]', text))
    english_chars = len(re.findall(r'[a-zA-Z]', text))
    other_chars = len(text) - korean_chars - english_chars

    # Estimate token counts
    korean_tokens = korean_chars / KOREAN_CHARS_PER_TOKEN
    english_tokens = english_chars / AVERAGE_CHARS_PER_TOKEN
    other_tokens = other_chars / AVERAGE_CHARS_PER_TOKEN

    estimated_tokens = int(korean_tokens + english_tokens + other_tokens)

    # Guarantee at least 1 token
    return max(1, estimated_tokens)


def validate_prompt_length(prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> tuple[bool, int, Optional[str]]:
    """
    Validate the prompt length.

    Args:
        prompt: Prompt to validate
        max_tokens: Maximum allowed token count

    Returns:
        tuple: (is_valid, token_count, error_message)
    """
    if not prompt:
        return False, 0, "Prompt is empty."

    token_count = estimate_token_count(prompt)

    if token_count > max_tokens:
        error_msg = (
            f"Prompt is too long. "
            f"Current: {token_count} tokens, maximum: {max_tokens} tokens. "
            f"Please shorten the prompt by {token_count - max_tokens} tokens."
        )
        return False, token_count, error_msg

    return True, token_count, None


def truncate_prompt(prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> str:
    """
    Truncate the prompt to the given token budget.

    Args:
        prompt: Prompt to truncate
        max_tokens: Maximum token count

    Returns:
        str: Truncated prompt
    """
    if not prompt:
        return ""

    current_tokens = estimate_token_count(prompt)
    if current_tokens <= max_tokens:
        return prompt

    # Cut the text by an approximate ratio
    ratio = max_tokens / current_tokens
    target_length = int(len(prompt) * ratio * 0.9)  # 10% safety margin

    truncated = prompt[:target_length]

    # Cut at a word/sentence boundary
    if len(truncated) < len(prompt):
        # Keep only up to the last complete word
        last_space = truncated.rfind(' ')
        last_korean = truncated.rfind('다')  # Common Korean sentence ending
        last_punct = max(truncated.rfind('.'), truncated.rfind(','), truncated.rfind('!'))

        cut_point = max(last_space, last_korean, last_punct)
        if cut_point > target_length * 0.8:  # Avoid cutting off too much
            truncated = truncated[:cut_point]

    # Final validation
    final_tokens = estimate_token_count(truncated)
    if final_tokens > max_tokens:
        # Force character-level truncation
        chars_per_token = len(truncated) / final_tokens
        target_chars = int(max_tokens * chars_per_token * 0.95)
        truncated = truncated[:target_chars]

    return truncated.strip()


def get_prompt_stats(prompt: str) -> dict:
    """
    Return prompt statistics.

    Args:
        prompt: Prompt to analyze

    Returns:
        dict: Prompt statistics
    """
    if not prompt:
        return {
            "character_count": 0,
            "estimated_tokens": 0,
            "korean_chars": 0,
            "english_chars": 0,
            "other_chars": 0,
            "is_valid": False,
            "remaining_tokens": MAX_PROMPT_TOKENS
        }

    char_count = len(prompt)
    korean_chars = len(re.findall(r'[가-힣]', prompt))
    english_chars = len(re.findall(r'[a-zA-Z]', prompt))
    other_chars = char_count - korean_chars - english_chars
    estimated_tokens = estimate_token_count(prompt)
    is_valid = estimated_tokens <= MAX_PROMPT_TOKENS
    remaining_tokens = MAX_PROMPT_TOKENS - estimated_tokens

    return {
        "character_count": char_count,
        "estimated_tokens": estimated_tokens,
        "korean_chars": korean_chars,
        "english_chars": english_chars,
        "other_chars": other_chars,
        "is_valid": is_valid,
        "remaining_tokens": remaining_tokens,
        "max_tokens": MAX_PROMPT_TOKENS
    }


# Interface that can be swapped for a real tokenizer later
class TokenCounter:
    """Token counter interface"""

    def __init__(self, tokenizer_name: Optional[str] = None):
        """
        Initialize the token counter.

        Args:
            tokenizer_name: Name of the tokenizer to use (for future extension)
        """
        self.tokenizer_name = tokenizer_name or "estimate"
        logger.info(f"Token counter initialized: {self.tokenizer_name}")

    def count_tokens(self, text: str) -> int:
        """Count the number of tokens in a text"""
        if self.tokenizer_name == "estimate":
            return estimate_token_count(text)
        else:
            # Real tokenizer support to be implemented later
            raise NotImplementedError(f"Tokenizer '{self.tokenizer_name}' is not implemented yet.")

    def validate_prompt(self, prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> tuple[bool, int, Optional[str]]:
        """Validate a prompt"""
        return validate_prompt_length(prompt, max_tokens)

    def truncate_prompt(self, prompt: str, max_tokens: int = MAX_PROMPT_TOKENS) -> str:
        """Truncate a prompt"""
        return truncate_prompt(prompt, max_tokens)


# Global token counter instance
default_token_counter = TokenCounter()
108
test_main.py
Normal file
@@ -0,0 +1,108 @@
#!/usr/bin/env python3
"""
Simple test for the consolidated main.py
"""

import sys
import os
import io
from PIL import Image
import base64

# Add current directory to PYTHONPATH
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)


def create_test_image():
    """Create a simple test image"""
    img = Image.new('RGB', (2048, 2048), color='lightblue')
    from PIL import ImageDraw
    draw = ImageDraw.Draw(img)
    draw.rectangle([500, 500, 1500, 1500], fill='red')

    buffer = io.BytesIO()
    img.save(buffer, format='PNG')
    return buffer.getvalue()


def test_preview_function():
    """Test the preview image creation"""
    # Import the function from main.py
    from main import create_preview_image_b64, get_image_info

    # Create test image
    test_png = create_test_image()
    print(f"Created test PNG: {len(test_png)} bytes")

    # Get image info
    info = get_image_info(test_png)
    print(f"Image info: {info}")

    # Create preview
    preview_b64 = create_preview_image_b64(test_png, target_size=512, quality=85)

    if preview_b64:
        print(f"✅ Preview created: {len(preview_b64)} chars")

        # Save preview to verify
        preview_bytes = base64.b64decode(preview_b64)
        with open('test_preview.jpg', 'wb') as f:
            f.write(preview_bytes)

        ratio = (len(preview_bytes) / len(test_png)) * 100
        print(f"Compression: {ratio:.1f}% (original: {len(test_png)}, preview: {len(preview_bytes)})")
        print("Preview saved as test_preview.jpg")
        return True
    else:
        print("❌ Failed to create preview")
        return False


def test_imports():
    """Test all necessary imports"""
    try:
        from main import (
            Imagen4MCPServer,
            Imagen4ToolHandlers,
            ImageGenerationResult,
            get_tools
        )
        print("✅ All classes imported successfully")

        # Test tool definitions
        tools = get_tools()
        print(f"✅ {len(tools)} tools defined")

        return True
    except Exception as e:
        print(f"❌ Import error: {e}")
        return False


def main():
    """Main test function"""
    print("=== Testing Consolidated main.py ===\n")

    success = True

    # Test imports
    if not test_imports():
        success = False

    print()

    # Test preview function
    if not test_preview_function():
        success = False

    print("\n=== Test Results ===")
    if success:
        print("✅ All tests passed!")
        print("\nTo run the server:")
        print("1. Make sure .env is configured")
        print("2. Run: python main.py")
        print("3. Or use: run.bat")
    else:
        print("❌ Some tests failed")

    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main())
178
tests/test_connector.py
Normal file
@@ -0,0 +1,178 @@
"""
Tests for Imagen4 Connector
"""

import pytest
import asyncio
from unittest.mock import Mock, patch

from src.connector import Config, Imagen4Client
from src.connector.imagen4_client import ImageGenerationRequest


class TestConfig:
    """Config class tests"""

    def test_config_creation(self):
        """Config creation test"""
        config = Config(
            project_id="test-project",
            location="us-central1",
            output_path="./test_images"
        )

        assert config.project_id == "test-project"
        assert config.location == "us-central1"
        assert config.output_path == "./test_images"

    def test_config_validation(self):
        """Config validation test"""
        # Valid configuration
        valid_config = Config(project_id="test-project")
        assert valid_config.validate() is True

        # Invalid configuration
        invalid_config = Config(project_id="")
        assert invalid_config.validate() is False

    @patch.dict('os.environ', {
        'GOOGLE_CLOUD_PROJECT_ID': 'test-project',
        'GOOGLE_CLOUD_LOCATION': 'us-west1',
        'GENERATED_IMAGES_PATH': './custom_path'
    })
    def test_config_from_env(self):
        """Loading config from environment variables test"""
        config = Config.from_env()

        assert config.project_id == "test-project"
        assert config.location == "us-west1"
        assert config.output_path == "./custom_path"

    @patch.dict('os.environ', {}, clear=True)
    def test_config_from_env_missing_project_id(self):
        """Missing required environment variable test"""
        with pytest.raises(ValueError, match="GOOGLE_CLOUD_PROJECT_ID environment variable is required"):
            Config.from_env()


class TestImageGenerationRequest:
    """ImageGenerationRequest tests"""

    def test_valid_request(self):
        """Valid request test"""
        request = ImageGenerationRequest(
            prompt="A beautiful landscape",
            seed=12345
        )

        # Validation should complete without raising
        request.validate()

    def test_invalid_prompt(self):
        """Invalid prompt test"""
        request = ImageGenerationRequest(
            prompt="",
            seed=12345
        )

        with pytest.raises(ValueError, match="Prompt is required"):
            request.validate()

    def test_invalid_seed(self):
        """Invalid seed value test"""
        request = ImageGenerationRequest(
            prompt="Test prompt",
            seed=-1
        )

        with pytest.raises(ValueError, match="Seed value must be an integer between 0 and 4294967295"):
            request.validate()

    def test_invalid_number_of_images(self):
        """Invalid number of images test"""
        request = ImageGenerationRequest(
            prompt="Test prompt",
            seed=12345,
            number_of_images=3
        )

        with pytest.raises(ValueError, match="Number of images must be 1 or 2 only"):
            request.validate()


class TestImagen4Client:
    """Imagen4Client tests"""

    def test_client_initialization(self):
        """Client initialization test"""
        config = Config(project_id="test-project")

        with patch('src.connector.imagen4_client.genai.Client'):
            client = Imagen4Client(config)
            assert client.config == config

    def test_client_invalid_config(self):
        """Client initialization with invalid config test"""
        invalid_config = Config(project_id="")

        with pytest.raises(ValueError, match="Invalid configuration"):
            Imagen4Client(invalid_config)

    @pytest.mark.asyncio
    async def test_generate_image_success(self):
        """Image generation success test"""
        config = Config(project_id="test-project")

        # Create mock response
        mock_image_data = b"fake_image_data"
        mock_response = Mock()
        mock_response.generated_images = [Mock()]
        mock_response.generated_images[0].image = Mock()
        mock_response.generated_images[0].image.image_bytes = mock_image_data

        with patch('src.connector.imagen4_client.genai.Client') as mock_client_class:
            mock_client = Mock()
            mock_client_class.return_value = mock_client
            mock_client.models.generate_images = Mock(return_value=mock_response)

            client = Imagen4Client(config)

            request = ImageGenerationRequest(
                prompt="Test prompt",
                seed=12345
            )

            with patch('asyncio.to_thread', return_value=mock_response):
                response = await client.generate_image(request)

            assert response.success is True
            assert len(response.images_data) == 1
            assert response.images_data[0] == mock_image_data

    @pytest.mark.asyncio
    async def test_generate_image_failure(self):
        """Image generation failure test"""
        config = Config(project_id="test-project")

        with patch('src.connector.imagen4_client.genai.Client') as mock_client_class:
            mock_client = Mock()
            mock_client_class.return_value = mock_client

            client = Imagen4Client(config)

            request = ImageGenerationRequest(
                prompt="Test prompt",
                seed=12345
            )

            # Simulate API call failure
            with patch('asyncio.to_thread', side_effect=Exception("API Error")):
                response = await client.generate_image(request)

            assert response.success is False
            assert "API Error" in response.error_message
            assert len(response.images_data) == 0


if __name__ == "__main__":
    pytest.main([__file__])
136
tests/test_server.py
Normal file
@@ -0,0 +1,136 @@
"""
Tests for Imagen4 Server
"""

import pytest
from unittest.mock import Mock, AsyncMock, patch

from src.connector import Config
from src.server import Imagen4MCPServer, ToolHandlers, MCPToolDefinitions
from src.server.models import MCPToolDefinitions


class TestMCPToolDefinitions:
    """MCPToolDefinitions tests"""

    def test_get_generate_image_tool(self):
        """Image generation tool definition test"""
        tool = MCPToolDefinitions.get_generate_image_tool()

        assert tool.name == "generate_image"
        assert "Google Imagen 4" in tool.description
        assert "prompt" in tool.inputSchema["properties"]
        assert "seed" in tool.inputSchema["required"]

    def test_get_regenerate_from_json_tool(self):
        """JSON regeneration tool definition test"""
        tool = MCPToolDefinitions.get_regenerate_from_json_tool()

        assert tool.name == "regenerate_from_json"
        assert "JSON file" in tool.description
        assert "json_file_path" in tool.inputSchema["properties"]

    def test_get_generate_random_seed_tool(self):
        """Random seed generation tool definition test"""
        tool = MCPToolDefinitions.get_generate_random_seed_tool()

        assert tool.name == "generate_random_seed"
        assert "random seed" in tool.description

    def test_get_all_tools(self):
        """All tools return test"""
        tools = MCPToolDefinitions.get_all_tools()

        # get_all_tools() returns four tools, including validate_prompt
        assert len(tools) == 4
        tool_names = [tool.name for tool in tools]
        assert "generate_image" in tool_names
        assert "regenerate_from_json" in tool_names
        assert "generate_random_seed" in tool_names
        assert "validate_prompt" in tool_names


class TestToolHandlers:
    """ToolHandlers tests"""

    @pytest.fixture
    def config(self):
        """Test configuration"""
        return Config(project_id="test-project")

    @pytest.fixture
    def handlers(self, config):
        """Test handlers"""
        with patch('src.server.handlers.Imagen4Client'):
            return ToolHandlers(config)

    @pytest.mark.asyncio
    async def test_handle_generate_random_seed(self, handlers):
        """Random seed generation handler test"""
        result = await handlers.handle_generate_random_seed({})

        assert len(result) == 1
        assert result[0].type == "text"
        assert "Generated random seed:" in result[0].text

    @pytest.mark.asyncio
    async def test_handle_generate_image_missing_prompt(self, handlers):
        """Missing prompt test"""
        arguments = {"seed": 12345}

        result = await handlers.handle_generate_image(arguments)

        assert len(result) == 1
        assert result[0].type == "text"
        assert "Prompt is required" in result[0].text

    @pytest.mark.asyncio
    async def test_handle_generate_image_missing_seed(self, handlers):
        """Missing seed value test"""
        arguments = {"prompt": "Test prompt"}

        result = await handlers.handle_generate_image(arguments)

        assert len(result) == 1
        assert result[0].type == "text"
        assert "Seed value is required" in result[0].text

    @pytest.mark.asyncio
    async def test_handle_regenerate_from_json_missing_path(self, handlers):
        """Missing JSON file path test"""
        arguments = {}

        result = await handlers.handle_regenerate_from_json(arguments)

        assert len(result) == 1
        assert result[0].type == "text"
        assert "JSON file path is required" in result[0].text


class TestImagen4MCPServer:
    """Imagen4MCPServer tests"""

    @pytest.fixture
    def config(self):
        """Test configuration"""
        return Config(project_id="test-project")

    def test_server_initialization(self, config):
        """Server initialization test"""
        with patch('src.server.mcp_server.ToolHandlers'):
            server = Imagen4MCPServer(config)

            assert server.config == config
            assert server.server is not None
            assert server.handlers is not None

    def test_get_server(self, config):
        """Server instance return test"""
        with patch('src.server.mcp_server.ToolHandlers'):
            mcp_server = Imagen4MCPServer(config)
            server = mcp_server.get_server()

            assert server is not None
            assert server.name == "imagen4-mcp-server"


if __name__ == "__main__":
    pytest.main([__file__])