diff --git a/docs/source/images/supersimplenet/architecture.png b/docs/source/images/supersimplenet/architecture.png new file mode 100644 index 0000000000..f8822ba753 Binary files /dev/null and b/docs/source/images/supersimplenet/architecture.png differ diff --git a/src/anomalib/models/__init__.py b/src/anomalib/models/__init__.py index 26f8695ab6..6eec0032bb 100644 --- a/src/anomalib/models/__init__.py +++ b/src/anomalib/models/__init__.py @@ -29,6 +29,7 @@ ReverseDistillation, Rkde, Stfpm, + SuperSimpleNet, Uflow, VlmAd, WinClip, @@ -57,6 +58,7 @@ class UnknownModelError(ModuleNotFoundError): "ReverseDistillation", "Rkde", "Stfpm", + "SuperSimpleNet", "Uflow", "VlmAd", "WinClip", diff --git a/src/anomalib/models/components/feature_extractors/torchfx.py b/src/anomalib/models/components/feature_extractors/torchfx.py index 600f2a961d..485441fa50 100644 --- a/src/anomalib/models/components/feature_extractors/torchfx.py +++ b/src/anomalib/models/components/feature_extractors/torchfx.py @@ -166,7 +166,11 @@ class can be provided and it will try to load the weights from the provided weig backbone_class = backbone.class_path backbone_model = backbone_class(**backbone.init_args) - if isinstance(weights, WeightsEnum): # torchvision models + if isinstance(weights, WeightsEnum) or weights in { + "IMAGENET1K_V1", + "IMAGENET1K_V2", + "DEFAULT", + }: # torchvision models feature_extractor = create_feature_extractor(model=backbone_model, return_nodes=return_nodes) elif weights is not None: if not isinstance(weights, str): diff --git a/src/anomalib/models/image/__init__.py b/src/anomalib/models/image/__init__.py index b09da8b07b..f161bbd4ab 100644 --- a/src/anomalib/models/image/__init__.py +++ b/src/anomalib/models/image/__init__.py @@ -19,6 +19,7 @@ from .reverse_distillation import ReverseDistillation from .rkde import Rkde from .stfpm import Stfpm +from .supersimplenet import SuperSimpleNet from .uflow import Uflow from .vlm_ad import VlmAd from .winclip import WinClip @@ -40,6 +41,7 @@ "ReverseDistillation", "Rkde", "Stfpm", + "SuperSimpleNet", "Uflow", "VlmAd", "WinClip", diff --git a/src/anomalib/models/image/supersimplenet/LICENSE b/src/anomalib/models/image/supersimplenet/LICENSE new file mode 100644 index 0000000000..7578a7064b --- /dev/null +++ b/src/anomalib/models/image/supersimplenet/LICENSE @@ -0,0 +1,29 @@ +Copyright (c) 2024 Intel Corporation +SPDX-License-Identifier: Apache-2.0 + +Some files in this folder are based on the original SuperSimpleNet implementation by Blaž Rolih + +Original license: +----------------- + + MIT License + + Copyright (c) 2024 Blaž Rolih + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. diff --git a/src/anomalib/models/image/supersimplenet/README.md b/src/anomalib/models/image/supersimplenet/README.md new file mode 100644 index 0000000000..fb9ab776c8 --- /dev/null +++ b/src/anomalib/models/image/supersimplenet/README.md @@ -0,0 +1,57 @@ +# SuperSimpleNet: Unifying Unsupervised and Supervised Learning for Fast and Reliable Surface Defect Detection + +This is an implementation of the [SuperSimpleNet](https://arxiv.org/pdf/2408.03143) paper, based on the [official code](https://github.com/blaz-r/SuperSimpleNet). + +Model Type: Segmentation + +## Description + +**SuperSimpleNet** is a simple yet strong discriminative defect / anomaly detection model evolved from the SimpleNet architecture. It consists of four components: +feature extractor with upscaling, feature adaptor, synthetic feature-level anomaly generation module, and +segmentation-detection module. + +A ResNet-like feature extractor first extracts features, which are then upscaled and +average-pooled to capture neighboring context. Features are further refined for anomaly detection task in the adaptor module. +During training, synthetic anomalies are generated at the feature level by adding Gaussian noise to regions defined by the +binary Perlin noise mask. The perturbed features are then fed into the segmentation-detection +module, which produces the anomaly map and the anomaly score. During inference, anomaly generation is skipped, and the model +directly predicts the anomaly map and score. The predicted anomaly map is upscaled to match the input image size +and refined with a Gaussian filter. + +This implementation supports both unsupervised and supervised setting, but Anomalib currently supports only unsupervised learning. + +## Architecture + +![SuperSimpleNet architecture](/docs/source/images/supersimplenet/architecture.png "SuperSimpleNet architecture") + +## Usage + +`anomalib train --model SuperSimpleNet --data MVTec --data.category ` + +> It is recommended to train the model for 300 epochs with batch size of 32 to achieve stable training with random anomaly generation. Training with lower parameter values will still work, but might not yield the optimal results. +> +> For supervised learning, refer to the [official code](https://github.com/blaz-r/SuperSimpleNet). + +## MVTec AD results + +The following results were obtained using this Anomalib implementation trained for 300 epochs with seed 42, default params, and batch size 32. +| | **Image AUROC** | **Pixel AUPRO** | +| ----------- | :-------------: | :-------------: | +| Bottle | 1.000 | 0.914 | +| Cable | 0.981 | 0.895 | +| Capsule | 0.990 | 0.926 | +| Carpet | 0.987 | 0.936 | +| Grid | 0.998 | 0.935 | +| Hazelnut | 0.999 | 0.946 | +| Leather | 1.000 | 0.972 | +| Metal_nut | 0.996 | 0.923 | +| Pill | 0.960 | 0.942 | +| Screw | 0.903 | 0.952 | +| Tile | 0.989 | 0.817 | +| Toothbrush | 0.917 | 0.861 | +| Transistor | 1.000 | 0.909 | +| Wood | 0.996 | 0.868 | +| Zipper | 0.996 | 0.944 | +| **Average** | 0.981 | 0.916 | + +For other results on VisA, SensumSODF, and KSDD2, refer to the [paper](https://arxiv.org/pdf/2408.03143). 
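+
+## Python API usage
+
+The CLI command in the Usage section can also be launched from the Python API. The snippet below is a minimal sketch, assuming the standard anomalib `Engine` and `MVTec` datamodule interfaces; the category and batch size are illustrative.
+
+```python
+from anomalib.data import MVTec
+from anomalib.engine import Engine
+from anomalib.models import SuperSimpleNet
+
+# Recommended settings from the note above: 300 epochs, batch size 32.
+datamodule = MVTec(category="bottle", train_batch_size=32)
+model = SuperSimpleNet()
+engine = Engine(max_epochs=300)
+
+engine.fit(model=model, datamodule=datamodule)
+engine.test(model=model, datamodule=datamodule)
+```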
diff --git a/src/anomalib/models/image/supersimplenet/__init__.py b/src/anomalib/models/image/supersimplenet/__init__.py new file mode 100644 index 0000000000..2f1e7d990b --- /dev/null +++ b/src/anomalib/models/image/supersimplenet/__init__.py @@ -0,0 +1,8 @@ +"""SuperSimpleNet model.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import SuperSimpleNet + +__all__ = ["SuperSimpleNet"] diff --git a/src/anomalib/models/image/supersimplenet/anomaly_generator.py b/src/anomalib/models/image/supersimplenet/anomaly_generator.py new file mode 100644 index 0000000000..0ed8ed3992 --- /dev/null +++ b/src/anomalib/models/image/supersimplenet/anomaly_generator.py @@ -0,0 +1,163 @@ +"""Anomaly generator for the SuperSimplenet model implementation.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.nn.functional as F # noqa: N812 +from torch import nn + +from anomalib.data.utils.generators.perlin import _rand_perlin_2d + + +class SSNAnomalyGenerator(nn.Module): + """Anomaly generator of the SuperSimpleNet model.""" + + def __init__( + self, + noise_mean: float, + noise_std: float, + threshold: float, + perlin_range: tuple[int, int] = (0, 6), + ) -> None: + super().__init__() + + self.noise_mean = noise_mean + self.noise_std = noise_std + + self.threshold = threshold + + self.min_perlin_scale = perlin_range[0] + self.max_perlin_scale = perlin_range[1] + + @staticmethod + def next_power_2(num: int) -> int: + """Get the next power of 2 for given number. + + Args: + num (int): value of interest + + Returns: + next power of 2 value for given number + """ + return 1 << (num - 1).bit_length() + + def generate_perlin(self, batches: int, height: int, width: int) -> torch.Tensor: + """Generate 2d perlin noise masks with dims [b, 1, h, w]. + + Args: + batches (int): number of batches (different masks) + height (int): height of features + width (int): width of features + + Returns: + tensor with b perlin binarized masks + """ + perlin = [] + for _ in range(batches): + # get scale of perlin in x and y direction + perlin_scalex = 2 ** ( + torch.randint( + self.min_perlin_scale, + self.max_perlin_scale, + (1,), + ).item() + ) + perlin_scaley = 2 ** ( + torch.randint( + self.min_perlin_scale, + self.max_perlin_scale, + (1,), + ).item() + ) + + perlin_height = self.next_power_2(height) + perlin_width = self.next_power_2(width) + + perlin_noise = _rand_perlin_2d( + (perlin_height, perlin_width), + (perlin_scalex, perlin_scaley), + ) + # original is power of 2 scale, so fit to our size + perlin_noise = F.interpolate( + perlin_noise.reshape(1, 1, perlin_height, perlin_width), + size=(height, width), + mode="bilinear", + ) + # binarize + perlin_thr = torch.where(perlin_noise > self.threshold, 1, 0) + + # 50% of anomaly + if torch.rand(1).item() > 0.5: + perlin_thr = torch.zeros_like(perlin_thr) + + perlin.append(perlin_thr) + return torch.cat(perlin) + + def forward( + self, + features: torch.Tensor, + mask: torch.Tensor, + labels: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate anomaly on features using thresholded perlin noise and Gaussian noise. + + Also update GT masks and labels with new anomaly information. + + Args: + features: input features. + mask: GT masks. + labels: GT labels. + + Returns: + perturbed features, updated GT masks and labels. 
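+
+        Example:
+            A shape sketch (derived from the duplication step below): for input
+            ``features`` of shape ``[B, C, H, W]``, the batch is doubled, so the
+            returned tensors have shapes ``[2B, C, H, W]`` (perturbed features),
+            ``[2B, 1, H, W]`` (masks) and ``[2B]`` (labels).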
+ """ + b, _, h, w = features.shape + + # duplicate + features = torch.cat((features, features)) + mask = torch.cat((mask, mask)) + labels = torch.cat((labels, labels)) + + noise = torch.normal( + mean=self.noise_mean, + std=self.noise_std, + size=features.shape, + device=features.device, + requires_grad=False, + ) + + # mask indicating which regions will have noise applied + # [B * 2, 1, H, W] initial all masked as anomalous + noise_mask = torch.ones( + b * 2, + 1, + h, + w, + device=features.device, + requires_grad=False, + ) + + # no overlap: don't apply to already anomalous regions (mask=1 -> bad) + noise_mask = noise_mask * (1 - mask) + + # shape of noise is [B * 2, 1, H, W] + perlin_mask = self.generate_perlin(b * 2, h, w).to(features.device) + # only apply where perlin mask is 1 + noise_mask = noise_mask * perlin_mask + + # update gt mask + mask = mask + noise_mask + # binarize + mask = torch.where(mask > 0, torch.ones_like(mask), torch.zeros_like(mask)) + + # make new labels. 1 if any part of mask is 1, 0 otherwise + new_anomalous = noise_mask.reshape(b * 2, -1).any(dim=1).type(torch.float32) + labels = labels + new_anomalous + # binarize + labels = torch.where(labels > 0, torch.ones_like(labels), torch.zeros_like(labels)) + + # apply masked noise + perturbed = features + noise * noise_mask + + return perturbed, mask, labels diff --git a/src/anomalib/models/image/supersimplenet/lightning_model.py b/src/anomalib/models/image/supersimplenet/lightning_model.py new file mode 100644 index 0000000000..cd6f58073e --- /dev/null +++ b/src/anomalib/models/image/supersimplenet/lightning_model.py @@ -0,0 +1,139 @@ +"""SuperSimpleNet: Unifying Unsupervised and Supervised Learning for Fast and Reliable Surface Defect Detection. + +Paper https://arxiv.org/pdf/2408.03143 +""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler +from torch.optim import AdamW +from torch.optim.lr_scheduler import MultiStepLR + +from anomalib import LearningType +from anomalib.data import Batch +from anomalib.metrics import Evaluator +from anomalib.models import AnomalibModule +from anomalib.post_processing import PostProcessor +from anomalib.pre_processing import PreProcessor + +from .loss import SSNLoss +from .torch_model import SuperSimpleNetModel + + +class SuperSimpleNet(AnomalibModule): + """PL Lightning Module for the SuperSimpleNet model. + + Args: + perlin_threshold (float): threshold value for Perlin noise thresholding during anomaly generation. + backbone (str): backbone name + layers (list[str]): backbone layers utilised + supervised (bool): whether the model will be trained in supervised mode. False by default (unsupervised). 
+ """ + + def __init__( + self, + perlin_threshold: float = 0.2, + backbone: str = "wide_resnet50_2", + layers: list[str] = ["layer2", "layer3"], # noqa: B006 + supervised: bool = False, + pre_processor: PreProcessor | bool = True, + post_processor: PostProcessor | None = None, + evaluator: Evaluator | bool = True, + ) -> None: + super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator) + self.supervised = supervised + # stop grad in unsupervised + if supervised: + stop_grad = False + self.norm_clip_val = 1 + else: + stop_grad = True + self.norm_clip_val = 0 + + self.model = SuperSimpleNetModel( + perlin_threshold=perlin_threshold, + backbone=backbone, + layers=layers, + stop_grad=stop_grad, + ) + self.loss = SSNLoss() + + def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: + """Perform the training step input and return the loss. + + Args: + batch (batch: dict[str, str | torch.Tensor]): Input batch + args: Additional arguments. + kwargs: Additional keyword arguments. + + Returns: + STEP_OUTPUT: Dictionary containing the loss value. + """ + del args, kwargs # These variables are not used. + + anomaly_map, anomaly_score, masks, labels = self.model( + images=batch.image, + masks=batch.gt_mask, + labels=batch.gt_label, + ) + loss = self.loss(pred_map=anomaly_map, pred_score=anomaly_score, target_mask=masks, target_label=labels) + self.log("train_loss", loss.item(), on_epoch=True, prog_bar=True, logger=True) + return {"loss": loss} + + def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: + """Perform the validation step and return the anomaly map and anomaly score. + + Args: + batch (dict[str, str | torch.Tensor]): Input batch + args: Additional arguments. + kwargs: Additional keyword arguments. + + Returns: + STEP_OUTPUT | None: batch dictionary containing anomaly-maps. + """ + # These variables are not used. + del args, kwargs + + # Get anomaly maps and predicted scores from the model. + predictions = self.model(batch.image) + + return batch.update(**predictions._asdict()) + + @property + def trainer_arguments(self) -> dict[str, Any]: + """Return SuperSimpleNet trainer arguments.""" + return {"gradient_clip_val": self.norm_clip_val, "num_sanity_val_steps": 0} + + def configure_optimizers(self) -> OptimizerLRScheduler: + """Configure AdamW optimizer and MultiStepLR scheduler.""" + optim = AdamW( + [ + { + "params": self.model.adaptor.parameters(), + "lr": 0.0001, + }, + { + "params": self.model.segdec.parameters(), + "lr": 0.0002, + "weight_decay": 0.00001, + }, + ], + ) + sched = MultiStepLR( + optim, + milestones=[int(self.trainer.max_epochs * 0.8), int(self.trainer.max_epochs * 0.9)], + gamma=0.4, + ) + return [optim], [sched] + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. + + Returns: + LearningType: Learning type of the model. + """ + return LearningType.ONE_CLASS diff --git a/src/anomalib/models/image/supersimplenet/loss.py b/src/anomalib/models/image/supersimplenet/loss.py new file mode 100644 index 0000000000..984e5299cb --- /dev/null +++ b/src/anomalib/models/image/supersimplenet/loss.py @@ -0,0 +1,74 @@ +"""Loss function for the SuperSimpleNet model implementation.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from functools import partial + +import torch +from torch import nn +from torchvision.ops.focal_loss import sigmoid_focal_loss + + +class SSNLoss(nn.Module): + """SuperSimpleNet loss function. 
+ + Args: + truncation_term (float): L1 loss truncation term preventing overfitting. + """ + + def __init__(self, truncation_term: float = 0.5) -> None: + super().__init__() + self.focal_loss = partial(sigmoid_focal_loss, alpha=-1, gamma=4.0, reduction="mean") + self.th = truncation_term + + def trunc_l1_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Calculate the truncated L1 loss between `pred` and `target`. + + Args: + pred (torch.Tensor): predicted values. + target (torch.Tensor): target GT values. + + Returns: + torch.Tensor: L1 truncated loss value. + """ + normal_scores = pred[target == 0] + anomalous_scores = pred[target > 0] + # push normal towards negative numbers + true_loss = torch.clip(normal_scores + self.th, min=0) + # push anomalous towards positive numbers + fake_loss = torch.clip(-anomalous_scores + self.th, min=0) + + true_loss = true_loss.mean() if len(true_loss) else torch.Tensor(0) + fake_loss = fake_loss.mean() if len(fake_loss) else torch.Tensor(0) + + return true_loss + fake_loss + + def forward( + self, + pred_map: torch.Tensor, + pred_score: torch.Tensor, + target_mask: torch.Tensor, + target_label: torch.Tensor, + ) -> torch.Tensor: + """Calculate loss based on predicted anomaly maps and scores. + + Total loss = Lseg and Lcls + where + Lseg = Lfocal(map) + Ltruncl1(map) + Lcls = Lfocal(score) + + Args: + pred_map: predicted anomaly maps. + pred_score: predicted anomaly scores. + target_mask: GT anomaly masks. + target_label: GT anomaly labels. + + Returns: + torch.Tensor: loss value. + """ + map_focal = self.focal_loss(pred_map, target_mask) + map_trunc_l1 = self.trunc_l1_loss(pred_map, target_mask) + score_focal = self.focal_loss(pred_score, target_label) + + return map_focal + map_trunc_l1 + score_focal diff --git a/src/anomalib/models/image/supersimplenet/torch_model.py b/src/anomalib/models/image/supersimplenet/torch_model.py new file mode 100644 index 0000000000..0f61575f63 --- /dev/null +++ b/src/anomalib/models/image/supersimplenet/torch_model.py @@ -0,0 +1,367 @@ +"""PyTorch model for the SuperSimpleNet model implementation.""" + +# Original Code +# Copyright (c) 2024 Blaž Rolih +# https://github.com/blaz-r/SuperSimpleNet. +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import math + +import torch +import torch.nn.functional as F # noqa: N812 +from torch import nn +from torch.nn import Parameter + +from anomalib.data import InferenceBatch +from anomalib.models.components import GaussianBlur2d, TorchFXFeatureExtractor +from anomalib.models.image.supersimplenet.anomaly_generator import SSNAnomalyGenerator + + +class SuperSimpleNetModel(nn.Module): + """SuperSimpleNet Pytorch model. + + It consists of feature extractor, feature adaptor, anomaly generation mechanism and segmentation-detection module. + + Args: + perlin_threshold (float): threshold value for Perlin noise thresholding during anomaly generation. + backbone (str): backbone name + layers (list[str]): backbone layers utilised + stop_grad (bool): whether to stop gradient from class. to seg. head. 
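+
+    Example:
+        An illustrative inference-mode call (shapes assume a 256x256 input and the
+        default configuration):
+
+        >>> import torch
+        >>> model = SuperSimpleNetModel().eval()
+        >>> output = model(torch.rand(2, 3, 256, 256))
+        >>> # output.anomaly_map: [2, 1, 256, 256], output.pred_score: [2]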
+ """ + + def __init__( + self, + perlin_threshold: float = 0.2, + backbone: str = "wide_resnet50_2", + layers: list[str] = ["layer2", "layer3"], # noqa: B006 + stop_grad: bool = True, + ) -> None: + super().__init__() + self.feature_extractor = FeatureExtractor(backbone=backbone, layers=layers) + + channels = self.feature_extractor.get_channels_dim() + self.adaptor = FeatureAdaptor(channels) + self.segdec = SegmentationDetectionModule(channel_dim=channels, stop_grad=stop_grad) + self.anomaly_generator = SSNAnomalyGenerator(noise_mean=0, noise_std=0.015, threshold=perlin_threshold) + + self.anomaly_map_generator = AnomalyMapGenerator(sigma=4) + + def forward( + self, + images: torch.Tensor, + masks: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + ) -> InferenceBatch | tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """SuperSimpleNet forward pass. + + Extract and process features, adapt them, generate anomalies (train only) and predict anomaly map and score. + + Args: + images (torch.Tensor): Input images. + masks (torch.Tensor): GT masks. + labels (torch.Tensor): GT labels. + + Returns: + inference: anomaly map and score + training: anomaly map, score and GT masks and labels + """ + output_size = images.shape[-2:] + + features = self.feature_extractor(images) + adapted = self.adaptor(features) + + if self.training: + masks = self.downsample_mask(masks, *features.shape[-2:]) + # make linter happy :) + if labels is not None: + labels = labels.type(torch.float32) + + features, masks, labels = self.anomaly_generator( + adapted, + masks, + labels, + ) + + anomaly_map, anomaly_score = self.segdec(features) + return anomaly_map, anomaly_score, masks, labels + + anomaly_map, anomaly_score = self.segdec(adapted) + anomaly_map = self.anomaly_map_generator(anomaly_map, final_size=output_size) + + return InferenceBatch(anomaly_map=anomaly_map, pred_score=anomaly_score) + + @staticmethod + def downsample_mask(masks: torch.Tensor, feat_h: int, feat_w: int) -> torch.Tensor: + """Downsample the masks according to the feature dimensions. + + Primarily used in supervised setting. + + Args: + masks (torch.Tensor): input GT masks + feat_h (int): feature height. + feat_w (int): feature width. + + Returns: + (torch.Tensor): downsampled masks. + """ + masks = masks.type(torch.float32) + # best downsampling proposed by DestSeg + masks = F.interpolate( + masks.unsqueeze(1), + size=(feat_h, feat_w), + mode="bilinear", + ) + return torch.where( + masks < 0.5, + torch.zeros_like(masks), + torch.ones_like(masks), + ) + + +def init_weights(m: nn.Module) -> None: + """Init weight of the model. + + Args: + m (nn.Module): torch module. + """ + if isinstance(m, nn.Linear | nn.Conv2d): + nn.init.xavier_normal_(m.weight) + elif isinstance(m, nn.BatchNorm1d | nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + + +class FeatureExtractor(nn.Module): + """Feature extractor module. + + Args: + backbone (str): backbone name. + layers (list[str]): list of layers used for extraction. + """ + + def __init__(self, backbone: str, layers: list[str], patch_size: int = 3) -> None: + super().__init__() + + self.feature_extractor = TorchFXFeatureExtractor( + backbone=backbone, + return_nodes=layers, + weights="IMAGENET1K_V1", + ) + self.pooler = nn.AvgPool2d( + kernel_size=patch_size, + stride=1, + padding=patch_size // 2, + ) + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + """Extract features from input tensor. 
+ + Args: + input_tensor: input tensor (images) + + Returns: + (torch.Tensor): extracted feature map. + """ + # extract features + self.feature_extractor.eval() + with torch.no_grad(): + features = self.feature_extractor(input_tensor) + + features = list(features.values()) + + _, _, h, w = features[0].shape + feature_map = [] + for layer in features: + # upscale all to 2x the size of the first (largest) + resized = F.interpolate( + layer, + size=(h * 2, w * 2), + mode="bilinear", + ) + feature_map.append(resized) + # channel-wise concat + feature_map = torch.cat(feature_map, dim=1) + + # neighboring patch aggregation + return self.pooler(feature_map) + + def get_channels_dim(self) -> int: + """Get feature channel dimension. + + Returns: + (int): feature channel dimension. + """ + # dryrun + self.feature_extractor.eval() + with torch.no_grad(): + features = self.feature_extractor(torch.rand(1, 3, 256, 256)) + # sum channels + return sum(feature.shape[1] for feature in features.values()) + + +class FeatureAdaptor(nn.Module): + """Feature adaptor used to adapt raw features for the task of anomaly detection. + + Args: + channel_dim (int): channel dimension of features. + """ + + def __init__(self, channel_dim: int) -> None: + super().__init__() + # linear layer equivalent + self.projection = nn.Conv2d( + in_channels=channel_dim, + out_channels=channel_dim, + kernel_size=1, + stride=1, + ) + self.apply(init_weights) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """Adapt features. + + Args: + features (torch.Tensor): input features + + Returns: + (torch.Tensor) adapted features + """ + return self.projection(features) + + +class SegmentationDetectionModule(nn.Module): + """SegmentationDetection module responsible for prediction of anomaly map and score. + + Args: + channel_dim (int): channel dimension of features. + stop_grad (bool): whether to stop gradient from class. head to seg. head. + """ + + def __init__( + self, + channel_dim: int, + stop_grad: bool = False, + ) -> None: + super().__init__() + self.stop_grad = stop_grad + + # 1x1 conv - linear layer equivalent + self.seg_head = nn.Sequential( + nn.Conv2d( + in_channels=channel_dim, + out_channels=1024, + kernel_size=1, + stride=1, + ), + nn.BatchNorm2d(1024), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d( + in_channels=1024, + out_channels=1, + kernel_size=1, + stride=1, + bias=False, + ), + ) + + # pooling for cls. conv out and map + self.map_avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1)) + self.map_max_pool = nn.AdaptiveMaxPool2d(output_size=(1, 1)) + + self.dec_avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1)) + self.dec_max_pool = nn.AdaptiveMaxPool2d(output_size=(1, 1)) + + # cls. head conv block + self.cls_conv = nn.Sequential( + nn.Conv2d( + in_channels=channel_dim + 1, + out_channels=128, + kernel_size=5, + padding="same", + ), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + ) + + # cls. head fc block: 128 from dec and 2 from map, * 2 due to max and avg pool + self.cls_fc = nn.Linear(in_features=128 * 2 + 2, out_features=1) + + self.apply(init_weights) + + def get_params(self) -> tuple[list[Parameter], list[Parameter]]: + """Get segmentation and classification head parameters. + + Returns: + seg. head parameters and class. head parameters. 
+ """ + seg_params = list(self.seg_head.parameters()) + dec_params = list(self.cls_conv.parameters()) + list(self.cls_fc.parameters()) + return seg_params, dec_params + + def forward(self, features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Predict anomaly map and anomaly score. + + Args: + features: adapted features. + + Returns: + predicted anomaly map and score. + """ + # get anomaly map from seg head + ano_map = self.seg_head(features) + + map_dec_copy = ano_map + if self.stop_grad: + map_dec_copy = map_dec_copy.detach() + # dec conv layer takes feat + map + mask_cat = torch.cat((features, map_dec_copy), dim=1) + dec_out = self.cls_conv(mask_cat) + + # conv block result pooling + dec_max = self.dec_max_pool(dec_out) + dec_avg = self.dec_avg_pool(dec_out) + + # predicted map pooling (and stop grad) + map_max = self.map_max_pool(ano_map) + if self.stop_grad: + map_max = map_max.detach() + + map_avg = self.map_avg_pool(ano_map) + if self.stop_grad: + map_avg = map_avg.detach() + + # final dec layer: conv channel max and avg and map max and avg + dec_cat = torch.cat((dec_max, dec_avg, map_max, map_avg), dim=1).squeeze() + ano_score = self.cls_fc(dec_cat).squeeze() + + return ano_map, ano_score + + +class AnomalyMapGenerator(nn.Module): + """Final anomaly map generator, responsible for upscaling and smoothing. + + Args: + sigma (float) Gaussian kernel sigma value. + """ + + def __init__(self, sigma: float) -> None: + super().__init__() + kernel_size = 2 * math.ceil(3 * sigma) + 1 + self.smoothing = GaussianBlur2d(kernel_size=kernel_size, sigma=4) + + def forward(self, out_map: torch.Tensor, final_size: tuple[int, int]) -> torch.Tensor: + """Upscale and smooth anomaly map to get final anomaly map of same size as input image. + + Args: + out_map (torch.Tensor): output anomaly map from seg. head. + final_size (tuple[int, int]): size (h, w) of final anomaly map. + + Returns: + torch.Tensor: final anomaly map. + """ + # upscale & smooth + anomaly_map = F.interpolate(out_map, size=final_size, mode="bilinear") + return self.smoothing(anomaly_map)