219 lines
7.8 KiB
Python
219 lines
7.8 KiB
Python
"""Unit tests for validation utilities"""
|
|
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
# Add src to path
|
|
import sys
|
|
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
|
|
|
|
from src.utils.validation import (
|
|
validate_edit_parameters,
|
|
validate_file_parameters,
|
|
validate_move_file_parameters,
|
|
validate_image_path_parameter,
|
|
sanitize_prompt,
|
|
validate_aspect_ratio_format,
|
|
validate_seed_range,
|
|
validate_filename_safety
|
|
)
|
|
|
|
|
|
class TestValidation(unittest.TestCase):
    """Test cases for the validation utilities.

    The ``validate_*_parameters`` helpers are exercised as functions that
    take an argument dict and return an ``(is_valid, error)`` pair: the
    tests expect ``error`` to be ``None`` on success and to contain a
    descriptive message on failure.
    """

    def _assert_accepted(self, validator, args):
        """Assert that ``validator(args)`` returns ``(True, None)``."""
        is_valid, error = validator(args)
        self.assertTrue(is_valid)
        self.assertIsNone(error)

    def _assert_rejected(self, validator, args, expected_fragment):
        """Assert that ``validator(args)`` fails with a matching message.

        The validator must report a falsy ``is_valid`` and an error string
        that contains ``expected_fragment``.
        """
        is_valid, error = validator(args)
        self.assertFalse(is_valid)
        self.assertIn(expected_fragment, error)

    def test_validate_edit_parameters_success(self):
        """Test successful validation of edit parameters"""
        valid_args = {
            # Small base64 payload that starts with the PNG signature.
            'input_image_b64': 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==',
            'prompt': 'Make the sky blue',
            'seed': 12345,
            'aspect_ratio': '16:9',
            'save_to_file': True,
        }

        self._assert_accepted(validate_edit_parameters, valid_args)

    def test_validate_edit_parameters_missing_required(self):
        """Test validation fails with missing required parameters"""
        # Each case omits exactly one required key.
        cases = [
            # Missing input_image_b64
            ({'prompt': 'test', 'seed': 12345},
             'input_image_b64 is required'),
            # Missing prompt
            ({'input_image_b64': 'test_b64', 'seed': 12345},
             'prompt is required'),
            # Missing seed
            ({'input_image_b64': 'test_b64', 'prompt': 'test'},
             'seed is required'),
        ]
        for args, expected_fragment in cases:
            with self.subTest(expected=expected_fragment):
                self._assert_rejected(
                    validate_edit_parameters, args, expected_fragment
                )

    def test_validate_edit_parameters_invalid_types(self):
        """Test validation fails with invalid parameter types"""
        # Invalid seed type
        args = {
            'input_image_b64': 'valid_b64',
            'prompt': 'test prompt',
            'seed': 'not_a_number',
        }
        self._assert_rejected(
            validate_edit_parameters, args, 'seed must be an integer'
        )

        # Invalid aspect ratio
        args = {
            'input_image_b64': 'valid_b64',
            'prompt': 'test prompt',
            'seed': 12345,
            'aspect_ratio': 'invalid_ratio',
        }
        self._assert_rejected(
            validate_edit_parameters, args, 'aspect_ratio must be one of'
        )

    def test_validate_edit_parameters_invalid_ranges(self):
        """Test validation fails with invalid parameter ranges"""
        # Seed out of range
        args = {
            'input_image_b64': 'valid_b64',
            'prompt': 'test prompt',
            'seed': -1,
        }
        self._assert_rejected(
            validate_edit_parameters, args, 'seed must be between'
        )

        # Prompt too long
        args = {
            'input_image_b64': 'valid_b64',
            'prompt': 'x' * 10001,  # Too long
            'seed': 12345,
        }
        self._assert_rejected(
            validate_edit_parameters, args, 'prompt is too long'
        )

    def test_validate_file_parameters_success(self):
        """Test successful validation of file parameters"""
        valid_args = {
            'input_image_name': 'test.png',
            'prompt': 'Edit this image',
            'seed': 54321,
            'aspect_ratio': '1:1',
            'save_to_file': False,
        }

        self._assert_accepted(validate_file_parameters, valid_args)

    def test_validate_file_parameters_invalid_filename(self):
        """Test validation fails with invalid filename"""
        # Path traversal attempt
        args = {
            'input_image_name': '../../../etc/passwd',
            'prompt': 'test',
            'seed': 12345,
        }
        self._assert_rejected(
            validate_file_parameters, args, 'cannot contain path separators'
        )

        # Invalid extension
        args = {
            'input_image_name': 'test.exe',
            'prompt': 'test',
            'seed': 12345,
        }
        self._assert_rejected(
            validate_file_parameters, args, 'must have a valid image extension'
        )

    def test_validate_move_file_parameters_success(self):
        """Test successful validation of move file parameters"""
        valid_args = {
            'temp_file_name': 'temp_image.png',
            'output_file_name': 'output.png',
            'copy_only': True,
        }

        self._assert_accepted(validate_move_file_parameters, valid_args)

    def test_validate_image_path_parameter_success(self):
        """Test successful validation of image path parameter"""
        valid_args = {'image_path': '/path/to/image.png'}

        self._assert_accepted(validate_image_path_parameter, valid_args)

    def test_sanitize_prompt(self):
        """Test prompt sanitization"""
        # Test whitespace normalization
        prompt = " This has extra whitespace "
        sanitized = sanitize_prompt(prompt)
        self.assertEqual(sanitized, "This has extra whitespace")

        # Test null byte removal: dropping both "\x00" characters must leave
        # every other character intact, including the "n" of "nulls".
        prompt = "Test\x00with\x00nulls"
        sanitized = sanitize_prompt(prompt)
        self.assertEqual(sanitized, "Testwithnulls")

        # Test length limiting
        long_prompt = "x" * 10001
        sanitized = sanitize_prompt(long_prompt)
        self.assertEqual(len(sanitized), 10000)

    def test_validate_aspect_ratio_format(self):
        """Test aspect ratio format validation"""
        # Valid formats
        for ratio in ("16:9", "1:1", "4:3"):
            with self.subTest(ratio=ratio):
                self.assertTrue(validate_aspect_ratio_format(ratio))

        # Invalid formats
        for ratio in ("16-9", "16:9:1", "a:b", "0:1"):
            with self.subTest(ratio=ratio):
                self.assertFalse(validate_aspect_ratio_format(ratio))

    def test_validate_seed_range(self):
        """Test seed range validation"""
        # Valid seeds
        for seed in (0, 12345, 2**32 - 1):
            with self.subTest(seed=seed):
                self.assertTrue(validate_seed_range(seed))

        # Invalid seeds
        for seed in (-1, 2**32, "not_a_number"):
            with self.subTest(seed=seed):
                self.assertFalse(validate_seed_range(seed))

    def test_validate_filename_safety(self):
        """Test filename safety validation"""
        # Safe filenames
        for filename in ("image.png", "my_image_123.jpg", "test-file.png"):
            with self.subTest(filename=filename):
                self.assertTrue(validate_filename_safety(filename))

        # Unsafe filenames
        unsafe_filenames = (
            "../image.png",
            "path/to/image.png",
            "image<>.png",
            "CON.png",  # Windows reserved
            "x" * 256,  # Too long
        )
        for filename in unsafe_filenames:
            with self.subTest(filename=filename):
                self.assertFalse(validate_filename_safety(filename))
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
|