Skip to content

Commit

Permalink
visualize results
Browse files Browse the repository at this point in the history
  • Loading branch information
KGallyamov committed Aug 27, 2024
1 parent 8c676a9 commit b47da9c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,23 @@ def predict_dataloader(self):
def setup_infer(self):
self.predict_dataset = AnomaliesDataset(self.predict_source, self.get_aug(self.aug, 'test'))

def save_preds(self, preds, stage: Stages, dst_path: pathlib.Path):
def save_preds(self, out_batches, stage: Stages, dst_path: pathlib.Path):
out_file_path = dst_path / "results"
os.mkdir(out_file_path)
n = 0
for preds_batch in preds:
for i, pred in enumerate(preds_batch):
for batch in out_batches:
for img, pred in zip(batch[0], batch[1]):
img = img.cpu().numpy()
pred = pred.numpy() * 255 # shape - (1024, 1024)
if pred.dtype != np.uint8:
pred = pred.astype(np.uint8)
filename = out_file_path / f"out_{n}.png"
n += 1
cv2.imwrite(filename, pred)
mask_vis = np.zeros_like(img)
mask_vis[1, :, :] = pred / 255
mask_vis = mask_vis
img_with_mask = (img * 255 * 0.75 + mask_vis * 255 * 0.25).astype(np.uint8).transpose((1, 2, 0))
img_with_mask = cv2.cvtColor(img_with_mask, cv2.COLOR_BGR2RGB)
cv2.imwrite(str(filename).replace('out_', 'vis_'), img_with_mask)
logging.info(f"Saved result to: {out_file_path}")
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_step(self, batch, batch_idx):
return {"loss": loss}

def predict_step(self, x, batch_idx, **kwargs):
return self.compute_anomaly_mask(x)
return (x, self.compute_anomaly_mask(x))

def compute_anomaly_mask(self, x):
x_rec = self.forward(x) # (B, C, W, H)
Expand Down

0 comments on commit b47da9c

Please sign in to comment.