-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added: Updated tests, Cloud and Docker implementation.
- Loading branch information
Showing
11 changed files
with
223 additions
and
89 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Base image | ||
FROM python:3.10-slim | ||
|
||
# Set environment variables | ||
ENV PYTHONUNBUFFERED=1 | ||
|
||
# Set working directory | ||
WORKDIR /app | ||
|
||
# Copy requirements and install dependencies | ||
COPY requirements.txt /app/requirements.txt | ||
RUN pip install --no-cache-dir -r requirements.txt | ||
|
||
# Copy the app code | ||
COPY . /app | ||
|
||
# Expose FastAPI port | ||
EXPOSE 8000 | ||
|
||
# Command to run FastAPI | ||
CMD ["uvicorn", "api.FastAPIHandler:app", "--host", "0.0.0.0", "--port", "8000"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
version: "3.9" | ||
services: | ||
artify-api: | ||
build: . | ||
ports: | ||
- "8000:8000" | ||
volumes: | ||
- ./models:/app/models | ||
- ./images:/app/images |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,44 @@ | ||
from PIL import Image | ||
from core.StyleTransferModel import StyleTransferModel | ||
from utilities.StyleRegistry import StyleRegistry | ||
from utilities.Logger import Logger | ||
import os | ||
import logging | ||
|
||
# Setup logger | ||
logger = Logger.setup_logger(log_file="training.log", log_level=logging.INFO) | ||
|
||
# Initialize the style transfer model and registry | ||
model = StyleTransferModel() | ||
registry = StyleRegistry() | ||
|
||
# Define the content image for training | ||
content_image = Image.open("images/content/sample_content.jpg") | ||
content_image_path = "images/content/sample_content.jpg" | ||
if not os.path.exists(content_image_path): | ||
logger.error(f"Content image not found at {content_image_path}. Exiting.") | ||
raise FileNotFoundError(f"Content image not found at {content_image_path}.") | ||
content_image = Image.open(content_image_path) | ||
|
||
# Ensure models directory exists | ||
os.makedirs("models", exist_ok=True) | ||
output_dir = "models" | ||
os.makedirs(output_dir, exist_ok=True) | ||
|
||
# Train a model for each style category | ||
for category in registry.styles.keys(): | ||
# Get a random style image from the category | ||
style_image_path = registry.get_random_style_image(category) | ||
style_image = Image.open(style_image_path) | ||
try: | ||
# Get a random style image from the category | ||
style_image_path = registry.get_random_style_image(category) | ||
if not os.path.exists(style_image_path): | ||
logger.warning(f"Style image not found: {style_image_path}. Skipping category {category}.") | ||
continue | ||
|
||
style_image = Image.open(style_image_path) | ||
|
||
# Define the output path for the trained model | ||
output_path = f"models/{category}_model.pth" | ||
# Define the output path for the trained model | ||
output_path = os.path.join(output_dir, f"{category}_model.pth") | ||
|
||
print(f"Training model for {category} using style: {style_image_path}...") | ||
model.train_model(content_image, style_image, output_path) | ||
print(f"Model saved to {output_path}") | ||
logger.info(f"Training model for '{category}' using style: {style_image_path}...") | ||
model.train_model(content_image, style_image, output_path) | ||
logger.info(f"Model saved to {output_path}") | ||
except Exception as e: | ||
logger.error(f"Failed to train model for category '{category}': {e}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,40 @@ | ||
import json | ||
import os | ||
|
||
|
||
class ConfigManager: | ||
@staticmethod | ||
def load_config(config_path): | ||
with open(config_path, "r") as f: | ||
return json.load(f) | ||
|
||
def load_config(config_path, default_config=None): | ||
""" | ||
Load configuration from a file. If the file does not exist and default_config is provided, | ||
create the file with the default configuration. | ||
:param config_path: Path to the configuration file. | ||
:param default_config: A dictionary with default configuration values. | ||
:return: Loaded configuration as a dictionary. | ||
""" | ||
if not os.path.exists(config_path): | ||
if default_config is not None: | ||
ConfigManager.save_config(default_config, config_path) | ||
return default_config | ||
raise FileNotFoundError(f"Config file '{config_path}' not found.") | ||
|
||
try: | ||
with open(config_path, "r") as f: | ||
return json.load(f) | ||
except json.JSONDecodeError as e: | ||
raise ValueError(f"Invalid JSON in config file '{config_path}': {e}") | ||
|
||
@staticmethod | ||
def save_config(config, config_path): | ||
with open(config_path, "w") as f: | ||
json.dump(config, f, indent=4) | ||
""" | ||
Save configuration to a file. | ||
:param config: Dictionary containing configuration values. | ||
:param config_path: Path to save the configuration file. | ||
""" | ||
try: | ||
with open(config_path, "w") as f: | ||
json.dump(config, f, indent=4) | ||
except Exception as e: | ||
raise IOError(f"Failed to save config to '{config_path}': {e}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,23 @@ | ||
import logging | ||
|
||
|
||
class Logger: | ||
@staticmethod | ||
def setup_logger(): | ||
def setup_logger(log_file=None, log_level=logging.INFO): | ||
""" | ||
Set up the logger with console and optional file logging. | ||
:param log_file: Path to the log file (optional). | ||
:param log_level: Logging level (default: INFO). | ||
:return: Configured logger instance. | ||
""" | ||
handlers = [logging.StreamHandler()] | ||
if log_file: | ||
handlers.append(logging.FileHandler(log_file)) | ||
|
||
logging.basicConfig( | ||
level=logging.INFO, | ||
level=log_level, | ||
format="%(asctime)s - %(levelname)s - %(message)s", | ||
handlers=[logging.StreamHandler()], | ||
handlers=handlers, | ||
) | ||
return logging.getLogger("Artify") | ||
return logging.getLogger("Artify") |
Oops, something went wrong.