206 lines
7.0 KiB
Python
206 lines
7.0 KiB
Python
"""Unit tests for token utilities"""
|
|
|
|
import pytest
|
|
from src.utils.token_utils import (
|
|
estimate_token_count,
|
|
get_token_limit_for_size,
|
|
determine_optimal_size_for_aspect_ratio,
|
|
validate_prompt_length,
|
|
get_prompt_stats,
|
|
truncate_prompt_to_fit,
|
|
suggest_quality_for_prompt,
|
|
TOKEN_LIMITS
|
|
)
|
|
|
|
|
|
def test_estimate_token_count():
    """Check estimate_token_count on empty, short, long and punctuated text.

    Only the empty string has an exact expected value; every other sample is
    accepted anywhere inside an inclusive range.
    """
    # An empty prompt must yield exactly zero tokens.
    assert estimate_token_count("") == 0

    # (sample text, lowest accepted count, highest accepted count)
    samples = (
        ("Hello world", 2, 4),
        ("a" * 100, 20, 30),  # approximately 4 chars per token
        ("This is a test. With multiple sentences!", 8, 12),
    )
    for sample, lowest, highest in samples:
        assert lowest <= estimate_token_count(sample) <= highest
|
|
|
|
|
|
def test_get_token_limit_for_size():
    """Verify the limit returned for each size/quality pair, plus fallbacks."""
    # (size, quality, expected token limit)
    expectations = (
        # High quality
        ("1024x1024", "high", 4160),
        ("1024x1536", "high", 6240),
        ("1536x1024", "high", 6208),
        # Medium quality
        ("1024x1024", "medium", 1056),
        ("1024x1536", "medium", 1584),
        # Low quality
        ("1024x1024", "low", 272),
        # An unrecognised size is expected to fall back to the square limit.
        ("999x999", "high", 4160),
        # An unrecognised quality is expected to fall back to "high".
        ("1024x1024", "invalid", 4160),
    )
    for size, quality, expected_limit in expectations:
        assert get_token_limit_for_size(size, quality) == expected_limit
|
|
|
|
|
|
def test_determine_optimal_size_for_aspect_ratio():
    """Verify the (size, aspect) pair chosen for various input dimensions."""
    # (width, height, expected size string, expected aspect label)
    cases = (
        (100, 100, "256x256", "square"),      # small square
        (400, 400, "512x512", "square"),      # medium square
        (1000, 1000, "1024x1024", "square"),  # large square
        (1600, 900, "1536x1024", "landscape"),
        (900, 1600, "1024x1536", "portrait"),
    )
    for width, height, expected_size, expected_aspect in cases:
        chosen_size, chosen_aspect = determine_optimal_size_for_aspect_ratio(
            width, height
        )
        assert chosen_size == expected_size
        assert chosen_aspect == expected_aspect
|
|
|
|
|
|
def test_validate_prompt_length():
    """Exercise validate_prompt_length with short, oversized and near-limit prompts.

    All three calls use the 1024x1024 / high-quality combination, whose limit
    is 4160 tokens.
    """
    size, quality = "1024x1024", "high"

    # A brief prompt is accepted, with a positive count and empty error text.
    ok, count, message = validate_prompt_length("Edit this image", size, quality)
    assert ok is True
    assert count > 0
    assert message == ""

    # 2000 repetitions of "word " should land far beyond the limit.
    oversized = "word " * 2000
    ok, count, message = validate_prompt_length(oversized, size, quality)
    assert ok is False
    assert count > 4160  # exceeds the high-quality limit
    assert "too long" in message.lower()

    # Near-limit case: at roughly 4 characters per token the limit corresponds
    # to about 16,640 characters, so 16,000 characters should sit close to it.
    # Only the token count is checked here; validity is deliberately not pinned.
    near_limit = "a" * 16000
    _ok, count, _message = validate_prompt_length(near_limit, size, quality)
    assert count > 3000
|
|
|
|
|
|
def test_get_prompt_stats():
    """Verify the keys and values reported by get_prompt_stats for a short prompt."""
    stats = get_prompt_stats(
        "Make the sky blue and add some clouds", "1024x1024", "high"
    )

    # Every statistic the callers rely on must be present.
    required_keys = (
        "estimated_tokens",
        "token_limit",
        "usage_percentage",
        "remaining_tokens",
        "quality",
        "size",
        "is_valid",
    )
    for key in required_keys:
        assert key in stats

    # The request parameters and the matching limit are echoed back.
    expected_values = {
        "token_limit": 4160,
        "quality": "high",
        "size": "1024x1024",
    }
    for key, expected in expected_values.items():
        assert stats[key] == expected

    assert stats["is_valid"] is True
    assert stats["usage_percentage"] < 10  # a short prompt uses little budget
|
|
|
|
|
|
def test_truncate_prompt_to_fit():
    """Test prompt truncation.

    A prompt that already fits must come back unchanged. An oversized prompt
    must come back shorter, must validate successfully, and must sit within
    the buffered token budget (limit * buffer).
    """
    size, quality = "1024x1024", "high"
    limit = 4160  # high-quality token limit for 1024x1024
    buffer = 0.95

    # Short prompt - should not be truncated
    short_prompt = "Edit this image"
    truncated = truncate_prompt_to_fit(short_prompt, size, quality)
    assert truncated == short_prompt

    # Long prompt - should be truncated
    long_prompt = " ".join(f"word{i}" for i in range(5000))
    truncated = truncate_prompt_to_fit(long_prompt, size, quality, buffer=buffer)

    # Check that truncated version is shorter
    assert len(truncated) < len(long_prompt)

    # Check that truncated version fits within limits
    is_valid, tokens, _ = validate_prompt_length(truncated, size, quality)
    assert is_valid is True
    # Inclusive bound: a truncation that lands exactly on the buffered budget
    # is still within the buffer. This matches the `<=` comparison used by the
    # low-quality truncation test.
    assert tokens <= limit * buffer
|
|
|
|
|
|
def test_truncate_prompt_with_low_quality():
    """Check truncation against the strict low-quality budget (272 tokens)."""
    size, quality = "1024x1024", "low"
    original = " ".join(f"word{n}" for n in range(200))

    shortened = truncate_prompt_to_fit(original, size, quality)

    # The 200-word prompt is expected not to fit, so text must have been cut.
    assert len(shortened) < len(original)

    # The shortened prompt must validate and stay within 95% of the limit.
    fits, used, _ = validate_prompt_length(shortened, size, quality)
    assert fits is True
    assert used <= 272 * 0.95
|
|
|
|
|
|
def test_suggest_quality_for_prompt():
    """Verify the quality tier suggested for prompts of increasing length."""
    size = "1024x1024"

    def numbered_words(how_many):
        # Builds "word0 word1 ... word<how_many - 1>".
        return " ".join(f"word{i}" for i in range(how_many))

    # Very short prompt - the lowest tier is expected.
    assert suggest_quality_for_prompt("blue sky", size) == "low"

    # Medium prompt - either of the two lower tiers is acceptable.
    assert suggest_quality_for_prompt(numbered_words(100), size) in ("low", "medium")

    # Long prompt - expected to require the high tier.
    assert suggest_quality_for_prompt(numbered_words(1000), size) == "high"

    # Very long prompt - still "high" (the prompt will need truncation).
    assert suggest_quality_for_prompt(numbered_words(5000), size) == "high"
|
|
|
|
|
|
def test_token_limits_structure():
    """Verify TOKEN_LIMITS exposes every quality tier and every supported size."""
    # All three quality tiers must be present.
    for tier in ("low", "medium", "high"):
        assert tier in TOKEN_LIMITS

    # Every tier that is present must define all three sizes.
    sizes = ("1024x1024", "1024x1536", "1536x1024")
    for quality in TOKEN_LIMITS:
        per_size = TOKEN_LIMITS[quality]
        for size in sizes:
            assert size in per_size

    # The high-quality limits must match the documented values.
    documented_high = (
        ("1024x1024", 4160),
        ("1024x1536", 6240),
        ("1536x1024", 6208),
    )
    for size, expected_limit in documented_high:
        assert TOKEN_LIMITS["high"][size] == expected_limit
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly. pytest.main() returns the run's exit
    # code; hand it back to the shell so a failing run is not reported as
    # success.
    raise SystemExit(pytest.main([__file__, "-v"]))
|