diff --git a/api/FastAPIHandler.py b/api/FastAPIHandler.py
new file mode 100644
index 0000000..e69de29
diff --git a/core/StyleTransferModel.py b/core/StyleTransferModel.py
index 0727de6..168868d 100644
--- a/core/StyleTransferModel.py
+++ b/core/StyleTransferModel.py
@@ -1,8 +1,8 @@
-import torch
+import torch
+import torch.optim as optim
 import torchvision.transforms as transforms
-from PIL import Image
 from torchvision.models import vgg19
-
+from PIL import Image
 
 class StyleTransferModel:
     def __init__(self):
@@ -10,13 +10,86 @@ def __init__(self):
         self.model = self._load_pretrained_model()
 
     def _load_pretrained_model(self):
-        '''Using VGG-19 for feature extraction'''
+        # Load the pretrained VGG-19 feature extractor and freeze its weights
         model = vgg19(pretrained=True).features
         for param in model.parameters():
-        param.requires_grad = False
+            param.requires_grad = False
         return model.to(self.device)
 
-    def apply_style(self, content_image, style_image):
-        '''Style transfer logic'''
-        ...
+    def apply_style(self, content_image, style_image, iterations=300, style_weight=1e6, content_weight=1):
+        # Preprocess both images and move them to the model device
+        content_tensor = self._image_to_tensor(content_image).to(self.device)
+        style_tensor = self._image_to_tensor(style_image).to(self.device)
+
+        # Initialize the target image as a trainable clone of the content image
+        target = content_tensor.clone().requires_grad_(True)
+
+        # Optimize the target pixels directly; the network weights stay frozen
+        optimizer = optim.Adam([target], lr=0.003)
+
+        # Content and style feature maps are fixed, so compute them once
+        style_features = self._extract_features(style_tensor)
+        content_features = self._extract_features(content_tensor)
+
+        # Iteratively minimize the weighted sum of style and content losses
+        for i in range(iterations):
+            target_features = self._extract_features(target)
+            content_loss = self._calculate_content_loss(content_features, target_features)
+            style_loss = self._calculate_style_loss(style_features, target_features)
+
+            total_loss = style_weight * style_loss + content_weight * content_loss
+
+            optimizer.zero_grad()
+            total_loss.backward()
+            optimizer.step()
+
+        return self._tensor_to_image(target)
+
+    def _image_to_tensor(self, image):
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+        ])
+        # Force 3-channel RGB so normalization never sees an alpha channel
+        return transform(image.convert("RGB")).unsqueeze(0)
+
+    def _tensor_to_image(self, tensor):
+        # Detach from the autograd graph and move to CPU before PIL conversion
+        tensor = tensor.detach().cpu().squeeze(0)
+        unnormalize = transforms.Normalize(
+            mean=[-2.12, -2.04, -1.8],
+            std=[4.37, 4.46, 4.44],
+        )
+        # Undo the ImageNet normalization and clamp into the valid [0, 1] range
+        tensor = unnormalize(tensor).clamp(0, 1)
+        return transforms.ToPILImage()(tensor)
+
+    def _extract_features(self, tensor):
+        # VGG-19 layer indices mapped to human-readable names
+        layers = {
+            "0": "conv1_1",
+            "5": "conv2_1",
+            "10": "conv3_1",
+            "19": "conv4_1",
+            "21": "conv4_2",  # Content representation
+            "28": "conv5_1",
+        }
+        features = {}
+        x = tensor
+        for name, layer in self.model._modules.items():
+            x = layer(x)
+            if name in layers:
+                features[layers[name]] = x
+        return features
+
+    def _calculate_content_loss(self, content_features, target_features):
+        return torch.mean((target_features["conv4_2"] - content_features["conv4_2"]) ** 2)
+
+    def _calculate_style_loss(self, style_features, target_features):
+        style_loss = 0
+        for layer in style_features:
+            target_gram = self._gram_matrix(target_features[layer])
+            style_gram = self._gram_matrix(style_features[layer])
+            _, d, h, w = target_features[layer].size()
+            # Normalize each layer's contribution by its feature map size
+            style_loss += torch.mean((target_gram - style_gram) ** 2) / (d * h * w)
+        return style_loss
+
+    def _gram_matrix(self, tensor):
+        # (1, d, h, w) -> (d, h*w); assumes a batch size of one
+        _, d, h, w = tensor.size()
+        tensor = tensor.view(d, h * w)
+        return torch.mm(tensor, tensor.t())
diff --git a/interface/CLIHandler.py b/interface/CLIHandler.py
new file mode 100644
index 0000000..d944abe
--- /dev/null
+++ b/interface/CLIHandler.py
@@ -0,0 +1,25 @@
+import argparse
+from core.StyleTransferModel import StyleTransferModel
+from core.ImageProcessor import ImageProcessor
+
+def main():
+    parser = argparse.ArgumentParser(description="Artify: Apply artistic styles to images.")
+    parser.add_argument("--content", required=True, help="Path to content image")
+    parser.add_argument("--style", required=True, help="Path to style image")
+    parser.add_argument("--output", required=True, help="Path to save the styled image")
+
+    args = parser.parse_args()
+
+    processor = ImageProcessor()
+    model = StyleTransferModel()
+
+    content_image = processor.preprocess_image(args.content)
+    style_image = processor.preprocess_image(args.style)
+
+    styled_image = model.apply_style(content_image, style_image)
+    processor.save_image(styled_image, args.output)
+
+    print(f"Styled image saved to {args.output}")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/interface/UIHandler.py b/interface/UIHandler.py
new file mode 100644
index 0000000..116d262
--- /dev/null
+++ b/interface/UIHandler.py
@@ -0,0 +1,28 @@
+import streamlit as st
+from core.StyleTransferModel import StyleTransferModel
+from core.ImageProcessor import ImageProcessor
+from PIL import Image
+
+def main():
+    st.title("Artify: AI-Powered Image Style Transfer")
+    st.write("Upload your content and style images to generate a styled result!")
+
+    content_file = st.file_uploader("Upload Content Image", type=["jpg", "png"])
+    style_file = st.file_uploader("Upload Style Image", type=["jpg", "png"])
+
+    if content_file and style_file:
+        processor = ImageProcessor()
+        model = StyleTransferModel()
+
+        content_image = Image.open(content_file).convert("RGB")
+        style_image = Image.open(style_file).convert("RGB")
+
+        st.image(content_image, caption="Content Image", use_column_width=True)
+        st.image(style_image, caption="Style Image", use_column_width=True)
+
+        if st.button("Generate Styled Image"):
+            styled_image = model.apply_style(content_image, style_image)
+            st.image(styled_image, caption="Styled Image", use_column_width=True)
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/test_artify.py b/tests/test_artify.py
index 8d8c8f2..a431030 100644
--- a/tests/test_artify.py
+++ b/tests/test_artify.py
@@ -2,9 +2,9 @@
 from core.ImageProcessor import ImageProcessor
 
 if __name__ == "__main__":
-    content_image_path = ""
-    style_image_path = ""
-    output_image_path = ""
+    content_image_path = "path/to/content.jpg"
+    style_image_path = "path/to/style.jpg"
+    output_image_path = "path/to/output.jpg"
 
     processor = ImageProcessor()
     model = StyleTransferModel()
@@ -12,10 +12,7 @@
     content_image = processor.preprocess_image(content_image_path)
     style_image = processor.preprocess_image(style_image_path)
 
-    # Apply style (placeholder for now)
-    # styled_image = model.apply_style(content_image, style_image)
+    styled_image = model.apply_style(content_image, style_image)
+    processor.save_image(styled_image, output_image_path)
 
-    # Save the output (placeholder for now)
-    # processor.save_image(styled_image, output_image_path)
-
-    print("Pipeline tested successfully")
+    print("Style transfer completed!")
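
Usage sketch (illustrative, not part of the diff): a minimal end-to-end driver for the new apply_style API. It deliberately bypasses ImageProcessor, whose interface is not shown in this diff, and feeds PIL images directly, which is what _image_to_tensor accepts; the file names are placeholders.

    # Hypothetical driver script; the paths and direct-PIL usage are assumptions.
    from PIL import Image
    from core.StyleTransferModel import StyleTransferModel

    model = StyleTransferModel()
    content = Image.open("content.jpg")  # any RGB photograph
    style = Image.open("style.jpg")      # e.g. a painting
    styled = model.apply_style(content, style, iterations=300)
    styled.save("styled.jpg")            # apply_style returns a PIL image

The same pipeline is reachable from the new CLI, e.g. python -m interface.CLIHandler --content content.jpg --style style.jpg --output styled.jpg run from the repository root, assuming core/ImageProcessor.py provides the preprocess_image and save_image helpers the CLI imports.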