From c4d1304f64791f6849937db1ba838e6fafe1aa85 Mon Sep 17 00:00:00 2001 From: Klus3kk Date: Wed, 18 Dec 2024 13:54:29 +0100 Subject: [PATCH] Added: Updated tests, Cloud and Docker implementation. --- Dockerfile | 21 +++++++++ deployment/upload_models.py | 2 +- docker-compose.yml | 9 ++++ tests/test_api.py | 24 ++++++++-- tests/test_core.py | 28 +++++++++--- tests/test_interface.py | 37 ++++++++++++++-- tests/test_utilities.py | 7 ++- train_models.py | 37 +++++++++++----- utilities/ConfigManager.py | 40 ++++++++++++++--- utilities/Logger.py | 20 +++++++-- utilities/StyleRegistry.py | 87 ++++++++++++++----------------------- 11 files changed, 223 insertions(+), 89 deletions(-) create mode 100644 docker-compose.yml diff --git a/Dockerfile b/Dockerfile index e69de29..f1e6882 100644 --- a/Dockerfile +++ b/Dockerfile @@ -0,0 +1,21 @@ +# Base image +FROM python:3.10-slim + +# Set environment variables +ENV PYTHONUNBUFFERED=1 + +# Set working directory +WORKDIR /app + +# Copy requirements and install dependencies +COPY requirements.txt /app/requirements.txt +RUN pip install --no-cache-dir -r requirements.txt + +# Copy the app code +COPY . /app + +# Expose FastAPI port +EXPOSE 8000 + +# Command to run FastAPI +CMD ["uvicorn", "api.FastAPIHandler:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/deployment/upload_models.py b/deployment/upload_models.py index 50af617..a6d899d 100644 --- a/deployment/upload_models.py +++ b/deployment/upload_models.py @@ -10,7 +10,7 @@ registry = StyleRegistry() # Google Cloud Storage bucket name -bucket_name = os.getenv("GCS_BUCKET_NAME", "your-bucket-name") +bucket_name = os.getenv("GCS_BUCKET_NAME", "artify-models") # Upload each model to the cloud for category in registry.styles.keys(): diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..b5cf065 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,9 @@ +version: "3.9" +services: + artify-api: + build: . + ports: + - "8000:8000" + volumes: + - ./models:/app/models + - ./images:/app/images \ No newline at end of file diff --git a/tests/test_api.py b/tests/test_api.py index e97852d..0cb69ab 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -8,12 +8,30 @@ def test_root_endpoint(): assert response.status_code == 200 assert response.json() == {"message": "Welcome to Artify! Use /apply_style to stylize your images."} -def test_apply_style_endpoint(): - with open("images/content/sample_content.jpg", "rb") as content, open("images/style/impressionism/sample_style.jpg", "rb") as style: +def test_apply_style_endpoint_success(): + with open("images/content/sample_content.jpg", "rb") as content: response = client.post( "/apply_style/", files={"content": ("sample_content.jpg", content, "image/jpeg")}, data={"style_category": "impressionism"} ) assert response.status_code == 200 - assert "output_path" in response.json(), "API should return the output path." + assert "output_path" in response.json() + assert response.json()["message"] == "Style applied successfully!" + +def test_apply_style_endpoint_invalid_style(): + with open("images/content/sample_content.jpg", "rb") as content: + response = client.post( + "/apply_style/", + files={"content": ("sample_content.jpg", content, "image/jpeg")}, + data={"style_category": "invalid_style"} + ) + assert response.status_code == 400 + assert "error" in response.json() + +def test_apply_style_endpoint_missing_content(): + response = client.post( + "/apply_style/", + data={"style_category": "impressionism"} + ) + assert response.status_code == 422 diff --git a/tests/test_core.py b/tests/test_core.py index e792264..88813b3 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -4,18 +4,27 @@ from PIL import Image from io import BytesIO -def test_style_transfer_model_load(): - model = StyleTransferModel() - assert model.model is not None, "Model should be loaded successfully." - -def test_image_preprocessor_preprocess(): - processor = ImageProcessor() +@pytest.fixture +def mock_image(): image = Image.new("RGB", (800, 600), color="red") image_data = BytesIO() image.save(image_data, format="JPEG") image_data.seek(0) + return image_data + +def test_style_transfer_model_load(): + model = StyleTransferModel() + assert model.model is not None, "Model should be loaded successfully." + +def test_style_transfer_model_apply(mock_image): + model = StyleTransferModel() + content_image = Image.open(mock_image) + styled_image = model.apply_style(content_image, content_image) + assert styled_image is not None, "Styled image should be generated successfully." - processed_image = processor.preprocess_image(image_data) +def test_image_preprocessor_preprocess(mock_image): + processor = ImageProcessor() + processed_image = processor.preprocess_image(mock_image) assert processed_image.size == (512, 512), "Image should be resized to 512x512." def test_image_preprocessor_save(tmp_path): @@ -24,3 +33,8 @@ def test_image_preprocessor_save(tmp_path): save_path = tmp_path / "test_image.jpg" processor.save_image(image, save_path) assert save_path.exists(), "Image should be saved successfully." + +def test_image_preprocessor_invalid_input(): + processor = ImageProcessor() + with pytest.raises(ValueError): + processor.preprocess_image(None) diff --git a/tests/test_interface.py b/tests/test_interface.py index 042f9b2..3d7b875 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -1,6 +1,7 @@ import subprocess +import os -def test_cli_handler(): +def test_cli_handler_success(): command = [ "python", "interface/CLIHandler.py", @@ -12,5 +13,35 @@ def test_cli_handler(): "images/output/test_output.jpg" ] result = subprocess.run(command, capture_output=True, text=True) - assert result.returncode == 0, "CLI should execute without errors." - assert "Styled image saved to:" in result.stdout, "CLI should output success message." + assert result.returncode == 0 + assert "Styled image saved to:" in result.stdout + +def test_cli_handler_invalid_content(): + command = [ + "python", + "interface/CLIHandler.py", + "--content", + "invalid_path.jpg", + "--style_category", + "impressionism", + "--output", + "images/output/test_output.jpg" + ] + result = subprocess.run(command, capture_output=True, text=True) + assert result.returncode != 0 + assert "Error" in result.stderr + +def test_cli_handler_invalid_style(): + command = [ + "python", + "interface/CLIHandler.py", + "--content", + "images/content/sample_content.jpg", + "--style_category", + "invalid_style", + "--output", + "images/output/test_output.jpg" + ] + result = subprocess.run(command, capture_output=True, text=True) + assert result.returncode != 0 + assert "Error" in result.stderr diff --git a/tests/test_utilities.py b/tests/test_utilities.py index 6d21eeb..bd2822e 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -8,7 +8,12 @@ def test_style_registry_random_selection(): registry = StyleRegistry() random_style = registry.get_random_style_image("impressionism") assert random_style is not None, "Random style image should be selected." - assert "impressionism" in random_style, "Selected image should belong to the requested category." + assert "impressionism" in random_style + +def test_style_registry_invalid_category(): + registry = StyleRegistry() + with pytest.raises(ValueError): + registry.get_random_style_image("invalid_category") def test_config_manager(tmp_path): config_path = tmp_path / "config.json" diff --git a/train_models.py b/train_models.py index 27ead23..eac5ea7 100644 --- a/train_models.py +++ b/train_models.py @@ -1,27 +1,44 @@ from PIL import Image from core.StyleTransferModel import StyleTransferModel from utilities.StyleRegistry import StyleRegistry +from utilities.Logger import Logger import os +import logging + +# Setup logger +logger = Logger.setup_logger(log_file="training.log", log_level=logging.INFO) # Initialize the style transfer model and registry model = StyleTransferModel() registry = StyleRegistry() # Define the content image for training -content_image = Image.open("images/content/sample_content.jpg") +content_image_path = "images/content/sample_content.jpg" +if not os.path.exists(content_image_path): + logger.error(f"Content image not found at {content_image_path}. Exiting.") + raise FileNotFoundError(f"Content image not found at {content_image_path}.") +content_image = Image.open(content_image_path) # Ensure models directory exists -os.makedirs("models", exist_ok=True) +output_dir = "models" +os.makedirs(output_dir, exist_ok=True) # Train a model for each style category for category in registry.styles.keys(): - # Get a random style image from the category - style_image_path = registry.get_random_style_image(category) - style_image = Image.open(style_image_path) + try: + # Get a random style image from the category + style_image_path = registry.get_random_style_image(category) + if not os.path.exists(style_image_path): + logger.warning(f"Style image not found: {style_image_path}. Skipping category {category}.") + continue + + style_image = Image.open(style_image_path) - # Define the output path for the trained model - output_path = f"models/{category}_model.pth" + # Define the output path for the trained model + output_path = os.path.join(output_dir, f"{category}_model.pth") - print(f"Training model for {category} using style: {style_image_path}...") - model.train_model(content_image, style_image, output_path) - print(f"Model saved to {output_path}") + logger.info(f"Training model for '{category}' using style: {style_image_path}...") + model.train_model(content_image, style_image, output_path) + logger.info(f"Model saved to {output_path}") + except Exception as e: + logger.error(f"Failed to train model for category '{category}': {e}") diff --git a/utilities/ConfigManager.py b/utilities/ConfigManager.py index 2c11635..ff8d360 100644 --- a/utilities/ConfigManager.py +++ b/utilities/ConfigManager.py @@ -1,12 +1,40 @@ import json +import os + class ConfigManager: @staticmethod - def load_config(config_path): - with open(config_path, "r") as f: - return json.load(f) - + def load_config(config_path, default_config=None): + """ + Load configuration from a file. If the file does not exist and default_config is provided, + create the file with the default configuration. + + :param config_path: Path to the configuration file. + :param default_config: A dictionary with default configuration values. + :return: Loaded configuration as a dictionary. + """ + if not os.path.exists(config_path): + if default_config is not None: + ConfigManager.save_config(default_config, config_path) + return default_config + raise FileNotFoundError(f"Config file '{config_path}' not found.") + + try: + with open(config_path, "r") as f: + return json.load(f) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in config file '{config_path}': {e}") + @staticmethod def save_config(config, config_path): - with open(config_path, "w") as f: - json.dump(config, f, indent=4) \ No newline at end of file + """ + Save configuration to a file. + + :param config: Dictionary containing configuration values. + :param config_path: Path to save the configuration file. + """ + try: + with open(config_path, "w") as f: + json.dump(config, f, indent=4) + except Exception as e: + raise IOError(f"Failed to save config to '{config_path}': {e}") diff --git a/utilities/Logger.py b/utilities/Logger.py index e650d30..469d1a5 100644 --- a/utilities/Logger.py +++ b/utilities/Logger.py @@ -1,11 +1,23 @@ import logging + class Logger: @staticmethod - def setup_logger(): + def setup_logger(log_file=None, log_level=logging.INFO): + """ + Set up the logger with console and optional file logging. + + :param log_file: Path to the log file (optional). + :param log_level: Logging level (default: INFO). + :return: Configured logger instance. + """ + handlers = [logging.StreamHandler()] + if log_file: + handlers.append(logging.FileHandler(log_file)) + logging.basicConfig( - level=logging.INFO, + level=log_level, format="%(asctime)s - %(levelname)s - %(message)s", - handlers=[logging.StreamHandler()], + handlers=handlers, ) - return logging.getLogger("Artify") \ No newline at end of file + return logging.getLogger("Artify") diff --git a/utilities/StyleRegistry.py b/utilities/StyleRegistry.py index 34fc036..4ffe968 100644 --- a/utilities/StyleRegistry.py +++ b/utilities/StyleRegistry.py @@ -1,62 +1,41 @@ import random +import os + class StyleRegistry: - def __init__(self): - self.styles = { - "abstract": [ - "images/style/abstract/Frankenthaler_Helen_Mountains_and_Sea.jpg", - "images/style/abstract/Jackson_Pollock_Autumn_Rhythm.jpg", - "images/style/abstract/Jackson_Pollock_Full-Fathom-Five.jpg", - "images/style/abstract/Joan_Miro_Triptych_Bleu.jpg", - "images/style/abstract/Paul_Klee_Senecio.jpg", - "images/style/abstract/Piet_Mondrian_Broadway_Boogie_Woogie.jpg", - "images/style/abstract/Theo_van_Doesburg_Peinture_Pure.jpg", - "images/style/abstract/Vassily_Kandinsky_Composition_VIII.jpg" - ], - "baroque": [ - "images/style/baroque/Annibale_Carracci_Lamentation_of_Christ.jpg", - "images/style/baroque/Caravaggio_Death_of_the_Virgin.jpg", - "images/style/baroque/Caravaggio_Judith_Beheading_Holofernes.jpg", - "images/style/baroque/Caravaggio_The_Calling_of_Saint_Matthew.jpg", - "images/style/baroque/Diego_Velázquez_Las_Meninas.jpg", - "images/style/baroque/Johannes_Vermeer_Girl_with_a_Pearl_Earring.jpg", - "images/style/baroque/Rembrandt_The_Anatomy_Lesson_of_Dr_Nicolaes_Tulp.jpg", - "images/style/baroque/Rembrandt_van_Rijn_The_Night_Watch.jpg" - ], - "cubism": [ - "images/style/cubism/Georges_Braque_Natura_morta_con clarinetto.jpg", - "images/style/cubism/Georges_Braque_Violin_and_Candlestick.jpg", - "images/style/cubism/Juan_Gris_Guitar_on_the_table.jpg", - "images/style/cubism/Juan_Gris_Portrait_of_Pablo_Picasso.jpg", - "images/style/cubism/Marc_Chagall_I_am_The_Village.jpg", - "images/style/cubism/Pablo_Picasso_Guernica.jpg", - "images/style/cubism/Pablo-Picasso-Panny-Z-Awinionu.jpg", - "images/style/cubism/Paul_Cézanne_Les_Grandes_Baigneuses.jpg" - ], - "expressionism": [ - "images/style/expressionism/David_Alfaro_Siqueiros_Birth_of_Fascism.jpg", - "images/style/expressionism/Edvard_Munch_The_Scream.jpg", - "images/style/expressionism/Emil_Nolde_Dance_Around_the_Golden Calf.jpg", - "images/style/expressionism/Ernst_Ludwig_Kirchner_Street_Berlin.jpg", - "images/style/expressionism/Marc_Franz_Blaues_Pferdchen_Saarlandmuseum.jpg", - "images/style/expressionism/Max_Beckmann_The_Night.jpg", - "images/style/expressionism/Pablo_Picasso_Stary_Gitarzysta.jpg", - "images/style/expressionism/Wassily_Kandinsky_The_Blue_Rider.jpg" - ], - "impressionism": [ - "images/style/impressionism/Auguste_Renoir_Dance_at_Le_Moulin_de_la_Galette.jpg", - "images/style/impressionism/Claude_Monet_Impression_Sunrise.jpg", - "images/style/impressionism/Claude_Monet_Lilies.jpg", - "images/style/impressionism/Claude_Monet_Woman_with_a_Parasol.jpg", - "images/style/impressionism/Edgar_Degas_The_Dance_Class.jpg", - "images/style/impressionism/Edouard_Manet_Woman_Reading.jpg", - "images/style/impressionism/Gustave_Caillebotte_LHomme_au_balcon_boulevard.jpg", - "images/style/impressionism/Gustave_Caillebotte_Paris_Street_Rainy_Day.jpg" - ] - } + def __init__(self, styles_dir="images/style"): + """ + Initialize the StyleRegistry and dynamically load styles from the directory. + + :param styles_dir: Path to the directory containing style categories and images. + """ + self.styles_dir = styles_dir + self.styles = self._load_styles() + + def _load_styles(self): + """ + Load style categories and image paths dynamically from the directory. + + :return: A dictionary with style categories as keys and lists of image paths as values. + """ + styles = {} + for category in os.listdir(self.styles_dir): + category_path = os.path.join(self.styles_dir, category) + if os.path.isdir(category_path): + styles[category] = [ + os.path.join(category_path, file) + for file in os.listdir(category_path) + if file.endswith(('.jpg', '.png')) + ] + return styles def get_random_style_image(self, category): + """ + Get a random image path from the specified style category. + + :param category: Style category name. + :return: Path to a random image in the category. + """ if category in self.styles: return random.choice(self.styles[category]) raise ValueError(f"Style category '{category}' not found!") -