Skip to content

Commit

Permalink
Implemented: Model training, TV Loss
Browse files Browse the repository at this point in the history
  • Loading branch information
Klus3kk committed Dec 18, 2024
1 parent 389ec91 commit 22ad9f2
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 4 deletions.
11 changes: 8 additions & 3 deletions api/FastAPIHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,17 @@

@app.post("/apply_style/")
async def apply_style(content: UploadFile = File(...), style_category: str = "impressionism"):
# Dynamically load the model for the selected style
model_path = f"models/{style_category}_model.pth"
model.load_model_from_gcloud("your-bucket-name", model_path)

# Read content image
content_image = Image.open(BytesIO(await content.read())).convert("RGB")
style_image_path = registry.get_random_style_image(style_category)
style_image = Image.open(style_image_path).convert("RGB")

styled_image = model.apply_style(content_image, style_image)
# Apply the style
styled_image = model.apply_style(content_image, None)

# Save and return the styled image
output_path = os.path.join(OUTPUT_DIR, f"styled_{content.filename}")
processor.save_image(styled_image, output_path)

Expand Down
89 changes: 88 additions & 1 deletion core/StyleTransferModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,15 @@ def apply_style(self, content_image, style_image, iterations=300, style_weight=1

return self._tensor_to_image(target)

def _image_to_tensor(self, image):
def _image_to_tensor(self, image, target_size=(512, 512)):
"""
Convert a PIL image to a normalized tensor and resize it.
:param image: PIL Image object.
:param target_size: Target size to resize the image (default: 512x512).
:return: Normalized tensor.
"""
transform = transforms.Compose([
transforms.Resize(target_size), # Resize to target size
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
Expand Down Expand Up @@ -88,8 +95,88 @@ def _calculate_style_loss(self, style_features, target_features):
_, d, h, w = target_features[layer].size()
style_loss += torch.mean((target_gram - style_gram) ** 2) / (d * h * w)
return style_loss

def _calculate_tv_loss(self, tensor):
diff_x = torch.sum(torch.abs(tensor[:, :, :, :-1] - tensor[:, :, :, 1:]))
diff_y = torch.sum(torch.abs(tensor[:, :, :-1, :] - tensor[:, :, 1:, :]))
return diff_x + diff_y


def _gram_matrix(self, tensor):
_, d, h, w = tensor.size()
tensor = tensor.view(d, h * w)
return torch.mm(tensor, tensor.t())

def train_model(self, content_image, style_image, output_path, iterations=300, style_weight=1e6, content_weight=5):
"""
Train a style transfer model and save it to the specified path.
:param content_image: PIL Image object for content.
:param style_image: PIL Image object for style.
:param output_path: Path to save the trained model.
:param iterations: Number of training iterations.
:param style_weight: Weight for style loss.
:param content_weight: Weight for content loss.
"""
# Preprocess images
content_tensor = self._image_to_tensor(content_image).to(self.device)
style_tensor = self._image_to_tensor(style_image).to(self.device)

# Initialize target image (clone of content image)
target = content_tensor.clone().requires_grad_(True)

# Define optimizer
optimizer = optim.Adam([target], lr=0.005)

# Extract features
style_features = self._extract_features(style_tensor)
content_features = self._extract_features(content_tensor)

# Training loop
print(f"Starting training for {iterations} iterations...")
for i in range(iterations):
target_features = self._extract_features(target)
content_loss = self._calculate_content_loss(content_features, target_features)
style_loss = self._calculate_style_loss(style_features, target_features)
tv_loss = self._calculate_tv_loss(target)

total_loss = style_weight * style_loss + content_weight * content_loss + 1e-5 * tv_loss

optimizer.zero_grad()
total_loss.backward()
optimizer.step()

if (i + 1) % 10 == 0 or i == 0:
print(f"\nIteration {i + 1}/{iterations}: \n"
f"Content Loss = {content_loss.item():.4f} \n"
f"Style Loss = {style_loss.item():.4f} \n"
f"TV Loss = {tv_loss.item():.4f} \n"
f"Total Loss = {total_loss.item():.4f}")

torch.save(target, output_path)
print(f"Model saved to {output_path}")



def load_model(self, model_path):
"""
Load a pretrained model from the local filesystem.
:param model_path: Path to the model file.
"""
self.model = torch.load(model_path, map_location=self.device)
print(f"Model loaded from {model_path}")

def load_model_from_gcloud(self, bucket_name, model_name):
"""
Load a model dynamically from Google Cloud Storage.
:param bucket_name: Name of the GCS bucket.
:param model_name: Name of the model file in GCS.
"""
from google.cloud import storage
client = storage.Client()
bucket = client.bucket(bucket_name)
blob = bucket.blob(model_name)

local_path = f"/tmp/{model_name}"
blob.download_to_filename(local_path)
self.load_model(local_path)
print(f"Model loaded from GCS bucket {bucket_name}, file {model_name}")
18 changes: 18 additions & 0 deletions deployment/upload_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from deployment.CloudUploader import CloudUploader
from utilities.StyleRegistry import StyleRegistry
import os

# Initialize StyleRegistry
registry = StyleRegistry()

# Google Cloud Storage bucket name
bucket_name = "your-bucket-name"

# Upload each model to the cloud
for category in registry.styles.keys():
model_path = f"models/{category}_model.pth"
remote_path = f"models/{os.path.basename(model_path)}"
print(f"Uploading {model_path} to {bucket_name}/{remote_path}...")
CloudUploader.upload_to_cloud(bucket_name, model_path, remote_path)

print("All models uploaded successfully!")
27 changes: 27 additions & 0 deletions train_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from PIL import Image
from core.StyleTransferModel import StyleTransferModel
from utilities.StyleRegistry import StyleRegistry
import os

# Initialize the style transfer model and registry
model = StyleTransferModel()
registry = StyleRegistry()

# Define the content image for training
content_image = Image.open("images/content/sample_content.jpg")

# Ensure models directory exists
os.makedirs("models", exist_ok=True)

# Train a model for each style category
for category in registry.styles.keys():
# Get a random style image from the category
style_image_path = registry.get_random_style_image(category)
style_image = Image.open(style_image_path)

# Define the output path for the trained model
output_path = f"models/{category}_model.pth"

print(f"Training model for {category} using style: {style_image_path}...")
model.train_model(content_image, style_image, output_path)
print(f"Model saved to {output_path}")

0 comments on commit 22ad9f2

Please sign in to comment.