Skip to content

Commit

Permalink
Updated core modules, made some unit tests for them.
Browse files Browse the repository at this point in the history
  • Loading branch information
Klus3kk committed Dec 1, 2024
1 parent 50c42b0 commit 730320d
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 4 deletions.
2 changes: 2 additions & 0 deletions core/ImageProcessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
class ImageProcessor:
@staticmethod
def preprocess_image(image_path, size=512):
'''Resize and normalize image for input.'''
image = Image.open(image_path).convert("RGB")
image = image.resize((size, size))
return image

@staticmethod
def save_image(image, output_path):
'''Save the processed image.'''
image.save(output_path)


13 changes: 11 additions & 2 deletions core/StyleTransferModel.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision.models import vgg19


class StyleTransferModel:
def __init__(self, model_path):
def __init__(self):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = torch.load(model_path).to(self.device)
self.model = self._load_pretrained_model()

def _load_pretrained_model(self):
'''Using VGG-19 for feature extraction'''
model = vgg19(pretrained=True).features
for param in model.parameters():
param.requires_grad = False
return model.to(self.device)

def apply_style(self, content_image, style_image):
'''Style transfer logic'''
Expand Down
Binary file not shown.
21 changes: 21 additions & 0 deletions tests/test_artify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from core.StyleTransferModel import StyleTransferModel
from core.ImageProcessor import ImageProcessor

if __name__ == "__main__":
content_image_path = ""
style_image_path = ""
output_image_path = ""

processor = ImageProcessor()
model = StyleTransferModel()

content_image = processor.preprocess_image(content_image_path)
style_image = processor.preprocess_image(style_image_path)

# Apply style (placeholder for now)
# styled_image = model.apply_style(content_image, style_image)

# Save the output (placeholder for now)
# processor.save_image(styled_image, output_image_path)

print("Pipeline tested successfully")
13 changes: 13 additions & 0 deletions tests/test_image_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import os
from core.ImageProcessor import ImageProcessor

def test_preprocess_image():
processor = ImageProcessor()
processed_image = processor.preprocess_image("")
assert processed_image.size == (512, 512)

def test_save_image():
processor = ImageProcessor()
sample_image = processor.preprocess_image("")
processor.save_image(sample_image, "")
assert os.path.exists("")
2 changes: 0 additions & 2 deletions tests/test_sample.py

This file was deleted.

5 changes: 5 additions & 0 deletions tests/test_style_transfer_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from core.StyleTransferModel import StyleTransferModel

def test_model_initialization():
model = StyleTransferModel()
assert model.model is not None

0 comments on commit 730320d

Please sign in to comment.