Commit

Implemented CLI and UI.
Klus3kk committed Dec 15, 2024
1 parent 730320d commit 95de7b6
Showing 5 changed files with 140 additions and 17 deletions.
Empty file added api/FastAPIHandler.py
89 changes: 81 additions & 8 deletions core/StyleTransferModel.py
@@ -1,22 +1,95 @@
import torch
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torchvision.models import vgg19


class StyleTransferModel:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self._load_pretrained_model()

    def _load_pretrained_model(self):
        '''Load the pretrained VGG-19 feature extractor; weights stay frozen.'''
        # Note: pretrained=True is deprecated in newer torchvision (weights=VGG19_Weights.DEFAULT)
        model = vgg19(pretrained=True).features
        for param in model.parameters():
            param.requires_grad = False
        return model.to(self.device)

    def apply_style(self, content_image, style_image, iterations=300, style_weight=1e6, content_weight=1):
        '''Optimize a copy of the content image to match the style image's Gram statistics.'''
        # Preprocess images
        content_tensor = self._image_to_tensor(content_image).to(self.device)
        style_tensor = self._image_to_tensor(style_image).to(self.device)

        # Initialize target image (trainable clone of the content image)
        target = content_tensor.clone().requires_grad_(True)

        # Optimize the target pixels directly
        optimizer = optim.Adam([target], lr=0.003)

        # Feature maps for content and style (fixed; computed once)
        style_features = self._extract_features(style_tensor)
        content_features = self._extract_features(content_tensor)

        # Iteratively minimize the weighted sum of style and content losses
        for i in range(iterations):
            target_features = self._extract_features(target)
            content_loss = self._calculate_content_loss(content_features, target_features)
            style_loss = self._calculate_style_loss(style_features, target_features)

            total_loss = style_weight * style_loss + content_weight * content_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        return self._tensor_to_image(target)

    def _image_to_tensor(self, image):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # convert("RGB") guards against RGBA/grayscale inputs that would break the 3-channel Normalize
        return transform(image.convert("RGB")).unsqueeze(0)

    def _tensor_to_image(self, tensor):
        # Inverse of the ImageNet normalization above (mean=-m/s, std=1/s)
        unnormalize = transforms.Normalize(
            mean=[-2.12, -2.04, -1.8],
            std=[4.37, 4.46, 4.44],
        )
        # detach/cpu are required before PIL conversion; clamp avoids byte-overflow artifacts
        tensor = unnormalize(tensor.detach().cpu().squeeze(0)).clamp(0, 1)
        return transforms.ToPILImage()(tensor)

    def _extract_features(self, tensor):
        # Indices into vgg19().features, mapped to the layer names used by Gatys et al.
        layers = {
            "0": "conv1_1",
            "5": "conv2_1",
            "10": "conv3_1",
            "19": "conv4_1",
            "21": "conv4_2",  # Content representation
            "28": "conv5_1",
        }
        features = {}
        x = tensor
        for name, layer in self.model._modules.items():
            x = layer(x)
            if name in layers:
                features[layers[name]] = x
        return features

    def _calculate_content_loss(self, content_features, target_features):
        return torch.mean((target_features["conv4_2"] - content_features["conv4_2"]) ** 2)

    def _calculate_style_loss(self, style_features, target_features):
        style_loss = 0
        for layer in style_features:
            target_gram = self._gram_matrix(target_features[layer])
            style_gram = self._gram_matrix(style_features[layer])
            _, d, h, w = target_features[layer].size()
            style_loss += torch.mean((target_gram - style_gram) ** 2) / (d * h * w)
        return style_loss

    def _gram_matrix(self, tensor):
        # Channel-by-channel feature correlations: flatten spatial dims, then X @ X^T
        _, d, h, w = tensor.size()
        tensor = tensor.view(d, h * w)
        return torch.mm(tensor, tensor.t())
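
For orientation, a minimal sketch of driving the model directly; the file paths here are hypothetical, and apply_style expects PIL images (it applies the torchvision transforms itself):

    from PIL import Image
    from core.StyleTransferModel import StyleTransferModel

    model = StyleTransferModel()
    content = Image.open("examples/content.jpg")  # hypothetical path
    style = Image.open("examples/style.jpg")      # hypothetical path
    styled = model.apply_style(content, style, iterations=300)
    styled.save("examples/styled.jpg")
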
25 changes: 25 additions & 0 deletions interface/CLIHandler.py
@@ -0,0 +1,25 @@
import argparse
from core.StyleTransferModel import StyleTransferModel
from core.ImageProcessor import ImageProcessor

def main():
    parser = argparse.ArgumentParser(description="Artify: Apply artistic styles to images.")
    parser.add_argument("--content", required=True, help="Path to content image")
    parser.add_argument("--style", required=True, help="Path to style image")
    parser.add_argument("--output", required=True, help="Path to save the styled image")

    args = parser.parse_args()

    processor = ImageProcessor()
    model = StyleTransferModel()

    content_image = processor.preprocess_image(args.content)
    style_image = processor.preprocess_image(args.style)

    styled_image = model.apply_style(content_image, style_image)
    processor.save_image(styled_image, args.output)

    print(f"Styled image saved to {args.output}")

if __name__ == "__main__":
    main()
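
Assuming the command is run from the repository root (so the core package resolves), an invocation might look like:

    python -m interface.CLIHandler --content path/to/content.jpg --style path/to/style.jpg --output path/to/output.jpg
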
28 changes: 28 additions & 0 deletions interface/UIHandler.py
@@ -0,0 +1,28 @@
import streamlit as st
from core.StyleTransferModel import StyleTransferModel
from core.ImageProcessor import ImageProcessor
from PIL import Image

def main():
    st.title("Artify: AI-Powered Image Style Transfer")
    st.write("Upload your content and style images to generate a styled result!")

    content_file = st.file_uploader("Upload Content Image", type=["jpg", "png"])
    style_file = st.file_uploader("Upload Style Image", type=["jpg", "png"])

    if content_file and style_file:
        processor = ImageProcessor()
        model = StyleTransferModel()

        content_image = Image.open(content_file)
        style_image = Image.open(style_file)

        st.image(content_image, caption="Content Image", use_column_width=True)
        st.image(style_image, caption="Style Image", use_column_width=True)

        if st.button("Generate Styled Image"):
            styled_image = model.apply_style(content_image, style_image)
            st.image(styled_image, caption="Styled Image", use_column_width=True)

if __name__ == "__main__":
    main()
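
Since Streamlit sets __name__ to "__main__" when it executes a script, the app can presumably be launched with the standard runner from the repository root:

    streamlit run interface/UIHandler.py
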
15 changes: 6 additions & 9 deletions tests/test_artify.py
@@ -2,20 +2,17 @@
from core.ImageProcessor import ImageProcessor

if __name__ == "__main__":
    content_image_path = "path/to/content.jpg"
    style_image_path = "path/to/style.jpg"
    output_image_path = "path/to/output.jpg"

    processor = ImageProcessor()
    model = StyleTransferModel()

    content_image = processor.preprocess_image(content_image_path)
    style_image = processor.preprocess_image(style_image_path)

    styled_image = model.apply_style(content_image, style_image)
    processor.save_image(styled_image, output_image_path)

    print("Style transfer completed!")
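
A quick way to exercise this smoke-test script, assuming the placeholder paths above point at real images and the repository root is the working directory:

    python -m tests.test_artify
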
