Skip to content

Commit

Permalink
fix for inference results saving
Browse files Browse the repository at this point in the history
  • Loading branch information
KGallyamov committed Aug 19, 2024
1 parent a26c9d5 commit 8c676a9
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
import logging
import pathlib

import pandas as pd
import torch
import cv2
import numpy as np
from torch.utils.data import random_split

from innofw.constants import Frameworks
Expand Down Expand Up @@ -79,11 +82,18 @@ def predict_dataloader(self):
return test_dataloader

def setup_infer(self):
self.predict_dataset = AnomaliesDataset(self.predict_source, self.aug)
self.predict_dataset = AnomaliesDataset(self.predict_source, self.get_aug(self.aug, 'test'))

def save_preds(self, preds, stage: Stages, dst_path: pathlib.Path):
dst_path = pathlib.Path(dst_path)
df = pd.DataFrame(list(preds), columns=["prediction"])
dst_filepath = dst_path / "prediction.csv"
df.to_csv(dst_filepath)
logging.info(f"Saved results to: {dst_filepath}")
out_file_path = dst_path / "results"
os.mkdir(out_file_path)
n = 0
for preds_batch in preds:
for i, pred in enumerate(preds_batch):
pred = pred.numpy() * 255 # shape - (1024, 1024)
if pred.dtype != np.uint8:
pred = pred.astype(np.uint8)
filename = out_file_path / f"out_{n}.png"
n += 1
cv2.imwrite(filename, pred)
logging.info(f"Saved result to: {out_file_path}")
3 changes: 1 addition & 2 deletions innofw/core/datasets/anomalies.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,5 @@ def __getitem__(self, idx):
return self.augmentations(image) if self.augmentations is not None else image
mask = cv2.imread(str(self.labels[idx]), 0)
if self.augmentations is not None:
image, mask = self.augmentations(image=image, mask=mask)

image, mask = self.augmentations(image, mask)
return image, mask
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torchmetrics.classification import BinaryJaccardIndex, BinaryF1Score, BinaryPrecision, \
BinaryRecall
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
from lovely_numpy import lo

from innofw.core.models.torch.lightning_modules.base import BaseLightningModule

Expand Down Expand Up @@ -77,20 +78,23 @@ def training_step(self, x, batch_idx):

def validation_step(self, batch, batch_idx):
x, y = batch
y = y.bool()
x_rec = self.forward(x)
loss = self.loss_fn(x, x_rec)
mask = self.compute_anomaly_mask(x)
metrics = self.compute_metrics('val', mask, y)
self.log_metrics('val', metrics)
print(mask.float().mean(), y.float().mean())
self.log("val_loss", loss, on_step=False, on_epoch=True)
return {"loss": loss}

def test_step(self, x, batch_idx):
def test_step(self, batch, batch_idx):
x, y = batch
x_rec = self.forward(x)
loss = self.loss_fn(x, x_rec)
mask = self.compute_anomaly_mask(x)
metrics = self.compute_metrics('val', mask, y)
self.log_metrics('val', metrics)
metrics = self.compute_metrics('test', mask, y)
self.log_metrics('test', metrics)
self.log("test_loss", loss, on_step=False, on_epoch=True)
return {"loss": loss}

Expand Down
2 changes: 1 addition & 1 deletion innofw/core/models/torch_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def predict(self, datamodule, ckpt_path=None):
def train(self, data_module, ckpt_path=None):
self.trainer.fit(self.pl_module, data_module, ckpt_path=ckpt_path)

def test(self, data_module):
def test(self, data_module, ckpt_path=None):
outputs = self.trainer.test(self.pl_module, data_module)
return outputs

Expand Down

0 comments on commit 8c676a9

Please sign in to comment.