Skip to content

Commit

Permalink
Refactoring (qubvel#528)
Browse files Browse the repository at this point in the history
* Reorganize decoders, add deprecation for utils, add dataset
* Fix imports
  • Loading branch information
qubvel authored Dec 27, 2021
1 parent 4f94380 commit bf1d2bf
Show file tree
Hide file tree
Showing 31 changed files with 174 additions and 43 deletions.
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@ torchvision>=0.5.0
pretrainedmodels==0.7.4
efficientnet-pytorch==0.6.3
timm==0.4.12

tqdm
opencv-python-headless
18 changes: 9 additions & 9 deletions segmentation_models_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from .unet import Unet
from .unetplusplus import UnetPlusPlus
from .manet import MAnet
from .linknet import Linknet
from .fpn import FPN
from .pspnet import PSPNet
from .deeplabv3 import DeepLabV3, DeepLabV3Plus
from .pan import PAN
from .decoders.unet import Unet
from .decoders.unetplusplus import UnetPlusPlus
from .decoders.manet import MAnet
from .decoders.linknet import Linknet
from .decoders.fpn import FPN
from .decoders.pspnet import PSPNet
from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus
from .decoders.pan import PAN

from . import encoders
from . import utils
from . import decoders
from . import losses

from .__version__ import __version__
Expand Down
4 changes: 2 additions & 2 deletions segmentation_models_pytorch/base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def forward(self, x):

return masks

@torch.no_grad()
def predict(self, x):
"""Inference method. Switch model to `eval` mode, call `.forward(x)` with `torch.no_grad()`
Expand All @@ -36,7 +37,6 @@ def predict(self, x):
if self.training:
self.eval()

with torch.no_grad():
x = self.forward(x)
x = self.forward(x)

return x
1 change: 1 addition & 0 deletions segmentation_models_pytorch/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .oxford_pet import OxfordPetDataset, SimpleOxfordPetDataset
125 changes: 125 additions & 0 deletions segmentation_models_pytorch/datasets/oxford_pet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import os
import cv2
import shutil
import numpy as np
from tqdm import tqdm
from urllib.request import urlretrieve


class OxfordPetDataset(torch.utils.data.Dataset):

def __init__(self, root, mode="train", transform=None):

assert mode in {"train", "valid", "test"}

self.root = root
self.mode = mode
self.transform = transform

self._download_dataset() # download only if it does not exist

self.images_directory = os.path.join(self.root, "images")
self.masks_directory = os.path.join(self.root, "annotations", "trimaps")

self.filenames = self._read_split() # read train/valid/test splits

def __len__(self):
return len(self.filenames)

def __getitem__(self, idx):

filename = self.filenames[idx]
image_path = os.path.join(self.images_directory, filename + ".jpg")
mask_path = os.path.join(self.masks_directory, filename + ".png")

image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

trimap = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
mask = self._preprocess_mask(trimap)

sample = dict(image=image, mask=mask, trimap=trimap)
if self.transform is not None:
sample = self.transform(**sample)

return sample

@staticmethod
def _preprocess_mask(mask):
mask = mask.astype(np.float32)
mask[mask == 2.0] = 0.0
mask[(mask == 1.0) | (mask == 3.0)] = 1.0
return mask

def _read_split(self):
split_filename = "test.txt" if self.mode == "test" else "trainval.txt"
split_filepath = os.path.join(self.root, "annotations", split_filename)
with open(split_filepath) as f:
split_data = f.read().strip("\n").split("\n")
filenames = [x.split(" ")[0] for x in split_data]
if self.mode == "train": # 90% for train
filenames = [x for i, x in enumerate(filenames) if i % 10 != 0]
elif self.mode == "valid": # 10% for validation
filenames = [x for i, x in enumerate(filenames) if i % 10 == 0]
return filenames

def _download_dataset(self):

# load images
filepath = os.path.join(self.root, "images.tar.gz")
download_url(
url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", filepath=filepath,
)
extract_archive(filepath)

# load annotations
filepath = os.path.join(self.root, "annotations.tar.gz")
download_url(
url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", filepath=filepath,
)
extract_archive(filepath)


class SimpleOxfordPetDataset(OxfordPetDataset):
"""Dataset for example without augmentations and transforms"""

def __getitem__(self, *args, **kwargs):

sample = super().__getitem__(*args, **kwargs)

# resize images
image = cv2.resize(sample["image"], (256, 256), cv2.INTER_LINEAR)
mask = cv2.resize(sample["mask"], (256, 256), cv2.INTER_NEAREST)
trimap = cv2.resize(sample["trimap"], (256, 256), cv2.INTER_NEAREST)

# convert to other format HWC -> CHW
sample["image"] = np.moveaxis(image, -1, 0)
sample["mask"] = np.expand_dims(mask, 0)
sample["trimap"] = np.expand_dims(trimap, 0)

return sample


class TqdmUpTo(tqdm):
def update_to(self, b=1, bsize=1, tsize=None):
if tsize is not None:
self.total = tsize
self.update(b * bsize - self.n)


def download_url(url, filepath):
directory = os.path.dirname(os.path.abspath(filepath))
os.makedirs(directory, exist_ok=True)
if os.path.exists(filepath):
return

with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=os.path.basename(filepath)) as t:
urlretrieve(url, filename=filepath, reporthook=t.update_to, data=None)
t.total = t.n


def extract_archive(filepath):
extract_dir = os.path.dirname(os.path.abspath(filepath))
dst_dir = os.path.splitext(filepath)[0]
if not os.path.exists(dst_dir):
shutil.unpack_archive(filepath, extract_dir)
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import torch.nn as nn

from torch import nn
from typing import Optional
from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder
from ..base import SegmentationModel, SegmentationHead, ClassificationHead
from ..encoders import get_encoder

from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead, ClassificationHead
from segmentation_models_pytorch.encoders import get_encoder
from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder

class DeepLabV3(SegmentationModel):
"""DeepLabV3_ implementation from "Rethinking Atrous Convolution for Semantic Image Segmentation"
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional, Union
from .decoder import FPNDecoder
from ..base import SegmentationModel, SegmentationHead, ClassificationHead
from ..encoders import get_encoder

from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead, ClassificationHead
from segmentation_models_pytorch.encoders import get_encoder
from .decoder import FPNDecoder

class FPN(SegmentationModel):
"""FPN_ is a fully convolution neural network for image semantic segmentation.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch.nn as nn

from ..base import modules
from segmentation_models_pytorch.base import modules


class TransposeX2(nn.Sequential):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Optional, Union

from segmentation_models_pytorch.base import SegmentationHead, SegmentationModel, ClassificationHead
from segmentation_models_pytorch.encoders import get_encoder
from .decoder import LinknetDecoder
from ..base import SegmentationHead, SegmentationModel, ClassificationHead
from ..encoders import get_encoder


class Linknet(SegmentationModel):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..base import modules as md

from segmentation_models_pytorch.base import modules as md


class PAB(nn.Module):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional, Union, List

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead, ClassificationHead
from .decoder import MAnetDecoder
from ..encoders import get_encoder
from ..base import SegmentationModel
from ..base import SegmentationHead, ClassificationHead


class MAnet(SegmentationModel):
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional, Union

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead, ClassificationHead
from .decoder import PANDecoder
from ..encoders import get_encoder
from ..base import SegmentationModel
from ..base import SegmentationHead, ClassificationHead


class PAN(SegmentationModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch.nn as nn
import torch.nn.functional as F

from ..base import modules
from segmentation_models_pytorch.base import modules


class PSPBlock(nn.Module):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from typing import Optional, Union

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead, ClassificationHead
from .decoder import PSPDecoder
from ..encoders import get_encoder

from ..base import SegmentationModel
from ..base import SegmentationHead, ClassificationHead


class PSPNet(SegmentationModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch.nn as nn
import torch.nn.functional as F

from ..base import modules as md
from segmentation_models_pytorch.base import modules as md


class DecoderBlock(nn.Module):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional, Union, List

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead, ClassificationHead
from .decoder import UnetDecoder
from ..encoders import get_encoder
from ..base import SegmentationModel
from ..base import SegmentationHead, ClassificationHead


class Unet(SegmentationModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch.nn as nn
import torch.nn.functional as F

from ..base import modules as md
from segmentation_models_pytorch.base import modules as md


class DecoderBlock(nn.Module):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional, Union, List

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead, ClassificationHead
from .decoder import UnetPlusPlusDecoder
from ..encoders import get_encoder
from ..base import SegmentationModel
from ..base import SegmentationHead, ClassificationHead


class UnetPlusPlus(SegmentationModel):
Expand Down
Empty file.
5 changes: 4 additions & 1 deletion segmentation_models_pytorch/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import warnings
warnings.warn("`smp.utils` module is deprecated and will be removed in future releases.", DeprecationWarning)

from . import train
from . import losses
from . import metrics
from . import metrics

0 comments on commit bf1d2bf

Please sign in to comment.