Skip to content

Commit

Permalink
Added: Updated tests, Cloud and Docker implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
Klus3kk committed Dec 18, 2024
1 parent 192e31f commit c4d1304
Show file tree
Hide file tree
Showing 11 changed files with 223 additions and 89 deletions.
21 changes: 21 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Base image
FROM python:3.10-slim

# Set environment variables
ENV PYTHONUNBUFFERED=1

# Set working directory
WORKDIR /app

# Copy requirements and install dependencies
COPY requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# Copy the app code
COPY . /app

# Expose FastAPI port
EXPOSE 8000

# Command to run FastAPI
CMD ["uvicorn", "api.FastAPIHandler:app", "--host", "0.0.0.0", "--port", "8000"]
2 changes: 1 addition & 1 deletion deployment/upload_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
registry = StyleRegistry()

# Google Cloud Storage bucket name
bucket_name = os.getenv("GCS_BUCKET_NAME", "your-bucket-name")
bucket_name = os.getenv("GCS_BUCKET_NAME", "artify-models")

# Upload each model to the cloud
for category in registry.styles.keys():
Expand Down
9 changes: 9 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
version: "3.9"
services:
artify-api:
build: .
ports:
- "8000:8000"
volumes:
- ./models:/app/models
- ./images:/app/images
24 changes: 21 additions & 3 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,30 @@ def test_root_endpoint():
assert response.status_code == 200
assert response.json() == {"message": "Welcome to Artify! Use /apply_style to stylize your images."}

def test_apply_style_endpoint():
with open("images/content/sample_content.jpg", "rb") as content, open("images/style/impressionism/sample_style.jpg", "rb") as style:
def test_apply_style_endpoint_success():
with open("images/content/sample_content.jpg", "rb") as content:
response = client.post(
"/apply_style/",
files={"content": ("sample_content.jpg", content, "image/jpeg")},
data={"style_category": "impressionism"}
)
assert response.status_code == 200
assert "output_path" in response.json(), "API should return the output path."
assert "output_path" in response.json()
assert response.json()["message"] == "Style applied successfully!"

def test_apply_style_endpoint_invalid_style():
with open("images/content/sample_content.jpg", "rb") as content:
response = client.post(
"/apply_style/",
files={"content": ("sample_content.jpg", content, "image/jpeg")},
data={"style_category": "invalid_style"}
)
assert response.status_code == 400
assert "error" in response.json()

def test_apply_style_endpoint_missing_content():
response = client.post(
"/apply_style/",
data={"style_category": "impressionism"}
)
assert response.status_code == 422
28 changes: 21 additions & 7 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,27 @@
from PIL import Image
from io import BytesIO

def test_style_transfer_model_load():
model = StyleTransferModel()
assert model.model is not None, "Model should be loaded successfully."

def test_image_preprocessor_preprocess():
processor = ImageProcessor()
@pytest.fixture
def mock_image():
image = Image.new("RGB", (800, 600), color="red")
image_data = BytesIO()
image.save(image_data, format="JPEG")
image_data.seek(0)
return image_data

def test_style_transfer_model_load():
model = StyleTransferModel()
assert model.model is not None, "Model should be loaded successfully."

def test_style_transfer_model_apply(mock_image):
model = StyleTransferModel()
content_image = Image.open(mock_image)
styled_image = model.apply_style(content_image, content_image)
assert styled_image is not None, "Styled image should be generated successfully."

processed_image = processor.preprocess_image(image_data)
def test_image_preprocessor_preprocess(mock_image):
processor = ImageProcessor()
processed_image = processor.preprocess_image(mock_image)
assert processed_image.size == (512, 512), "Image should be resized to 512x512."

def test_image_preprocessor_save(tmp_path):
Expand All @@ -24,3 +33,8 @@ def test_image_preprocessor_save(tmp_path):
save_path = tmp_path / "test_image.jpg"
processor.save_image(image, save_path)
assert save_path.exists(), "Image should be saved successfully."

def test_image_preprocessor_invalid_input():
processor = ImageProcessor()
with pytest.raises(ValueError):
processor.preprocess_image(None)
37 changes: 34 additions & 3 deletions tests/test_interface.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import subprocess
import os

def test_cli_handler():
def test_cli_handler_success():
command = [
"python",
"interface/CLIHandler.py",
Expand All @@ -12,5 +13,35 @@ def test_cli_handler():
"images/output/test_output.jpg"
]
result = subprocess.run(command, capture_output=True, text=True)
assert result.returncode == 0, "CLI should execute without errors."
assert "Styled image saved to:" in result.stdout, "CLI should output success message."
assert result.returncode == 0
assert "Styled image saved to:" in result.stdout

def test_cli_handler_invalid_content():
command = [
"python",
"interface/CLIHandler.py",
"--content",
"invalid_path.jpg",
"--style_category",
"impressionism",
"--output",
"images/output/test_output.jpg"
]
result = subprocess.run(command, capture_output=True, text=True)
assert result.returncode != 0
assert "Error" in result.stderr

def test_cli_handler_invalid_style():
command = [
"python",
"interface/CLIHandler.py",
"--content",
"images/content/sample_content.jpg",
"--style_category",
"invalid_style",
"--output",
"images/output/test_output.jpg"
]
result = subprocess.run(command, capture_output=True, text=True)
assert result.returncode != 0
assert "Error" in result.stderr
7 changes: 6 additions & 1 deletion tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@ def test_style_registry_random_selection():
registry = StyleRegistry()
random_style = registry.get_random_style_image("impressionism")
assert random_style is not None, "Random style image should be selected."
assert "impressionism" in random_style, "Selected image should belong to the requested category."
assert "impressionism" in random_style

def test_style_registry_invalid_category():
registry = StyleRegistry()
with pytest.raises(ValueError):
registry.get_random_style_image("invalid_category")

def test_config_manager(tmp_path):
config_path = tmp_path / "config.json"
Expand Down
37 changes: 27 additions & 10 deletions train_models.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,44 @@
from PIL import Image
from core.StyleTransferModel import StyleTransferModel
from utilities.StyleRegistry import StyleRegistry
from utilities.Logger import Logger
import os
import logging

# Setup logger
logger = Logger.setup_logger(log_file="training.log", log_level=logging.INFO)

# Initialize the style transfer model and registry
model = StyleTransferModel()
registry = StyleRegistry()

# Define the content image for training
content_image = Image.open("images/content/sample_content.jpg")
content_image_path = "images/content/sample_content.jpg"
if not os.path.exists(content_image_path):
logger.error(f"Content image not found at {content_image_path}. Exiting.")
raise FileNotFoundError(f"Content image not found at {content_image_path}.")
content_image = Image.open(content_image_path)

# Ensure models directory exists
os.makedirs("models", exist_ok=True)
output_dir = "models"
os.makedirs(output_dir, exist_ok=True)

# Train a model for each style category
for category in registry.styles.keys():
# Get a random style image from the category
style_image_path = registry.get_random_style_image(category)
style_image = Image.open(style_image_path)
try:
# Get a random style image from the category
style_image_path = registry.get_random_style_image(category)
if not os.path.exists(style_image_path):
logger.warning(f"Style image not found: {style_image_path}. Skipping category {category}.")
continue

style_image = Image.open(style_image_path)

# Define the output path for the trained model
output_path = f"models/{category}_model.pth"
# Define the output path for the trained model
output_path = os.path.join(output_dir, f"{category}_model.pth")

print(f"Training model for {category} using style: {style_image_path}...")
model.train_model(content_image, style_image, output_path)
print(f"Model saved to {output_path}")
logger.info(f"Training model for '{category}' using style: {style_image_path}...")
model.train_model(content_image, style_image, output_path)
logger.info(f"Model saved to {output_path}")
except Exception as e:
logger.error(f"Failed to train model for category '{category}': {e}")
40 changes: 34 additions & 6 deletions utilities/ConfigManager.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,40 @@
import json
import os


class ConfigManager:
@staticmethod
def load_config(config_path):
with open(config_path, "r") as f:
return json.load(f)

def load_config(config_path, default_config=None):
"""
Load configuration from a file. If the file does not exist and default_config is provided,
create the file with the default configuration.
:param config_path: Path to the configuration file.
:param default_config: A dictionary with default configuration values.
:return: Loaded configuration as a dictionary.
"""
if not os.path.exists(config_path):
if default_config is not None:
ConfigManager.save_config(default_config, config_path)
return default_config
raise FileNotFoundError(f"Config file '{config_path}' not found.")

try:
with open(config_path, "r") as f:
return json.load(f)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in config file '{config_path}': {e}")

@staticmethod
def save_config(config, config_path):
with open(config_path, "w") as f:
json.dump(config, f, indent=4)
"""
Save configuration to a file.
:param config: Dictionary containing configuration values.
:param config_path: Path to save the configuration file.
"""
try:
with open(config_path, "w") as f:
json.dump(config, f, indent=4)
except Exception as e:
raise IOError(f"Failed to save config to '{config_path}': {e}")
20 changes: 16 additions & 4 deletions utilities/Logger.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
import logging


class Logger:
@staticmethod
def setup_logger():
def setup_logger(log_file=None, log_level=logging.INFO):
"""
Set up the logger with console and optional file logging.
:param log_file: Path to the log file (optional).
:param log_level: Logging level (default: INFO).
:return: Configured logger instance.
"""
handlers = [logging.StreamHandler()]
if log_file:
handlers.append(logging.FileHandler(log_file))

logging.basicConfig(
level=logging.INFO,
level=log_level,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler()],
handlers=handlers,
)
return logging.getLogger("Artify")
return logging.getLogger("Artify")
Loading

0 comments on commit c4d1304

Please sign in to comment.