diff --git a/mapmaster/__init__.py b/mapmaster/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/mapmaster/dataset/nuscenes_bemapnet.py b/mapmaster/dataset/nuscenes_bemapnet.py
new file mode 100644
index 0000000..2498965
--- /dev/null
+++ b/mapmaster/dataset/nuscenes_bemapnet.py
@@ -0,0 +1,141 @@
+import os
+import torch
+import numpy as np
+from PIL import Image
+from copy import deepcopy
+from skimage import io as skimage_io
+from torch.utils.data import Dataset
+
+
+class NuScenesMapDataset(Dataset):
+ def __init__(self, img_key_list, map_conf, ida_conf, bezier_conf, transforms, data_split="training"):
+ super().__init__()
+ self.img_key_list = img_key_list
+ self.map_conf = map_conf
+ self.ida_conf = ida_conf
+ self.bez_conf = bezier_conf
+ self.ego_size = map_conf["ego_size"]
+ self.mask_key = map_conf["mask_key"]
+ self.nusc_root = map_conf["nusc_root"]
+ self.anno_root = map_conf["anno_root"]
+ self.split_dir = map_conf["split_dir"]
+ self.num_degree = bezier_conf["num_degree"]
+ self.max_pieces = bezier_conf["max_pieces"]
+ self.max_instances = bezier_conf["max_instances"]
+ self.split_mode = 'train' if data_split == "training" else 'val'
+ split_path = os.path.join(self.split_dir, f'{self.split_mode}.txt')
+ self.tokens = [token.strip() for token in open(split_path).readlines()]
+ self.transforms = transforms
+
+ def __getitem__(self, idx: int):
+ token = self.tokens[idx]
+ sample = np.load(os.path.join(self.anno_root, f'{token}.npz'), allow_pickle=True)
+ resize_dims, crop, flip, rotate = self.sample_ida_augmentation()
+ images, ida_mats = [], []
+ for im_view in self.img_key_list:
+ for im_path in sample['image_paths']:
+ if im_path.startswith(f'samples/{im_view}/'):
+ im_path = os.path.join(self.nusc_root, im_path)
+ img = skimage_io.imread(im_path)
+ img, ida_mat = self.img_transform(img, resize_dims, crop, flip, rotate)
+ images.append(img)
+ ida_mats.append(ida_mat)
+ extrinsic = np.stack([np.eye(4) for _ in range(sample["trans"].shape[0])], axis=0)
+ extrinsic[:, :3, :3] = sample["rots"]
+ extrinsic[:, :3, 3] = sample["trans"]
+ intrinsic = sample['intrins']
+        ctr_points = np.zeros((self.max_instances, max(self.max_pieces) * max(self.num_degree) + 1, 2), dtype=np.float32)
+ ins_labels = np.zeros((self.max_instances, 3), dtype=np.int16) - 1
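+        # each filled label row encodes [class_id, num_bezier_pieces - 1, num_valid_points]; -1 marks unused slots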
+ for ins_id, ctr_info in enumerate(sample['ctr_points']):
+ cls_id = int(ctr_info['type'])
+ ctr_pts_raw = np.array(ctr_info['pts'])
+ max_points = self.max_pieces[cls_id] * self.num_degree[cls_id] + 1
+            num_points = min(max_points, ctr_pts_raw.shape[0])
+            assert num_points >= self.num_degree[cls_id] + 1
+            ctr_points[ins_id][:num_points] = ctr_pts_raw[:num_points]
+ ins_labels[ins_id] = [cls_id, (num_points - 1) // self.num_degree[cls_id] - 1, num_points]
+        masks = sample[self.mask_key]
+        img_key_list = self.img_key_list
+ if flip:
+ new_order = [2, 1, 0, 5, 4, 3]
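+            # swapping indices 0<->2 and 3<->5 mirrors the left/right camera views to stay consistent
+            # with the horizontal image flip (assuming the usual FL/F/FR, BL/B/BR camera ordering)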
+ img_key_list = [self.img_key_list[i] for i in new_order]
+ images = [images[i] for i in new_order]
+ ida_mats = [ida_mats[i] for i in new_order]
+ extrinsic = [extrinsic[i] for i in new_order]
+ intrinsic = [intrinsic[i] for i in new_order]
+ masks = [np.flip(mask, axis=1) for mask in masks]
+ ctr_points = self.point_flip(ctr_points, ins_labels, self.ego_size)
+ item = dict(
+ images=images, targets=dict(masks=masks, points=ctr_points, labels=ins_labels),
+ extrinsic=np.stack(extrinsic), intrinsic=np.stack(intrinsic), ida_mats=np.stack(ida_mats),
+            extra_infos=dict(token=token, img_key_list=img_key_list, map_size=self.ego_size, do_flip=flip)
+ )
+ if self.transforms is not None:
+ item = self.transforms(item)
+ return item
+
+ def __len__(self):
+ return len(self.tokens)
+
+ def sample_ida_augmentation(self):
+ """Generate ida augmentation values based on ida_config."""
+ resize_dims = w, h = self.ida_conf["resize_dims"]
+ crop = (0, 0, w, h)
+ if self.ida_conf["up_crop_ratio"] > 0:
+ crop = (0, int(self.ida_conf["up_crop_ratio"] * h), w, h)
+ flip, color, rotate_ida = False, False, 0
+ if self.split_mode == "train":
+ if self.ida_conf["rand_flip"] and np.random.choice([0, 1]):
+ flip = True
+ if self.ida_conf["rot_lim"]:
+ assert isinstance(self.ida_conf["rot_lim"], (tuple, list))
+ rotate_ida = np.random.uniform(*self.ida_conf["rot_lim"])
+ return resize_dims, crop, flip, rotate_ida
+
+ def img_transform(self, img, resize_dims, crop, flip, rotate):
+ img = Image.fromarray(img)
+ ida_rot = torch.eye(2)
+ ida_tran = torch.zeros(2)
+ W, H = img.size
+ img = img.resize(resize_dims)
+ img = img.crop(crop)
+ if flip:
+ img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
+ img = img.rotate(rotate)
+
+ # post-homography transformation
+        scales = torch.Tensor([resize_dims[0] / W, resize_dims[1] / H])
+        ida_rot *= scales
+ ida_tran -= torch.Tensor(crop[:2])
+ if flip:
+ A = torch.Tensor([[-1, 0], [0, 1]])
+ b = torch.Tensor([crop[2] - crop[0], 0])
+ ida_rot = A.matmul(ida_rot)
+ ida_tran = A.matmul(ida_tran) + b
+ A = self.get_rot(rotate / 180 * np.pi)
+ b = torch.Tensor([crop[2] - crop[0], crop[3] - crop[1]]) / 2
+ b = A.matmul(-b) + b
+ ida_rot = A.matmul(ida_rot)
+ ida_tran = A.matmul(ida_tran) + b
+ ida_mat = ida_rot.new_zeros(3, 3)
+ ida_mat[2, 2] = 1
+ ida_mat[:2, :2] = ida_rot
+ ida_mat[:2, 2] = ida_tran
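+        # ida_mat is the 3x3 homography mapping original-image pixel coordinates to augmented-image coordinates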
+ return np.asarray(img), ida_mat
+
+ @staticmethod
+ def point_flip(points, labels, map_shape):
+
+        def _flip(pts):
+            pts = pts.copy()  # work on a copy to avoid mutating the caller's array in place
+            pts[:, 0] = map_shape[1] - pts[:, 0]
+            return pts
+
+ points_ret = deepcopy(points)
+ for ins_id in range(points.shape[0]):
+ end = labels[ins_id, 2]
+ points_ret[ins_id][:end] = _flip(points[ins_id][:end])
+
+ return points_ret
+
+ @staticmethod
+ def get_rot(h):
+ return torch.Tensor([[np.cos(h), np.sin(h)], [-np.sin(h), np.cos(h)]])
diff --git a/mapmaster/dataset/nuscenes_pivotnet.py b/mapmaster/dataset/nuscenes_pivotnet.py
new file mode 100644
index 0000000..270db9b
--- /dev/null
+++ b/mapmaster/dataset/nuscenes_pivotnet.py
@@ -0,0 +1,56 @@
+import os
+import numpy as np
+import pickle as pkl
+from PIL import Image
+from torch.utils.data import Dataset
+
+class NuScenesMapDataset(Dataset):
+ def __init__(self, img_key_list, map_conf, transforms, data_split="training"):
+ super().__init__()
+ self.img_key_list = img_key_list
+ self.map_conf = map_conf
+
+ self.ego_size = map_conf["ego_size"]
+ self.mask_key = map_conf["mask_key"]
+ self.nusc_root = map_conf["nusc_root"]
+ self.anno_root = map_conf["anno_root"]
+ self.split_dir = map_conf["split_dir"] # instance_mask/instance_mask8
+
+ self.split_mode = 'train' if data_split == "training" else 'val'
+ split_path = os.path.join(self.split_dir, f'{self.split_mode}.txt')
+ self.tokens = [token.strip() for token in open(split_path).readlines()]
+ self.transforms = transforms
+
+ def __getitem__(self, idx: int):
+ token = self.tokens[idx]
+ sample = np.load(os.path.join(self.anno_root, f'{token}.npz'), allow_pickle=True)
+ # images
+ images = []
+ for im_view in self.img_key_list:
+ for im_path in sample['image_paths']:
+ if im_path.startswith(f'samples/{im_view}/'):
+ im_path = os.path.join(self.nusc_root, im_path)
+ img = np.asarray(Image.open(im_path))
+ images.append(img)
+ # pivot pts
+ pivot_pts = sample["pivot_pts"].item()
+ valid_length = sample["pivot_length"].item()
+ # targets
+        masks = sample[self.mask_key]
+ targets = dict(masks=masks, points=pivot_pts, valid_len=valid_length)
+ # pose
+ extrinsic = np.stack([np.eye(4) for _ in range(sample["trans"].shape[0])], axis=0)
+ extrinsic[:, :3, :3] = sample["rots"]
+ extrinsic[:, :3, 3] = sample["trans"]
+ intrinsic = sample['intrins']
+ # transform
+ item = dict(images=images, targets=targets,
+ extra_infos=dict(token=token, map_size=self.ego_size),
+ extrinsic=np.stack(extrinsic, axis=0), intrinsic=np.stack(intrinsic, axis=0))
+ if self.transforms is not None:
+ item = self.transforms(item)
+
+ return item
+
+ def __len__(self):
+ return len(self.tokens)
diff --git a/mapmaster/dataset/sampler.py b/mapmaster/dataset/sampler.py
new file mode 100644
index 0000000..23cd2cd
--- /dev/null
+++ b/mapmaster/dataset/sampler.py
@@ -0,0 +1,61 @@
+import torch
+import itertools
+import torch.distributed as dist
+from typing import Optional
+from torch.utils.data.sampler import Sampler
+
+
+class InfiniteSampler(Sampler):
+ """
+ In training, we only care about the "infinite stream" of training data.
+ So this sampler produces an infinite stream of indices and
+ all workers cooperate to correctly shuffle the indices and sample different indices.
+    The sampler in each worker effectively produces `indices[worker_id::num_workers]`,
+    where `indices` is an infinite stream of indices consisting of
+    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
+    or `range(size) + range(size) + ...` (if shuffle is False).
+ """
+
+ def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = 0, rank=0, world_size=1, drop_last=False):
+ """
+ Args:
+ size (int): the total number of data of the underlying dataset to sample from
+ shuffle (bool): whether to shuffle the indices or not
+            seed (int): the initial seed of the shuffle. Must be the same
+                across all workers. Note that this implementation requires an
+                integer seed; passing None is not supported.
+ """
+ self._size = size
+ assert size > 0
+ self._shuffle = shuffle
+ self._seed = int(seed)
+ self.drop_last = drop_last
+
+ if dist.is_available() and dist.is_initialized():
+ self._rank = dist.get_rank()
+ self._world_size = dist.get_world_size()
+ else:
+ self._rank = rank
+ self._world_size = world_size
+
+ def set_epoch(self, epoch):
+ pass
+
+ def __iter__(self):
+ start = self._rank
+ yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)
+
+ def _infinite_indices(self):
+ g = torch.Generator()
+ g.manual_seed(self._seed)
+ while True:
+ if self._shuffle:
+ yield from torch.randperm(self._size, generator=g).tolist()
+ else:
+ yield from list(range(self._size))
+
+ def __len__(self):
+ if self.drop_last:
+ return self._size // self._world_size
+ else:
+ return (self._size + self._world_size - 1) // self._world_size
diff --git a/mapmaster/dataset/transform.py b/mapmaster/dataset/transform.py
new file mode 100644
index 0000000..0d8d2f0
--- /dev/null
+++ b/mapmaster/dataset/transform.py
@@ -0,0 +1,274 @@
+import cv2
+import mmcv
+import torch
+import numpy as np
+from PIL import Image
+from collections.abc import Sequence
+
+class Resize(object):
+ def __init__(self, img_scale=None, backend="cv2", interpolation="bilinear"):
+ self.size = img_scale
+ self.backend = backend
+ self.interpolation = interpolation
+ self.cv2_interp_codes = {
+ "nearest": cv2.INTER_NEAREST,
+ "bilinear": cv2.INTER_LINEAR,
+ "bicubic": cv2.INTER_CUBIC,
+ "area": cv2.INTER_AREA,
+ "lanczos": cv2.INTER_LANCZOS4,
+ }
+ self.pillow_interp_codes = {
+ "nearest": Image.NEAREST,
+ "bilinear": Image.BILINEAR,
+ "bicubic": Image.BICUBIC,
+ "box": Image.BOX,
+ "lanczos": Image.LANCZOS,
+ "hamming": Image.HAMMING,
+ }
+
+ def __call__(self, data_dict):
+ """Call function to resize images.
+
+ Args:
+ data_dict (dict): Result dict from loading pipeline.
+
+ Returns:
+            dict: Resized data_dict; a 'scale_factor' entry is added to data_dict['extra_infos'].
+ """
+
+        h, w = data_dict["images"][0].shape[:2]  # original size, captured before resizing
+        imgs = []
+        for img in data_dict["images"]:
+            img = self.im_resize(img, self.size, interpolation=self.interpolation, backend=self.backend)
+            imgs.append(img)
+        data_dict["images"] = imgs
+
+        new_h, new_w = imgs[0].shape[:2]
+        w_scale = new_w / w
+        h_scale = new_h / h
+        scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
+ data_dict["extra_infos"].update({"scale_factor": scale_factor})
+
+ return data_dict
+
+ def im_resize(self, img, size, return_scale=False, interpolation="bilinear", out=None, backend="cv2"):
+ """Resize image to a given size.
+ Args:
+ img (ndarray): The input image.
+ size (tuple[int]): Target size (w, h).
+ return_scale (bool): Whether to return `w_scale` and `h_scale`.
+            interpolation (str): Interpolation method. Accepted values are
+                "nearest", "bilinear", "bicubic", "area", "lanczos" for the 'cv2'
+                backend and "nearest", "bilinear", "bicubic", "box", "lanczos",
+                "hamming" for the 'pillow' backend.
+            out (ndarray): The output destination (only used by the 'cv2' backend).
+            backend (str): The image resize backend type. Options are 'cv2' and 'pillow'.
+ Returns:
+ tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+ `resized_img`.
+ """
+ h, w = img.shape[:2]
+ if backend not in ["cv2", "pillow"]:
+            raise ValueError(
+                f"backend: {backend} is not supported for resize. Supported backends are 'cv2', 'pillow'."
+            )
+
+ if backend == "pillow":
+            assert img.dtype == np.uint8, "Pillow backend only supports uint8 images"
+ pil_image = Image.fromarray(img)
+ pil_image = pil_image.resize(size, self.pillow_interp_codes[interpolation])
+ resized_img = np.array(pil_image)
+ else:
+ resized_img = cv2.resize(img, size, dst=out, interpolation=self.cv2_interp_codes[interpolation])
+ if not return_scale:
+ return resized_img
+ else:
+ w_scale = size[0] / w
+ h_scale = size[1] / h
+ return resized_img, w_scale, h_scale
+
+class Normalize(object):
+ """Normalize the image.
+
+ Added key is "img_norm_cfg".
+
+ Args:
+ mean (sequence): Mean values of 3 channels.
+ std (sequence): Std values of 3 channels.
+ to_rgb (bool): Whether to convert the image from BGR to RGB,
+ default is true.
+ """
+
+ def __init__(self, mean, std, to_rgb=True):
+ self.mean = np.array(mean, dtype=np.float32)
+ self.std = np.array(std, dtype=np.float32)
+ self.to_rgb = to_rgb
+
+ def __call__(self, data_dict):
+ imgs = []
+ for img in data_dict["images"]:
+ if self.to_rgb:
+ img = img.astype(np.float32) / 255.0
+ img = self.im_normalize(img, self.mean, self.std, self.to_rgb)
+ imgs.append(img)
+ data_dict["images"] = imgs
+ data_dict["extra_infos"]["img_norm_cfg"] = dict(mean=self.mean, std=self.std, to_rgb=self.to_rgb)
+ return data_dict
+
+ @staticmethod
+ def im_normalize(img, mean, std, to_rgb=True):
+ img = img.copy().astype(np.float32)
+ assert img.dtype != np.uint8 # cv2 inplace normalization does not accept uint8
+ mean = np.float64(mean.reshape(1, -1))
+ stdinv = 1 / np.float64(std.reshape(1, -1))
+ if to_rgb:
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace
+ cv2.subtract(img, mean, img) # inplace
+ cv2.multiply(img, stdinv, img) # inplace
+ return img
+
+
+class ToTensor(object):
+ """Default formatting bundle."""
+
+ def __call__(self, data_dict):
+ """Call function to transform and format common fields in data_dict.
+
+ Args:
+ data_dict (dict): Data dict contains the data to convert.
+
+ Returns:
+ dict: The result dict contains the data that is formatted with default bundle.
+ """
+
+ for k in ["images", "extrinsic", "intrinsic", "ida_mats"]:
+ if k == "images":
+ data_dict[k] = np.stack([img.transpose(2, 0, 1) for img in data_dict[k]], axis=0)
+ data_dict[k] = self.to_tensor(np.ascontiguousarray(data_dict[k]))
+
+ for k in ["masks", "points", "labels"]:
+ data_dict["targets"][k] = self.to_tensor(np.ascontiguousarray(data_dict["targets"][k]))
+
+ return data_dict
+
+ @staticmethod
+ def to_tensor(data):
+ if isinstance(data, torch.Tensor):
+ return data
+ elif isinstance(data, np.ndarray):
+ return torch.from_numpy(data)
+ elif isinstance(data, Sequence) and not mmcv.is_str(data):
+ return torch.tensor(data)
+ elif isinstance(data, int):
+ return torch.LongTensor([data])
+ elif isinstance(data, float):
+ return torch.FloatTensor([data])
+ else:
+ raise TypeError(f"type {type(data)} cannot be converted to tensor.")
+
+class ToTensor_Pivot(object):
+ """Default formatting bundle."""
+
+ def __call__(self, data_dict):
+ """Call function to transform and format common fields in data_dict.
+
+ Args:
+ data_dict (dict): Data dict contains the data to convert.
+
+ Returns:
+ dict: The result dict contains the data that is formatted with default bundle.
+ """
+ if "images" in data_dict:
+ if isinstance(data_dict["images"], list):
+ # process multiple imgs in single frame
+ imgs = [img.transpose(2, 0, 1) for img in data_dict["images"]]
+ imgs = np.ascontiguousarray(np.stack(imgs, axis=0))
+ data_dict["images"] = self.to_tensor(imgs)
+ else:
+ img = np.ascontiguousarray(data_dict["img"].transpose(2, 0, 1))
+ data_dict["images"] = self.to_tensor(img)
+
+ for k in ["masks"]:
+ data_dict["targets"][k] = self.to_tensor(np.ascontiguousarray(data_dict["targets"][k]))
+
+ return data_dict
+
+ @staticmethod
+ def to_tensor(data):
+ """Convert objects of various python types to :obj:`torch.Tensor`.
+ Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
+ :class:`Sequence`, :class:`int` and :class:`float`.
+ Args:
+ data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
+ be converted.
+ """
+
+ if isinstance(data, torch.Tensor):
+ return data
+ elif isinstance(data, np.ndarray):
+ return torch.from_numpy(data)
+ elif isinstance(data, Sequence) and not mmcv.is_str(data):
+ return torch.tensor(data)
+ elif isinstance(data, int):
+ return torch.LongTensor([data])
+ elif isinstance(data, float):
+ return torch.FloatTensor([data])
+ else:
+ raise TypeError(f"type {type(data)} cannot be converted to tensor.")
+
+
+
+class Pad(object):
+ """Pad the image & mask.
+
+ There are two padding modes: (1) pad to a fixed size and (2) pad to the
+ minimum size that is divisible by some number.
+ Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor",
+
+ Args:
+ size (tuple, optional): Fixed padding size.
+ size_divisor (int, optional): The divisor of padded size.
+ pad_val (float, optional): Padding value, 0 by default.
+ """
+
+ def __init__(self, size_divisor=None, pad_val=0):
+ self.size_divisor = size_divisor
+ self.pad_val = pad_val
+ # only one of size and size_divisor should be valid
+ assert size_divisor is not None
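+        # e.g. with size_divisor=32 (a typical backbone stride), a 900x1600 image is padded to 928x1600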
+
+ def __call__(self, data_dict):
+ """Call function to pad images, masks, semantic segmentation maps.
+
+ Args:
+ data_dict (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Updated result dict.
+ """
+ padded_img = None
+ padded_imgs = []
+ for img in data_dict["images"]:
+ padded_img = self.im_pad_to_multiple(img, self.size_divisor, pad_val=self.pad_val)
+ padded_imgs.append(padded_img)
+ data_dict["images"] = padded_imgs
+ data_dict["extra_infos"].update(
+ {
+ "pad_shape": padded_img.shape,
+ "pad_size_divisor": self.size_divisor if self.size_divisor is not None else "None",
+ }
+ )
+ return data_dict
+
+    def im_pad_to_multiple(self, img, divisor, pad_val=0):
+        """Pad an image (on the bottom/right) so that each edge is a multiple of some number.
+        Args:
+            img (ndarray): Image to be padded.
+            divisor (int): Padded image edges will be a multiple of divisor.
+            pad_val (Number): Value used for the padded area.
+        Returns:
+            ndarray: The padded image.
+        """
+        pad_h = int(np.ceil(img.shape[0] / divisor)) * divisor
+        pad_w = int(np.ceil(img.shape[1] / divisor)) * divisor
+        # constant padding on the bottom/right edges (mirrors mmcv.impad with shape=(pad_h, pad_w))
+        pad_width = [(0, pad_h - img.shape[0]), (0, pad_w - img.shape[1])]
+        if img.ndim == 3:
+            pad_width.append((0, 0))  # do not pad the channel dimension
+        return np.pad(img, pad_width, mode="constant", constant_values=pad_val)
diff --git a/mapmaster/engine/callbacks.py b/mapmaster/engine/callbacks.py
new file mode 100644
index 0000000..f410930
--- /dev/null
+++ b/mapmaster/engine/callbacks.py
@@ -0,0 +1,299 @@
+import os
+import glob
+import tqdm
+import torch
+import pickle
+from clearml import Task
+from loguru import logger
+from typing import Callable, Optional
+from tensorboardX import SummaryWriter
+from torch.nn.utils import clip_grad_norm_
+from mapmaster.engine.executor import Callback, BaseExecutor, Trainer
+from mapmaster.utils.misc import AvgMeter
+
+
+__all__ = ["Callback", "MasterOnlyCallback", "CheckPointSaver", "CheckPointLoader", "CheckPointC2Loader",
+ "ClearMLCallback", "EvalResultsSaver", "LamdaCallback", "ClipGrad", "ProgressBar",
+ "LearningRateMonitor", "TextMonitor", "TensorBoardMonitor"]
+
+
+class MasterOnlyCallback(Callback):
+ enabled_rank = [0]
+
+
+class CheckPointSaver(MasterOnlyCallback):
+ def __init__(
+ self,
+ local_path,
+ filename=r"checkpoint_epoch_{epoch}.pth",
+ remote_path=None,
+ save_interval: int = 1,
+ num_keep_latest=None,
+ ):
+ self.local_path = local_path
+ self.filename = filename
+ self.remote_path = remote_path
+ self.save_interval = save_interval
+ self.num_keep_latest = num_keep_latest
+ os.makedirs(local_path, exist_ok=True)
+
+ def _make_checkpoint(self, trainer: Trainer):
+ model_state = None
+ if hasattr(trainer, "ema_model"):
+ model = trainer.ema_model.ema
+ else:
+ model = trainer.model
+ if model:
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+ model_state = model.module.state_dict()
+ model_state_cpu = type(model_state)()
+ for key, val in model_state.items():
+ model_state_cpu[key] = val.cpu()
+ model_state = model_state_cpu
+ else:
+ model_state = model.state_dict()
+
+ optim_state = trainer.optimizer.state_dict() if trainer.optimizer else None
+
+ callback_states = {}
+ for cb in trainer.callbacks:
+ if hasattr(cb, "state_dict"):
+ cls_name = cb.__class__.__name__
+ callback_states[cls_name] = cb.state_dict()
+
+ ckpt = {
+ "epoch": trainer.epoch,
+ "it": trainer.global_step,
+ "global_step": trainer.global_step,
+ "model_state": model_state,
+ "optimizer_state": optim_state,
+ "callback": callback_states,
+ }
+
+ # save grad_scaler
+ if hasattr(trainer, "grad_scaler"):
+ ckpt["grad_scaler_state"] = trainer.grad_scaler.state_dict()
+
+ return ckpt
+
+ def after_epoch(self, trainer: Trainer, epoch: int, update_best_ckpt: bool = False):
+ if (epoch + 1) % self.save_interval != 0:
+ return
+ filename = self.filename.format(epoch=epoch)
+ save_path = os.path.join(self.local_path, filename)
+ torch.save(self._make_checkpoint(trainer), save_path)
+ if update_best_ckpt:
+            torch.save(self._make_checkpoint(trainer), os.path.join(self.local_path, "checkpoint_best.pth"))
+ self._remove_out_of_date_ckpt()
+
+ def _remove_out_of_date_ckpt(self):
+ if not self.num_keep_latest:
+ return
+
+ ckpt_list = glob.glob(os.path.join(self.local_path, self.filename.format(epoch="*")))
+ ckpt_list.sort(key=os.path.getmtime)
+ if len(ckpt_list) > self.num_keep_latest:
+ for cur_file_idx in range(0, len(ckpt_list) - self.num_keep_latest):
+ os.remove(ckpt_list[cur_file_idx])
+
+
+class CheckPointLoader(Callback):
+ def __init__(
+ self,
+ path,
+ weight_only=False,
+ ):
+ self.path = path
+ self.weight_only = weight_only
+
+ def load_checkpoint(self, trainer: Trainer):
+ logger.info(f"Loading parameters from checkpoint {self.path}")
+ with open(self.path, "rb") as f:
+ checkpoint = torch.load(f, map_location=torch.device("cpu"))
+
+        # TODO: build model finetune callback
+ model_state_dict = trainer.model.state_dict()
+ checkpoint_state_dict = checkpoint["model_state"]
+ for k in list(checkpoint_state_dict.keys()):
+ if k in model_state_dict:
+ shape_model = tuple(model_state_dict[k].shape)
+ shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
+ if shape_model != shape_checkpoint:
+ logger.info(
+ "'{}' has shape {} in the checkpoint but {} in the "
+ "model! Skipped.".format(k, shape_checkpoint, shape_model)
+ )
+ checkpoint_state_dict.pop(k)
+ trainer.model.load_state_dict(checkpoint_state_dict, strict=False)
+
+ if self.weight_only:
+ return
+
+ trainer.epoch = checkpoint.get("epoch", -1) + 1
+ trainer.global_step = checkpoint.get("global_step", -1) + 1
+ if "optimizer_state" in checkpoint:
+ trainer.optimizer.load_state_dict(checkpoint["optimizer_state"])
+ # resume callback
+ for cb in trainer.callbacks:
+ if hasattr(cb, "state_dict"):
+ cls_name = cb.__class__.__name__
+ if cls_name in checkpoint["callback"]:
+ cb.load_state_dict(checkpoint["callback"][cls_name])
+ # resume grad_scaler
+ if hasattr(trainer, "grad_scaler") and "grad_scaler_state" in checkpoint:
+ trainer.grad_scaler.load_state_dict(checkpoint["grad_scaler_state"])
+
+
+class ClearMLCallback(MasterOnlyCallback):
+ def __init__(self):
+ super().__init__()
+ self.task_id = None
+
+ def after_init(self, executor: BaseExecutor):
+ if self.task_id is None:
+ self.task = Task.init(
+ project_name="det3d",
+ task_name=executor.exp.exp_name,
+ auto_connect_frameworks={"pytorch": False},
+ reuse_last_task_id=False,
+ continue_last_task=False,
+ )
+ else:
+ self.task = Task.get_task(task_id=self.task_id)
+ self.task.add_tags(["resume"])
+ logger.info(f"continue from clearml task {self.task_id}")
+ self.task.connect(executor.exp)
+ if hasattr(executor.exp, "get_pcdet_cfg"):
+ self.task.connect(executor.exp.get_pcdet_cfg(), "pcdet_config")
+
+ def state_dict(self):
+ return {"task_id": self.task.task_id}
+
+ def load_state_dict(self, state_dict):
+ self.task_id = state_dict["task_id"]
+
+
+class EvalResultsSaver(MasterOnlyCallback):
+ def __init__(self, out_dir: str):
+ self.out_dir = out_dir
+
+ def after_eval(self, executor, det_annos: list):
+ out_file = os.path.join(self.out_dir, "result.pkl")
+ pickle.dump(det_annos, open(out_file, "wb"))
+
+
+class LamdaCallback:
+ def __init__(
+ self,
+ setup: Optional[Callable] = None,
+ load_checkpoint: Optional[Callable] = None,
+ after_init: Optional[Callable] = None,
+ before_train: Optional[Callable] = None,
+ before_epoch: Optional[Callable] = None,
+ before_step: Optional[Callable] = None,
+ before_backward: Optional[Callable] = None,
+ before_optimize: Optional[Callable] = None,
+ after_step: Optional[Callable] = None,
+ after_epoch: Optional[Callable] = None,
+ after_train: Optional[Callable] = None,
+ ) -> None:
+ for k, v in locals().items():
+ if k == "self":
+ continue
+ if v is not None:
+ setattr(self, k, v)
+
+
+class ClipGrad(Callback):
+ def __init__(self, max_norm: float):
+ self.max_norm = max_norm
+
+ def before_optimize(self, trainer):
+ clip_grad_norm_(trainer.model.parameters(), self.max_norm)
+
+
+class ProgressBar(MasterOnlyCallback):
+ def __init__(self, logger=None) -> None:
+ self.epoch_bar = None
+ self.step_bar = None
+ self.logger = logger
+
+ def setup(self, trainer: Trainer):
+ self.epoch_bar = tqdm.tqdm(initial=0, total=trainer.exp.max_epoch, desc="[Epoch]", dynamic_ncols=True)
+ self.step_bar = tqdm.tqdm(initial=0, desc="[Step]", dynamic_ncols=True, leave=False)
+ if self.logger:
+ self.logger.remove(0)
+ self.logger.add(lambda msg: self.step_bar.write(msg, end=""))
+
+ def before_epoch(self, trainer: Trainer, epoch: int):
+ self.epoch_bar.update(epoch - self.epoch_bar.n)
+ self.step_bar.reset(len(trainer.train_dataloader))
+
+ def after_step(self, trainer: Trainer, step, data_dict, *args, **kwargs):
+ self.step_bar.update()
+
+ def after_train(self, trainer: Trainer):
+ if self.step_bar:
+ self.step_bar.close()
+ if self.epoch_bar:
+ self.epoch_bar.close()
+
+
+class LearningRateMonitor:
+ def _get_learning_rate(self, optimizer):
+ if hasattr(optimizer, "lr"):
+ lr = float(optimizer.lr)
+ else:
+ lr = optimizer.param_groups[0]["lr"]
+ return lr
+
+
+class TextMonitor(MasterOnlyCallback, LearningRateMonitor):
+ def __init__(self, interval=10):
+ self.interval = interval
+ self.avg_loss = AvgMeter()
+ self.ext_dict = None
+
+ def after_step(self, trainer: Trainer, step, data_dict, *args, **kwargs):
+ self.avg_loss.update(kwargs["loss"])
+
+ lr = self._get_learning_rate(trainer.optimizer)
+
+ ext_info = ""
+ if kwargs["extra"] is not None:
+ if self.ext_dict is None:
+ self.ext_dict = {k: AvgMeter() for k in kwargs["extra"]}
+ for key, val in kwargs["extra"].items():
+ self.ext_dict[key].update(val)
+ ext_info = "".join([f" {k}={v.window_avg :.4f}" for k, v in self.ext_dict.items()])
+
+ if step % self.interval != 0:
+ return
+
+ trainer.logger.info(
+ f"e:{trainer.epoch}[{step}/{self.total_step}] lr={lr :.6f} loss={self.avg_loss.window_avg :.4f}{ext_info}"
+ )
+
+ def before_epoch(self, trainer: Trainer, epoch: int):
+ lr = trainer.optimizer.param_groups[0]["lr"]
+ trainer.logger.info(f"e:{epoch} learning rate = {lr :.6f}")
+ self.total_step = len(trainer.train_dataloader)
+
+
+class TensorBoardMonitor(MasterOnlyCallback, LearningRateMonitor):
+ def __init__(self, log_dir, interval=10):
+ os.makedirs(log_dir, exist_ok=True)
+ self.tb_log = SummaryWriter(log_dir=log_dir)
+ self.interval = interval
+
+ def after_step(self, trainer: Trainer, step, data_dict, *args, **kwargs):
+ accumulated_iter = trainer.global_step
+ if accumulated_iter % self.interval != 0:
+ return
+ lr = self._get_learning_rate(trainer.optimizer)
+ self.tb_log.add_scalar("epoch", trainer.epoch, accumulated_iter)
+ self.tb_log.add_scalar("train/loss", kwargs["loss"], accumulated_iter)
+ self.tb_log.add_scalar("meta_data/learning_rate", lr, accumulated_iter)
+ if kwargs["extra"] is not None:
+ for key, val in kwargs["extra"].items():
+ self.tb_log.add_scalar(f"train/{key}", val, accumulated_iter)
diff --git a/mapmaster/engine/core.py b/mapmaster/engine/core.py
new file mode 100644
index 0000000..4ecc69f
--- /dev/null
+++ b/mapmaster/engine/core.py
@@ -0,0 +1,191 @@
+import os
+import sys
+import argparse
+import datetime
+import warnings
+import subprocess
+from mapmaster.engine.executor import Trainer, BeMapNetEvaluator
+from mapmaster.engine.environ import ShareFSUUIDNameServer, RlaunchReplicaEnv
+from mapmaster.engine.callbacks import CheckPointLoader, CheckPointSaver, ClearMLCallback, ProgressBar
+from mapmaster.engine.callbacks import TensorBoardMonitor, TextMonitor, ClipGrad
+from mapmaster.utils.env import collect_env_info, get_root_dir
+from mapmaster.utils.misc import setup_logger, sanitize_filename, PyDecorator, all_gather_object
+
+
+__all__ = ["BaseCli", "BeMapNetCli"]
+
+
+class BaseCli:
+ """Command line tools for any exp."""
+
+ def __init__(self, Exp):
+ """Make sure the order of initialization is: build_args --> build_env --> build_exp,
+ since experiments depend on the environment and the environment depends on args.
+
+ Args:
+ Exp : experiment description class
+ """
+ self.ExpCls = Exp
+ self.args = self._get_parser(Exp).parse_args()
+ self.env = RlaunchReplicaEnv(self.args.sync_bn, self.args.devices, self.args.find_unused_parameters)
+
+ @property
+ def exp(self):
+ if not hasattr(self, "_exp"):
+ exp = self.ExpCls(
+ **{x if y is not None else "none": y for (x, y) in vars(self.args).items()},
+ total_devices=self.env.world_size(),
+ )
+ self.exp_updated_cfg_msg = exp.update_attr(self.args.exp_options)
+ self._exp = exp
+ return self._exp
+
+ def _get_parser(self, Exp):
+ parser = argparse.ArgumentParser()
+ parser = Exp.add_argparse_args(parser)
+ parser = self.add_argparse_args(parser)
+ return parser
+
+ @staticmethod
+ def add_argparse_args(parser: argparse.ArgumentParser):
+ parser.add_argument("--eval", dest="eval", action="store_true", help="conduct evaluation only")
+ parser.add_argument("-te", "--train_and_eval", dest="train_and_eval", action="store_true", help="train+eval")
+ parser.add_argument("--find_unused_parameters", dest="find_unused_parameters", action="store_true")
+ parser.add_argument("-d", "--devices", default="0-7", type=str, help="device for training")
+ parser.add_argument("--ckpt", type=str, default=None, help="checkpoint to start from or be evaluated")
+ parser.add_argument("--pretrained_model", type=str, default=None, help="pretrained_model used by training")
+ parser.add_argument("--sync_bn", type=int, default=0, help="0-> disable sync_bn, 1-> whole world")
+ clearml_parser = parser.add_mutually_exclusive_group(required=False)
+ clearml_parser.add_argument("--clearml", dest="clearml", action="store_true", help="enabel clearml for train")
+ clearml_parser.add_argument("--no-clearml", dest="clearml", action="store_false", help="disable clearml")
+ parser.set_defaults(clearml=True)
+ return parser
+
+ def _get_exp_output_dir(self):
+        exp_dir = os.path.join(get_root_dir(), "outputs", sanitize_filename(self.exp.exp_name))
+ os.makedirs(exp_dir, exist_ok=True)
+ output_dir = None
+ if self.args.ckpt:
+ output_dir = os.path.dirname(os.path.dirname(os.path.abspath(self.args.ckpt)))
+ elif self.env.global_rank() == 0:
+ output_dir = os.path.join(exp_dir, datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S"))
+ os.makedirs(output_dir, exist_ok=True)
+ # make a symlink "latest"
+ symlink, symlink_tmp = os.path.join(exp_dir, "latest"), os.path.join(exp_dir, "latest_tmp")
+ if os.path.exists(symlink_tmp):
+ os.remove(symlink_tmp)
+ os.symlink(os.path.relpath(output_dir, exp_dir), symlink_tmp)
+ os.rename(symlink_tmp, symlink)
+ output_dir = all_gather_object(output_dir)[0]
+ return output_dir
+
+ def get_evaluator(self, callbacks=None):
+ exp = self.exp
+ if self.args.ckpt is None:
+ warnings.warn("No checkpoint is specified for evaluation")
+ if exp.eval_executor_class is None:
+ sys.exit("No evaluator is specified for evaluation")
+
+ output_dir = self._get_exp_output_dir()
+ logger = setup_logger(output_dir, distributed_rank=self.env.global_rank(), filename="eval.log")
+ self._set_basic_log_message(logger)
+ if callbacks is None:
+ callbacks = [self.env, CheckPointLoader(self.args.ckpt)]
+ evaluator = exp.eval_executor_class(exp=exp, callbacks=callbacks, logger=logger)
+ return evaluator
+
+ def _set_basic_log_message(self, logger):
+ logger.opt(ansi=True).info("Cli arguments:\n{}".format(self.args))
+ logger.info(f"exp_name: {self.exp.exp_name}")
+ logger.opt(ansi=True).info(
+ "Used experiment configs:\n{}".format(self.exp.get_cfg_as_str())
+ )
+ if self.exp_updated_cfg_msg:
+ logger.opt(ansi=True).info(
+ "List of override configs:\n{}".format(self.exp_updated_cfg_msg)
+ )
+ logger.opt(ansi=True).info("Environment info:\n{}".format(collect_env_info()))
+
+ def get_trainer(self, callbacks=None, evaluator=None):
+ args = self.args
+ exp = self.exp
+ if evaluator is not None:
+ output_dir = self.exp.output_dir
+ else:
+ output_dir = self._get_exp_output_dir()
+
+ logger = setup_logger(output_dir, distributed_rank=self.env.global_rank(), filename="train.log")
+ self._set_basic_log_message(logger)
+
+ if callbacks is None:
+ callbacks = [
+ self.env,
+ ProgressBar(logger=logger),
+ TextMonitor(interval=exp.print_interval),
+ TensorBoardMonitor(os.path.join(output_dir, "tensorboard"), interval=exp.print_interval),
+ CheckPointSaver(
+ local_path=os.path.join(output_dir, "dump_model"),
+ remote_path=exp.ckpt_oss_save_dir,
+ save_interval=exp.dump_interval,
+ num_keep_latest=exp.num_keep_latest_ckpt,
+ ),
+ ]
+ if "grad_clip_value" in exp.__dict__:
+ callbacks.append(ClipGrad(exp.grad_clip_value))
+ if args.clearml:
+ callbacks.append(ClearMLCallback())
+ if args.ckpt:
+ callbacks.append(CheckPointLoader(args.ckpt))
+ if args.pretrained_model:
+ callbacks.append(CheckPointLoader(args.pretrained_model, weight_only=True))
+ callbacks.extend(exp.callbacks)
+
+ trainer = Trainer(exp=exp, callbacks=callbacks, logger=logger, evaluator=evaluator)
+ return trainer
+
+ def executor(self):
+ if self.args.eval:
+ self.get_evaluator().eval()
+ elif self.args.train_and_eval:
+ evaluator = self.get_evaluator(callbacks=[])
+ self.get_trainer(evaluator=evaluator).train()
+ else:
+ self.get_trainer().train()
+
+ def dispatch(self, executor_func):
+ is_master = self.env.global_rank() == 0
+ with ShareFSUUIDNameServer(is_master) as ns:
+ self.env.set_master_uri(ns)
+ self.env.setup_nccl()
+ if self.env.local_rank() == 0:
+ command = sys.argv.copy()
+ command[0] = os.path.abspath(command[0])
+ command = [sys.executable] + command
+ for local_rank in range(1, self.env.nr_gpus):
+ env_copy = os.environ.copy()
+ env_copy["LOCAL_RANK"] = f"{local_rank}"
+ subprocess.Popen(command, env=env_copy)
+ self.env.init_dist()
+ executor_func()
+
+ def run(self):
+ self.dispatch(self.executor)
+
+
+class MapMasterCli(BaseCli):
+ @PyDecorator.overrides(BaseCli)
+ def get_evaluator(self, callbacks=None):
+ exp = self.exp
+
+ output_dir = self._get_exp_output_dir()
+ self.exp.output_dir = output_dir
+ logger = setup_logger(output_dir, distributed_rank=self.env.global_rank(), filename="eval.log")
+ self._set_basic_log_message(logger)
+ if callbacks is None:
+ callbacks = [
+ self.env,
+ CheckPointLoader(self.args.ckpt),
+ ]
+
+ evaluator = BeMapNetEvaluator(exp=exp, callbacks=callbacks, logger=logger)
+ return evaluator
diff --git a/mapmaster/engine/environ.py b/mapmaster/engine/environ.py
new file mode 100644
index 0000000..d6a2757
--- /dev/null
+++ b/mapmaster/engine/environ.py
@@ -0,0 +1,151 @@
+import os
+import time
+import uuid
+import socket
+import torch
+import subprocess
+import numpy as np
+from torch import nn
+from loguru import logger
+import torch.distributed as dist
+from mapmaster.utils.env import get_root_dir
+from mapmaster.utils.misc import parse_devices
+from mapmaster.engine.callbacks import Callback
+
+
+__all__ = ["ShareFSUUIDNameServer", "RlaunchReplicaEnv"]
+output_root_dir = os.path.join(get_root_dir(), "outputs")
+
+
+class ShareFSUUIDNameServer:
+ def __init__(self, is_master):
+ self.exp_id = self._get_exp_id()
+ self.is_master = is_master
+ os.makedirs(os.path.dirname(self.filepath), exist_ok=True)
+
+ def _get_exp_id(self):
+ if "DET3D_EXPID" not in os.environ:
+ if int(os.environ.get("RLAUNCH_REPLICA_TOTAL", 1)) == 1:
+ return str(uuid.uuid4())
+ msg = """cannot find DET3D_EXPID in environ please use following
+ command DET3D_EXPID=$(cat /proc/sys/kernel/random/uuid) rlaunch ...
+ """
+ logger.error(msg)
+ raise RuntimeError
+ return str(os.environ["DET3D_EXPID"])
+
+ @property
+ def filepath(self):
+ return os.path.join(output_root_dir, f"master_ip_{self.exp_id}.txt")
+
+ def __enter__(self):
+ if self.is_master:
+ self.set_master()
+ return self
+
+ def __exit__(self, exc_type, exc_value, exc_tb):
+ if self.is_master:
+ os.remove(self.filepath)
+
+ def set_master(self):
+ assert not os.path.exists(self.filepath)
+ hostname = "Host"
+ with open(self.filepath, "w") as f:
+ f.write(hostname)
+
+ def get_master(self):
+ while True:
+ if os.path.exists(self.filepath):
+ with open(self.filepath, "r") as f:
+ return f.read()
+ else:
+ time.sleep(5)
+
+
+class _DDPEnv(Callback):
+ def __init__(self, sync_bn=0, devices=None, find_unused_parameters=False):
+ if devices:
+ devices = parse_devices(devices)
+ os.environ["CUDA_VISIBLE_DEVICES"] = devices
+ self.nr_gpus = torch.cuda.device_count()
+ self.sync_bn = sync_bn
+ self.find_unused_parameters = find_unused_parameters
+
+ @staticmethod
+ def setup_nccl():
+ ifname = filter(lambda x: x not in ("lo",), os.listdir("/sys/class/net/"))
+ os.environ["NCCL_SOCKET_IFNAME"] = ",".join(ifname)
+ os.environ["NCCL_IB_DISABLE"] = "1"
+
+ # os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
+ os.environ["NCCL_IB_HCA"] = subprocess.getoutput(
+ "cd /sys/class/infiniband/ > /dev/null; for i in mlx5_*; "
+ "do cat $i/ports/1/gid_attrs/types/* 2>/dev/null "
+ "| grep v >/dev/null && echo $i ; done; > /dev/null"
+ )
+ os.environ["NCCL_IB_GID_INDEX"] = "3"
+ os.environ["NCCL_IB_TC"] = "106"
+
+ def after_init(self, trainer):
+ trainer.model.cuda()
+ if int(self.sync_bn) > 1:
+ ranks = np.arange(self.world_size()).reshape(-1, self.sync_bn)
+ process_groups = [torch.distributed.new_group(list(pids)) for pids in ranks]
+ trainer.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
+ trainer.model, process_groups[self.global_rank() // self.sync_bn]
+ )
+ elif int(self.sync_bn) == 1:
+ trainer.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(trainer.model)
+ trainer.model = nn.parallel.DistributedDataParallel(
+ trainer.model, device_ids=[self.local_rank()], find_unused_parameters=self.find_unused_parameters
+ )
+
+ def cleanup(self):
+ dist.destroy_process_group()
+
+ def init_dist(self):
+ torch.cuda.set_device(self.local_rank())
+ dist.init_process_group(
+ backend="nccl",
+ init_method=self._master_uri,
+ rank=self.global_rank(),
+ world_size=self.world_size(),
+ )
+ dist.barrier()
+
+
+class RlaunchReplicaEnv(_DDPEnv):
+ def __init__(self, sync_bn=0, devices=None, find_unused_parameters=False):
+ super().__init__(sync_bn, devices, find_unused_parameters)
+
+ def set_master_uri(self, ns):
+ self._master_uri = f"tcp://{self.master_address(ns)}:{self.master_port()}"
+ logger.info(self._master_uri)
+
+ @staticmethod
+ def is_brainpp_mm_env():
+ return int(os.environ.get("RLAUNCH_REPLICA_TOTAL", 1)) > 1
+
+ def master_address(self, ns) -> str:
+ if self.node_rank() == 0:
+ root_node = "localhost"
+ else:
+ root_node = ns.get_master()
+ os.environ["MASTER_ADDR"] = root_node
+ return root_node
+
+ def master_port(self) -> int:
+ port = os.environ.get("MASTER_PORT", 12345)
+ os.environ["MASTER_PORT"] = str(port)
+ return int(port)
+
+ def world_size(self) -> int:
+ return int(os.environ.get("RLAUNCH_REPLICA_TOTAL", 1)) * int(self.nr_gpus)
+
+ def global_rank(self) -> int:
+ return int(self.nr_gpus) * self.node_rank() + self.local_rank()
+
+ def local_rank(self) -> int:
+ return int(os.environ.get("LOCAL_RANK", 0))
+
+ def node_rank(self) -> int:
+ return int(os.environ.get("RLAUNCH_REPLICA", 0))
diff --git a/mapmaster/engine/executor.py b/mapmaster/engine/executor.py
new file mode 100644
index 0000000..a7bcf4c
--- /dev/null
+++ b/mapmaster/engine/executor.py
@@ -0,0 +1,227 @@
+import os
+import torch
+from tqdm import tqdm
+from typing import Sequence
+from mapmaster.engine.experiment import BaseExp
+from mapmaster.utils.misc import get_rank, synchronize
+
+
+__all__ = ["Callback", "BaseExecutor", "Trainer", "BeMapNetEvaluator"]
+
+
+class Callback:
+
+ # callback enabled rank list
+ # None means callback is always enabled
+ enabled_rank = None
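+    # e.g. MasterOnlyCallback in callbacks.py sets `enabled_rank = [0]` so its hooks run only on the master process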
+
+ def setup(self, executor):
+ pass
+
+ def load_checkpoint(self, executor):
+ pass
+
+ def after_init(self, executor):
+ pass
+
+ def before_train(self, executor):
+ pass
+
+ def before_epoch(self, executor, epoch: int):
+ pass
+
+ def before_step(self, executor, step, data_dict):
+ pass
+
+ def before_backward(self, executor):
+ pass
+
+ def before_optimize(self, executor):
+ pass
+
+ def after_step(self, executor, step, data_dict, *args, **kwargs):
+ pass
+
+ def after_epoch(self, executor, epoch: int, update_best_ckpt: bool = False):
+ pass
+
+ def after_train(self, executor):
+ pass
+
+
+class BaseExecutor:
+ def __init__(self, exp: BaseExp, callbacks: Sequence["Callback"], logger=None) -> None:
+ self.exp = exp
+ self.logger = logger
+ self.callbacks = callbacks
+ self._invoke_callback("setup")
+
+ self.epoch = 0
+ self.global_step = 0
+ self._invoke_callback("load_checkpoint")
+ self._invoke_callback("after_init")
+
+ @property
+ def train_dataloader(self):
+ return self.exp.train_dataloader
+
+ @property
+ def val_dataloader(self):
+ return self.exp.val_dataloader
+
+ @property
+ def model(self):
+ return self.exp.model
+
+ @model.setter
+ def model(self, value):
+ self.exp.model = value
+
+ @property
+ def optimizer(self):
+ return self.exp.optimizer
+
+ @property
+ def lr_scheduler(self):
+ return self.exp.lr_scheduler
+
+ def _invoke_callback(self, callback_name, *args, **kwargs):
+ for cb in self.callbacks:
+ if cb.enabled_rank is None or self.global_rank in cb.enabled_rank:
+ func = getattr(cb, callback_name, None)
+ if func:
+ func(self, *args, **kwargs)
+
+ @property
+ def global_rank(self):
+ return get_rank()
+
+
+class Trainer(BaseExecutor):
+ def __init__(
+ self, exp: BaseExp, callbacks: Sequence["Callback"], logger=None, use_amp=False, evaluator=None
+ ) -> None:
+ super(Trainer, self).__init__(exp, callbacks, logger)
+ self.use_amp = use_amp
+ self.evaluator = evaluator
+ if self.use_amp:
+ self.grad_scaler = torch.cuda.amp.GradScaler()
+
+ def train(self):
+ self.train_iter = iter(self.train_dataloader)
+ self._invoke_callback("before_train")
+ self.model.cuda()
+ self.model.train()
+ self.optimizer_to(self.optimizer, next(self.model.parameters()).device)
+ start_epoch = self.epoch
+ for epoch in range(start_epoch, self.exp.max_epoch):
+ self.epoch = epoch
+ self.model.train()
+ self.train_epoch(epoch)
+ self._invoke_callback("after_train")
+
+ def train_epoch(self, epoch):
+ self._invoke_callback("before_epoch", epoch)
+ sampler = self.train_dataloader.sampler
+ if hasattr(sampler, "set_epoch"):
+ sampler.set_epoch(epoch)
+ for step in range(len(self.train_dataloader)):
+ try:
+ data = next(self.train_iter)
+ except StopIteration:
+ self.train_iter = iter(self.train_dataloader)
+ data = next(self.train_iter)
+ self.train_step(data, step)
+ if self.evaluator is not None:
+ self.evaluator.eval()
+ self._invoke_callback("after_epoch", epoch, update_best_ckpt=False)
+
+ def train_step(self, data, step):
+ self._invoke_callback("before_step", step, data)
+ self.lr_scheduler.step(self.global_step)
+ self.model.train()
+ self.optimizer.zero_grad()
+ if not self.use_amp:
+ ret = self.exp.training_step(data)
+ else:
+ with torch.cuda.amp.autocast():
+ ret = self.exp.training_step(data)
+ if isinstance(ret, torch.Tensor):
+ loss = ret
+ ext_dict = None
+ elif isinstance(ret, tuple):
+ loss, ext_dict = ret
+ ext_dict = {k: v.detach() if isinstance(v, torch.Tensor) else v for k, v in ext_dict.items()}
+ else:
+ raise TypeError
+ self._invoke_callback("before_backward")
+ if not self.use_amp:
+ loss.backward()
+ self._invoke_callback("before_optimize")
+ self.optimizer.step()
+ else:
+ self.grad_scaler.scale(loss).backward()
+ self.grad_scaler.unscale_(self.optimizer) # NOTE: grads are unscaled before "before_optimize" callbacks
+ self._invoke_callback("before_optimize")
+ self.grad_scaler.step(self.optimizer)
+ self.grad_scaler.update()
+ self._invoke_callback("after_step", step, data, loss=loss.detach(), extra=ext_dict)
+ self.global_step += 1
+
+ # refer to: https://github.com/pytorch/pytorch/issues/8741
+ @staticmethod
+ def optimizer_to(optim, device):
+ for param in optim.state.values():
+ # Not sure there are any global tensors in the state dict
+ if isinstance(param, torch.Tensor):
+ param.data = param.data.to(device)
+ if param._grad is not None:
+ param._grad.data = param._grad.data.to(device)
+ elif isinstance(param, dict):
+ for subparam in param.values():
+ if isinstance(subparam, torch.Tensor):
+ subparam.data = subparam.data.to(device)
+ if subparam._grad is not None:
+ subparam._grad.data = subparam._grad.data.to(device)
+
+
+class BeMapNetEvaluator(BaseExecutor):
+ def __init__(self, exp: BaseExp, callbacks: Sequence["Callback"], logger=None) -> None:
+ super(BeMapNetEvaluator, self).__init__(exp, callbacks, logger)
+
+ def eval(self, ckpt_name=None):
+
+ exp = self.exp
+ val_iter = iter(self.val_dataloader)
+
+ self._invoke_callback("before_eval")
+
+ if ckpt_name is not None:
+ if get_rank() == 0:
+ self.logger.info("Eval with best checkpoint!")
+ path = os.path.join(exp.output_dir, 'dump_model', ckpt_name)
+                checkpoint = torch.load(path, map_location=torch.device("cpu"))
+ self.model.load_state_dict(checkpoint["model_state"], strict=False)
+
+ self.model.cuda()
+ self.model.eval()
+
+ for step in tqdm(range(len(self.val_dataloader))):
+ batch_data = next(val_iter)
+ with torch.no_grad():
+ exp.test_step(batch_data)
+ self._invoke_callback("after_step", step, {})
+
+ synchronize()
+
+ if get_rank() == 0:
+ self.logger.info("Done with inference, start evaluation later!")
+ gt_dir = exp.exp_config.map_conf['anno_root']
+ dt_dir = exp.evaluation_save_dir
+ val_txts = exp.exp_config.VAL_TXT
+
+ for val_txt in val_txts:
+ ap_table = "".join(os.popen(f"python3 tools/evaluation/eval.py {gt_dir} {dt_dir} {val_txt}").readlines())
+ self.logger.info(" AP-Performance with HDMapNetAPI: \n" + val_txt + "\n" + ap_table)
+
+ self._invoke_callback("after_eval")
diff --git a/mapmaster/engine/experiment.py b/mapmaster/engine/experiment.py
new file mode 100644
index 0000000..af5aee2
--- /dev/null
+++ b/mapmaster/engine/experiment.py
@@ -0,0 +1,187 @@
+import os
+import sys
+import torch
+import functools
+import numpy as np
+from torch.nn import Module
+from tabulate import tabulate
+from abc import ABCMeta, abstractmethod
+from mapmaster.utils.misc import DictAction
+
+
+class BaseExp(metaclass=ABCMeta):
+ """Basic class for any experiment in Perceptron.
+
+ Args:
+ batch_size_per_device (int):
+ batch_size of each device
+
+ total_devices (int):
+ number of devices to use
+
+ max_epoch (int):
+ total training epochs, the reason why we need to give max_epoch
+ is that lr_scheduler may need to be adapted according to max_epoch
+ """
+
+ def __init__(self, batch_size_per_device, total_devices, max_epoch):
+ self._batch_size_per_device = batch_size_per_device
+ self._max_epoch = max_epoch
+ self._total_devices = total_devices
+ # ----------------------------------------------- extra configure ------------------------- #
+ self.seed = None
+ self.exp_name = os.path.splitext(os.path.basename(sys.argv.copy()[0]))[0] # entrypoint filename as exp_name
+ self.print_interval = 100
+ self.dump_interval = 10
+ self.eval_interval = 10
+ self.num_keep_latest_ckpt = 10
+ self.ckpt_oss_save_dir = None
+ self.enable_tensorboard = False
+ self.eval_executor_class = None
+
+ @property
+ def train_dataloader(self):
+ if "_train_dataloader" not in self.__dict__:
+ self._train_dataloader = self._configure_train_dataloader()
+ return self._train_dataloader
+
+ @property
+ def val_dataloader(self):
+ if "_val_dataloader" not in self.__dict__:
+ self._val_dataloader = self._configure_val_dataloader()
+ return self._val_dataloader
+
+ @property
+ def test_dataloader(self):
+ if "_test_dataloader" not in self.__dict__:
+ self._test_dataloader = self._configure_test_dataloader()
+ return self._test_dataloader
+
+ @property
+ def model(self):
+ if "_model" not in self.__dict__:
+ self._model = self._configure_model()
+ return self._model
+
+ @model.setter
+ def model(self, value):
+ self._model = value
+
+ @property
+ def callbacks(self):
+ if not hasattr(self, "_callbacks"):
+ self._callbacks = self._configure_callbacks()
+ return self._callbacks
+
+ @property
+ def optimizer(self):
+ if "_optimizer" not in self.__dict__:
+ self._optimizer = self._configure_optimizer()
+ return self._optimizer
+
+ @property
+ def lr_scheduler(self):
+ if "_lr_scheduler" not in self.__dict__:
+ self._lr_scheduler = self._configure_lr_scheduler()
+ return self._lr_scheduler
+
+ @property
+ def batch_size_per_device(self):
+ return self._batch_size_per_device
+
+ @property
+ def max_epoch(self):
+ return self._max_epoch
+
+ @property
+ def total_devices(self):
+ return self._total_devices
+
+ @abstractmethod
+ def _configure_model(self) -> Module:
+ pass
+
+ @abstractmethod
+ def _configure_train_dataloader(self):
+ """"""
+
+ def _configure_callbacks(self):
+ return []
+
+ @abstractmethod
+ def _configure_val_dataloader(self):
+ """"""
+
+ @abstractmethod
+ def _configure_test_dataloader(self):
+ """"""
+
+ def training_step(self, *args, **kwargs):
+ pass
+
+ @abstractmethod
+ def _configure_optimizer(self) -> torch.optim.Optimizer:
+ pass
+
+ @abstractmethod
+ def _configure_lr_scheduler(self, **kwargs):
+ pass
+
+ def update_attr(self, options: dict) -> str:
+ if options is None:
+ return ""
+ assert isinstance(options, dict)
+ msg = ""
+ for k, v in options.items():
+ if k in self.__dict__:
+ old_v = self.__getattribute__(k)
+ if not v == old_v:
+ self.__setattr__(k, v)
+ msg = "{}\n'{}' is overriden from '{}' to '{}'".format(msg, k, old_v, v)
+ else:
+ self.__setattr__(k, v)
+ msg = "{}\n'{}' is set to '{}'".format(msg, k, v)
+
+ # update exp_name
+ exp_name_suffix = "-".join(sorted([f"{k}-{v}" for k, v in options.items()]))
+ self.exp_name = f"{self.exp_name}--{exp_name_suffix}"
+ return msg
+
+ def get_cfg_as_str(self) -> str:
+ config_table = []
+ for c, v in self.__dict__.items():
+            if not isinstance(v, (int, float, str, list, tuple, dict, np.ndarray)):
+                if isinstance(v, functools.partial):  # check partial first: it has no __name__ but does have __class__
+                    v = v.func.__name__
+                elif hasattr(v, "__name__"):
+                    v = v.__name__
+                elif hasattr(v, "__class__"):
+                    v = v.__class__
+ if c[0] == "_":
+ c = c[1:]
+ config_table.append((str(c), str(v)))
+
+ headers = ["config key", "value"]
+ config_table = tabulate(config_table, headers, tablefmt="plain")
+ return config_table
+
+ def __str__(self):
+ return self.get_cfg_as_str()
+
+ def to_onnx(self):
+ pass
+
+ @classmethod
+ def add_argparse_args(cls, parser): # pragma: no-cover
+ parser.add_argument(
+ "--exp_options",
+ nargs="+",
+ action=DictAction,
+ help="override some settings in the exp, the key-value pair in xxx=yyy format will be merged into exp. "
+ 'If the value to be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+ 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+ "Note that the quotation marks are necessary and that no white space is allowed.",
+ )
+ parser.add_argument("-b", "--batch-size-per-device", type=int, default=None)
+ parser.add_argument("-e", "--max-epoch", type=int, default=None)
+ return parser
diff --git a/mapmaster/models/__init__.py b/mapmaster/models/__init__.py
new file mode 100644
index 0000000..f8f8380
--- /dev/null
+++ b/mapmaster/models/__init__.py
@@ -0,0 +1 @@
+from .network import MapMaster
diff --git a/mapmaster/models/backbone/__init__.py b/mapmaster/models/backbone/__init__.py
new file mode 100644
index 0000000..e9a30cc
--- /dev/null
+++ b/mapmaster/models/backbone/__init__.py
@@ -0,0 +1 @@
+from .model import ResNetBackbone, EfficientNetBackbone, SwinTRBackbone
diff --git a/mapmaster/models/backbone/bifpn/__init__.py b/mapmaster/models/backbone/bifpn/__init__.py
new file mode 100644
index 0000000..0ed49b5
--- /dev/null
+++ b/mapmaster/models/backbone/bifpn/__init__.py
@@ -0,0 +1 @@
+from .model import BiFPN
diff --git a/mapmaster/models/backbone/bifpn/model.py b/mapmaster/models/backbone/bifpn/model.py
new file mode 100644
index 0000000..0615998
--- /dev/null
+++ b/mapmaster/models/backbone/bifpn/model.py
@@ -0,0 +1,372 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from .utils import Swish, Conv2dStaticSamePadding, MaxPool2dStaticSamePadding
+
+
+class SeparableConvBlock(nn.Module):
+ """
+ created by Zylo117
+ """
+
+ def __init__(self, in_channels, out_channels=None, norm=True, activation=False, norm_layer=nn.BatchNorm2d):
+ super(SeparableConvBlock, self).__init__()
+ if out_channels is None:
+ out_channels = in_channels
+
+ # Q: whether separate conv
+ # share bias between depthwise_conv and pointwise_conv
+ # or just pointwise_conv apply bias.
+ # A: Confirmed, just pointwise_conv applies bias, depthwise_conv has no bias.
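+        # For reference: the 3x3 depthwise conv costs 9 * C_in parameters and the 1x1 pointwise conv
+        # costs C_in * C_out (+ C_out bias), versus 9 * C_in * C_out for a standard 3x3 convolution.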
+
+ self.depthwise_conv = Conv2dStaticSamePadding(
+ in_channels, in_channels, kernel_size=3, stride=1, groups=in_channels, bias=False
+ )
+ self.pointwise_conv = Conv2dStaticSamePadding(in_channels, out_channels, kernel_size=1, stride=1)
+
+ self.norm = norm
+ if self.norm:
+ # Warning: pytorch momentum is different from tensorflow's, momentum_pytorch = 1 - momentum_tensorflow
+ self.bn = norm_layer(num_features=out_channels, momentum=0.01, eps=1e-3)
+
+ self.activation = activation
+ if self.activation:
+ self.swish = Swish()
+
+ def forward(self, x):
+ x = self.depthwise_conv(x)
+ x = self.pointwise_conv(x)
+
+ if self.norm:
+ x = self.bn(x)
+
+ if self.activation:
+ x = self.swish(x)
+
+ return x
+
+
+class BiFPNLayer(nn.Module):
+ """
+ modified by Zylo117
+ """
+
+ def __init__(
+ self,
+ num_channels,
+ conv_channels,
+ first_time=False,
+ epsilon=1e-4,
+ attention=True,
+ use_p8=False,
+ norm_layer=nn.BatchNorm2d,
+ ):
+ """
+ Args:
+ num_channels:
+ conv_channels:
+            first_time: whether the input comes directly from the backbone;
+                if True, down-channel it first, and downsample P5 to generate P6 then P7
+            epsilon: epsilon of the fast weighted attention sum of BiFPN, not the BN's epsilon
+ """
+ super(BiFPNLayer, self).__init__()
+ self.epsilon = epsilon
+ self.use_p8 = use_p8
+
+ # Conv layers
+ self.conv6_up = SeparableConvBlock(num_channels, norm_layer=norm_layer)
+ self.conv5_up = SeparableConvBlock(num_channels, norm_layer=norm_layer)
+ self.conv4_up = SeparableConvBlock(num_channels, norm_layer=norm_layer)
+ self.conv3_up = SeparableConvBlock(num_channels, norm_layer=norm_layer)
+ self.conv4_down = SeparableConvBlock(num_channels, norm_layer=norm_layer)
+ self.conv5_down = SeparableConvBlock(num_channels, norm_layer=norm_layer)
+ self.conv6_down = SeparableConvBlock(num_channels, norm_layer=norm_layer)
+ self.conv7_down = SeparableConvBlock(num_channels, norm_layer=norm_layer)
+ if use_p8:
+ self.conv7_up = SeparableConvBlock(num_channels, norm_layer=norm_layer)
+ self.conv8_down = SeparableConvBlock(num_channels, norm_layer=norm_layer)
+
+ # Feature scaling layers
+ self.p6_upsample = nn.Upsample(scale_factor=2, mode="nearest")
+ self.p5_upsample = nn.Upsample(scale_factor=2, mode="nearest")
+ self.p4_upsample = nn.Upsample(scale_factor=2, mode="nearest")
+ self.p3_upsample = nn.Upsample(scale_factor=2, mode="nearest")
+
+ self.p4_downsample = MaxPool2dStaticSamePadding(3, 2)
+ self.p5_downsample = MaxPool2dStaticSamePadding(3, 2)
+ self.p6_downsample = MaxPool2dStaticSamePadding(3, 2)
+ self.p7_downsample = MaxPool2dStaticSamePadding(3, 2)
+ if use_p8:
+ self.p7_upsample = nn.Upsample(scale_factor=2, mode="nearest")
+ self.p8_downsample = MaxPool2dStaticSamePadding(3, 2)
+
+ self.swish = Swish()
+
+ self.first_time = first_time
+ if self.first_time:
+ self.p5_down_channel = nn.Sequential(
+ Conv2dStaticSamePadding(conv_channels[2], num_channels, 1),
+ norm_layer(num_channels, momentum=0.01, eps=1e-3),
+ )
+ self.p4_down_channel = nn.Sequential(
+ Conv2dStaticSamePadding(conv_channels[1], num_channels, 1),
+ norm_layer(num_channels, momentum=0.01, eps=1e-3),
+ )
+ self.p3_down_channel = nn.Sequential(
+ Conv2dStaticSamePadding(conv_channels[0], num_channels, 1),
+ norm_layer(num_channels, momentum=0.01, eps=1e-3),
+ )
+
+ self.p5_to_p6 = nn.Sequential(
+ Conv2dStaticSamePadding(conv_channels[2], num_channels, 1),
+ norm_layer(num_channels, momentum=0.01, eps=1e-3),
+ MaxPool2dStaticSamePadding(3, 2),
+ )
+ self.p6_to_p7 = nn.Sequential(MaxPool2dStaticSamePadding(3, 2))
+ if use_p8:
+ self.p7_to_p8 = nn.Sequential(MaxPool2dStaticSamePadding(3, 2))
+
+ self.p4_down_channel_2 = nn.Sequential(
+ Conv2dStaticSamePadding(conv_channels[1], num_channels, 1),
+ norm_layer(num_channels, momentum=0.01, eps=1e-3),
+ )
+ self.p5_down_channel_2 = nn.Sequential(
+ Conv2dStaticSamePadding(conv_channels[2], num_channels, 1),
+ norm_layer(num_channels, momentum=0.01, eps=1e-3),
+ )
+
+ # Weight
+ self.p6_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
+ self.p6_w1_relu = nn.ReLU()
+ self.p5_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
+ self.p5_w1_relu = nn.ReLU()
+ self.p4_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
+ self.p4_w1_relu = nn.ReLU()
+ self.p3_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
+ self.p3_w1_relu = nn.ReLU()
+
+ self.p4_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
+ self.p4_w2_relu = nn.ReLU()
+ self.p5_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
+ self.p5_w2_relu = nn.ReLU()
+ self.p6_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
+ self.p6_w2_relu = nn.ReLU()
+ self.p7_w2 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
+ self.p7_w2_relu = nn.ReLU()
+
+ self.attention = attention
+
+ def forward(self, inputs):
+ """
+ illustration of a minimal bifpn unit
+ P7_0 -------------------------> P7_2 -------->
+ |-------------| ↑
+ ↓ |
+ P6_0 ---------> P6_1 ---------> P6_2 -------->
+ |-------------|--------------↑ ↑
+ ↓ |
+ P5_0 ---------> P5_1 ---------> P5_2 -------->
+ |-------------|--------------↑ ↑
+ ↓ |
+ P4_0 ---------> P4_1 ---------> P4_2 -------->
+ |-------------|--------------↑ ↑
+ |--------------↓ |
+ P3_0 -------------------------> P3_2 -------->
+ """
+
+ # Bring each input to the target level's resolution/channels before fusing:
+ # if it already matches the target level, pass it through unchanged;
+ # if it comes from a higher-resolution (earlier) level, downsample it by max pooling;
+ # if it comes from a lower-resolution (later) level, upsample it by nearest interpolation;
+ # channel mismatches are handled by the same-padding 1x1 down-channel convolutions.
+
+ if self.attention:
+ outs = self._forward_fast_attention(inputs)
+ else:
+ outs = self._forward(inputs)
+
+ return outs
+
+ def _forward_fast_attention(self, inputs):
+ if self.first_time:
+ p3, p4, p5 = inputs
+
+ p6_in = self.p5_to_p6(p5)
+ p7_in = self.p6_to_p7(p6_in)
+
+ p3_in = self.p3_down_channel(p3)
+ p4_in = self.p4_down_channel(p4)
+ p5_in = self.p5_down_channel(p5)
+
+ else:
+ # P3_0, P4_0, P5_0, P6_0 and P7_0
+ p3_in, p4_in, p5_in, p6_in, p7_in = inputs
+
+ # P7_0 to P7_2
+
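+ # Fast normalized fusion (as in EfficientDet): each fused input gets a learnable scalar
+ # weight, ReLU keeps the weights non-negative, and they are normalized to sum to ~1;
+ # epsilon only guards against division by zero. For example, raw weights (0.6, 0.2)
+ # become roughly (0.75, 0.25) after normalization.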
+ # Weights for P6_0 and P7_0 to P6_1
+ p6_w1 = self.p6_w1_relu(self.p6_w1)
+ weight = p6_w1 / (torch.sum(p6_w1, dim=0) + self.epsilon)
+ # Connections for P6_0 and P7_0 to P6_1 respectively
+ p6_up = self.conv6_up.forward(self.swish(weight[0] * p6_in + weight[1] * self.p6_upsample(p7_in)))
+
+ # Weights for P5_0 and P6_1 to P5_1
+ p5_w1 = self.p5_w1_relu(self.p5_w1)
+ weight = p5_w1 / (torch.sum(p5_w1, dim=0) + self.epsilon)
+ # Connections for P5_0 and P6_1 to P5_1 respectively
+ p5_up = self.conv5_up.forward(self.swish(weight[0] * p5_in + weight[1] * self.p5_upsample(p6_up)))
+
+ # Weights for P4_0 and P5_1 to P4_1
+ p4_w1 = self.p4_w1_relu(self.p4_w1)
+ weight = p4_w1 / (torch.sum(p4_w1, dim=0) + self.epsilon)
+ # Connections for P4_0 and P5_1 to P4_1 respectively
+ p4_up = self.conv4_up.forward(self.swish(weight[0] * p4_in + weight[1] * self.p4_upsample(p5_up)))
+
+ # Weights for P3_0 and P4_1 to P3_2
+ p3_w1 = self.p3_w1_relu(self.p3_w1)
+ weight = p3_w1 / (torch.sum(p3_w1, dim=0) + self.epsilon)
+ # Connections for P3_0 and P4_1 to P3_2 respectively
+ p3_out = self.conv3_up.forward(self.swish(weight[0] * p3_in + weight[1] * self.p3_upsample(p4_up)))
+
+ if self.first_time:
+ p4_in = self.p4_down_channel_2(p4)
+ p5_in = self.p5_down_channel_2(p5)
+
+ # Weights for P4_0, P4_1 and P3_2 to P4_2
+ p4_w2 = self.p4_w2_relu(self.p4_w2)
+ weight = p4_w2 / (torch.sum(p4_w2, dim=0) + self.epsilon)
+ # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively
+ p4_out = self.conv4_down.forward(
+ self.swish(weight[0] * p4_in + weight[1] * p4_up + weight[2] * self.p4_downsample(p3_out))
+ )
+
+ # Weights for P5_0, P5_1 and P4_2 to P5_2
+ p5_w2 = self.p5_w2_relu(self.p5_w2)
+ weight = p5_w2 / (torch.sum(p5_w2, dim=0) + self.epsilon)
+ # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively
+ p5_out = self.conv5_down.forward(
+ self.swish(weight[0] * p5_in + weight[1] * p5_up + weight[2] * self.p5_downsample(p4_out))
+ )
+
+ # Weights for P6_0, P6_1 and P5_2 to P6_2
+ p6_w2 = self.p6_w2_relu(self.p6_w2)
+ weight = p6_w2 / (torch.sum(p6_w2, dim=0) + self.epsilon)
+ # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively
+ p6_out = self.conv6_down.forward(
+ self.swish(weight[0] * p6_in + weight[1] * p6_up + weight[2] * self.p6_downsample(p5_out))
+ )
+
+ # Weights for P7_0 and P6_2 to P7_2
+ p7_w2 = self.p7_w2_relu(self.p7_w2)
+ weight = p7_w2 / (torch.sum(p7_w2, dim=0) + self.epsilon)
+ # Connections for P7_0 and P6_2 to P7_2
+ p7_out = self.conv7_down.forward(self.swish(weight[0] * p7_in + weight[1] * self.p7_downsample(p6_out)))
+
+ return p3_out, p4_out, p5_out, p6_out, p7_out
+
+ def _forward(self, inputs):
+ if self.first_time:
+ p3, p4, p5 = inputs
+
+ p6_in = self.p5_to_p6(p5)
+ p7_in = self.p6_to_p7(p6_in)
+ if self.use_p8:
+ p8_in = self.p7_to_p8(p7_in)
+
+ p3_in = self.p3_down_channel(p3)
+ p4_in = self.p4_down_channel(p4)
+ p5_in = self.p5_down_channel(p5)
+
+ else:
+ if self.use_p8:
+ # P3_0, P4_0, P5_0, P6_0, P7_0 and P8_0
+ p3_in, p4_in, p5_in, p6_in, p7_in, p8_in = inputs
+ else:
+ # P3_0, P4_0, P5_0, P6_0 and P7_0
+ p3_in, p4_in, p5_in, p6_in, p7_in = inputs
+
+ if self.use_p8:
+ # P8_0 to P8_2
+
+ # Connections for P7_0 and P8_0 to P7_1 respectively
+ p7_up = self.conv7_up.forward(self.swish(p7_in + self.p7_upsample(p8_in)))
+
+ # Connections for P6_0 and P7_0 to P6_1 respectively
+ p6_up = self.conv6_up.forward(self.swish(p6_in + self.p6_upsample(p7_up)))
+ else:
+ # P7_0 to P7_2
+
+ # Connections for P6_0 and P7_0 to P6_1 respectively
+ p6_up = self.conv6_up.forward(self.swish(p6_in + self.p6_upsample(p7_in)))
+
+ # Connections for P5_0 and P6_1 to P5_1 respectively
+ p5_up = self.conv5_up.forward(self.swish(p5_in + self.p5_upsample(p6_up)))
+
+ # Connections for P4_0 and P5_1 to P4_1 respectively
+ p4_up = self.conv4_up.forward(self.swish(p4_in + self.p4_upsample(p5_up)))
+
+ # Connections for P3_0 and P4_1 to P3_2 respectively
+ p3_out = self.conv3_up.forward(self.swish(p3_in + self.p3_upsample(p4_up)))
+
+ if self.first_time:
+ p4_in = self.p4_down_channel_2(p4)
+ p5_in = self.p5_down_channel_2(p5)
+
+ # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively
+ p4_out = self.conv4_down.forward(self.swish(p4_in + p4_up + self.p4_downsample(p3_out)))
+
+ # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively
+ p5_out = self.conv5_down.forward(self.swish(p5_in + p5_up + self.p5_downsample(p4_out)))
+
+ # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively
+ p6_out = self.conv6_down.forward(self.swish(p6_in + p6_up + self.p6_downsample(p5_out)))
+
+ if self.use_p8:
+ # Connections for P7_0, P7_1 and P6_2 to P7_2 respectively
+ p7_out = self.conv7_down.forward(self.swish(p7_in + p7_up + self.p7_downsample(p6_out)))
+
+ # Connections for P8_0 and P7_2 to P8_2
+ p8_out = self.conv8_down.forward(self.swish(p8_in + self.p8_downsample(p7_out)))
+
+ return p3_out, p4_out, p5_out, p6_out, p7_out, p8_out
+ else:
+ # Connections for P7_0 and P6_2 to P7_2
+ p7_out = self.conv7_down.forward(self.swish(p7_in + self.p7_downsample(p6_out)))
+
+ return p3_out, p4_out, p5_out, p6_out, p7_out
+
+
+class BiFPN(nn.Module):
+ def __init__(
+ self, conv_channels, fpn_cell_repeat=2, fpn_num_filters=64, norm_layer=nn.BatchNorm2d, use_checkpoint=False,
+ tgt_shape=(12, 28*6),
+ ):
+ super(BiFPN, self).__init__()
+ self.model = nn.Sequential(
+ *[
+ BiFPNLayer(fpn_num_filters, conv_channels, True if i == 0 else False, norm_layer=norm_layer)
+ for i in range(fpn_cell_repeat)
+ ]
+ )
+ self.tgt_shape = tgt_shape
+ self.use_checkpoint = use_checkpoint
+
+ def forward(self, im_bkb_features):
+ if self.use_checkpoint and self.training:
+ im_nek_features = checkpoint.checkpoint(self._forward, *im_bkb_features)
+ else:
+ im_nek_features = self._forward(*im_bkb_features)
+ im_nek_features = [torch.cat([self.up_sample(x, tgt_shape=self.tgt_shape) for x in im_nek_features], dim=1)]
+ return im_nek_features
+
+ def _forward(self, *inputs):
+ outputs = self.model(inputs[-3:])
+ return outputs
+
+ def up_sample(self, x, tgt_shape=None):
+ tgt_shape = self.tgt_shape if tgt_shape is None else tgt_shape
+ if tuple(x.shape[-2:]) == tuple(tgt_shape):
+ return x
+ return F.interpolate(x, size=tgt_shape, mode="bilinear", align_corners=True)
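+
+
+# Minimal usage sketch (illustrative only; the channel counts assume EfficientNet-B0
+# reduction_3/4/5 features and the spatial sizes are chosen so every level halves cleanly):
+#   neck = BiFPN(conv_channels=[40, 112, 320], fpn_cell_repeat=2, fpn_num_filters=64,
+#                tgt_shape=(12, 28 * 6))
+#   feats = [torch.rand(6, 40, 64, 176), torch.rand(6, 112, 32, 88), torch.rand(6, 320, 16, 44)]
+#   fused = neck(feats)  # a list with one tensor of shape (6, 5 * 64, 12, 168)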
diff --git a/mapmaster/models/backbone/bifpn/utils.py b/mapmaster/models/backbone/bifpn/utils.py
new file mode 100644
index 0000000..0f7bb4b
--- /dev/null
+++ b/mapmaster/models/backbone/bifpn/utils.py
@@ -0,0 +1,90 @@
+# Author: Zylo117
+
+import math
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class Swish(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class Conv2dStaticSamePadding(nn.Module):
+ """
+ created by Zylo117
+ The real keras/tensorflow conv2d with same padding
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, groups=1, dilation=1, **kwargs):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, bias=bias, groups=groups)
+ self.stride = self.conv.stride
+ self.kernel_size = self.conv.kernel_size
+ self.dilation = self.conv.dilation
+
+ if isinstance(self.stride, int):
+ self.stride = [self.stride] * 2
+ elif len(self.stride) == 1:
+ self.stride = [self.stride[0]] * 2
+
+ if isinstance(self.kernel_size, int):
+ self.kernel_size = [self.kernel_size] * 2
+ elif len(self.kernel_size) == 1:
+ self.kernel_size = [self.kernel_size[0]] * 2
+
+ def forward(self, x):
+ h, w = x.shape[-2:]
+
+ extra_h = (math.ceil(w / self.stride[1]) - 1) * self.stride[1] - w + self.kernel_size[1]
+ extra_v = (math.ceil(h / self.stride[0]) - 1) * self.stride[0] - h + self.kernel_size[0]
+
+ left = extra_h // 2
+ right = extra_h - left
+ top = extra_v // 2
+ bottom = extra_v - top
+
+ x = F.pad(x, [left, right, top, bottom])
+
+ x = self.conv(x)
+ return x
+
+
+class MaxPool2dStaticSamePadding(nn.Module):
+ """
+ created by Zylo117
+ The real keras/tensorflow MaxPool2d with same padding
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ self.pool = nn.MaxPool2d(*args, **kwargs)
+ self.stride = self.pool.stride
+ self.kernel_size = self.pool.kernel_size
+
+ if isinstance(self.stride, int):
+ self.stride = [self.stride] * 2
+ elif len(self.stride) == 1:
+ self.stride = [self.stride[0]] * 2
+
+ if isinstance(self.kernel_size, int):
+ self.kernel_size = [self.kernel_size] * 2
+ elif len(self.kernel_size) == 1:
+ self.kernel_size = [self.kernel_size[0]] * 2
+
+ def forward(self, x):
+ h, w = x.shape[-2:]
+
+ extra_h = (math.ceil(w / self.stride[1]) - 1) * self.stride[1] - w + self.kernel_size[1]
+ extra_v = (math.ceil(h / self.stride[0]) - 1) * self.stride[0] - h + self.kernel_size[0]
+
+ left = extra_h // 2
+ right = extra_h - left
+ top = extra_v // 2
+ bottom = extra_v - top
+
+ x = F.pad(x, [left, right, top, bottom])
+
+ x = self.pool(x)
+ return x
diff --git a/mapmaster/models/backbone/efficientnet/__init__.py b/mapmaster/models/backbone/efficientnet/__init__.py
new file mode 100644
index 0000000..464fc1b
--- /dev/null
+++ b/mapmaster/models/backbone/efficientnet/__init__.py
@@ -0,0 +1 @@
+from .model import EfficientNet
diff --git a/mapmaster/models/backbone/efficientnet/model.py b/mapmaster/models/backbone/efficientnet/model.py
new file mode 100644
index 0000000..63d869d
--- /dev/null
+++ b/mapmaster/models/backbone/efficientnet/model.py
@@ -0,0 +1,468 @@
+"""model.py - Model and module class for EfficientNet.
+ They are built to mirror those in the official TensorFlow implementation.
+"""
+
+# Author: lukemelas (github username)
+# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
+# With adjustments and added comments by workingcoder (github username).
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.utils import checkpoint as cp
+from mapmaster.models.backbone.efficientnet.utils import (
+ round_filters,
+ round_repeats,
+ drop_connect,
+ get_same_padding_conv2d,
+ get_model_params,
+ efficientnet_params,
+ load_pretrained_weights,
+ Swish,
+ MemoryEfficientSwish,
+ calculate_output_image_size,
+)
+
+
+VALID_MODELS = (
+ "efficientnet-b0",
+ "efficientnet-b1",
+ "efficientnet-b2",
+ "efficientnet-b3",
+ "efficientnet-b4",
+ "efficientnet-b5",
+ "efficientnet-b6",
+ "efficientnet-b7",
+ "efficientnet-b8",
+ # Support the construction of 'efficientnet-l2' without pretrained weights
+ "efficientnet-l2",
+)
+
+
+class MBConvBlock(nn.Module):
+ """Mobile Inverted Residual Bottleneck Block.
+
+ Args:
+ block_args (namedtuple): BlockArgs, defined in utils.py.
+ global_params (namedtuple): GlobalParam, defined in utils.py.
+ image_size (tuple or list): [image_height, image_width].
+
+ References:
+ [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
+ [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
+ [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
+ """
+
+ def __init__(self, block_args, global_params, image_size=None, norm_layer=nn.BatchNorm2d):
+ super().__init__()
+ self._block_args = block_args
+ self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow
+ self._bn_eps = global_params.batch_norm_epsilon
+ self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
+ self.id_skip = block_args.id_skip # whether to use skip connection and drop connect
+
+ # Expansion phase (Inverted Bottleneck)
+ inp = self._block_args.input_filters # number of input channels
+ oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
+ if self._block_args.expand_ratio != 1:
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
+ self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
+ self._bn0 = norm_layer(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
+ # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
+
+ # Depthwise convolution phase
+ k = self._block_args.kernel_size
+ s = self._block_args.stride
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
+ self._depthwise_conv = Conv2d(
+ in_channels=oup,
+ out_channels=oup,
+ groups=oup, # groups makes it depthwise
+ kernel_size=k,
+ stride=s,
+ bias=False,
+ )
+ self._bn1 = norm_layer(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
+ image_size = calculate_output_image_size(image_size, s)
+
+ # Squeeze and Excitation layer, if desired
+ if self.has_se:
+ Conv2d = get_same_padding_conv2d(image_size=(1, 1))
+ num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
+ self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
+ self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
+
+ # Pointwise convolution phase
+ final_oup = self._block_args.output_filters
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
+ self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
+ self._bn2 = norm_layer(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
+ self._swish = MemoryEfficientSwish()
+
+ def forward(self, inputs, drop_connect_rate=None):
+ """MBConvBlock's forward function.
+
+ Args:
+ inputs (tensor): Input tensor.
+ drop_connect_rate (float): Drop connect rate, between 0 and 1.
+
+ Returns:
+ Output of this block after processing.
+ """
+
+ # Expansion and Depthwise Convolution
+ x = inputs
+ if self._block_args.expand_ratio != 1:
+ x = self._expand_conv(inputs)
+ x = self._bn0(x)
+ x = self._swish(x)
+
+ x = self._depthwise_conv(x)
+ x = self._bn1(x)
+ x = self._swish(x)
+
+ # Squeeze and Excitation
+ if self.has_se:
+ x_squeezed = F.adaptive_avg_pool2d(x, 1)
+ x_squeezed = self._se_reduce(x_squeezed)
+ x_squeezed = self._swish(x_squeezed)
+ x_squeezed = self._se_expand(x_squeezed)
+ x = torch.sigmoid(x_squeezed) * x
+
+ # Pointwise Convolution
+ x = self._project_conv(x)
+ x = self._bn2(x)
+
+ # Skip connection and drop connect
+ input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
+ if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
+ # The combination of skip connection and drop connect brings about stochastic depth.
+ if drop_connect_rate:
+ x = drop_connect(x, p=drop_connect_rate, training=self.training)
+ x = x + inputs # skip connection
+ return x
+
+ def set_swish(self, memory_efficient=True):
+ """Sets swish function as memory efficient (for training) or standard (for export).
+
+ Args:
+ memory_efficient (bool): Whether to use memory-efficient version of swish.
+ """
+ self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+
+
+class EfficientNet(nn.Module):
+ """EfficientNet model.
+ Most easily loaded with the .from_name or .from_pretrained methods.
+
+ Args:
+ blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
+ global_params (namedtuple): A set of GlobalParams shared between blocks.
+
+ References:
+ [1] https://arxiv.org/abs/1905.11946 (EfficientNet)
+
+ Example:
+ >>> import torch
+ >>> from mapmaster.models.backbone.efficientnet import EfficientNet
+ >>> inputs = torch.rand(1, 3, 224, 224)
+ >>> model = EfficientNet.from_pretrained('efficientnet-b0')
+ >>> model.eval()
+ >>> outputs = model(inputs)
+ """
+
+ def __init__(self, blocks_args=None, global_params=None, with_head=True, with_cp=False, norm_layer=nn.BatchNorm2d):
+ super().__init__()
+ assert isinstance(blocks_args, list), "blocks_args should be a list"
+ assert len(blocks_args) > 0, "block args must be greater than 0"
+ self._global_params = global_params
+ self._blocks_args = blocks_args
+ self.with_head = with_head
+ self.with_cp = with_cp
+
+ # Batch norm parameters
+ bn_mom = 1 - self._global_params.batch_norm_momentum
+ bn_eps = self._global_params.batch_norm_epsilon
+
+ # Get stem static or dynamic convolution depending on image size
+ image_size = global_params.image_size
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
+
+ # Stem
+ in_channels = 3 # rgb
+ out_channels = round_filters(32, self._global_params) # number of output channels
+ self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+ self._bn0 = norm_layer(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+ image_size = calculate_output_image_size(image_size, 2)
+
+ # Build blocks
+ self._blocks = nn.ModuleList([])
+
+ for block_args in self._blocks_args:
+
+ # Update block input and output filters based on depth multiplier.
+ block_args = block_args._replace(
+ input_filters=round_filters(block_args.input_filters, self._global_params),
+ output_filters=round_filters(block_args.output_filters, self._global_params),
+ num_repeat=round_repeats(block_args.num_repeat, self._global_params),
+ )
+
+ # The first block needs to take care of stride and filter size increase.
+ self._blocks.append(
+ MBConvBlock(block_args, self._global_params, image_size=image_size, norm_layer=norm_layer)
+ )
+ image_size = calculate_output_image_size(image_size, block_args.stride)
+ if block_args.num_repeat > 1: # modify block_args to keep same output size
+ block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
+ for _ in range(block_args.num_repeat - 1):
+ self._blocks.append(
+ MBConvBlock(block_args, self._global_params, image_size=image_size, norm_layer=norm_layer)
+ )
+ # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1
+
+ # Head
+ in_channels = block_args.output_filters # output of final block
+ out_channels = round_filters(1280, self._global_params)
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
+ self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
+ self._bn1 = norm_layer(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+
+ # Final linear layer
+ self._avg_pooling = nn.AdaptiveAvgPool2d(1)
+ if self._global_params.include_top:
+ self._dropout = nn.Dropout(self._global_params.dropout_rate)
+ self._fc = nn.Linear(out_channels, self._global_params.num_classes)
+
+ # set activation to memory efficient swish by default
+ self._swish = MemoryEfficientSwish()
+
+ def set_swish(self, memory_efficient=True):
+ """Sets swish function as memory efficient (for training) or standard (for export).
+
+ Args:
+ memory_efficient (bool): Whether to use memory-efficient version of swish.
+ """
+ self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+ for block in self._blocks:
+ block.set_swish(memory_efficient)
+
+ def extract_endpoints(self, inputs):
+ """Use convolution layer to extract features
+ from reduction levels i in [1, 2, 3, 4, 5].
+
+ Args:
+ inputs (tensor): Input tensor.
+
+ Returns:
+ Dictionary of last intermediate features
+ with reduction levels i in [1, 2, 3, 4, 5].
+ Example:
+ >>> import torch
+ >>> from mapmaster.models.backbone.efficientnet import EfficientNet
+ >>> inputs = torch.rand(1, 3, 224, 224)
+ >>> model = EfficientNet.from_pretrained('efficientnet-b0')
+ >>> endpoints = model.extract_endpoints(inputs)
+ >>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
+ >>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
+ >>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
+ >>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
+ >>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7])
+ >>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7])
+ """
+ endpoints = dict()
+
+ # Stem
+ x = self._swish(self._bn0(self._conv_stem(inputs)))
+ prev_x = x
+
+ # Blocks
+ for idx, block in enumerate(self._blocks):
+ drop_connect_rate = self._global_params.drop_connect_rate
+ if drop_connect_rate:
+ drop_connect_rate *= float(idx) / len(self._blocks) # scale drop_connect_rate linearly with depth
+
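+ # When with_cp is set, run the block under torch.utils.checkpoint so its activations
+ # are recomputed during the backward pass instead of being stored (saves GPU memory).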
+ if self.with_cp and x.requires_grad:
+ x = cp.checkpoint(block, x, drop_connect_rate)
+ else:
+ x = block(x, drop_connect_rate=drop_connect_rate)
+
+ if prev_x.size(2) > x.size(2):
+ endpoints["reduction_{}".format(len(endpoints) + 1)] = prev_x
+ elif idx == len(self._blocks) - 1:
+ endpoints["reduction_{}".format(len(endpoints) + 1)] = x
+ prev_x = x
+
+ if self.with_head:
+ # Head
+ x = self._swish(self._bn1(self._conv_head(x)))
+ endpoints["reduction_{}".format(len(endpoints) + 1)] = x
+
+ return endpoints
+
+ def extract_features(self, inputs):
+ """use convolution layer to extract feature .
+
+ Args:
+ inputs (tensor): Input tensor.
+
+ Returns:
+ Output of the final convolution
+ layer in the efficientnet model.
+ """
+ # Stem
+ x = self._swish(self._bn0(self._conv_stem(inputs)))
+
+ # Blocks
+ for idx, block in enumerate(self._blocks):
+ drop_connect_rate = self._global_params.drop_connect_rate
+ if drop_connect_rate:
+ drop_connect_rate *= float(idx) / len(self._blocks) # scale drop_connect_rate linearly with depth
+ x = block(x, drop_connect_rate=drop_connect_rate)
+
+ # Head
+ x = self._swish(self._bn1(self._conv_head(x)))
+
+ return x
+
+ def forward(self, inputs):
+ """EfficientNet's forward function.
+ Calls extract_features to extract features, applies final linear layer, and returns logits.
+
+ Args:
+ inputs (tensor): Input tensor.
+
+ Returns:
+ Output of this model after processing.
+ """
+ # Convolution layers
+ x = self.extract_features(inputs)
+ # Pooling and final linear layer
+ x = self._avg_pooling(x)
+ if self._global_params.include_top:
+ x = x.flatten(start_dim=1)
+ x = self._dropout(x)
+ x = self._fc(x)
+ return x
+
+ @classmethod
+ def from_name(
+ cls, model_name, in_channels=3, out_stride=32, with_head=True, with_cp=False, norm_layer=nn.BatchNorm2d,
+ **override_params
+ ):
+ """Create an efficientnet model according to name.
+
+ Args:
+ model_name (str): Name for efficientnet.
+ in_channels (int): Input data's channel number.
+ override_params (other key word params):
+ Params to override model's global_params.
+ Optional key:
+ 'width_coefficient', 'depth_coefficient',
+ 'image_size', 'dropout_rate',
+ 'num_classes', 'batch_norm_momentum',
+ 'batch_norm_epsilon', 'drop_connect_rate',
+ 'depth_divisor', 'min_depth'
+
+ Returns:
+ An efficientnet model.
+ """
+ cls._check_model_name_is_valid(model_name)
+ blocks_args, global_params = get_model_params(model_name, out_stride, override_params)
+ model = cls(blocks_args, global_params, with_head=with_head, with_cp=with_cp, norm_layer=norm_layer)
+ model._change_in_channels(in_channels)
+ return model
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ model_name,
+ weights_path=None,
+ advprop=False,
+ in_channels=3,
+ num_classes=100,
+ out_stride=32,
+ with_head=True,
+ with_cp=False,
+ norm_layer=nn.BatchNorm2d,
+ **override_params
+ ):
+ """Create an efficientnet model according to name.
+
+ Args:
+ model_name (str): Name for efficientnet.
+ weights_path (None or str):
+ str: path to pretrained weights file on the local disk.
+ None: use pretrained weights downloaded from the Internet.
+ advprop (bool):
+ Whether to load pretrained weights
+ trained with advprop (valid when weights_path is None).
+ in_channels (int): Input data's channel number.
+ num_classes (int):
+ Number of categories for classification.
+ It controls the output size for final linear layer.
+ override_params (other key word params):
+ Params to override model's global_params.
+ Optional key:
+ 'width_coefficient', 'depth_coefficient',
+ 'image_size', 'dropout_rate',
+ 'batch_norm_momentum',
+ 'batch_norm_epsilon', 'drop_connect_rate',
+ 'depth_divisor', 'min_depth'
+
+ Returns:
+ A pretrained efficientnet model.
+ """
+ model = cls.from_name(
+ model_name,
+ num_classes=num_classes,
+ out_stride=out_stride,
+ with_head=with_head,
+ with_cp=with_cp,
+ norm_layer=norm_layer,
+ **override_params
+ )
+ load_pretrained_weights(
+ model, model_name, weights_path=weights_path, load_fc=(num_classes == 1000), advprop=advprop
+ )
+ model._change_in_channels(in_channels)
+ return model
+
+ @classmethod
+ def get_image_size(cls, model_name):
+ """Get the input image size for a given efficientnet model.
+
+ Args:
+ model_name (str): Name for efficientnet.
+
+ Returns:
+ Input image size (resolution).
+ """
+ cls._check_model_name_is_valid(model_name)
+ _, _, res, _ = efficientnet_params(model_name)
+ return res
+
+ @classmethod
+ def _check_model_name_is_valid(cls, model_name):
+ """Validates model name.
+
+ Args:
+ model_name (str): Name for efficientnet.
+
+ Returns:
+ bool: Is a valid name or not.
+ """
+ if model_name not in VALID_MODELS:
+ raise ValueError("model_name should be one of: " + ", ".join(VALID_MODELS))
+
+ def _change_in_channels(self, in_channels):
+ """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.
+
+ Args:
+ in_channels (int): Input data's channel number.
+ """
+ if in_channels != 3:
+ Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
+ out_channels = round_filters(32, self._global_params)
+ self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
diff --git a/mapmaster/models/backbone/efficientnet/utils.py b/mapmaster/models/backbone/efficientnet/utils.py
new file mode 100644
index 0000000..c6b715f
--- /dev/null
+++ b/mapmaster/models/backbone/efficientnet/utils.py
@@ -0,0 +1,656 @@
+"""utils.py - Helper functions for building the model and for loading model parameters.
+ These helper functions are built to mirror those in the official TensorFlow implementation.
+"""
+
+# Author: lukemelas (github username)
+# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
+# With adjustments and added comments by workingcoder (github username).
+
+import re
+import math
+import collections
+from functools import partial
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.utils import model_zoo
+
+
+################################################################################
+# Helper functions for model architecture
+################################################################################
+
+# GlobalParams and BlockArgs: Two namedtuples
+# Swish and MemoryEfficientSwish: Two implementations of the method
+# round_filters and round_repeats:
+# Functions to calculate params for scaling model width and depth ! ! !
+# get_width_and_height_from_size and calculate_output_image_size
+# drop_connect: A structural design
+# get_same_padding_conv2d:
+# Conv2dDynamicSamePadding
+# Conv2dStaticSamePadding
+# get_same_padding_maxPool2d:
+# MaxPool2dDynamicSamePadding
+# MaxPool2dStaticSamePadding
+# It's an additional function, not used in EfficientNet,
+ # but can be used in other models (such as EfficientDet).
+
+# Parameters for the entire model (stem, all blocks, and head)
+GlobalParams = collections.namedtuple(
+ "GlobalParams",
+ [
+ "width_coefficient",
+ "depth_coefficient",
+ "image_size",
+ "dropout_rate",
+ "num_classes",
+ "batch_norm_momentum",
+ "batch_norm_epsilon",
+ "drop_connect_rate",
+ "depth_divisor",
+ "min_depth",
+ "include_top",
+ ],
+)
+
+# Parameters for an individual model block
+BlockArgs = collections.namedtuple(
+ "BlockArgs",
+ ["num_repeat", "kernel_size", "stride", "expand_ratio", "input_filters", "output_filters", "se_ratio", "id_skip"],
+)
+
+# Set GlobalParams and BlockArgs's defaults
+GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
+BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
+
+# Swish activation function
+if hasattr(nn, "SiLU"):
+ Swish = nn.SiLU
+else:
+ # For compatibility with old PyTorch versions
+ class Swish(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+# A memory-efficient implementation of Swish function
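+# (it stores only the input tensor for backward and recomputes the sigmoid there, instead of
+# letting autograd keep the intermediate results of x * torch.sigmoid(x))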
+class SwishImplementation(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, i):
+ result = i * torch.sigmoid(i)
+ ctx.save_for_backward(i)
+ return result
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ i = ctx.saved_tensors[0]
+ sigmoid_i = torch.sigmoid(i)
+ return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+
+class MemoryEfficientSwish(nn.Module):
+ def forward(self, x):
+ return SwishImplementation.apply(x)
+
+
+def round_filters(filters, global_params):
+ """Calculate and round number of filters based on width multiplier.
+ Use width_coefficient, depth_divisor and min_depth of global_params.
+
+ Args:
+ filters (int): Filters number to be calculated.
+ global_params (namedtuple): Global params of the model.
+
+ Returns:
+ new_filters: New filters number after calculating.
+ """
+ multiplier = global_params.width_coefficient
+ if not multiplier:
+ return filters
+ # TODO: modify the params names.
+ # maybe the names (width_divisor,min_width)
+ # are more suitable than (depth_divisor,min_depth).
+ divisor = global_params.depth_divisor
+ min_depth = global_params.min_depth
+ filters *= multiplier
+ min_depth = min_depth or divisor # pay attention to this line when using min_depth
+ # follow the formula transferred from official TensorFlow implementation
+ new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
+ if new_filters < 0.9 * filters: # prevent rounding by more than 10%
+ new_filters += divisor
+ return int(new_filters)
+
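+# Worked example for round_filters (illustrative): with width_coefficient=1.1 and
+# depth_divisor=8 (efficientnet-b2), round_filters(32, gp) -> 32 * 1.1 = 35.2 -> rounded to
+# the nearest multiple of 8 gives 32; since 32 >= 0.9 * 35.2, no 10% correction is applied.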
+
+def round_repeats(repeats, global_params):
+ """Calculate module's repeat number of a block based on depth multiplier.
+ Use depth_coefficient of global_params.
+
+ Args:
+ repeats (int): num_repeat to be calculated.
+ global_params (namedtuple): Global params of the model.
+
+ Returns:
+ new repeat: New repeat number after calculating.
+ """
+ multiplier = global_params.depth_coefficient
+ if not multiplier:
+ return repeats
+ # follow the formula transferred from official TensorFlow implementation
+ return int(math.ceil(multiplier * repeats))
+
+
+def drop_connect(inputs, p, training):
+ """Drop connect.
+
+ Args:
+ inputs (tensor: BCHW): Input of this structure.
+ p (float: 0.0~1.0): Probability of drop connection.
+ training (bool): The running mode.
+
+ Returns:
+ output: Output after drop connection.
+ """
+ assert 0 <= p <= 1, "p must be in range of [0,1]"
+
+ if not training:
+ return inputs
+
+ batch_size = inputs.shape[0]
+ keep_prob = 1 - p
+
+ # generate binary_tensor mask according to probability (p for 0, 1-p for 1)
+ random_tensor = keep_prob
+ random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
+ binary_tensor = torch.floor(random_tensor)
+
+ output = inputs / keep_prob * binary_tensor
+ return output
+
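+# drop_connect behaviour sketch: in eval mode it is the identity; in training with p=0.2 each
+# sample's branch is zeroed with probability 0.2 and the survivors are scaled by 1 / 0.8,
+# so the expected value of the output equals the input.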
+
+def get_width_and_height_from_size(x):
+ """Obtain height and width from x.
+
+ Args:
+ x (int, tuple or list): Data size.
+
+ Returns:
+ size: A tuple or list (H,W).
+ """
+ if isinstance(x, int):
+ return x, x
+ if isinstance(x, list) or isinstance(x, tuple):
+ return x
+ else:
+ raise TypeError()
+
+
+def calculate_output_image_size(input_image_size, stride):
+ """Calculates the output image size when using Conv2dSamePadding with a stride.
+ Necessary for static padding. Thanks to mannatsingh for pointing this out.
+
+ Args:
+ input_image_size (int, tuple or list): Size of input image.
+ stride (int, tuple or list): Conv2d operation's stride.
+
+ Returns:
+ output_image_size: A list [H,W].
+ """
+ if input_image_size is None:
+ return None
+ image_height, image_width = get_width_and_height_from_size(input_image_size)
+ stride = stride if isinstance(stride, int) else stride[0]
+ image_height = int(math.ceil(image_height / stride))
+ image_width = int(math.ceil(image_width / stride))
+ return [image_height, image_width]
+
+
+# Note:
+# The following 'SamePadding' functions make output size equal ceil(input size/stride).
+# Only when stride equals 1, can the output size be the same as input size.
+# Don't be confused by their function names ! ! !
+
+
+def get_same_padding_conv2d(image_size=None):
+ """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
+ Static padding is necessary for ONNX exporting of models.
+
+ Args:
+ image_size (int or tuple): Size of the image.
+
+ Returns:
+ Conv2dDynamicSamePadding or Conv2dStaticSamePadding.
+ """
+ if image_size is None:
+ return Conv2dDynamicSamePadding
+ else:
+ return partial(Conv2dStaticSamePadding, image_size=image_size)
+
+
+class Conv2dDynamicSamePadding(nn.Conv2d):
+ """2D Convolutions like TensorFlow, for a dynamic image size.
+ The padding is computed dynamically in the forward pass.
+ """
+
+ # Tips for 'SAME' mode padding.
+ # Given the following:
+ # i: width or height
+ # s: stride
+ # k: kernel size
+ # d: dilation
+ # p: padding
+ # Output after Conv2d:
+ # o = floor((i+p-((k-1)*d+1))/s+1)
+ # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1),
+ # => p = (i-1)*s+((k-1)*d+1)-i
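+ # e.g. in forward with i=7, s=2, k=3, d=1: o = ceil(7/2) = 4 and
+ # pad = max((4-1)*2 + (3-1)*1 + 1 - 7, 0) = 2, split as 1 on each side.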
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
+ super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
+ self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
+
+ def forward(self, x):
+ ih, iw = x.size()[-2:]
+ kh, kw = self.weight.size()[-2:]
+ sh, sw = self.stride
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) # change the output size according to stride ! ! !
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
+ return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+
+class Conv2dStaticSamePadding(nn.Conv2d):
+ """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.
+ The padding module is built in the constructor, then applied in forward.
+ """
+
+ # With the same calculation as Conv2dDynamicSamePadding
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
+ super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
+ self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
+
+ # Calculate padding based on image size and save it
+ assert image_size is not None
+ ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
+ kh, kw = self.weight.size()[-2:]
+ sh, sw = self.stride
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+ if pad_h > 0 or pad_w > 0:
+ self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
+ else:
+ self.static_padding = nn.Identity()
+
+ def forward(self, x):
+ x = self.static_padding(x)
+ x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+ return x
+
+
+def get_same_padding_maxPool2d(image_size=None):
+ """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
+ Static padding is necessary for ONNX exporting of models.
+
+ Args:
+ image_size (int or tuple): Size of the image.
+
+ Returns:
+ MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding.
+ """
+ if image_size is None:
+ return MaxPool2dDynamicSamePadding
+ else:
+ return partial(MaxPool2dStaticSamePadding, image_size=image_size)
+
+
+class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
+ """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size.
+ The padding is computed dynamically in the forward pass.
+ """
+
+ def __init__(self, kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False):
+ super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode)
+ self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
+ self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
+ self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
+
+ def forward(self, x):
+ ih, iw = x.size()[-2:]
+ kh, kw = self.kernel_size
+ sh, sw = self.stride
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
+ return F.max_pool2d(
+ x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode, self.return_indices
+ )
+
+
+class MaxPool2dStaticSamePadding(nn.MaxPool2d):
+ """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size.
+ The padding module is built in the constructor, then applied in forward.
+ """
+
+ def __init__(self, kernel_size, stride, image_size=None, **kwargs):
+ super().__init__(kernel_size, stride, **kwargs)
+ self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
+ self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
+ self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
+
+ # Calculate padding based on image size and save it
+ assert image_size is not None
+ ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
+ kh, kw = self.kernel_size
+ sh, sw = self.stride
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+ if pad_h > 0 or pad_w > 0:
+ self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
+ else:
+ self.static_padding = nn.Identity()
+
+ def forward(self, x):
+ x = self.static_padding(x)
+ x = F.max_pool2d(
+ x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode, self.return_indices
+ )
+ return x
+
+
+################################################################################
+# Helper functions for loading model params
+################################################################################
+
+# BlockDecoder: A Class for encoding and decoding BlockArgs
+# efficientnet_params: A function to query compound coefficient
+# get_model_params and efficientnet:
+# Functions to get BlockArgs and GlobalParams for efficientnet
+# url_map and url_map_advprop: Dicts of url_map for pretrained weights
+# load_pretrained_weights: A function to load pretrained weights
+
+
+class BlockDecoder(object):
+ """Block Decoder for readability,
+ straight from the official TensorFlow repository.
+ """
+
+ @staticmethod
+ def _decode_block_string(block_string):
+ """Get a block through a string notation of arguments.
+
+ Args:
+ block_string (str): A string notation of arguments.
+ Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.
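+ The example above decodes to BlockArgs(num_repeat=1, kernel_size=3, stride=[1],
+ expand_ratio=1, input_filters=32, output_filters=16, se_ratio=0.25, id_skip=False).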
+
+ Returns:
+ BlockArgs: The namedtuple defined at the top of this file.
+ """
+ assert isinstance(block_string, str)
+
+ ops = block_string.split("_")
+ options = {}
+ for op in ops:
+ splits = re.split(r"(\d.*)", op)
+ if len(splits) >= 2:
+ key, value = splits[:2]
+ options[key] = value
+
+ # Check stride
+ assert ("s" in options and len(options["s"]) == 1) or (
+ len(options["s"]) == 2 and options["s"][0] == options["s"][1]
+ )
+
+ return BlockArgs(
+ num_repeat=int(options["r"]),
+ kernel_size=int(options["k"]),
+ stride=[int(options["s"][0])],
+ expand_ratio=int(options["e"]),
+ input_filters=int(options["i"]),
+ output_filters=int(options["o"]),
+ se_ratio=float(options["se"]) if "se" in options else None,
+ id_skip=("noskip" not in block_string),
+ )
+
+ @staticmethod
+ def _encode_block_string(block):
+ """Encode a block to a string.
+
+ Args:
+ block (namedtuple): A BlockArgs type argument.
+
+ Returns:
+ block_string: A String form of BlockArgs.
+ """
+ stride = block.stride[0] if isinstance(block.stride, (list, tuple)) else block.stride
+ args = [
+ "r%d" % block.num_repeat,
+ "k%d" % block.kernel_size,
+ "s%d%d" % (stride, stride),
+ "e%s" % block.expand_ratio,
+ "i%d" % block.input_filters,
+ "o%d" % block.output_filters,
+ ]
+ if block.se_ratio is not None and 0 < block.se_ratio <= 1:
+ args.append("se%s" % block.se_ratio)
+ if block.id_skip is False:
+ args.append("noskip")
+ return "_".join(args)
+
+ @staticmethod
+ def decode(string_list):
+ """Decode a list of string notations to specify blocks inside the network.
+
+ Args:
+ string_list (list[str]): A list of strings, each string is a notation of block.
+
+ Returns:
+ blocks_args: A list of BlockArgs namedtuples of block args.
+ """
+ assert isinstance(string_list, list)
+ blocks_args = []
+ for block_string in string_list:
+ blocks_args.append(BlockDecoder._decode_block_string(block_string))
+ return blocks_args
+
+ @staticmethod
+ def encode(blocks_args):
+ """Encode a list of BlockArgs to a list of strings.
+
+ Args:
+ blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.
+
+ Returns:
+ block_strings: A list of strings, each string is a notation of block.
+ """
+ block_strings = []
+ for block in blocks_args:
+ block_strings.append(BlockDecoder._encode_block_string(block))
+ return block_strings
+
+
+def efficientnet_params(model_name):
+ """Map EfficientNet model name to parameter coefficients.
+
+ Args:
+ model_name (str): Model name to be queried.
+
+ Returns:
+ params_dict[model_name]: A (width,depth,res,dropout) tuple.
+ """
+ params_dict = {
+ # Coefficients: width,depth,res,dropout
+ "efficientnet-b0": (1.0, 1.0, 224, 0.2),
+ "efficientnet-b1": (1.0, 1.1, 240, 0.2),
+ "efficientnet-b2": (1.1, 1.2, 260, 0.3),
+ "efficientnet-b3": (1.2, 1.4, 300, 0.3),
+ "efficientnet-b4": (1.4, 1.8, 380, 0.4),
+ "efficientnet-b5": (1.6, 2.2, 456, 0.4),
+ "efficientnet-b6": (1.8, 2.6, 528, 0.5),
+ "efficientnet-b7": (2.0, 3.1, 600, 0.5),
+ "efficientnet-b8": (2.2, 3.6, 672, 0.5),
+ "efficientnet-l2": (4.3, 5.3, 800, 0.5),
+ }
+ return params_dict[model_name]
+
+
+def efficientnet(
+ width_coefficient=None,
+ depth_coefficient=None,
+ image_size=None,
+ dropout_rate=0.2,
+ drop_connect_rate=0.2,
+ num_classes=1000,
+ include_top=True,
+ out_stride=32,
+):
+ """Create BlockArgs and GlobalParams for efficientnet model.
+
+ Args:
+ width_coefficient (float)
+ depth_coefficient (float)
+ image_size (int)
+ dropout_rate (float)
+ drop_connect_rate (float)
+ num_classes (int)
+
+ Meaning as the name suggests.
+
+ Returns:
+ blocks_args, global_params.
+ """
+
+ # Blocks args for the whole model(efficientnet-b0 by default)
+ # It will be modified in the construction of EfficientNet Class according to model
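+ # out_stride=32 is the standard configuration; out_stride=16 keeps the sixth stage
+ # ('r4_k5_...') at stride 1 so the deepest feature map is only 16x downsampled.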
+ if out_stride == 32:
+ blocks_args = [
+ "r1_k3_s11_e1_i32_o16_se0.25",
+ "r2_k3_s22_e6_i16_o24_se0.25",
+ "r2_k5_s22_e6_i24_o40_se0.25",
+ "r3_k3_s22_e6_i40_o80_se0.25",
+ "r3_k5_s11_e6_i80_o112_se0.25",
+ "r4_k5_s22_e6_i112_o192_se0.25",
+ "r1_k3_s11_e6_i192_o320_se0.25",
+ ]
+ elif out_stride == 16:
+ blocks_args = [
+ "r1_k3_s11_e1_i32_o16_se0.25",
+ "r2_k3_s22_e6_i16_o24_se0.25",
+ "r2_k5_s22_e6_i24_o40_se0.25",
+ "r3_k3_s22_e6_i40_o80_se0.25",
+ "r3_k5_s11_e6_i80_o112_se0.25",
+ "r4_k5_s11_e6_i112_o192_se0.25",
+ "r1_k3_s11_e6_i192_o320_se0.25",
+ ]
+ else:
+ raise NotImplementedError
+ blocks_args = BlockDecoder.decode(blocks_args)
+
+ global_params = GlobalParams(
+ width_coefficient=width_coefficient,
+ depth_coefficient=depth_coefficient,
+ image_size=image_size,
+ dropout_rate=dropout_rate,
+ num_classes=num_classes,
+ batch_norm_momentum=0.99,
+ batch_norm_epsilon=1e-3,
+ drop_connect_rate=drop_connect_rate,
+ depth_divisor=8,
+ min_depth=None,
+ include_top=include_top,
+ )
+
+ return blocks_args, global_params
+
+
+def get_model_params(model_name, out_stride, override_params):
+ """Get the block args and global params for a given model name.
+
+ Args:
+ model_name (str): Model's name.
+ override_params (dict): A dict to modify global_params.
+
+ Returns:
+ blocks_args, global_params
+ """
+ if model_name.startswith("efficientnet"):
+ w, d, s, p = efficientnet_params(model_name)
+ # note: all models have drop connect rate = 0.2
+ blocks_args, global_params = efficientnet(
+ width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s, out_stride=out_stride
+ )
+ else:
+ raise NotImplementedError("model name is not pre-defined: {}".format(model_name))
+ if override_params:
+ # ValueError will be raised here if override_params has fields not included in global_params.
+ global_params = global_params._replace(**override_params)
+ return blocks_args, global_params
+
+
+# Weights trained with standard methods.
+# See the paper (EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks) for more details.
+url_map = {
+ "efficientnet-b0": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth",
+ "efficientnet-b1": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth",
+ "efficientnet-b2": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth",
+ "efficientnet-b3": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth",
+ "efficientnet-b4": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
+ "efficientnet-b5": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
+ "efficientnet-b6": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
+ "efficientnet-b7": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth",
+}
+
+# Weights trained with Adversarial Examples (AdvProp).
+# See the paper (Adversarial Examples Improve Image Recognition) for more details.
+url_map_advprop = {
+ "efficientnet-b0": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth",
+ "efficientnet-b1": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth",
+ "efficientnet-b2": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth",
+ "efficientnet-b3": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth",
+ "efficientnet-b4": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth",
+ "efficientnet-b5": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth",
+ "efficientnet-b6": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth",
+ "efficientnet-b7": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth",
+ "efficientnet-b8": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth",
+}
+
+# TODO: add the pretrained weights url map of 'efficientnet-l2'
+
+
+def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False, verbose=True):
+ """Loads pretrained weights from weights path or download using url.
+
+ Args:
+ model (Module): The whole model of efficientnet.
+ model_name (str): Model name of efficientnet.
+ weights_path (None or str):
+ str: path to pretrained weights file on the local disk.
+ None: use pretrained weights downloaded from the Internet.
+ load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model.
+ advprop (bool): Whether to load pretrained weights
+ trained with advprop (valid when weights_path is None).
+ """
+ if isinstance(weights_path, str):
+ state_dict = torch.load(weights_path)
+ else:
+ # AutoAugment or Advprop (different preprocessing)
+ url_map_ = url_map_advprop if advprop else url_map
+ state_dict = model_zoo.load_url(url_map_[model_name])
+
+ if load_fc:
+ ret = model.load_state_dict(state_dict, strict=False)
+ assert not ret.missing_keys, "Missing keys when loading pretrained weights: {}".format(ret.missing_keys)
+ else:
+ state_dict.pop("_fc.weight")
+ state_dict.pop("_fc.bias")
+ ret = model.load_state_dict(state_dict, strict=False)
+ assert set(ret.missing_keys) == set(
+ ["_fc.weight", "_fc.bias"]
+ ), "Missing keys when loading pretrained weights: {}".format(ret.missing_keys)
+ assert not ret.unexpected_keys, "Unexpected keys when loading pretrained weights: {}".format(ret.unexpected_keys)
+
+ if verbose:
+ print("Loaded pretrained weights for {}".format(model_name))
diff --git a/mapmaster/models/backbone/model.py b/mapmaster/models/backbone/model.py
new file mode 100644
index 0000000..003124b
--- /dev/null
+++ b/mapmaster/models/backbone/model.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mapmaster.models.backbone.resnet import ResNet
+from mapmaster.models.backbone.efficientnet import EfficientNet
+from mapmaster.models.backbone.swin_transformer import SwinTransformer
+from mapmaster.models.backbone.bifpn import BiFPN
+
+
+class ResNetBackbone(nn.Module):
+ def __init__(self, bkb_kwargs, fpn_kwarg=None, up_shape=None, ret_layers=1):
+ super(ResNetBackbone, self).__init__()
+ assert 0 < ret_layers < 4
+ self.ret_layers = ret_layers
+ self.bkb = ResNet(**bkb_kwargs)
+ self.fpn = None if fpn_kwarg is None else BiFPN(**fpn_kwarg)
+ self.up_shape = None if up_shape is None else up_shape
+ self.bkb.init_weights()
+
+ def forward(self, inputs):
+ images = inputs["images"]
+ images = images.view(-1, *images.shape[-3:])
+ bkb_features = list(self.bkb(images)[-self.ret_layers:])
+ nek_features = self.fpn(bkb_features) if self.fpn is not None else None
+ return {"im_bkb_features": bkb_features, "im_nek_features": nek_features}
+
+
+class EfficientNetBackbone(nn.Module):
+ def __init__(self, bkb_kwargs, fpn_kwarg=None, up_shape=None, ret_layers=1):
+ super(EfficientNetBackbone, self).__init__()
+ assert 0 < ret_layers < 4
+ self.ret_layers = ret_layers
+ self.bkb = EfficientNet.from_pretrained(**bkb_kwargs)
+ self.fpn = None if fpn_kwarg is None else BiFPN(**fpn_kwarg)
+ self.up_shape = None if up_shape is None else up_shape
+ del self.bkb._conv_head
+ del self.bkb._bn1
+ del self.bkb._avg_pooling
+ del self.bkb._dropout
+ del self.bkb._fc
+
+ def forward(self, inputs):
+ images = inputs["images"]
+ images = images.view(-1, *images.shape[-3:])
+ endpoints = self.bkb.extract_endpoints(images)
+ bkb_features = []
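+ # skip reduction_1 (the shallowest endpoint); the last ret_layers of the remaining
+ # feature maps are what the neck (BiFPN) consumes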
+ for i, (key, value) in enumerate(endpoints.items()):
+ if i > 0:
+ bkb_features.append(value)
+ bkb_features = list(bkb_features[-self.ret_layers:])
+ nek_features = self.fpn(bkb_features) if self.fpn is not None else None
+ return {"im_bkb_features": bkb_features, "im_nek_features": nek_features}
+
+
+class SwinTRBackbone(nn.Module):
+ def __init__(self, bkb_kwargs, fpn_kwarg=None, up_shape=None, ret_layers=1):
+ super(SwinTRBackbone, self).__init__()
+ assert 0 < ret_layers < 4
+ self.ret_layers = ret_layers
+ self.bkb = SwinTransformer(**bkb_kwargs)
+ self.fpn = None if fpn_kwarg is None else BiFPN(**fpn_kwarg)
+ self.up_shape = None if up_shape is None else up_shape
+
+ def forward(self, inputs):
+ images = inputs["images"]
+ images = images.view(-1, *images.shape[-3:])
+ bkb_features = list(self.bkb(images)[-self.ret_layers:])
+ nek_features = None
+ if self.fpn is not None:
+ nek_features = self.fpn(bkb_features)
+ else:
+ if self.up_shape is not None:
+ nek_features = [torch.cat([self.up_sample(x, self.up_shape) for x in bkb_features], dim=1)]
+
+ return {"im_bkb_features": bkb_features, "im_nek_features": nek_features}
+
+ def up_sample(self, x, tgt_shape=None):
+ tgt_shape = self.up_shape if tgt_shape is None else tgt_shape
+ if tuple(x.shape[-2:]) == tuple(tgt_shape):
+ return x
+ return F.interpolate(x, size=tgt_shape, mode="bilinear", align_corners=True)
diff --git a/mapmaster/models/backbone/resnet/__init__.py b/mapmaster/models/backbone/resnet/__init__.py
new file mode 100644
index 0000000..4f73423
--- /dev/null
+++ b/mapmaster/models/backbone/resnet/__init__.py
@@ -0,0 +1 @@
+from .resnet import ResNet
diff --git a/mapmaster/models/backbone/resnet/resnet.py b/mapmaster/models/backbone/resnet/resnet.py
new file mode 100644
index 0000000..3762e25
--- /dev/null
+++ b/mapmaster/models/backbone/resnet/resnet.py
@@ -0,0 +1,596 @@
+import warnings
+
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from torch.nn.modules.batchnorm import _BatchNorm
+from mmcv.runner import BaseModule
+from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer
+from .utils import ResLayer
+
+
+class BasicBlock(BaseModule):
+ expansion = 1
+
+ def __init__(
+ self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style="pytorch",
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type="BN"),
+ dcn=None,
+ plugins=None,
+ init_cfg=None,
+ ):
+ super(BasicBlock, self).__init__(init_cfg)
+ assert dcn is None, "Not implemented yet."
+ assert plugins is None, "Not implemented yet."
+
+ self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
+
+ self.conv1 = build_conv_layer(
+ conv_cfg, inplanes, planes, 3, stride=stride, padding=dilation, dilation=dilation, bias=False
+ )
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(conv_cfg, planes, planes, 3, padding=1, bias=False)
+ self.add_module(self.norm2_name, norm2)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.with_cp = with_cp
+
+ @property
+ def norm1(self):
+ """nn.Module: normalization layer after the first convolution layer"""
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: normalization layer after the second convolution layer"""
+ return getattr(self, self.norm2_name)
+
+ def forward(self, x):
+ """Forward function."""
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(BaseModule):
+ expansion = 4
+
+ def __init__(
+ self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style="pytorch",
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type="BN"),
+ dcn=None,
+ plugins=None,
+ init_cfg=None,
+ ):
+ """Bottleneck block for ResNet.
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
+ it is "caffe", the stride-two layer is the first 1x1 conv layer.
+ """
+ super(Bottleneck, self).__init__(init_cfg)
+ assert style in ["pytorch", "caffe"]
+ assert dcn is None or isinstance(dcn, dict)
+ assert plugins is None or isinstance(plugins, list)
+ if plugins is not None:
+ allowed_position = ["after_conv1", "after_conv2", "after_conv3"]
+ assert all(p["position"] in allowed_position for p in plugins)
+
+ self.inplanes = inplanes
+ self.planes = planes
+ self.stride = stride
+ self.dilation = dilation
+ self.style = style
+ self.with_cp = with_cp
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.dcn = dcn
+ self.with_dcn = dcn is not None
+ self.plugins = plugins
+ self.with_plugins = plugins is not None
+
+ if self.with_plugins:
+ # collect plugins for conv1/conv2/conv3
+ self.after_conv1_plugins = [plugin["cfg"] for plugin in plugins if plugin["position"] == "after_conv1"]
+ self.after_conv2_plugins = [plugin["cfg"] for plugin in plugins if plugin["position"] == "after_conv2"]
+ self.after_conv3_plugins = [plugin["cfg"] for plugin in plugins if plugin["position"] == "after_conv3"]
+
+ if self.style == "pytorch":
+ self.conv1_stride = 1
+ self.conv2_stride = stride
+ else:
+ self.conv1_stride = stride
+ self.conv2_stride = 1
+
+ self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
+ self.norm3_name, norm3 = build_norm_layer(norm_cfg, planes * self.expansion, postfix=3)
+
+ self.conv1 = build_conv_layer(conv_cfg, inplanes, planes, kernel_size=1, stride=self.conv1_stride, bias=False)
+ self.add_module(self.norm1_name, norm1)
+ fallback_on_stride = False
+ if self.with_dcn:
+ fallback_on_stride = dcn.pop("fallback_on_stride", False)
+ if not self.with_dcn or fallback_on_stride:
+ self.conv2 = build_conv_layer(
+ conv_cfg,
+ planes,
+ planes,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False,
+ )
+ else:
+ assert self.conv_cfg is None, "conv_cfg must be None for DCN"
+ self.conv2 = build_conv_layer(
+ dcn,
+ planes,
+ planes,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False,
+ )
+
+ self.add_module(self.norm2_name, norm2)
+ self.conv3 = build_conv_layer(conv_cfg, planes, planes * self.expansion, kernel_size=1, bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+
+ if self.with_plugins:
+ self.after_conv1_plugin_names = self.make_block_plugins(planes, self.after_conv1_plugins)
+ self.after_conv2_plugin_names = self.make_block_plugins(planes, self.after_conv2_plugins)
+ self.after_conv3_plugin_names = self.make_block_plugins(planes * self.expansion, self.after_conv3_plugins)
+
+ def make_block_plugins(self, in_channels, plugins):
+ """make plugins for block.
+ Args:
+ in_channels (int): Input channels of plugin.
+ plugins (list[dict]): List of plugins cfg to build.
+ Returns:
+ list[str]: List of the names of plugin.
+ """
+ assert isinstance(plugins, list)
+ plugin_names = []
+ for plugin in plugins:
+ plugin = plugin.copy()
+ name, layer = build_plugin_layer(plugin, in_channels=in_channels, postfix=plugin.pop("postfix", ""))
+ assert not hasattr(self, name), f"duplicate plugin {name}"
+ self.add_module(name, layer)
+ plugin_names.append(name)
+ return plugin_names
+
+ def forward_plugin(self, x, plugin_names):
+ out = x
+ for name in plugin_names:
+ out = getattr(self, name)(out)
+ return out
+
+ @property
+ def norm1(self):
+ """nn.Module: normalization layer after the first convolution layer"""
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: normalization layer after the second convolution layer"""
+ return getattr(self, self.norm2_name)
+
+ @property
+ def norm3(self):
+ """nn.Module: normalization layer after the third convolution layer"""
+ return getattr(self, self.norm3_name)
+
+ def forward(self, x):
+ """Forward function."""
+
+ def _inner_forward(x):
+ identity = x
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv1_plugin_names)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv2_plugin_names)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv3_plugin_names)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(BaseModule):
+ """ResNet backbone.
+ Args:
+ depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+ stem_channels (int | None): Number of stem channels. If not specified,
+ it will be the same as `base_channels`. Default: None.
+ base_channels (int): Number of base channels of res layer. Default: 64.
+ in_channels (int): Number of input image channels. Default: 3.
+ num_stages (int): Resnet stages. Default: 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ deep_stem (bool): Replace the 7x7 conv in the input stem with three 3x3 convs.
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ plugins (list[dict]): List of plugins for stages, each dict contains:
+ - cfg (dict, required): Cfg dict to build plugin.
+ - position (str, required): Position inside block to insert
+ plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'.
+ - stages (tuple[bool], optional): Stages to apply plugin, length
+ should be same as 'num_stages'.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity.
+ pretrained (str, optional): model pretrained path. Default: None
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ Example:
+ >>> from mmdet.models import ResNet
+ >>> import torch
+ >>> self = ResNet(depth=18)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 64, 8, 8)
+ (1, 128, 4, 4)
+ (1, 256, 2, 2)
+ (1, 512, 1, 1)
+ """
+
+ arch_settings = {
+ 18: (BasicBlock, (2, 2, 2, 2)),
+ 34: (BasicBlock, (3, 4, 6, 3)),
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3)),
+ }
+
+ def __init__(
+ self,
+ depth,
+ in_channels=3,
+ stem_channels=None,
+ base_channels=64,
+ num_stages=4,
+ strides=(1, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ out_indices=(0, 1, 2, 3),
+ style="pytorch",
+ deep_stem=False,
+ avg_down=False,
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type="BN", requires_grad=True),
+ norm_eval=True,
+ dcn=None,
+ stage_with_dcn=(False, False, False, False),
+ plugins=None,
+ with_cp=False,
+ zero_init_residual=True,
+ pretrained=None,
+ init_cfg=None,
+ ):
+ super(ResNet, self).__init__(init_cfg)
+ self.zero_init_residual = zero_init_residual
+ if depth not in self.arch_settings:
+ raise KeyError(f"invalid depth {depth} for resnet")
+
+ block_init_cfg = None
+ assert not (init_cfg and pretrained), "init_cfg and pretrained cannot be specified at the same time"
+ if isinstance(pretrained, str):
+ warnings.warn("DeprecationWarning: pretrained is deprecated, " 'please use "init_cfg" instead')
+ self.init_cfg = dict(type="Pretrained", checkpoint=pretrained)
+ elif pretrained is None:
+ if init_cfg is None:
+ self.init_cfg = [
+ dict(type="Kaiming", layer="Conv2d"),
+ dict(type="Constant", val=1, layer=["_BatchNorm", "GroupNorm"]),
+ ]
+ block = self.arch_settings[depth][0]
+ if self.zero_init_residual:
+ if block is BasicBlock:
+ block_init_cfg = dict(type="Constant", val=0, override=dict(name="norm2"))
+ elif block is Bottleneck:
+ block_init_cfg = dict(type="Constant", val=0, override=dict(name="norm3"))
+ else:
+ raise TypeError("pretrained must be a str or None")
+
+ self.depth = depth
+ if stem_channels is None:
+ stem_channels = base_channels
+ self.stem_channels = stem_channels
+ self.base_channels = base_channels
+ self.num_stages = num_stages
+ assert num_stages >= 1 and num_stages <= 4
+ self.strides = strides
+ self.dilations = dilations
+ assert len(strides) == len(dilations) == num_stages
+ self.out_indices = out_indices
+ assert max(out_indices) < num_stages
+ self.style = style
+ self.deep_stem = deep_stem
+ self.avg_down = avg_down
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.with_cp = with_cp
+ self.norm_eval = norm_eval
+ self.dcn = dcn
+ self.stage_with_dcn = stage_with_dcn
+ if dcn is not None:
+ assert len(stage_with_dcn) == num_stages
+ self.plugins = plugins
+ self.block, stage_blocks = self.arch_settings[depth]
+ self.stage_blocks = stage_blocks[:num_stages]
+ self.inplanes = stem_channels
+
+ self._make_stem_layer(in_channels, stem_channels)
+
+ self.res_layers = []
+ for i, num_blocks in enumerate(self.stage_blocks):
+ stride = strides[i]
+ dilation = dilations[i]
+ dcn = self.dcn if self.stage_with_dcn[i] else None
+ if plugins is not None:
+ stage_plugins = self.make_stage_plugins(plugins, i)
+ else:
+ stage_plugins = None
+ planes = base_channels * 2 ** i
+ res_layer = self.make_res_layer(
+ block=self.block,
+ inplanes=self.inplanes,
+ planes=planes,
+ num_blocks=num_blocks,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ avg_down=self.avg_down,
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ dcn=dcn,
+ plugins=stage_plugins,
+ init_cfg=block_init_cfg,
+ )
+ self.inplanes = planes * self.block.expansion
+ layer_name = f"layer{i + 1}"
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self._freeze_stages()
+
+ self.feat_dim = self.block.expansion * base_channels * 2 ** (len(self.stage_blocks) - 1)
+
+ def make_stage_plugins(self, plugins, stage_idx):
+ """Make plugins for ResNet ``stage_idx`` th stage.
+ Currently we support inserting ``context_block``,
+ ``empirical_attention_block`` and ``nonlocal_block`` into backbones
+ such as ResNet/ResNeXt. They can be inserted after conv1/conv2/conv3 of
+ a Bottleneck block.
+ An example of the plugins format:
+ Examples:
+ >>> plugins=[
+ ... dict(cfg=dict(type='xxx', arg1='xxx'),
+ ... stages=(False, True, True, True),
+ ... position='after_conv2'),
+ ... dict(cfg=dict(type='yyy'),
+ ... stages=(True, True, True, True),
+ ... position='after_conv3'),
+ ... dict(cfg=dict(type='zzz', postfix='1'),
+ ... stages=(True, True, True, True),
+ ... position='after_conv3'),
+ ... dict(cfg=dict(type='zzz', postfix='2'),
+ ... stages=(True, True, True, True),
+ ... position='after_conv3')
+ ... ]
+ >>> self = ResNet(depth=18)
+ >>> stage_plugins = self.make_stage_plugins(plugins, 0)
+ >>> assert len(stage_plugins) == 3
+ Suppose ``stage_idx=0``, the structure of blocks in the stage would be:
+ .. code-block:: none
+ conv1->conv2->conv3->yyy->zzz1->zzz2
+ Suppose ``stage_idx=1``, the structure of blocks in the stage would be:
+ .. code-block:: none
+ conv1->conv2->xxx->conv3->yyy->zzz1->zzz2
+ If stages is missing, the plugin would be applied to all stages.
+ Args:
+ plugins (list[dict]): List of plugins cfg to build. The postfix is
+ required if multiple same type plugins are inserted.
+ stage_idx (int): Index of stage to build
+ Returns:
+ list[dict]: Plugins for current stage
+ """
+ stage_plugins = []
+ for plugin in plugins:
+ plugin = plugin.copy()
+ stages = plugin.pop("stages", None)
+ assert stages is None or len(stages) == self.num_stages
+ # whether to insert plugin into current stage
+ if stages is None or stages[stage_idx]:
+ stage_plugins.append(plugin)
+
+ return stage_plugins
+
+ def make_res_layer(self, **kwargs):
+ """Pack all blocks in a stage into a ``ResLayer``."""
+ return ResLayer(**kwargs)
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ def _make_stem_layer(self, in_channels, stem_channels):
+ if self.deep_stem:
+ self.stem = nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg, in_channels, stem_channels // 2, kernel_size=3, stride=2, padding=1, bias=False
+ ),
+ build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+ nn.ReLU(inplace=True),
+ build_conv_layer(
+ self.conv_cfg,
+ stem_channels // 2,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ ),
+ build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+ nn.ReLU(inplace=True),
+ build_conv_layer(
+ self.conv_cfg, stem_channels // 2, stem_channels, kernel_size=3, stride=1, padding=1, bias=False
+ ),
+ build_norm_layer(self.norm_cfg, stem_channels)[1],
+ nn.ReLU(inplace=True),
+ )
+ else:
+ self.conv1 = build_conv_layer(
+ self.conv_cfg, in_channels, stem_channels, kernel_size=7, stride=2, padding=3, bias=False
+ )
+ self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, stem_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ if self.deep_stem:
+ self.stem.eval()
+ for param in self.stem.parameters():
+ param.requires_grad = False
+ else:
+ self.norm1.eval()
+ for m in [self.conv1, self.norm1]:
+ for param in m.parameters():
+ param.requires_grad = False
+
+ for i in range(1, self.frozen_stages + 1):
+ m = getattr(self, f"layer{i}")
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def forward(self, x):
+ """Forward function."""
+ if self.deep_stem:
+ x = self.stem(x)
+ else:
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return tuple(outs)
+
+ def train(self, mode=True):
+ """Convert the model into training mode while keep normalization layer
+ freezed."""
+ super(ResNet, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ # trick: eval have effect on BatchNorm only
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+
+class ResNetV1d(ResNet):
+ r"""ResNetV1d variant described in `Bag of Tricks
+ `_.
+ Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in
+ the input stem with three 3x3 convs. And in the downsampling block, a 2x2
+ avg_pool with stride 2 is added before conv, whose stride is changed to 1.
+ """
+
+ def __init__(self, **kwargs):
+ super(ResNetV1d, self).__init__(deep_stem=True, avg_down=True, **kwargs)
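+
+
+# Usage sketch (illustrative; the depth / out_indices / input size below are
+# arbitrary examples, not the values used by the project configs):
+#
+#   import torch
+#   net = ResNet(depth=50, out_indices=(1, 2, 3), frozen_stages=1, norm_eval=True)
+#   c3, c4, c5 = net(torch.rand(2, 3, 224, 224))
+#   # c3 / c4 / c5 have 512 / 1024 / 2048 channels at strides 8 / 16 / 32
+#   # (28x28, 14x14 and 7x7 feature maps for a 224x224 input).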
diff --git a/mapmaster/models/backbone/resnet/utils.py b/mapmaster/models/backbone/resnet/utils.py
new file mode 100644
index 0000000..affe659
--- /dev/null
+++ b/mapmaster/models/backbone/resnet/utils.py
@@ -0,0 +1,93 @@
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmcv.runner import Sequential
+from torch import nn as nn
+
+
+class ResLayer(Sequential):
+ """ResLayer to build ResNet style backbone.
+ Args:
+ block (nn.Module): block used to build ResLayer.
+ inplanes (int): inplanes of block.
+ planes (int): planes of block.
+ num_blocks (int): number of blocks.
+ stride (int): stride of the first block. Default: 1
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ downsample_first (bool): Downsample at the first block or last block.
+ False for Hourglass, True for ResNet. Default: True
+ """
+
+ def __init__(
+ self,
+ block,
+ inplanes,
+ planes,
+ num_blocks,
+ stride=1,
+ avg_down=False,
+ conv_cfg=None,
+ norm_cfg=dict(type="BN"),
+ downsample_first=True,
+ **kwargs
+ ):
+ self.block = block
+
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = []
+ conv_stride = stride
+ if avg_down:
+ conv_stride = 1
+ downsample.append(
+ nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False)
+ )
+ downsample.extend(
+ [
+ build_conv_layer(
+ conv_cfg, inplanes, planes * block.expansion, kernel_size=1, stride=conv_stride, bias=False
+ ),
+ build_norm_layer(norm_cfg, planes * block.expansion)[1],
+ ]
+ )
+ downsample = nn.Sequential(*downsample)
+
+ layers = []
+ if downsample_first:
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=stride,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs
+ )
+ )
+ inplanes = planes * block.expansion
+ for _ in range(1, num_blocks):
+ layers.append(
+ block(inplanes=inplanes, planes=planes, stride=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, **kwargs)
+ )
+
+ else: # downsample_first=False is for HourglassModule
+ for _ in range(num_blocks - 1):
+ layers.append(
+ block(inplanes=inplanes, planes=inplanes, stride=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, **kwargs)
+ )
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=stride,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs
+ )
+ )
+ super(ResLayer, self).__init__(*layers)
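+
+
+# Example (illustrative): the second stage of a ResNet-50 style backbone could be built
+# as ``ResLayer(Bottleneck, inplanes=256, planes=128, num_blocks=4, stride=2)``, i.e. one
+# stride-2 block with a 1x1-conv downsample shortcut followed by three identity-shortcut
+# blocks, producing 128 * block.expansion = 512 output channels.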
diff --git a/mapmaster/models/backbone/swin_transformer/__init__.py b/mapmaster/models/backbone/swin_transformer/__init__.py
new file mode 100644
index 0000000..55af267
--- /dev/null
+++ b/mapmaster/models/backbone/swin_transformer/__init__.py
@@ -0,0 +1,85 @@
+import os
+import torch
+from .model import SwinTransformer as _SwinTransformer
+from torch.utils import model_zoo
+
+model_urls = {
+ "tiny": "https://github.com/SwinTransformer/storage/releases/download/v1.0.1/upernet_swin_tiny_patch4_window7_512x512.pth",
+ "base": "https://github.com/SwinTransformer/storage/releases/download/v1.0.1/upernet_swin_base_patch4_window7_512x512.pth",
+}
+
+
+class SwinTransformer(_SwinTransformer):
+ def __init__(
+ self,
+ arch="tiny",
+ pretrained=False,
+ window_size=7,
+ shift_mode=1,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.3,
+ ape=False,
+ patch_norm=True,
+ out_indices=(0, 1, 2, 3),
+ use_checkpoint=False,
+ **kwargs
+ ):
+ if arch == "tiny":
+ embed_dim = 96
+ depths = (2, 2, 6, 2)
+ num_heads = (3, 6, 12, 24)
+ elif arch == "small":
+ embed_dim = 96
+ depths = (2, 2, 18, 2)
+ num_heads = (3, 6, 12, 24)
+ elif arch == "base":
+ embed_dim = 128
+ depths = (2, 2, 18, 2)
+ num_heads = (4, 8, 16, 32)
+ else:
+ raise NotImplementedError
+
+ super(SwinTransformer, self).__init__(
+ embed_dim=embed_dim,
+ depths=depths,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_mode=shift_mode,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop_rate=drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=drop_path_rate,
+ ape=ape,
+ patch_norm=patch_norm,
+ out_indices=out_indices,
+ use_checkpoint=use_checkpoint,
+ **kwargs
+ )
+ if pretrained is True:
+ print(model_urls[arch])
+ state_dict = model_zoo.load_url(model_urls[arch])["state_dict"]
+ elif isinstance(pretrained, str):
+ assert os.path.exists(pretrained)
+ print(pretrained)
+ state_dict = torch.load(pretrained, map_location="cpu")["state_dict"]
+ elif pretrained is False or pretrained is None:
+ state_dict = None
+ else:
+ raise NotImplementedError
+
+ self.arch = arch
+ if state_dict is not None:
+ self.init_weights(state_dict=state_dict)
+
+ def init_weights(self, state_dict):
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ if "backbone" in key:
+ new_state_dict[key.replace("backbone.", "")] = value
+ ret = self.load_state_dict(new_state_dict, strict=False)
+ print("Backbone missing_keys: {}".format(ret.missing_keys))
+ print("Backbone unexpected_keys: {}".format(ret.unexpected_keys))
diff --git a/mapmaster/models/backbone/swin_transformer/model.py b/mapmaster/models/backbone/swin_transformer/model.py
new file mode 100644
index 0000000..f2e0da2
--- /dev/null
+++ b/mapmaster/models/backbone/swin_transformer/model.py
@@ -0,0 +1,670 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu, Yutong Lin, Yixuan Wei
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+import numpy as np
+from .utils import DropPath, to_2tuple, trunc_normal_, get_root_logger, load_checkpoint
+
+
+class Mlp(nn.Module):
+ """ Multilayer perceptron."""
+
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
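+
+
+# Shape round-trip (illustrative): for x of shape (2, 56, 56, 96) and window_size=7,
+# window_partition(x, 7) yields (2 * 8 * 8, 7, 7, 96) windows, and
+# window_reverse(windows, 7, 56, 56) restores the original (2, 56, 56, 96) layout.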
+
+
+class WindowAttention(nn.Module):
+ """Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0.0, proj_drop=0.0):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
+ ) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """Forward function.
+
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
+ ) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
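+
+
+# Size note (illustrative): with a 7x7 window, each window holds N = 49 tokens, the
+# relative position bias table has (2*7 - 1) ** 2 = 169 entries per head, and the
+# precomputed ``relative_position_index`` buffer has shape (49, 49).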
+
+
+class SwinTransformerBlock(nn.Module):
+ """Swin Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ window_size=7,
+ shift_size=0,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim,
+ window_size=to_2tuple(self.window_size),
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ self.H = None
+ self.W = None
+
+ def forward(self, x, mask_matrix):
+ """Forward function.
+
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ mask_matrix: Attention mask for cyclic shift.
+ """
+ B, L, C = x.shape
+ H, W = self.H, self.W
+ assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # pad feature maps to multiples of window size
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, Hp, Wp, _ = x.shape
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ attn_mask = mask_matrix
+ else:
+ shifted_x = x
+ attn_mask = None
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+
+ x = x.view(B, H * W, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x
+
+
+class PatchMerging(nn.Module):
+ """Patch Merging Layer
+
+ Args:
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x, H, W):
+ """Forward function.
+
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ x = x.view(B, H, W, C)
+
+ # padding
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
+ if pad_input:
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
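+
+
+# Shape note (illustrative): for H = W = 56 and C = 96, the input of shape (B, 3136, 96)
+# is merged into (B, 784, 192), i.e. spatial resolution halves and channel width doubles.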
+
+
+class BasicLayer(nn.Module):
+ """A basic Swin Transformer layer for one stage.
+
+ Args:
+ dim (int): Number of feature channels
+ depth (int): Depths of this stage.
+ num_heads (int): Number of attention head.
+ window_size (int): Local window size. Default: 7.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(
+ self,
+ dim,
+ depth,
+ num_heads,
+ window_size=7,
+ shift_mode=0,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False,
+ ):
+ super().__init__()
+ self.window_size = window_size
+ self.shift_size = window_size // 2
+ self.shift_mode = shift_mode
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList(
+ [
+ SwinTransformerBlock(
+ dim=dim,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer,
+ )
+ for i in range(depth)
+ ]
+ )
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x, H, W):
+ """Forward function.
+
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+
+ # calculate attention mask for SW-MSA
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
+ h_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ w_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ cnt = 0
+ for i, h in enumerate(h_slices):
+ for j, w in enumerate(w_slices):
+ img_mask[:, h, w, :] = cnt
+ if self.shift_mode == 1 and j == 1:
+ continue
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ for blk in self.blocks:
+ blk.H, blk.W = H, W
+ if self.use_checkpoint and self.training:
+ x = checkpoint.checkpoint(blk, x, attn_mask)
+ else:
+ x = blk(x, attn_mask)
+ if self.downsample is not None:
+ x_down = self.downsample(x, H, W)
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
+ return x, H, W, x_down, Wh, Ww
+ else:
+ return x, H, W, x, H, W
+
+
+class PatchEmbed(nn.Module):
+ """Image to Patch Embedding
+
+ Args:
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ patch_size = to_2tuple(patch_size)
+ self.patch_size = patch_size
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ """Forward function."""
+ # padding
+ _, _, H, W = x.size()
+ if W % self.patch_size[1] != 0:
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+ if H % self.patch_size[0] != 0:
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+ x = self.proj(x) # B C Wh Ww
+ if self.norm is not None:
+ Wh, Ww = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+
+ return x
+
+
+class SwinTransformer(nn.Module):
+ """Swin Transformer backbone.
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+ https://arxiv.org/pdf/2103.14030
+
+ Args:
+ pretrain_img_size (int): Input image size for training the pretrained model,
+ used in the absolute position embedding. Default: 224.
+ patch_size (int | tuple(int)): Patch size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ depths (tuple[int]): Depths of each Swin Transformer stage.
+ num_heads (tuple[int]): Number of attention head of each stage.
+ window_size (int): Window size. Default: 7.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+ drop_rate (float): Dropout rate.
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
+ out_indices (Sequence[int]): Output from which stages.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters.
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(
+ self,
+ pretrain_img_size=224,
+ patch_size=4,
+ in_chans=3,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ shift_mode=0,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.2,
+ norm_layer=nn.LayerNorm,
+ ape=False,
+ patch_norm=True,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ use_checkpoint=False,
+ ):
+ super().__init__()
+
+ self.pretrain_img_size = pretrain_img_size
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None,
+ )
+
+ # absolute position embedding
+ if self.ape:
+ pretrain_img_size = to_2tuple(pretrain_img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
+
+ self.absolute_pos_embed = nn.Parameter(
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
+ )
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(
+ dim=int(embed_dim * 2 ** i_layer),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ shift_mode=shift_mode,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=use_checkpoint,
+ )
+ self.layers.append(layer)
+
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
+ self.num_features = num_features
+
+ # add a norm layer for each output
+ for i_layer in out_indices:
+ layer = norm_layer(num_features[i_layer])
+ layer_name = f"norm{i_layer}"
+ self.add_module(layer_name, layer)
+
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+
+ if self.frozen_stages >= 1 and self.ape:
+ self.absolute_pos_embed.requires_grad = False
+
+ if self.frozen_stages >= 2:
+ self.pos_drop.eval()
+ for i in range(0, self.frozen_stages - 1):
+ m = self.layers[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+
+ def _init_weights(m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ if isinstance(pretrained, str):
+ self.apply(_init_weights)
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ self.apply(_init_weights)
+ else:
+ raise TypeError("pretrained must be a str or None")
+
+ def forward(self, x):
+ """Forward function."""
+ x = self.patch_embed(x)
+
+ Wh, Ww = x.size(2), x.size(3)
+ if self.ape:
+ # interpolate the position embedding to the corresponding size
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic")
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
+ else:
+ x = x.flatten(2).transpose(1, 2)
+ x = self.pos_drop(x)
+
+ outs = []
+ for i in range(self.num_layers):
+ layer = self.layers[i]
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+
+ if i in self.out_indices:
+ norm_layer = getattr(self, f"norm{i}")
+ x_out = norm_layer(x_out)
+
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+ outs.append(out)
+
+ return tuple(outs)
+
+ def train(self, mode=True):
+ """Convert the model into training mode while keep layers freezed."""
+ super(SwinTransformer, self).train(mode)
+ self._freeze_stages()
diff --git a/mapmaster/models/backbone/swin_transformer/utils.py b/mapmaster/models/backbone/swin_transformer/utils.py
new file mode 100644
index 0000000..4d239b9
--- /dev/null
+++ b/mapmaster/models/backbone/swin_transformer/utils.py
@@ -0,0 +1,695 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+import io
+import os
+import os.path as osp
+import pkgutil
+import time
+import math
+import logging
+import warnings
+from collections import OrderedDict
+from importlib import import_module
+from tempfile import TemporaryDirectory
+from itertools import repeat
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+import torchvision
+from torch import Tensor
+from torch.optim import Optimizer
+from torch.utils import model_zoo
+from torch.nn import functional as F
+
+# from torch._six import container_abcs
+import collections.abc as container_abcs
+
+import mmcv
+from mmcv.fileio import FileClient
+from mmcv.fileio import load as load_file
+from mmcv.parallel import is_module_wrapper
+from mmcv.utils import mkdir_or_exist
+from mmcv.runner import get_dist_info
+
+ENV_MMCV_HOME = "MMCV_HOME"
+ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
+DEFAULT_CACHE_DIR = "~/.cache"
+
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, container_abcs.Iterable):
+ return x
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+
+ """
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+ output = x.div(keep_prob) * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
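+
+
+# Behaviour note (illustrative): with drop_prob=0.2 during training, each sample in the
+# batch is zeroed with probability 0.2 and the survivors are rescaled by 1 / 0.8, so the
+# expected value of the residual branch is preserved.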
+
+
+def _get_mmcv_home():
+ mmcv_home = os.path.expanduser(
+ os.getenv(ENV_MMCV_HOME, os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "mmcv"))
+ )
+
+ mkdir_or_exist(mmcv_home)
+ return mmcv_home
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2,
+ )
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.0))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
+ # type: (Tensor, float, float, float, float) -> Tensor
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+logger_initialized = {}
+
+
+def get_logger(name, log_file=None, log_level=logging.INFO, file_mode="w"):
+ """Initialize and get a logger by name.
+
+ If the logger has not been initialized, this method will initialize the
+ logger by adding one or two handlers, otherwise the initialized logger will
+ be directly returned. During initialization, a StreamHandler will always be
+ added. If `log_file` is specified and the process rank is 0, a FileHandler
+ will also be added.
+
+ Args:
+ name (str): Logger name.
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the logger.
+ log_level (int): The logger level. Note that only the process of
+ rank 0 is affected; other processes will set the level to
+ "Error" and thus be silent most of the time.
+ file_mode (str): The file mode used in opening log file.
+ Defaults to 'w'.
+
+ Returns:
+ logging.Logger: The expected logger.
+ """
+ logger = logging.getLogger(name)
+ if name in logger_initialized:
+ return logger
+ # handle hierarchical names
+ # e.g., logger "a" is initialized, then logger "a.b" will skip the
+ # initialization since it is a child of "a".
+ for logger_name in logger_initialized:
+ if name.startswith(logger_name):
+ return logger
+
+ # handle duplicate logs to the console
+ # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET)
+ # to the root logger. As logger.propagate is True by default, this root
+ # level handler causes logging messages from rank>0 processes to
+ # unexpectedly show up on the console, creating much unwanted clutter.
+ # To fix this issue, we set the root logger's StreamHandler, if any, to log
+ # at the ERROR level.
+ for handler in logger.root.handlers:
+ if type(handler) is logging.StreamHandler:
+ handler.setLevel(logging.ERROR)
+
+ stream_handler = logging.StreamHandler()
+ handlers = [stream_handler]
+
+ if dist.is_available() and dist.is_initialized():
+ rank = dist.get_rank()
+ else:
+ rank = 0
+
+ # only rank 0 will add a FileHandler
+ if rank == 0 and log_file is not None:
+ # logging.FileHandler defaults to mode 'a', while this function defaults
+ # to 'w'; the ``file_mode`` argument provides an interface to change
+ # that behaviour.
+ file_handler = logging.FileHandler(log_file, file_mode)
+ handlers.append(file_handler)
+
+ formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+ for handler in handlers:
+ handler.setFormatter(formatter)
+ handler.setLevel(log_level)
+ logger.addHandler(handler)
+
+ if rank == 0:
+ logger.setLevel(log_level)
+ else:
+ logger.setLevel(logging.ERROR)
+
+ logger_initialized[name] = True
+
+ return logger
+
+
+def get_root_logger(log_file=None, log_level=logging.INFO):
+ """Get the root logger.
+
+ The logger will be initialized if it has not been initialized. By default a
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
+ also be added. The name of the root logger is the top-level package name,
+ e.g., "mmseg".
+
+ Args:
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the root logger.
+ log_level (int): The root logger level. Note that only the process of
+ rank 0 is affected, while other processes will set the level to
+ "Error" and be silent most of the time.
+
+ Returns:
+ logging.Logger: The root logger.
+ """
+
+ logger = get_logger(name="mmseg", log_file=log_file, log_level=log_level)
+
+ return logger
+
+
+def load_state_dict(module, state_dict, strict=False, logger=None):
+ """Load state_dict to a module.
+
+ This method is modified from :meth:`torch.nn.Module.load_state_dict`.
+ Default value for ``strict`` is set to ``False`` and the message for
+ param mismatch will be shown even if strict is False.
+
+ Args:
+ module (Module): Module that receives the state_dict.
+ state_dict (OrderedDict): Weights.
+ strict (bool): whether to strictly enforce that the keys
+ in :attr:`state_dict` match the keys returned by this module's
+ :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
+ logger (:obj:`logging.Logger`, optional): Logger to log the error
+ message. If not specified, print function will be used.
+ """
+ unexpected_keys = []
+ all_missing_keys = []
+ err_msg = []
+
+ metadata = getattr(state_dict, "_metadata", None)
+ state_dict = state_dict.copy()
+ if metadata is not None:
+ state_dict._metadata = metadata
+
+ # use _load_from_state_dict to enable checkpoint version control
+ def load(module, prefix=""):
+ # recursively check parallel module in case that the model has a
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
+ if is_module_wrapper(module):
+ module = module.module
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+ module._load_from_state_dict(
+ state_dict, prefix, local_metadata, True, all_missing_keys, unexpected_keys, err_msg
+ )
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + ".")
+
+ load(module)
+ load = None # break load->load reference cycle
+
+ # ignore "num_batches_tracked" of BN layers
+ missing_keys = [key for key in all_missing_keys if "num_batches_tracked" not in key]
+
+ if unexpected_keys:
+ err_msg.append("unexpected key in source " f'state_dict: {", ".join(unexpected_keys)}\n')
+ if missing_keys:
+ err_msg.append(f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
+
+ rank, _ = get_dist_info()
+ if len(err_msg) > 0 and rank == 0:
+ err_msg.insert(0, "The model and loaded state dict do not match exactly\n")
+ err_msg = "\n".join(err_msg)
+ if strict:
+ raise RuntimeError(err_msg)
+ elif logger is not None:
+ logger.warning(err_msg)
+ else:
+ print(err_msg)
+
+
+def load_url_dist(url, model_dir=None):
+ """In distributed setting, this function only download checkpoint at local
+ rank 0."""
+ rank, world_size = get_dist_info()
+ rank = int(os.environ.get("LOCAL_RANK", rank))
+ if rank == 0:
+ checkpoint = model_zoo.load_url(url, model_dir=model_dir)
+ if world_size > 1:
+ torch.distributed.barrier()
+ if rank > 0:
+ checkpoint = model_zoo.load_url(url, model_dir=model_dir)
+ return checkpoint
+
+
+def load_pavimodel_dist(model_path, map_location=None):
+ """In distributed setting, this function only download checkpoint at local
+ rank 0."""
+ try:
+ from pavi import modelcloud
+ except ImportError:
+ raise ImportError("Please install pavi to load checkpoint from modelcloud.")
+ rank, world_size = get_dist_info()
+ rank = int(os.environ.get("LOCAL_RANK", rank))
+ if rank == 0:
+ model = modelcloud.get(model_path)
+ with TemporaryDirectory() as tmp_dir:
+ downloaded_file = osp.join(tmp_dir, model.name)
+ model.download(downloaded_file)
+ checkpoint = torch.load(downloaded_file, map_location=map_location)
+ if world_size > 1:
+ torch.distributed.barrier()
+ if rank > 0:
+ model = modelcloud.get(model_path)
+ with TemporaryDirectory() as tmp_dir:
+ downloaded_file = osp.join(tmp_dir, model.name)
+ model.download(downloaded_file)
+ checkpoint = torch.load(downloaded_file, map_location=map_location)
+ return checkpoint
+
+
+def load_fileclient_dist(filename, backend, map_location):
+ """In distributed setting, this function only download checkpoint at local
+ rank 0."""
+ rank, world_size = get_dist_info()
+ rank = int(os.environ.get("LOCAL_RANK", rank))
+ allowed_backends = ["ceph"]
+ if backend not in allowed_backends:
+ raise ValueError(f"Load from Backend {backend} is not supported.")
+ if rank == 0:
+ fileclient = FileClient(backend=backend)
+ buffer = io.BytesIO(fileclient.get(filename))
+ checkpoint = torch.load(buffer, map_location=map_location)
+ if world_size > 1:
+ torch.distributed.barrier()
+ if rank > 0:
+ fileclient = FileClient(backend=backend)
+ buffer = io.BytesIO(fileclient.get(filename))
+ checkpoint = torch.load(buffer, map_location=map_location)
+ return checkpoint
+
+
+def get_torchvision_models():
+ model_urls = dict()
+ for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
+ if ispkg:
+ continue
+ _zoo = import_module(f"torchvision.models.{name}")
+ if hasattr(_zoo, "model_urls"):
+ _urls = getattr(_zoo, "model_urls")
+ model_urls.update(_urls)
+ return model_urls
+
+
+def get_external_models():
+ mmcv_home = _get_mmcv_home()
+ default_json_path = osp.join(mmcv.__path__[0], "model_zoo/open_mmlab.json")
+ default_urls = load_file(default_json_path)
+ assert isinstance(default_urls, dict)
+ external_json_path = osp.join(mmcv_home, "open_mmlab.json")
+ if osp.exists(external_json_path):
+ external_urls = load_file(external_json_path)
+ assert isinstance(external_urls, dict)
+ default_urls.update(external_urls)
+
+ return default_urls
+
+
+def get_mmcls_models():
+ mmcls_json_path = osp.join(mmcv.__path__[0], "model_zoo/mmcls.json")
+ mmcls_urls = load_file(mmcls_json_path)
+
+ return mmcls_urls
+
+
+def get_deprecated_model_names():
+ deprecate_json_path = osp.join(mmcv.__path__[0], "model_zoo/deprecated.json")
+ deprecate_urls = load_file(deprecate_json_path)
+ assert isinstance(deprecate_urls, dict)
+
+ return deprecate_urls
+
+
+def _process_mmcls_checkpoint(checkpoint):
+ state_dict = checkpoint["state_dict"]
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ if k.startswith("backbone."):
+ new_state_dict[k[9:]] = v
+ new_checkpoint = dict(state_dict=new_state_dict)
+
+ return new_checkpoint
+
+
+def _load_checkpoint(filename, map_location=None):
+ """Load checkpoint from somewhere (modelzoo, file, url).
+
+ Args:
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+ details.
+ map_location (str | None): Same as :func:`torch.load`. Default: None.
+
+ Returns:
+ dict | OrderedDict: The loaded checkpoint. It can be either an
+ OrderedDict storing model weights or a dict containing other
+ information, which depends on the checkpoint.
+ """
+ if filename.startswith("modelzoo://"):
+ warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' 'use "torchvision://" instead')
+ model_urls = get_torchvision_models()
+ model_name = filename[11:]
+ checkpoint = load_url_dist(model_urls[model_name])
+ elif filename.startswith("torchvision://"):
+ model_urls = get_torchvision_models()
+ model_name = filename[14:]
+ checkpoint = load_url_dist(model_urls[model_name])
+ elif filename.startswith("open-mmlab://"):
+ model_urls = get_external_models()
+ model_name = filename[13:]
+ deprecated_urls = get_deprecated_model_names()
+ if model_name in deprecated_urls:
+ warnings.warn(
+ f"open-mmlab://{model_name} is deprecated in favor " f"of open-mmlab://{deprecated_urls[model_name]}"
+ )
+ model_name = deprecated_urls[model_name]
+ model_url = model_urls[model_name]
+ # check if is url
+ if model_url.startswith(("http://", "https://")):
+ checkpoint = load_url_dist(model_url)
+ else:
+ filename = osp.join(_get_mmcv_home(), model_url)
+ if not osp.isfile(filename):
+ raise IOError(f"{filename} is not a checkpoint file")
+ checkpoint = torch.load(filename, map_location=map_location)
+ elif filename.startswith("mmcls://"):
+ model_urls = get_mmcls_models()
+ model_name = filename[8:]
+ checkpoint = load_url_dist(model_urls[model_name])
+ checkpoint = _process_mmcls_checkpoint(checkpoint)
+ elif filename.startswith(("http://", "https://")):
+ checkpoint = load_url_dist(filename)
+ elif filename.startswith("pavi://"):
+ model_path = filename[7:]
+ checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
+ elif filename.startswith("s3://"):
+ checkpoint = load_fileclient_dist(filename, backend="ceph", map_location=map_location)
+ else:
+ if not osp.isfile(filename):
+ raise IOError(f"{filename} is not a checkpoint file")
+ checkpoint = torch.load(filename, map_location=map_location)
+ return checkpoint
+
+
+def load_checkpoint(model, filename, map_location="cpu", strict=False, logger=None):
+ """Load checkpoint from a file or URI.
+
+ Args:
+ model (Module): Module to load checkpoint.
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+ details.
+ map_location (str): Same as :func:`torch.load`.
+ strict (bool): Whether to allow different params for the model and
+ checkpoint.
+ logger (:mod:`logging.Logger` or None): The logger for error message.
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+ checkpoint = _load_checkpoint(filename, map_location)
+ # OrderedDict is a subclass of dict
+ if not isinstance(checkpoint, dict):
+ raise RuntimeError(f"No state_dict found in checkpoint file {filename}")
+ # get state_dict from checkpoint
+ if "state_dict" in checkpoint:
+ state_dict = checkpoint["state_dict"]
+ elif "model" in checkpoint:
+ state_dict = checkpoint["model"]
+ else:
+ state_dict = checkpoint
+ # strip prefix of state_dict
+ if list(state_dict.keys())[0].startswith("module."):
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
+
+ # for MoBY, load model of online branch
+ if sorted(list(state_dict.keys()))[0].startswith("encoder"):
+ state_dict = {k.replace("encoder.", ""): v for k, v in state_dict.items() if k.startswith("encoder.")}
+
+ # reshape absolute position embedding
+ if state_dict.get("absolute_pos_embed") is not None:
+ absolute_pos_embed = state_dict["absolute_pos_embed"]
+ N1, L, C1 = absolute_pos_embed.size()
+ N2, C2, H, W = model.absolute_pos_embed.size()
+ if N1 != N2 or C1 != C2 or L != H * W:
+ logger.warning("Error in loading absolute_pos_embed, pass")
+ else:
+ state_dict["absolute_pos_embed"] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)
+
+ # interpolate position bias table if needed
+ relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
+ for table_key in relative_position_bias_table_keys:
+ table_pretrained = state_dict[table_key]
+ table_current = model.state_dict()[table_key]
+ L1, nH1 = table_pretrained.size()
+ L2, nH2 = table_current.size()
+ if nH1 != nH2:
+ logger.warning(f"Error in loading {table_key}, pass")
+ else:
+ if L1 != L2:
+ S1 = int(L1 ** 0.5)
+ S2 = int(L2 ** 0.5)
+ table_pretrained_resized = F.interpolate(
+ table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), mode="bicubic"
+ )
+ state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
+
+ # load state_dict
+ load_state_dict(model, state_dict, strict, logger)
+ return checkpoint
+
+
+def weights_to_cpu(state_dict):
+ """Copy a model state_dict to cpu.
+
+ Args:
+ state_dict (OrderedDict): Model weights on GPU.
+
+ Returns:
+ OrderedDict: Model weights on GPU.
+ """
+ state_dict_cpu = OrderedDict()
+ for key, val in state_dict.items():
+ state_dict_cpu[key] = val.cpu()
+ return state_dict_cpu
+
+
+def _save_to_state_dict(module, destination, prefix, keep_vars):
+ """Saves module state to `destination` dictionary.
+
+ This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
+
+ Args:
+ module (nn.Module): The module to generate state_dict.
+ destination (dict): A dict where state will be stored.
+ prefix (str): The prefix for parameters and buffers used in this
+ module.
+ """
+ for name, param in module._parameters.items():
+ if param is not None:
+ destination[prefix + name] = param if keep_vars else param.detach()
+ for name, buf in module._buffers.items():
+ # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
+ if buf is not None:
+ destination[prefix + name] = buf if keep_vars else buf.detach()
+
+
+def get_state_dict(module, destination=None, prefix="", keep_vars=False):
+ """Returns a dictionary containing a whole state of the module.
+
+ Both parameters and persistent buffers (e.g. running averages) are
+ included. Keys are corresponding parameter and buffer names.
+
+ This method is modified from :meth:`torch.nn.Module.state_dict` to
+ recursively check parallel module in case that the model has a complicated
+ structure, e.g., nn.Module(nn.Module(DDP)).
+
+ Args:
+ module (nn.Module): The module to generate state_dict.
+ destination (OrderedDict): Returned dict for the state of the
+ module.
+ prefix (str): Prefix of the key.
+ keep_vars (bool): Whether to keep the variable property of the
+ parameters. Default: False.
+
+ Returns:
+ dict: A dictionary containing a whole state of the module.
+ """
+ # recursively check parallel module in case that the model has a
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
+ if is_module_wrapper(module):
+ module = module.module
+
+ # below is the same as torch.nn.Module.state_dict()
+ if destination is None:
+ destination = OrderedDict()
+ destination._metadata = OrderedDict()
+ destination._metadata[prefix[:-1]] = local_metadata = dict(version=module._version)
+ _save_to_state_dict(module, destination, prefix, keep_vars)
+ for name, child in module._modules.items():
+ if child is not None:
+ get_state_dict(child, destination, prefix + name + ".", keep_vars=keep_vars)
+ for hook in module._state_dict_hooks.values():
+ hook_result = hook(module, destination, prefix, local_metadata)
+ if hook_result is not None:
+ destination = hook_result
+ return destination
+
+
+def save_checkpoint(model, filename, optimizer=None, meta=None):
+ """Save checkpoint to file.
+
+ The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
+ ``optimizer``. By default ``meta`` will contain version and time info.
+
+ Args:
+ model (Module): Module whose params are to be saved.
+ filename (str): Checkpoint filename.
+ optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
+ meta (dict, optional): Metadata to be saved in checkpoint.
+ """
+ if meta is None:
+ meta = {}
+ elif not isinstance(meta, dict):
+ raise TypeError(f"meta must be a dict or None, but got {type(meta)}")
+ meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
+
+ if is_module_wrapper(model):
+ model = model.module
+
+ if hasattr(model, "CLASSES") and model.CLASSES is not None:
+ # save class name to the meta
+ meta.update(CLASSES=model.CLASSES)
+
+ checkpoint = {"meta": meta, "state_dict": weights_to_cpu(get_state_dict(model))}
+ # save optimizer state dict in the checkpoint
+ if isinstance(optimizer, Optimizer):
+ checkpoint["optimizer"] = optimizer.state_dict()
+ elif isinstance(optimizer, dict):
+ checkpoint["optimizer"] = {}
+ for name, optim in optimizer.items():
+ checkpoint["optimizer"][name] = optim.state_dict()
+
+ if filename.startswith("pavi://"):
+ try:
+ from pavi import modelcloud
+ from pavi.exception import NodeNotFoundError
+ except ImportError:
+ raise ImportError("Please install pavi to load checkpoint from modelcloud.")
+ model_path = filename[7:]
+ root = modelcloud.Folder()
+ model_dir, model_name = osp.split(model_path)
+ try:
+ model = modelcloud.get(model_dir)
+ except NodeNotFoundError:
+ model = root.create_training_model(model_dir)
+ with TemporaryDirectory() as tmp_dir:
+ checkpoint_file = osp.join(tmp_dir, model_name)
+ with open(checkpoint_file, "wb") as f:
+ torch.save(checkpoint, f)
+ f.flush()
+ model.create_file(checkpoint_file, name=model_name)
+ else:
+ mmcv.mkdir_or_exist(osp.dirname(filename))
+ # immediately flush buffer
+ with open(filename, "wb") as f:
+ torch.save(checkpoint, f)
+ f.flush()
diff --git a/mapmaster/models/bev_decoder/__init__.py b/mapmaster/models/bev_decoder/__init__.py
new file mode 100644
index 0000000..cd39500
--- /dev/null
+++ b/mapmaster/models/bev_decoder/__init__.py
@@ -0,0 +1 @@
+from .model import TransformerBEVDecoder, DeformTransformerBEVEncoder
diff --git a/mapmaster/models/bev_decoder/deform_transformer/__init__.py b/mapmaster/models/bev_decoder/deform_transformer/__init__.py
new file mode 100644
index 0000000..810662e
--- /dev/null
+++ b/mapmaster/models/bev_decoder/deform_transformer/__init__.py
@@ -0,0 +1 @@
+from .deform_transformer import DeformTransformer
diff --git a/mapmaster/models/bev_decoder/deform_transformer/deform_transformer.py b/mapmaster/models/bev_decoder/deform_transformer/deform_transformer.py
new file mode 100644
index 0000000..28f9ea1
--- /dev/null
+++ b/mapmaster/models/bev_decoder/deform_transformer/deform_transformer.py
@@ -0,0 +1,672 @@
+import copy
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from .ops import MSDeformAttn
+from .position_encoding import PositionEmbeddingSine
+from .position_encoding import PositionEmbeddingLearned
+
+class DeformTransformer(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ src_shape=(16, 168),
+ tgt_shape=(32, 32),
+ d_model=256,
+ n_heads=8,
+ num_encoder_layers=6,
+ num_decoder_layers=6,
+ dim_feedforward=1024,
+ dropout=0.1,
+ activation="relu",
+ return_intermediate_dec=False,
+ dec_n_points=4,
+ enc_n_points=4,
+ src_pos_encode="sine",
+ tgt_pos_encode="learned",
+ norm_layer=nn.BatchNorm2d,
+ use_checkpoint=False,
+ use_projection=False,
+ map_size=(400, 200),
+ image_shape=(900, 1600),
+ map_resolution=0.15,
+ image_order=(2, 1, 0, 5, 4, 3)
+ ):
+ super().__init__()
+
+ if isinstance(in_channels, int):
+ in_channels = [in_channels]
+ if isinstance(src_shape[0], int):
+ src_shape = [src_shape]
+ assert len(src_shape) == len(in_channels)
+ n_levels = len(in_channels)
+
+ self.input_proj = nn.ModuleList()
+ for i in range(len(in_channels)):
+ self.input_proj.append(
+ nn.Sequential(
+ nn.Conv2d(in_channels[i], d_model, kernel_size=1, bias=False),
+ norm_layer(d_model),
+ )
+ )
+
+ encoder_layer = DeformTransformerEncoderLayer(
+ d_model, dim_feedforward, dropout, activation, n_levels, n_heads, enc_n_points, use_checkpoint
+ )
+ self.encoder = DeformTransformerEncoder(encoder_layer, num_encoder_layers)
+
+ decoder_layer = DeformTransformerDecoderLayer(
+ d_model, dim_feedforward, dropout, activation, n_levels, n_heads, dec_n_points, use_checkpoint
+ )
+ self.decoder = DeformTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec)
+
+ self.dropout = nn.Dropout(dropout)
+
+ self.t2s_reference_points = nn.Linear(d_model, 2)
+
+ self._reset_parameters()
+
+ if src_pos_encode == "sine":
+ self.src_pos_embed = PositionEmbeddingSine(d_model, normalize=True)
+ self.src_lvl_embed = nn.Embedding(n_levels, d_model)
+ elif src_pos_encode == "learned":
+ self.src_pos_embed = nn.ModuleList(
+ [PositionEmbeddingLearned(shape, d_model) for shape in src_shape],
+ )
+ else:
+ raise NotImplementedError
+
+ if tgt_pos_encode == "sine":
+ self.tgt_pos_embed = PositionEmbeddingSine(d_model, normalize=True)
+ elif tgt_pos_encode == "learned":
+ self.tgt_pos_embed = PositionEmbeddingLearned(tgt_shape, d_model)
+ else:
+ raise NotImplementedError
+
+ self.tgt_embed = PositionEmbeddingLearned(tgt_shape, d_model)
+
+ self.src_shape = src_shape
+ self.tgt_shape = tgt_shape
+ self.src_pos_encode = src_pos_encode
+ self.tgt_pos_encode = tgt_pos_encode
+
+ """
+ use_projection: bool / whether to use IPM as the reference points
+ map_size: (x_width, y_width) shape of the original Map (400, 200)
+ image_shape: (Height, Width)
+ map_resolution: map resolution (m / pixel)
+ """
+ self.use_projection = use_projection # Use IPM Projection to get reference points
+ self.map_size = map_size
+ self.map_resolution = map_resolution
+ self.image_shape = image_shape
+ image_order = torch.tensor(image_order, dtype=torch.long)
+ self.register_buffer("image_order", image_order)
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ for m in self.modules():
+ if isinstance(m, MSDeformAttn):
+ m._reset_parameters()
+ nn.init.xavier_uniform_(self.t2s_reference_points.weight, gain=1.0)
+ nn.init.constant_(self.t2s_reference_points.bias, 0.0)
+
+ @staticmethod
+ def get_valid_ratio(mask):
+ _, H, W = mask.shape
+ valid_H = torch.sum(~mask[:, :, 0], 1)
+ valid_W = torch.sum(~mask[:, 0, :], 1)
+ valid_ratio_h = valid_H.float() / H
+ valid_ratio_w = valid_W.float() / W
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+ return valid_ratio
+
+ def get_projection_points(self, extrinsic, intrinsic, flip=False):
+ """
+ extrinsic:
+ torch.Tensor (6, 4, 4)
+ intrinsic:
+ torch.Tensor (6, 3, 3)
+ flip:
+ flip or not
+
+ Return
+ reference points (N, L, 2)
+ mask (N, )
+ """
+ map_forward_ratio = self.tgt_shape[0] / self.map_size[0]
+ map_lateral_ratio = self.tgt_shape[1] / self.map_size[1]
+
+ map_forward_res = self.map_resolution / map_forward_ratio
+ map_lateral_res = self.map_resolution / map_lateral_ratio
+
+ X = (torch.arange(self.tgt_shape[0] - 1, -1, -1, device=extrinsic.device) + 0.5 - self.tgt_shape[0] / 2) * map_forward_res
+ Y = (torch.arange(self.tgt_shape[1] - 1, -1, -1, device=extrinsic.device) + 0.5 - self.tgt_shape[1] / 2) * map_lateral_res
+ if flip:
+ Y = -1 * Y # Flip the Y axis
+
+ Z = torch.zeros(self.tgt_shape, device=extrinsic.device)
+ grid_X, grid_Y = torch.meshgrid(X, Y)
+ coords = torch.stack([grid_X, grid_Y, Z, torch.ones(self.tgt_shape, device=extrinsic.device)], dim=-1) # (H, W, 4) homogeneous coordinates
+ coords_flatten = coords.reshape(-1, 4) # (N, 4)
+
+ cams = []
+ for cam in extrinsic:
+ cam_coords = torch.linalg.inv(cam) @ coords_flatten.T # (4, N)
+ cam_coords = cam_coords[:3, :] # (3, N) -- x, y, z
+ cams.append(cam_coords)
+ cams = torch.stack(cams, dim=0) # (6, 3, N) Coordinates in Camera Frame
+ normed_coors = F.normalize(cams, p=1, dim=0) # (6, 3, N) Normalized Coordinates in Camera Frame
+
+ cams_z = normed_coors[:, 2, :] # (6, N) z coord
+ cam_id = torch.argmax(cams_z, dim=0) # (N, ) -- bev to img idx, Choose the camera with the smallest angle of view
+
+ max_z = cams_z[cam_id, torch.arange(cams.shape[-1])]
+ valid_mask = max_z > 0
+
+ intrinsic_percam = intrinsic[cam_id] # (N, 3, 3)
+
+ coords_percam = cams[cam_id, :, torch.arange(cams.shape[2])] # (N, 3)
+ pixel_coord = (intrinsic_percam @ coords_percam[:, :, None]).squeeze() # (N, 3)
+ pixel_coord = pixel_coord[:, :2] / pixel_coord[:, [2]] # divided by Z / (N, 2)
+
+ if not isinstance(self.image_shape, list):
+ image_shape = torch.tensor([self.image_shape for _ in range(len(extrinsic))], device=extrinsic.device)[cam_id]
+ else:
+ image_shape = torch.tensor(self.image_shape, device=extrinsic.device)[cam_id]
+
+ valid_pixelx = torch.bitwise_and(pixel_coord[:, 0] < image_shape[:,1], pixel_coord[:, 0] >= 0)
+ valid_pixely = torch.bitwise_and(pixel_coord[:, 1] < image_shape[:,0], pixel_coord[:, 1] >= 0)
+ valid_mask = valid_mask * valid_pixelx * valid_pixely
+
+ # cast to levels
+ reference_points = []
+ for level_shape in self.src_shape:
+ level_h, level_w = level_shape
+ level_w /= 6
+ image_h, image_w = image_shape.T
+
+ ratio_h = image_h / level_h
+ ratio_w = image_w / level_w
+
+ if flip:
+ img_x = level_w - pixel_coord[:, 0] / ratio_w
+ cam_id_ = self.image_order[cam_id]
+ x = cam_id_ * level_w + img_x
+
+ else:
+ x = cam_id * level_w + pixel_coord[:, 0] / ratio_w
+
+ y = pixel_coord[:, 1] / ratio_h
+
+ x /= (level_w * 6) # Normalize to [0 ~ 1]
+ y /= level_h
+
+ level_point = torch.stack([x, y], dim=-1)
+ reference_points.append(level_point)
+
+ reference_points = torch.stack(reference_points, dim=-2)
+ return reference_points, valid_mask
+
+ def forward(self, srcs, src_masks=None, cameras_info=None):
+
+ if not isinstance(srcs, (list, tuple)):
+ srcs = [srcs]
+ if not isinstance(src_masks, (list, tuple)):
+ src_masks = [src_masks]
+
+ if src_masks[0] is None:
+ src_masks = []
+
+ src_flatten = []
+ src_mask_flatten = []
+ src_pos_embed_flatten = []
+ src_spatial_shapes = []
+ for i, src in enumerate(srcs):
+ bs, c, h, w = src.shape
+ spatial_shape = (h, w)
+ src_spatial_shapes.append(spatial_shape)
+
+ if len(src_masks) < i + 1:
+ src_mask = torch.zeros((bs, h, w), dtype=torch.bool, device=src.device) # (N, H, W)
+ src_masks.append(src_mask)
+ else:
+ src_mask = src_masks[i]
+
+ if self.src_pos_encode == "sine":
+ src_pos_embed = self.src_pos_embed(src_mask) # (N, C, H, W)
+ src_pos_embed = src_pos_embed + self.src_lvl_embed.weight[i].view(-1, 1, 1) # (N, C, H, W)
+ else:
+ src_pos_embed = self.src_pos_embed[i](src_mask) # (N, C, H, W)
+
+ src = self.input_proj[i](src) # (N, C, H, W)
+ src = src + src_pos_embed # (N, C, H, W)
+
+ src = src.flatten(2).transpose(1, 2) # (N, H * W, C)
+ src_mask = src_mask.flatten(1) # (N, H * W)
+ src_pos_embed = src_pos_embed.flatten(2).transpose(1, 2) # (N, H * W, C)
+
+ src_flatten.append(src)
+ src_mask_flatten.append(src_mask)
+ src_pos_embed_flatten.append(src_pos_embed)
+
+ src = torch.cat(src_flatten, 1) # (N, L * H * W, C)
+ src_mask = torch.cat(src_mask_flatten, 1) # (N, L * H * W)
+ src_pos_embed = torch.cat(src_pos_embed_flatten, 1) # (N, L * H * W, C)
+ src_spatial_shapes = torch.as_tensor(src_spatial_shapes, dtype=torch.long, device=src.device) # (L, 2)
+ src_level_start_index = torch.cat(
+ (src_spatial_shapes.new_zeros((1,)), src_spatial_shapes.prod(1).cumsum(0)[:-1])
+ ) # (L,)
+ src_valid_ratios = torch.stack([self.get_valid_ratio(m) for m in src_masks], 1) # (N, L, 2)
+
+ tgt_mask = torch.zeros((srcs[0].size(0), *self.tgt_shape), dtype=torch.bool, device=srcs[0].device) # (N, H, W)
+ tgt_pos_embed = self.tgt_pos_embed(tgt_mask) # (N, C, H, W)
+ tgt_pos_embed = tgt_pos_embed.flatten(2).transpose(1, 2) # (N, H * W, C)
+ # tgt = tgt_pos_embed # (N, H * W, C)
+ tgt = self.tgt_embed(tgt_mask).flatten(2).transpose(1, 2)
+
+ tgt_spatial_shapes = torch.as_tensor(self.tgt_shape, dtype=torch.long, device=tgt.device).unsqueeze(0) # (1, 2)
+ tgt_valid_ratios = self.get_valid_ratio(tgt_mask).unsqueeze(1) # (N, 1, 2)
+ tgt_level_start_index = tgt_spatial_shapes.new_zeros((1,)) # (1,)
+ tgt_mask = tgt_mask.flatten(1) # (N, 1 * H * W)
+
+ t2s_reference_points = self.t2s_reference_points(tgt_pos_embed).sigmoid() # (N, H * W, 2)
+
+ if self.use_projection:
+ t2s_reference_points = t2s_reference_points.unsqueeze(-2).repeat(1, 1, len(self.src_shape), 1) # (N, H * W, L, 2)
+ bs = srcs[0].shape[0]
+
+ do_flip = cameras_info['do_flip']
+ if do_flip is None:
+ do_flip = torch.zeros((bs,), dtype=torch.bool)
+
+ for i in range(bs):
+ flip = do_flip[i].item()
+ extrinsic = cameras_info['extrinsic'][i].float()
+ intrinsic = cameras_info['intrinsic'][i].float()
+
+ # Use IPM to generate reference points, Original Size (900, 1600)
+ ipm_reference_points, valid_mask = self.get_projection_points(extrinsic, intrinsic, flip) # (N, L, 2)
+ loc = torch.where(valid_mask > 0)[0]
+
+ # Change the embeddings to reference point coordinate
+ t2s_reference_points[i, loc, :, :] = ipm_reference_points[loc, :, :]
+ else:
+ t2s_reference_points = t2s_reference_points[:, :, None]
+
+ # encoder
+ memory = self.encoder(
+ src, src_spatial_shapes, src_level_start_index, src_valid_ratios, src_pos_embed, src_mask
+ ) # (N, H * W, C)
+ # decoder
+ hs = self.decoder(
+ tgt,
+ memory,
+ tgt_pos_embed,
+ t2s_reference_points,
+ tgt_spatial_shapes,
+ src_spatial_shapes,
+ tgt_level_start_index,
+ src_level_start_index,
+ tgt_valid_ratios,
+ src_valid_ratios,
+ tgt_mask,
+ src_mask,
+ ) # (M, N, H * W, C)
+ ys = hs.transpose(2, 3) # (M, N, C, H * W)
+ ys = ys.reshape(*ys.shape[:-1], *self.tgt_shape).contiguous() # (M, N, C, H, W)
+ return [memory, hs, ys]
+
+
+class DeformTransformerEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model=256,
+ d_ffn=1024,
+ dropout=0.1,
+ activation="relu",
+ n_levels=4,
+ n_heads=8,
+ n_points=4,
+ use_checkpoint=False,
+ ):
+ super().__init__()
+
+ # self attention
+ self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+ self.dropout1 = nn.Dropout(dropout)
+ self.norm1 = nn.LayerNorm(d_model)
+
+ # ffn
+ self.linear1 = nn.Linear(d_model, d_ffn)
+ self.activation = _get_activation_fn(activation)
+ self.dropout2 = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(d_ffn, d_model)
+ self.dropout3 = nn.Dropout(dropout)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ self.use_checkpoint = use_checkpoint
+
+ @staticmethod
+ def with_pos_embed(tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_ffn(self, src):
+ src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
+ src = src + self.dropout3(src2)
+ src = self.norm2(src)
+ return src
+
+ def _forward(
+ self, src, src_pos_embed, src_reference_points, src_spatial_shapes, src_level_start_index, src_key_padding_mask
+ ):
+ # self attention
+ src2 = self.self_attn(
+ self.with_pos_embed(src, src_pos_embed),
+ src_reference_points,
+ src,
+ src_spatial_shapes,
+ src_level_start_index,
+ src_key_padding_mask,
+ )
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+
+ # ffn
+ src = self.forward_ffn(src)
+ return src
+
+ def forward(
+ self, src, src_pos_embed, src_reference_points, src_spatial_shapes, src_level_start_index, src_key_padding_mask
+ ):
+ if self.use_checkpoint and self.training:
+ src = checkpoint.checkpoint(
+ self._forward,
+ src,
+ src_pos_embed,
+ src_reference_points,
+ src_spatial_shapes,
+ src_level_start_index,
+ src_key_padding_mask,
+ )
+ else:
+ src = self._forward(
+ src,
+ src_pos_embed,
+ src_reference_points,
+ src_spatial_shapes,
+ src_level_start_index,
+ src_key_padding_mask,
+ )
+ return src
+
+
+class DeformTransformerEncoder(nn.Module):
+ def __init__(self, encoder_layer, num_layers):
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+
+ @staticmethod
+ def get_reference_points(spatial_shapes, valid_ratios, device):
+ reference_points_list = []
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
+ ref_y, ref_x = torch.meshgrid(
+ torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
+ )
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
+ ref = torch.stack((ref_x, ref_y), -1)
+ reference_points_list.append(ref)
+ reference_points = torch.cat(reference_points_list, 1)
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
+ return reference_points
+
+ def forward(
+ self,
+ src,
+ src_spatial_shapes,
+ src_level_start_index,
+ src_valid_ratios,
+ src_pos_embed=None,
+ src_key_padding_mask=None,
+ ):
+
+ src_reference_points = self.get_reference_points(src_spatial_shapes, src_valid_ratios, device=src.device)
+
+ output = src
+ for _, layer in enumerate(self.layers):
+ output = layer(
+ output,
+ src_pos_embed,
+ src_reference_points,
+ src_spatial_shapes,
+ src_level_start_index,
+ src_key_padding_mask,
+ )
+ return output
+
+
+class DeformTransformerDecoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model=256,
+ d_ffn=1024,
+ dropout=0.1,
+ activation="relu",
+ n_levels=4,
+ n_heads=8,
+ n_points=4,
+ use_checkpoint=False,
+ ):
+ super().__init__()
+
+ # cross attention
+ self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+ self.dropout1 = nn.Dropout(dropout)
+ self.norm1 = nn.LayerNorm(d_model)
+
+ # self attention
+ # self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
+
+ self.self_attn = MSDeformAttn(d_model, 1, n_heads, n_points)
+ self.dropout2 = nn.Dropout(dropout)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ # ffn
+ self.linear1 = nn.Linear(d_model, d_ffn)
+ self.activation = _get_activation_fn(activation)
+ self.dropout3 = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(d_ffn, d_model)
+ self.dropout4 = nn.Dropout(dropout)
+ self.norm3 = nn.LayerNorm(d_model)
+
+ self.use_checkpoint = use_checkpoint
+
+ @staticmethod
+ def with_pos_embed(tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_ffn(self, tgt):
+ tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout4(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+ def _forward(
+ self,
+ tgt,
+ src,
+ tgt_pos_embed,
+ tgt_reference_points,
+ t2s_reference_points,
+ tgt_spatial_shapes,
+ src_spatial_shapes,
+ tgt_level_start_index,
+ src_level_start_index,
+ tgt_key_padding_mask,
+ src_key_padding_mask,
+ ):
+ # self attention
+ # q = k = self.with_pos_embed(tgt, tgt_pos_embed)
+ # tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1)
+
+ tgt2 = self.self_attn(
+ self.with_pos_embed(tgt, tgt_pos_embed),
+ tgt_reference_points,
+ tgt,
+ tgt_spatial_shapes,
+ tgt_level_start_index,
+ tgt_key_padding_mask,
+ )
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+
+ # cross attention
+ tgt2 = self.cross_attn(
+ self.with_pos_embed(tgt, tgt_pos_embed),
+ t2s_reference_points,
+ src,
+ src_spatial_shapes,
+ src_level_start_index,
+ src_key_padding_mask,
+ )
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+
+ # ffn
+ tgt = self.forward_ffn(tgt)
+ return tgt
+
+ def forward(
+ self,
+ tgt,
+ src,
+ tgt_pos_embed,
+ tgt_reference_points,
+ t2s_reference_points,
+ tgt_spatial_shapes,
+ src_spatial_shapes,
+ tgt_level_start_index,
+ src_level_start_index,
+ tgt_key_padding_mask,
+ src_key_padding_mask,
+ ):
+ if self.use_checkpoint and self.training:
+ tgt = checkpoint.checkpoint(
+ self._forward,
+ tgt,
+ src,
+ tgt_pos_embed,
+ tgt_reference_points,
+ t2s_reference_points,
+ tgt_spatial_shapes,
+ src_spatial_shapes,
+ tgt_level_start_index,
+ src_level_start_index,
+ tgt_key_padding_mask,
+ src_key_padding_mask,
+ )
+ else:
+ tgt = self._forward(
+ tgt,
+ src,
+ tgt_pos_embed,
+ tgt_reference_points,
+ t2s_reference_points,
+ tgt_spatial_shapes,
+ src_spatial_shapes,
+ tgt_level_start_index,
+ src_level_start_index,
+ tgt_key_padding_mask,
+ src_key_padding_mask,
+ )
+ return tgt
+
+
+class DeformTransformerDecoder(nn.Module):
+ def __init__(self, decoder_layer, num_layers, return_intermediate=False):
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.return_intermediate = return_intermediate
+
+ @staticmethod
+ def get_reference_points(spatial_shapes, valid_ratios, device):
+ reference_points_list = []
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
+ ref_y, ref_x = torch.meshgrid(
+ torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
+ )
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
+ ref = torch.stack((ref_x, ref_y), -1)
+ reference_points_list.append(ref)
+ reference_points = torch.cat(reference_points_list, 1)
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
+ return reference_points
+
+ def forward(
+ self,
+ tgt,
+ src,
+ tgt_pos_embed,
+ t2s_reference_points,
+ tgt_spatial_shapes,
+ src_spatial_shapes,
+ tgt_level_start_index,
+ src_level_start_index,
+ tgt_valid_ratios,
+ src_valid_ratios,
+ tgt_key_padding_mask=None,
+ src_key_padding_mask=None,
+ ):
+
+ tgt_reference_points = self.get_reference_points(tgt_spatial_shapes, tgt_valid_ratios, device=tgt.device)
+ t2s_reference_points = t2s_reference_points * src_valid_ratios[:, None]
+ # t2s_reference_points = t2s_reference_points[:, :, None] * src_valid_ratios[:, None]
+
+ intermediate = []
+ output = tgt
+ for _, layer in enumerate(self.layers):
+ output = layer(
+ output,
+ src,
+ tgt_pos_embed,
+ tgt_reference_points,
+ t2s_reference_points,
+ tgt_spatial_shapes,
+ src_spatial_shapes,
+ tgt_level_start_index,
+ src_level_start_index,
+ tgt_key_padding_mask,
+ src_key_padding_mask,
+ )
+
+ if self.return_intermediate:
+ intermediate.append(output)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate)
+
+ return output.unsqueeze(0)
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/__init__.py b/mapmaster/models/bev_decoder/deform_transformer/ops/__init__.py
new file mode 100644
index 0000000..d8f3e3c
--- /dev/null
+++ b/mapmaster/models/bev_decoder/deform_transformer/ops/__init__.py
@@ -0,0 +1 @@
+from .modules import MSDeformAttn
diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/functions/__init__.py b/mapmaster/models/bev_decoder/deform_transformer/ops/functions/__init__.py
new file mode 100644
index 0000000..06ebc91
--- /dev/null
+++ b/mapmaster/models/bev_decoder/deform_transformer/ops/functions/__init__.py
@@ -0,0 +1,9 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from .ms_deform_attn_func import MSDeformAttnFunction
diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/functions/ms_deform_attn_func.py b/mapmaster/models/bev_decoder/deform_transformer/ops/functions/ms_deform_attn_func.py
new file mode 100644
index 0000000..7cd8d12
--- /dev/null
+++ b/mapmaster/models/bev_decoder/deform_transformer/ops/functions/ms_deform_attn_func.py
@@ -0,0 +1,73 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import torch
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+import MultiScaleDeformableAttention as MSDA
+
+
+class MSDeformAttnFunction(Function):
+ @staticmethod
+ def forward(
+ ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step
+ ):
+ ctx.im2col_step = im2col_step
+ output = MSDA.ms_deform_attn_forward(
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step
+ )
+ ctx.save_for_backward(
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
+ )
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
+ grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward(
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ grad_output,
+ ctx.im2col_step,
+ )
+
+ return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
+
+
+def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
+ # for debug and test only,
+ # need to use cuda version instead
+ N_, S_, M_, D_ = value.shape
+ _, Lq_, M_, L_, P_, _ = sampling_locations.shape
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
+ sampling_grids = 2 * sampling_locations - 1
+ sampling_value_list = []
+ for lid_, (H_, W_) in enumerate(value_spatial_shapes):
+ # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
+ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_)
+ # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
+ sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
+ # N_*M_, D_, Lq_, P_
+ sampling_value_l_ = F.grid_sample(
+ value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
+ )
+ sampling_value_list.append(sampling_value_l_)
+ # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
+ attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_)
+ output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_ * D_, Lq_)
+ return output.transpose(1, 2).contiguous()
diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/make.sh b/mapmaster/models/bev_decoder/deform_transformer/ops/make.sh
new file mode 100644
index 0000000..106b685
--- /dev/null
+++ b/mapmaster/models/bev_decoder/deform_transformer/ops/make.sh
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+python setup.py build install
diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/modules/__init__.py b/mapmaster/models/bev_decoder/deform_transformer/ops/modules/__init__.py
new file mode 100644
index 0000000..ff5ce83
--- /dev/null
+++ b/mapmaster/models/bev_decoder/deform_transformer/ops/modules/__init__.py
@@ -0,0 +1,9 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from .ms_deform_attn import MSDeformAttn
diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/modules/ms_deform_attn.py b/mapmaster/models/bev_decoder/deform_transformer/ops/modules/ms_deform_attn.py
new file mode 100644
index 0000000..c091ed6
--- /dev/null
+++ b/mapmaster/models/bev_decoder/deform_transformer/ops/modules/ms_deform_attn.py
@@ -0,0 +1,140 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import warnings
+import math
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn.init import xavier_uniform_, constant_
+
+from ..functions import MSDeformAttnFunction
+
+
+def _is_power_of_2(n):
+ if (not isinstance(n, int)) or (n < 0):
+ raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
+ return (n & (n - 1) == 0) and n != 0
+
+
+class MSDeformAttn(nn.Module):
+ def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
+ """
+ Multi-Scale Deformable Attention Module
+ :param d_model hidden dimension
+ :param n_levels number of feature levels
+ :param n_heads number of attention heads
+ :param n_points number of sampling points per attention head per feature level
+ """
+ super().__init__()
+ if d_model % n_heads != 0:
+ raise ValueError("d_model must be divisible by n_heads, but got {} and {}".format(d_model, n_heads))
+ _d_per_head = d_model // n_heads
+ # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
+ if not _is_power_of_2(_d_per_head):
+ warnings.warn(
+ "You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
+ "which is more efficient in our CUDA implementation."
+ )
+
+ self.im2col_step = 64
+
+ self.d_model = d_model
+ self.n_levels = n_levels
+ self.n_heads = n_heads
+ self.n_points = n_points
+
+ self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
+ self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
+ self.value_proj = nn.Linear(d_model, d_model)
+ self.output_proj = nn.Linear(d_model, d_model)
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ constant_(self.sampling_offsets.weight.data, 0.0)
+ thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = (
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
+ .view(self.n_heads, 1, 1, 2)
+ .repeat(1, self.n_levels, self.n_points, 1)
+ )
+ for i in range(self.n_points):
+ grid_init[:, :, i, :] *= i + 1
+ with torch.no_grad():
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+ constant_(self.attention_weights.weight.data, 0.0)
+ constant_(self.attention_weights.bias.data, 0.0)
+ xavier_uniform_(self.value_proj.weight.data)
+ constant_(self.value_proj.bias.data, 0.0)
+ xavier_uniform_(self.output_proj.weight.data)
+ constant_(self.output_proj.bias.data, 0.0)
+
+ def forward(
+ self,
+ query,
+ reference_points,
+ input_flatten,
+ input_spatial_shapes,
+ input_level_start_index,
+ input_padding_mask=None,
+ ):
+ """
+ :param query (N, Length_{query}, C)
+ :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
+ or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
+ :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
+ :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+ :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
+ :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
+
+ :return output (N, Length_{query}, C)
+ """
+ N, Len_q, _ = query.shape
+ N, Len_in, _ = input_flatten.shape
+ assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
+
+ value = self.value_proj(input_flatten)
+ if input_padding_mask is not None:
+ value = value.masked_fill(input_padding_mask[..., None], float(0))
+ value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
+ sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
+ attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
+ attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
+ # N, Len_q, n_heads, n_levels, n_points, 2
+ if reference_points.shape[-1] == 2:
+ offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :]
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+ )
+ elif reference_points.shape[-1] == 4:
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :2]
+ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
+ )
+ else:
+ raise ValueError(
+ "Last dim of reference_points must be 2 or 4, but get {} instead.".format(reference_points.shape[-1])
+ )
+ output = MSDeformAttnFunction.apply(
+ value,
+ input_spatial_shapes,
+ input_level_start_index,
+ sampling_locations,
+ attention_weights,
+ self.im2col_step,
+ )
+ output = self.output_proj(output)
+ return output
diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/setup.py b/mapmaster/models/bev_decoder/deform_transformer/ops/setup.py
new file mode 100644
index 0000000..148dc05
--- /dev/null
+++ b/mapmaster/models/bev_decoder/deform_transformer/ops/setup.py
@@ -0,0 +1,84 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+import os
+import glob
+
+import torch
+
+from torch.utils.cpp_extension import CUDA_HOME
+from torch.utils.cpp_extension import CppExtension
+from torch.utils.cpp_extension import CUDAExtension
+
+from setuptools import find_packages
+from setuptools import setup
+
+requirements = ["torch", "torchvision"]
+
+
+def get_extensions():
+ this_dir = os.path.dirname(os.path.abspath(__file__))
+ extensions_dir = os.path.join(this_dir, "src")
+
+ main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
+ source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
+ source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
+
+ sources = main_file + source_cpu
+ extension = CppExtension
+ extra_compile_args = {"cxx": []}
+ define_macros = []
+
+ if torch.cuda.is_available() and CUDA_HOME is not None:
+ extension = CUDAExtension
+ sources += source_cuda
+ define_macros += [("WITH_CUDA", None)]
+ extra_compile_args["nvcc"] = [
+ "-DCUDA_HAS_FP16=1",
+ "-D__CUDA_NO_HALF_OPERATORS__",
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
+ "-D__CUDA_NO_HALF2_OPERATORS__",
+ "-arch=sm_60",
+ "-gencode=arch=compute_60,code=sm_60",
+ "-gencode=arch=compute_61,code=sm_61",
+ "-gencode=arch=compute_70,code=sm_70",
+ "-gencode=arch=compute_75,code=sm_75",
+ # "-gencode=arch=compute_80,code=sm_80",
+ ]
+ else:
+ raise NotImplementedError("Cuda is not availabel")
+
+ sources = [os.path.join(extensions_dir, s) for s in sources]
+ include_dirs = [extensions_dir]
+ ext_modules = [
+ extension(
+ "MultiScaleDeformableAttention",
+ sources,
+ include_dirs=include_dirs,
+ define_macros=define_macros,
+ extra_compile_args=extra_compile_args,
+ )
+ ]
+ return ext_modules
+
+
+setup(
+ name="MultiScaleDeformableAttention",
+ version="1.0",
+ author="Weijie Su",
+ url="https://github.com/fundamentalvision/Deformable-DETR",
+ description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
+ packages=find_packages(
+ exclude=(
+ "configs",
+ "tests",
+ )
+ ),
+ ext_modules=get_extensions(),
+ cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
+)
diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/src/cpu/ms_deform_attn_cpu.cpp b/mapmaster/models/bev_decoder/deform_transformer/ops/src/cpu/ms_deform_attn_cpu.cpp
new file mode 100644
index 0000000..dc1c0a1
--- /dev/null
+++ b/mapmaster/models/bev_decoder/deform_transformer/ops/src/cpu/ms_deform_attn_cpu.cpp
@@ -0,0 +1,40 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include
+
+#include
+#include
+
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ AT_ERROR("Not implement on cpu");
+}
+
+std::vector
+ms_deform_attn_cpu_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+ AT_ERROR("Not implement on cpu");
+}
diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/src/cpu/ms_deform_attn_cpu.h b/mapmaster/models/bev_decoder/deform_transformer/ops/src/cpu/ms_deform_attn_cpu.h
new file mode 100644
index 0000000..67422ce
--- /dev/null
+++ b/mapmaster/models/bev_decoder/deform_transformer/ops/src/cpu/ms_deform_attn_cpu.h
@@ -0,0 +1,31 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step);
+
+std::vector
+ms_deform_attn_cpu_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step);
diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/src/cuda/ms_deform_attn_cuda.cu b/mapmaster/models/bev_decoder/deform_transformer/ops/src/cuda/ms_deform_attn_cuda.cu
new file mode 100644
index 0000000..aa9ee57
--- /dev/null
+++ b/mapmaster/models/bev_decoder/deform_transformer/ops/src/cuda/ms_deform_attn_cuda.cu
@@ -0,0 +1,153 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include
+#include "cuda/ms_deform_im2col_cuda.cuh"
+
+#include
+#include
+#include
+#include
+
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
+
+ const int batch_n = im2col_step_;
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto columns = output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
+ value.data() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data(),
+ level_start_index.data(),
+ sampling_loc.data() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ columns.data());
+
+ }));
+ }
+
+ output = output.view({batch, num_query, num_heads*channels});
+
+ return output;
+}
+
+
+std::vector ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto grad_value = at::zeros_like(value);
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
+ auto grad_attn_weight = at::zeros_like(attn_weight);
+
+ const int batch_n = im2col_step_;
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto grad_output_g = grad_output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
+ grad_output_g.data(),
+ value.data() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data(),
+ level_start_index.data(),
+ sampling_loc.data() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ grad_value.data() + n * im2col_step_ * per_value_size,
+ grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size,
+ grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size);
+
+ }));
+ }
+
+ return {
+ grad_value, grad_sampling_loc, grad_attn_weight
+ };
+}
diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/src/cuda/ms_deform_attn_cuda.h b/mapmaster/models/bev_decoder/deform_transformer/ops/src/cuda/ms_deform_attn_cuda.h
new file mode 100644
index 0000000..2d03704
--- /dev/null
+++ b/mapmaster/models/bev_decoder/deform_transformer/ops/src/cuda/ms_deform_attn_cuda.h
@@ -0,0 +1,29 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step);
+
+std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step);
diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/src/cuda/ms_deform_im2col_cuda.cuh b/mapmaster/models/bev_decoder/deform_transformer/ops/src/cuda/ms_deform_im2col_cuda.cuh
new file mode 100644
index 0000000..3d86a01
--- /dev/null
+++ b/mapmaster/models/bev_decoder/deform_transformer/ops/src/cuda/ms_deform_im2col_cuda.cuh
@@ -0,0 +1,1327 @@
+/*!
+**************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************
+* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
+* Copyright (c) 2018 Microsoft
+**************************************************************************
+*/
+
+#include <cstdio>
+#include <algorithm>
+#include <cstring>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THCAtomics.cuh>
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+ i < (n); \
+ i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N, const int num_threads)
+{
+ return (N + num_threads - 1) / num_threads;
+}
+
+
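+// bilinear sampling of the value map at (h, w) for head m, channel c; corners outside the feature map contribute zero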
+template <typename scalar_t>
+__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ }
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+
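+// backward of the bilinear sampling: scatters into grad_value with atomicAdd and writes the sampling-location and
+// attention-weight gradients through the supplied pointers (a per-thread shared-memory cache in the kernels below)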
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ *grad_attn_weight = top_grad * val;
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
+}
+
+
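+// same as above, but accumulates the location / weight gradients directly into global memory with atomicAdd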
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ atomicAdd(grad_attn_weight, top_grad * val);
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
+}
+
+
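+// forward kernel: one thread per output element (batch, query, head, channel), summing weighted bilinear samples over all levels and points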
+template <typename scalar_t>
+__global__ void ms_deformable_im2col_gpu_kernel(const int n,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ scalar_t *data_col_ptr = data_col + index;
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+ scalar_t col = 0;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
+ }
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ }
+ }
+ *data_col_ptr = col;
+ }
+}
+
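+// backward kernel for small channel counts: per-thread gradients are staged in static shared memory and summed by thread 0 of each block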
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
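+ // tree reduction over the per-thread caches; blockSize is a compile-time power of two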
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre)
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre)
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear_gm(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ grad_sampling_loc, grad_attn_weight);
+ }
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t>
+void ms_deformable_im2col_cuda(cudaStream_t stream,
+ const scalar_t* data_value,
+ const int64_t* data_spatial_shapes,
+ const int64_t* data_level_start_index,
+ const scalar_t* data_sampling_loc,
+ const scalar_t* data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* data_col)
+{
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ const int num_threads = CUDA_NUM_THREADS;
+ ms_deformable_im2col_gpu_kernel<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
+
+template <typename scalar_t>
+void ms_deformable_col2im_cuda(cudaStream_t stream,
+ const scalar_t* grad_col,
+ const scalar_t* data_value,
+ const int64_t * data_spatial_shapes,
+ const int64_t * data_level_start_index,
+ const scalar_t * data_sampling_loc,
+ const scalar_t * data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
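+ // pick a backward kernel by channel count: block-size-specialized shared-memory reductions for power-of-two channels
+ // up to 1024, generic shared-memory reductions otherwise, and a global-memory atomicAdd fallback for large channel
+ // counts that are not multiples of 1024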
+ if (channels > 1024)
+ {
+ if ((channels & 1023) == 0)
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+ ms_deformable_col2im_gpu_kernel_gm<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ else{
+ switch(channels)
+ {
+ case 1:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 2:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 4:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 8:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 16:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 32:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 64:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 128:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 256:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 512:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 1024:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ default:
+ if (channels < 64)
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ }
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/src/ms_deform_attn.h b/mapmaster/models/bev_decoder/deform_transformer/ops/src/ms_deform_attn.h
new file mode 100644
index 0000000..c5c6166
--- /dev/null
+++ b/mapmaster/models/bev_decoder/deform_transformer/ops/src/ms_deform_attn.h
@@ -0,0 +1,61 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+
+#include "cpu/ms_deform_attn_cpu.h"
+
+#ifdef WITH_CUDA
+#include "cuda/ms_deform_attn_cuda.h"
+#endif
+
+
+at::Tensor
+ms_deform_attn_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ if (value.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_forward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
+std::vector<at::Tensor>
+ms_deform_attn_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+ if (value.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_backward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/src/vision.cpp b/mapmaster/models/bev_decoder/deform_transformer/ops/src/vision.cpp
new file mode 100644
index 0000000..d37eafb
--- /dev/null
+++ b/mapmaster/models/bev_decoder/deform_transformer/ops/src/vision.cpp
@@ -0,0 +1,16 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include "ms_deform_attn.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
+ m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
+}
diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/test.py b/mapmaster/models/bev_decoder/deform_transformer/ops/test.py
new file mode 100644
index 0000000..50bb67e
--- /dev/null
+++ b/mapmaster/models/bev_decoder/deform_transformer/ops/test.py
@@ -0,0 +1,117 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import torch
+from torch.autograd import gradcheck
+
+from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
+
+
+N, M, D = 1, 2, 2
+Lq, L, P = 2, 2, 2
+shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
+level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1]))
+S = sum([(H * W).item() for H, W in shapes])
+
+
+torch.manual_seed(3)
+
+
+@torch.no_grad()
+def check_forward_equal_with_pytorch_double():
+ value = torch.rand(N, S, M, D).cuda() * 0.01
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+ im2col_step = 2
+ output_pytorch = (
+ ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double())
+ .detach()
+ .cpu()
+ )
+ output_cuda = (
+ MSDeformAttnFunction.apply(
+ value.double(),
+ shapes,
+ level_start_index,
+ sampling_locations.double(),
+ attention_weights.double(),
+ im2col_step,
+ )
+ .detach()
+ .cpu()
+ )
+ fwdok = torch.allclose(output_cuda, output_pytorch)
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
+
+ print(
+ f"* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}"
+ )
+
+
+@torch.no_grad()
+def check_forward_equal_with_pytorch_float():
+ value = torch.rand(N, S, M, D).cuda() * 0.01
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+ im2col_step = 2
+ output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
+ output_cuda = (
+ MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step)
+ .detach()
+ .cpu()
+ )
+ fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
+
+ print(
+ f"* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}"
+ )
+
+
+def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
+
+ value = torch.rand(N, S, M, channels).cuda() * 0.01
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+ im2col_step = 2
+ func = MSDeformAttnFunction.apply
+
+ value.requires_grad = grad_value
+ sampling_locations.requires_grad = grad_sampling_loc
+ attention_weights.requires_grad = grad_attn_weight
+
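+ # gradcheck compares analytic and numerical gradients; double-precision inputs keep the numerical check stable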
+ gradok = gradcheck(
+ func,
+ (
+ value.double(),
+ shapes,
+ level_start_index,
+ sampling_locations.double(),
+ attention_weights.double(),
+ im2col_step,
+ ),
+ )
+
+ print(f"* {gradok} check_gradient_numerical(D={channels})")
+
+
+if __name__ == "__main__":
+ check_forward_equal_with_pytorch_double()
+ check_forward_equal_with_pytorch_float()
+
+ for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
+ check_gradient_numerical(channels, True, True, True)
diff --git a/mapmaster/models/bev_decoder/deform_transformer/position_encoding.py b/mapmaster/models/bev_decoder/deform_transformer/position_encoding.py
new file mode 100644
index 0000000..7ecdcaf
--- /dev/null
+++ b/mapmaster/models/bev_decoder/deform_transformer/position_encoding.py
@@ -0,0 +1,63 @@
+"""
+Various positional encodings for the transformer.
+"""
+import math
+import torch
+from torch import nn
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, mask):
+ assert mask is not None
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats // 2, dtype=torch.float32, device=mask.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / (self.num_pos_feats // 2))
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+class PositionEmbeddingLearned(nn.Module):
+ """
+ Absolute pos embedding, learned.
+ """
+
+ def __init__(self, num_pos=(50, 50), num_pos_feats=256):
+ super().__init__()
+ self.num_pos = num_pos
+ self.pos_embed = nn.Embedding(num_pos[0] * num_pos[1], num_pos_feats)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.normal_(self.pos_embed.weight)
+
+ def forward(self, mask):
+ h, w = mask.shape[-2:]
+ pos = self.pos_embed.weight.view(*self.num_pos, -1)[:h, :w]
+ pos = pos.permute(2, 0, 1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1)
+ return pos
diff --git a/mapmaster/models/bev_decoder/model.py b/mapmaster/models/bev_decoder/model.py
new file mode 100644
index 0000000..daddb64
--- /dev/null
+++ b/mapmaster/models/bev_decoder/model.py
@@ -0,0 +1,53 @@
+import torch
+import numpy as np
+import torch.nn as nn
+from mapmaster.models.bev_decoder.transformer import Transformer
+from mapmaster.models.bev_decoder.deform_transformer import DeformTransformer
+
+class TransformerBEVDecoder(nn.Module):
+ def __init__(self, key='im_bkb_features', **kwargs):
+ super(TransformerBEVDecoder, self).__init__()
+ self.bev_encoder = Transformer(**kwargs)
+ self.key = key
+
+ def forward(self, inputs):
+ assert self.key in inputs
+ feats = inputs[self.key]
+ fuse_feats = feats[-1]
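+ # reshape to (B, num_cams, C, H, W), then concatenate the camera views along the width axis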
+ fuse_feats = fuse_feats.reshape(*inputs['images'].shape[:2], *fuse_feats.shape[-3:])
+ fuse_feats = torch.cat(torch.unbind(fuse_feats, dim=1), dim=-1)
+
+ cameras_info = {
+ 'extrinsic': inputs.get('extrinsic', None),
+ 'intrinsic': inputs.get('intrinsic', None),
+ 'ida_mats': inputs.get('ida_mats', None),
+ 'do_flip': inputs['extra_infos'].get('do_flip', None)
+ }
+
+ _, _, bev_feats = self.bev_encoder(fuse_feats, cameras_info=cameras_info)
+
+ return {"bev_enc_features": list(bev_feats)}
+
+class DeformTransformerBEVEncoder(nn.Module):
+ def __init__(self, **kwargs):
+ super(DeformTransformerBEVEncoder, self).__init__()
+ self.bev_encoder = DeformTransformer(**kwargs)
+
+ def forward(self, inputs):
+ assert "im_bkb_features" in inputs
+ feats = inputs["im_bkb_features"]
+ for i in range(len(feats)):
+ feats[i] = feats[i].reshape(*inputs["images"].shape[:2], *feats[i].shape[-3:])
+ feats[i] = feats[i].permute(0, 2, 3, 1, 4)
+ feats[i] = feats[i].reshape(*feats[i].shape[:3], -1)
+ cameras_info = {
+ 'extrinsic': inputs.get('extrinsic', None),
+ 'intrinsic': inputs.get('intrinsic', None),
+ 'do_flip': inputs['extra_infos'].get('do_flip', None)
+ }
+ # src_feats: (N, H1 * W1, C), tgt_feats: (M, N, H2 * W2, C)
+ _, _, bev_feats = self.bev_encoder(feats, cameras_info=cameras_info)
+
+ return {
+ "bev_enc_features": list(bev_feats),
+ }
diff --git a/mapmaster/models/bev_decoder/transformer.py b/mapmaster/models/bev_decoder/transformer.py
new file mode 100644
index 0000000..7e6bab7
--- /dev/null
+++ b/mapmaster/models/bev_decoder/transformer.py
@@ -0,0 +1,407 @@
+import copy
+import torch
+import numpy as np
+import torch.nn.functional as F
+from torch import nn, Tensor
+from typing import Optional
+from torch.utils.checkpoint import checkpoint
+from mapmaster.models.utils.position_encoding import PositionEmbeddingSine, PositionEmbeddingLearned
+from mapmaster.models.utils.position_encoding import PositionEmbeddingIPM, PositionEmbeddingTgt
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ src_shape=(32, 336),
+ query_shape=(32, 32),
+ d_model=512,
+ nhead=8,
+ num_encoder_layers=6,
+ num_decoder_layers=6,
+ dim_feedforward=2048,
+ src_pos_embed='sine',
+ tgt_pos_embed='sine',
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=False,
+ use_checkpoint=True,
+ ipm_proj_conf=None,
+ ipmpe_with_sine=True,
+ enforce_no_aligner=False,
+ ):
+ super().__init__()
+
+ self.src_shape = src_shape
+ self.query_shape = query_shape
+ self.d_model = d_model
+ self.nhead = nhead
+ self.ipm_proj_conf = ipm_proj_conf
+
+ self.ipmpe_with_sine = ipmpe_with_sine
+ self.enforce_no_aligner = enforce_no_aligner
+
+ num_queries = np.prod(query_shape).item()
+ self.input_proj = nn.Conv2d(in_channels, d_model, kernel_size=1)
+ self.query_embed = nn.Embedding(num_queries, d_model)
+ src_pe, tgt_pe = self._get_pos_embed_layers(src_pos_embed, tgt_pos_embed)
+ self.src_pos_embed_layer, self.tgt_pos_embed_layer = src_pe, tgt_pe
+
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm, use_checkpoint)
+
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
+ decoder_norm = nn.LayerNorm(d_model)
+ self.decoder = TransformerDecoder(
+ decoder_layer, num_decoder_layers, decoder_norm, return_intermediate_dec, use_checkpoint
+ )
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def _get_pos_embed_layers(self, src_pos_embed, tgt_pos_embed):
+
+ pos_embed_encoder = None
+ if (src_pos_embed.startswith('ipm_learned')) and (tgt_pos_embed.startswith('ipm_learned')) \
+ and not self.enforce_no_aligner:
+ pos_embed_encoder = nn.Sequential(
+ nn.Conv2d(self.d_model, self.d_model * 4, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(),
+ nn.Conv2d(self.d_model * 4, self.d_model, kernel_size=1, stride=1, padding=0)
+ )
+
+ if src_pos_embed == 'sine':
+ src_pos_embed_layer = PositionEmbeddingSine(self.d_model // 2, normalize=True)
+ elif src_pos_embed == 'learned':
+ src_pos_embed_layer = PositionEmbeddingLearned(self.src_shape, self.d_model)
+ elif src_pos_embed == 'ipm_learned':
+ input_shape = self.ipm_proj_conf['input_shape']
+ src_pos_embed_layer = PositionEmbeddingIPM(
+ pos_embed_encoder, self.src_shape, input_shape, num_pos_feats=self.d_model,
+ sine_encoding=self.ipmpe_with_sine)
+ else:
+ raise NotImplementedError
+ self.src_pos_embed = src_pos_embed
+
+ if tgt_pos_embed == 'sine':
+ tgt_pos_embed_layer = PositionEmbeddingSine(self.d_model // 2, normalize=True)
+ elif tgt_pos_embed == 'learned':
+ tgt_pos_embed_layer = PositionEmbeddingLearned(self.query_shape, self.d_model)
+ elif tgt_pos_embed == 'ipm_learned':
+ map_size, map_res = self.ipm_proj_conf['map_size'], self.ipm_proj_conf['map_resolution']
+ tgt_pos_embed_layer = PositionEmbeddingTgt(
+ pos_embed_encoder, self.query_shape, map_size, map_res, num_pos_feats=self.d_model, sine_encoding=True)
+ else:
+ raise NotImplementedError
+ self.tgt_pos_embed = tgt_pos_embed
+
+ return src_pos_embed_layer, tgt_pos_embed_layer
+
+ def forward(self, src, mask=None, cameras_info=None):
+
+ bs, c, h, w = src.shape
+ if mask is None:
+ mask = torch.zeros((bs, h, w), dtype=torch.bool, device=src.device) # (B, H, W)
+
+ src = self.input_proj(src) # (B, C, H, W)
+ src = src.flatten(2).permute(2, 0, 1) # (H* W, B, C)
+
+ if self.src_pos_embed.startswith('ipm_learned'):
+ extrinsic = cameras_info['extrinsic'].float()
+ intrinsic = cameras_info['intrinsic'].float()
+ ida_mats = cameras_info['ida_mats'].float()
+ do_flip = cameras_info['do_flip']
+ src_pos_embed, src_mask = self.src_pos_embed_layer(extrinsic, intrinsic, ida_mats, do_flip)
+ mask = ~src_mask
+ else:
+ src_pos_embed = self.src_pos_embed_layer(mask)
+
+ src_pos_embed = src_pos_embed.flatten(2).permute(2, 0, 1) # (H* W, B, C)
+ mask = mask.flatten(1) # (B, H * W)
+
+ query_embed = self.query_embed.weight # (H* W, C)
+ query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # (H* W, B, C)
+ tgt = query_embed # (H* W, B, C)
+
+ query_mask = torch.zeros((bs, *self.query_shape), dtype=torch.bool, device=src.device)
+ query_pos_embed = self.tgt_pos_embed_layer(query_mask) # (B, C, H, W)
+        query_pos_embed = query_pos_embed.flatten(2).permute(2, 0, 1)  # (H*W, B, C)
+
+        memory = self.encoder.forward(src, None, mask, src_pos_embed)  # (H*W, B, C)
+        hs = self.decoder.forward(tgt, memory, None, None, None, mask, src_pos_embed, query_pos_embed)  # (M, H*W, B, C)
+        ys = hs.permute(0, 2, 3, 1)  # (M, B, C, H*W)
+ ys = ys.reshape(*ys.shape[:-1], *self.query_shape) # (M, B, C, H, W)
+
+ return memory, hs, ys
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(self, encoder_layer, num_layers, norm=None, use_checkpoint=True):
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+ self.use_checkpoint = use_checkpoint
+
+ def forward(
+ self,
+ src,
+ mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
+ output = src
+
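+        # Gradient checkpointing (training only) trades extra forward computation for memory.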
+ for layer in self.layers:
+ if self.use_checkpoint and self.training:
+ output = checkpoint(layer, output, mask, src_key_padding_mask, pos)
+ else:
+ output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ return output
+
+
+class TransformerDecoder(nn.Module):
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False, use_checkpoint=True):
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+ self.return_intermediate = return_intermediate
+ self.use_checkpoint = use_checkpoint
+
+ def forward(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
+ output = tgt
+
+ intermediate = []
+
+ for layer in self.layers:
+ if self.use_checkpoint and self.training:
+ output = checkpoint(
+ layer,
+ output,
+ memory,
+ tgt_mask,
+ memory_mask,
+ tgt_key_padding_mask,
+ memory_key_padding_mask,
+ pos,
+ query_pos,
+ )
+ else:
+ output = layer(
+ output, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
+ )
+
+ if self.return_intermediate:
+ intermediate.append(self.norm(output))
+
+ if self.norm is not None:
+ output = self.norm(output)
+ if self.return_intermediate:
+ intermediate.pop()
+ intermediate.append(output)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate)
+
+ return output.unsqueeze(0)
+
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(
+ self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
+ q = k = self.with_pos_embed(src, pos)
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+ src = src + self.dropout2(src2)
+ src = self.norm2(src)
+ return src
+
+ def forward_pre(
+ self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
+ src2 = self.norm1(src)
+ q = k = self.with_pos_embed(src2, pos)
+ src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
+ src = src + self.dropout1(src2)
+ src2 = self.norm2(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+ src = src + self.dropout2(src2)
+ return src
+
+ def forward(
+ self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
+ if self.normalize_before:
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
+class TransformerDecoderLayer(nn.Module):
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
+ q = k = self.with_pos_embed(tgt, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+ tgt2 = self.multihead_attn(
+ query=self.with_pos_embed(tgt, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory,
+ attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask,
+ )[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout3(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+ def forward_pre(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
+ tgt2 = self.norm1(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.multihead_attn(
+ query=self.with_pos_embed(tgt2, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory,
+ attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask,
+ )[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt2 = self.norm3(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout3(tgt2)
+ return tgt
+
+ def forward(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
+ if self.normalize_before:
+ return self.forward_pre(
+ tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
+ )
+ return self.forward_post(
+ tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
+ )
+
+
+def _get_clones(module, N):
+    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
+
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
diff --git a/mapmaster/models/ins_decoder/__init__.py b/mapmaster/models/ins_decoder/__init__.py
new file mode 100644
index 0000000..4a45803
--- /dev/null
+++ b/mapmaster/models/ins_decoder/__init__.py
@@ -0,0 +1 @@
+from .model import Mask2formerINSDecoder, PointMask2formerINSDecoder
diff --git a/mapmaster/models/ins_decoder/mask2former.py b/mapmaster/models/ins_decoder/mask2former.py
new file mode 100644
index 0000000..c220c67
--- /dev/null
+++ b/mapmaster/models/ins_decoder/mask2former.py
@@ -0,0 +1,315 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
+import torch
+import logging
+from typing import Optional
+from torch import nn, Tensor
+from torch.nn import functional as F
+from mapmaster.models.utils.misc import Conv2d, c2_xavier_fill, get_activation_fn
+from mapmaster.models.utils.position_encoding import PositionEmbeddingSine
+
+
+class SelfAttentionLayer(nn.Module):
+
+ def __init__(self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ self.norm = nn.LayerNorm(d_model)
+ self.dropout = nn.Dropout(dropout)
+ self.activation = get_activation_fn(activation)
+ self.normalize_before = normalize_before
+ self._reset_parameters()
+
+ def forward(self, tgt, tgt_mask=None, tgt_key_padding_mask=None, query_pos=None):
+ if self.normalize_before:
+ return self.forward_pre(tgt, tgt_mask, tgt_key_padding_mask, query_pos)
+ return self.forward_post(tgt, tgt_mask, tgt_key_padding_mask, query_pos)
+
+ def forward_pre(self, tgt, tgt_mask=None, tgt_key_padding_mask=None, query_pos=None):
+ tgt2 = self.norm(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout(tgt2)
+ return tgt
+
+ def forward_post(self, tgt, tgt_mask=None, tgt_key_padding_mask=None, query_pos=None):
+ q = k = self.with_pos_embed(tgt, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout(tgt2)
+ tgt = self.norm(tgt)
+ return tgt
+
+ @staticmethod
+ def with_pos_embed(tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+
+class CrossAttentionLayer(nn.Module):
+
+ def __init__(self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False):
+ super().__init__()
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ self.norm = nn.LayerNorm(d_model)
+ self.dropout = nn.Dropout(dropout)
+ self.activation = get_activation_fn(activation)
+ self.normalize_before = normalize_before
+ self._reset_parameters()
+
+ def forward(self, tgt, memory, memory_mask=None, memory_key_padding_mask=None, pos=None, query_pos=None):
+ if self.normalize_before:
+ return self.forward_pre(tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos)
+ return self.forward_post(tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos)
+
+ def forward_pre(self, tgt, memory, memory_mask=None, memory_key_padding_mask=None, pos=None, query_pos=None):
+ tgt2 = self.norm(tgt)
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory, attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)[0]
+ tgt = tgt + self.dropout(tgt2)
+ return tgt
+
+ def forward_post(self, tgt, memory, memory_mask=None, memory_key_padding_mask=None, pos=None, query_pos=None):
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory, attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)[0]
+ tgt = tgt + self.dropout(tgt2)
+ tgt = self.norm(tgt)
+ return tgt
+
+ @staticmethod
+ def with_pos_embed(tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+
+class FFNLayer(nn.Module):
+
+ def __init__(self, d_model, dim_feedforward=2048, dropout=0.0, activation="relu", normalize_before=False):
+ super().__init__()
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+ self.norm = nn.LayerNorm(d_model)
+ self.activation = get_activation_fn(activation)
+ self.normalize_before = normalize_before
+ self._reset_parameters()
+
+ def forward(self, tgt):
+ if self.normalize_before:
+ return self.forward_pre(tgt)
+ return self.forward_post(tgt)
+
+ def forward_pre(self, tgt):
+ tgt2 = self.norm(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout(tgt2)
+ return tgt
+
+ def forward_post(self, tgt):
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout(tgt2)
+ tgt = self.norm(tgt)
+ return tgt
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+
+class MultiScaleMaskedTransformerDecoder(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ num_feature_levels,
+ mask_classification=True,
+ *,
+ num_classes: int,
+ hidden_dim: int,
+ num_queries: int,
+ nheads: int,
+ dim_feedforward: int,
+ dec_layers: int,
+ pre_norm: bool,
+ mask_dim: int,
+ enforce_input_project: bool,
+ ):
+ """
+ NOTE: this interface is experimental.
+ Args:
+ in_channels: channels of the input features
+ mask_classification: whether to add mask classifier or not
+ num_classes: number of classes
+ hidden_dim: Transformer feature dimension
+ num_queries: number of queries
+ nheads: number of heads
+ dim_feedforward: feature dimension in feedforward network
+            num_feature_levels: number of multi-scale feature levels attended by the decoder
+            dec_layers: number of Transformer decoder layers
+            pre_norm: whether to use pre-LayerNorm or not
+            mask_dim: mask feature dimension
+            enforce_input_project: add an input projection (1x1 conv) even if input
+                channels and hidden dim are identical
+ """
+ super().__init__()
+
+ assert mask_classification, "Only support mask classification model"
+ self.mask_classification = mask_classification
+
+ # positional encoding
+ self.pe_layer = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
+
+ # define Transformer decoder here
+ self.num_heads = nheads
+ self.num_layers = dec_layers
+ self.transformer_self_attention_layers = nn.ModuleList()
+ self.transformer_cross_attention_layers = nn.ModuleList()
+ self.transformer_ffn_layers = nn.ModuleList()
+ # d_model, nhead, dropout = 0.0, activation = "relu", normalize_before = False
+ for _ in range(self.num_layers):
+ self.transformer_self_attention_layers.append(
+ SelfAttentionLayer(d_model=hidden_dim, nhead=nheads, dropout=0.0, normalize_before=pre_norm)
+ )
+ self.transformer_cross_attention_layers.append(
+ CrossAttentionLayer(d_model=hidden_dim, nhead=nheads, dropout=0.0, normalize_before=pre_norm)
+ )
+ self.transformer_ffn_layers.append(
+ FFNLayer(d_model=hidden_dim, dim_feedforward=dim_feedforward, dropout=0.0, normalize_before=pre_norm)
+ )
+
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
+
+ self.num_queries = num_queries
+ # learnable query features
+ self.query_feat = nn.Embedding(num_queries, hidden_dim)
+ # learnable query p.e.
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
+
+ # level embedding (we always use 3 scales)
+ self.num_feature_levels = num_feature_levels
+ self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
+ self.input_proj = nn.ModuleList()
+ for _ in range(self.num_feature_levels):
+ if in_channels != hidden_dim or enforce_input_project:
+ self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
+ c2_xavier_fill(self.input_proj[-1])
+ else:
+ self.input_proj.append(nn.Sequential())
+
+ # output FFNs
+ if self.mask_classification:
+ self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
+
+ def forward(self, x, mask_features, mask=None):
+ # x is a list of multi-scale feature
+ assert len(x) == self.num_feature_levels
+ src = []
+ pos = []
+ size_list = []
+
+ # disable mask, it does not affect performance
+ del mask
+
+ for i in range(self.num_feature_levels):
+ size_list.append(x[i].shape[-2:])
+ mask = torch.zeros((x[i].size(0), x[i].size(2), x[i].size(3)), device=x[i].device, dtype=torch.bool)
+ pos.append(self.pe_layer(mask).flatten(2))
+ src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
+
+ # flatten NxCxHxW to HWxNxC
+ pos[-1] = pos[-1].permute(2, 0, 1)
+ src[-1] = src[-1].permute(2, 0, 1)
+
+ _, bs, _ = src[0].shape
+
+ # QxNxC
+ query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
+ output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
+
+ predictions_class = []
+ predictions_mask = []
+ decoder_outputs = []
+
+ # prediction heads on learnable query features
+ dec_out, outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(
+ output, mask_features, attn_mask_target_size=size_list[0]
+ )
+ predictions_class.append(outputs_class)
+ predictions_mask.append(outputs_mask)
+ decoder_outputs.append(dec_out)
+
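+        # Decoder layers attend to the feature levels in round-robin order. If a query's
+        # predicted mask is empty at the current scale, its attention mask would block every
+        # memory position, so such rows are reset to attend everywhere.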
+ for i in range(self.num_layers):
+ level_index = i % self.num_feature_levels
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
+ # attention: cross-attention first
+ output = self.transformer_cross_attention_layers[i](
+ output, src[level_index],
+ memory_mask=attn_mask,
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
+ pos=pos[level_index], query_pos=query_embed
+ )
+ output = self.transformer_self_attention_layers[i](
+ output, tgt_mask=None, tgt_key_padding_mask=None, query_pos=query_embed
+ )
+ output = self.transformer_ffn_layers[i](output)
+ dec_out, outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(
+ output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels]
+ )
+
+ predictions_class.append(outputs_class)
+ predictions_mask.append(outputs_mask)
+ decoder_outputs.append(dec_out)
+
+ assert len(predictions_class) == self.num_layers + 1
+
+ out = {
+ 'pred_logits': predictions_class,
+ 'pred_masks': predictions_mask,
+ 'decoder_outputs': decoder_outputs
+ }
+ return out
+
+ def forward_prediction_heads(self, output, mask_features, attn_mask_target_size):
+ decoder_output = self.decoder_norm(output)
+ decoder_output = decoder_output.transpose(0, 1) # (b, q, c')
+ outputs_class = self.class_embed(decoder_output) # (b, q, c') -> (b, q, 2)
+ mask_embed = self.mask_embed(decoder_output) # (b, q, c') -> (b, q, c)
+ outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
+
+ # NOTE: prediction is of higher-resolution
+ # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
+ attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
+ # must use bool type
+ # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
+ attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
+ attn_mask = attn_mask.detach()
+
+ return decoder_output, outputs_class, outputs_mask, attn_mask
diff --git a/mapmaster/models/ins_decoder/model.py b/mapmaster/models/ins_decoder/model.py
new file mode 100644
index 0000000..9ea5ad9
--- /dev/null
+++ b/mapmaster/models/ins_decoder/model.py
@@ -0,0 +1,39 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from mapmaster.models.ins_decoder.mask2former import MultiScaleMaskedTransformerDecoder
+from mapmaster.models.ins_decoder.pointmask2former import PointMask2TransformerDecoder
+
+
+class INSDecoderBase(nn.Module):
+ def __init__(self, decoder_ids=(5, ), tgt_shape=None):
+ super(INSDecoderBase, self).__init__()
+ self.decoder_ids = tuple(decoder_ids) # [0, 1, 2, 3, 4, 5]
+ self.tgt_shape = tgt_shape
+ self.bev_decoder = None
+
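+    # `decoder_ids` selects which decoder layers' predictions are returned; the initial
+    # prediction made before the first decoder layer is skipped via the [1:] slice.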
+ def forward(self, inputs):
+ assert "bev_enc_features" in inputs
+ bev_enc_features = inputs["bev_enc_features"]
+ if self.tgt_shape is not None:
+ bev_enc_features = [self.up_sample(x) for x in inputs["bev_enc_features"]]
+ out = self.bev_decoder(bev_enc_features[-1:], bev_enc_features[-1])
+ return {"mask_features": [out["pred_masks"][1:][i] for i in self.decoder_ids],
+ "obj_scores": [out["pred_logits"][1:][i] for i in self.decoder_ids],
+ "decoder_outputs": [out["decoder_outputs"][1:][i] for i in self.decoder_ids],
+ "bev_enc_features": bev_enc_features}
+
+ def up_sample(self, x, tgt_shape=None):
+ tgt_shape = self.tgt_shape if tgt_shape is None else tgt_shape
+ if tuple(x.shape[-2:]) == tuple(tgt_shape):
+ return x
+ return F.interpolate(x, size=tgt_shape, mode="bilinear", align_corners=True)
+
+class Mask2formerINSDecoder(INSDecoderBase):
+ def __init__(self, decoder_ids=(5, ), tgt_shape=None, **kwargs):
+ super(Mask2formerINSDecoder, self).__init__(decoder_ids, tgt_shape)
+ self.bev_decoder = MultiScaleMaskedTransformerDecoder(**kwargs)
+
+class PointMask2formerINSDecoder(INSDecoderBase):
+ def __init__(self, decoder_ids=(5, ), tgt_shape=None, **kwargs):
+ super(PointMask2formerINSDecoder, self).__init__(decoder_ids, tgt_shape)
+ self.bev_decoder = PointMask2TransformerDecoder(**kwargs)
diff --git a/mapmaster/models/ins_decoder/pointmask2former.py b/mapmaster/models/ins_decoder/pointmask2former.py
new file mode 100644
index 0000000..4e51026
--- /dev/null
+++ b/mapmaster/models/ins_decoder/pointmask2former.py
@@ -0,0 +1,346 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
+import torch
+from typing import Optional
+from torch import nn, Tensor
+from torch.nn import functional as F
+from mapmaster.models.utils.misc import Conv2d, c2_xavier_fill, get_activation_fn
+from mapmaster.models.utils.position_encoding import PositionEmbeddingSine
+
+
+class SelfAttentionLayer(nn.Module):
+
+ def __init__(self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ self.norm = nn.LayerNorm(d_model)
+ self.dropout = nn.Dropout(dropout)
+ self.activation = get_activation_fn(activation)
+ self.normalize_before = normalize_before
+ self._reset_parameters()
+
+ def forward(self, tgt, tgt_mask=None, tgt_key_padding_mask=None, query_pos=None):
+ if self.normalize_before:
+ return self.forward_pre(tgt, tgt_mask, tgt_key_padding_mask, query_pos)
+ return self.forward_post(tgt, tgt_mask, tgt_key_padding_mask, query_pos)
+
+ def forward_pre(self, tgt, tgt_mask=None, tgt_key_padding_mask=None, query_pos=None):
+ tgt2 = self.norm(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout(tgt2)
+ return tgt
+
+ def forward_post(self, tgt, tgt_mask=None, tgt_key_padding_mask=None, query_pos=None):
+ q = k = self.with_pos_embed(tgt, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout(tgt2)
+ tgt = self.norm(tgt)
+ return tgt
+
+ @staticmethod
+ def with_pos_embed(tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+
+class CrossAttentionLayer(nn.Module):
+
+ def __init__(self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False):
+ super().__init__()
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ self.norm = nn.LayerNorm(d_model)
+ self.dropout = nn.Dropout(dropout)
+ self.activation = get_activation_fn(activation)
+ self.normalize_before = normalize_before
+ self._reset_parameters()
+
+ def forward(self, tgt, memory, memory_mask=None, memory_key_padding_mask=None, pos=None, query_pos=None):
+ if self.normalize_before:
+ return self.forward_pre(tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos)
+ return self.forward_post(tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos)
+
+ def forward_pre(self, tgt, memory, memory_mask=None, memory_key_padding_mask=None, pos=None, query_pos=None):
+ tgt2 = self.norm(tgt)
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory, attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)[0]
+ tgt = tgt + self.dropout(tgt2)
+ return tgt
+
+ def forward_post(self, tgt, memory, memory_mask=None, memory_key_padding_mask=None, pos=None, query_pos=None):
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory, attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)[0]
+ tgt = tgt + self.dropout(tgt2)
+ tgt = self.norm(tgt)
+ return tgt
+
+ @staticmethod
+ def with_pos_embed(tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+
+class FFNLayer(nn.Module):
+
+ def __init__(self, d_model, dim_feedforward=2048, dropout=0.0, activation="relu", normalize_before=False):
+ super().__init__()
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+ self.norm = nn.LayerNorm(d_model)
+ self.activation = get_activation_fn(activation)
+ self.normalize_before = normalize_before
+ self._reset_parameters()
+
+ def forward(self, tgt):
+ if self.normalize_before:
+ return self.forward_pre(tgt)
+ return self.forward_post(tgt)
+
+ def forward_pre(self, tgt):
+ tgt2 = self.norm(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout(tgt2)
+ return tgt
+
+ def forward_post(self, tgt):
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout(tgt2)
+ tgt = self.norm(tgt)
+ return tgt
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+
+class PointMask2TransformerDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ num_feature_levels,
+ mask_classification=True,
+ *,
+ num_classes: int,
+ hidden_dim: int,
+ nheads: int,
+ dim_feedforward: int,
+ dec_layers: int,
+ pre_norm: bool,
+ mask_dim: int,
+ enforce_input_project: bool,
+ query_split=(20, 25, 15),
+ max_pieces=(10, 2, 30),
+ ):
+ """
+ NOTE: this interface is experimental.
+ Args:
+ in_channels: channels of the input features
+ mask_classification: whether to add mask classifier or not
+ num_classes: number of classes
+ hidden_dim: Transformer feature dimension
+            num_feature_levels: number of multi-scale feature levels attended by the decoder
+            nheads: number of heads
+            dim_feedforward: feature dimension in feedforward network
+            dec_layers: number of Transformer decoder layers
+            pre_norm: whether to use pre-LayerNorm or not
+            mask_dim: mask feature dimension
+            enforce_input_project: add an input projection (1x1 conv) even if input
+                channels and hidden dim are identical
+            query_split: number of instance queries per map class
+            max_pieces: number of point queries (pieces) per instance for each map class
+ """
+ super().__init__()
+
+ assert mask_classification, "Only support mask classification model"
+ self.mask_classification = mask_classification
+
+ # positional encoding
+ self.pe_layer = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
+
+ # define Transformer decoder here
+ self.num_heads = nheads
+ self.num_layers = dec_layers
+ self.transformer_self_attention_layers = nn.ModuleList()
+ self.transformer_cross_attention_layers = nn.ModuleList()
+ self.transformer_ffn_layers = nn.ModuleList()
+ # d_model, nhead, dropout = 0.0, activation = "relu", normalize_before = False
+ for _ in range(self.num_layers):
+ self.transformer_self_attention_layers.append(
+ SelfAttentionLayer(d_model=hidden_dim, nhead=nheads, dropout=0.0, normalize_before=pre_norm)
+ )
+ self.transformer_cross_attention_layers.append(
+ CrossAttentionLayer(d_model=hidden_dim, nhead=nheads, dropout=0.0, normalize_before=pre_norm)
+ )
+ self.transformer_ffn_layers.append(
+ FFNLayer(d_model=hidden_dim, dim_feedforward=dim_feedforward, dropout=0.0, normalize_before=pre_norm)
+ )
+
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
+
+ self.max_pieces = max_pieces
+ self.query_split = query_split
+ self.num_queries = sum([self.max_pieces[i] * self.query_split[i] for i in range(len(self.max_pieces))])
+
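+        # Each instance of class i is represented by max_pieces[i] point queries, so the total
+        # query count is sum(query_split[i] * max_pieces[i]) (700 with the defaults above).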
+ # learnable pt features
+ self.query_feat = nn.Embedding(self.num_queries, hidden_dim).weight # [700, C]
+ # learnable pt p.e.
+ self.query_embed = nn.Embedding(self.num_queries, hidden_dim).weight # [20 * 10 + 25 * 2 + 15 * 30, C] which is [700, C], note the memory
+
+ # level embedding (we always use 3 scales)
+ self.num_feature_levels = num_feature_levels
+ self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
+ self.input_proj = nn.ModuleList()
+ for _ in range(self.num_feature_levels):
+ if in_channels != hidden_dim or enforce_input_project:
+ self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
+ c2_xavier_fill(self.input_proj[-1])
+ else:
+ self.input_proj.append(nn.Sequential())
+
+ # output FFNs
+ if self.mask_classification:
+ self.class_embed = nn.ModuleList(
+ nn.Linear(hidden_dim*self.max_pieces[i], num_classes + 1) for i in range(len(query_split))
+ )
+ self.mask_embed = nn.ModuleList(
+ MLP(hidden_dim*self.max_pieces[i], hidden_dim, mask_dim, 3) for i in range(len(query_split))
+ )
+
+ # split info
+ self.cls_split = [self.query_split[i] * self.max_pieces[i] for i in range(len(self.query_split))]
+ self.ins_split = torch.cat([torch.ones(self.query_split[i], dtype=torch.long)*self.max_pieces[i] for i in range(len(self.query_split))])
+
+ def forward(self, x, mask_features, mask=None):
+ # x is a list of multi-scale feature
+ assert len(x) == self.num_feature_levels
+ src = []
+ pos = []
+ size_list = []
+
+ # disable mask, it does not affect performance
+ del mask
+
+ for i in range(self.num_feature_levels):
+ size_list.append(x[i].shape[-2:])
+ mask = torch.zeros((x[i].size(0), x[i].size(2), x[i].size(3)), device=x[i].device, dtype=torch.bool)
+ pos.append(self.pe_layer(mask).flatten(2))
+ src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
+
+ # flatten NxCxHxW to HWxNxC
+ pos[-1] = pos[-1].permute(2, 0, 1)
+ src[-1] = src[-1].permute(2, 0, 1)
+
+ _, bs, _ = src[0].shape
+
+ query_embed = self.query_embed.unsqueeze(1).repeat(1, bs, 1) # [n_q, bs, C]
+ output = self.query_feat.unsqueeze(1).repeat(1, bs, 1) # [n_q, bs, C]
+
+ predictions_class = []
+ predictions_mask = []
+ decoder_outputs = []
+
+ # prediction heads on learnable query features
+ dec_out, outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(
+ output, mask_features, attn_mask_target_size=size_list[0]
+ )
+ predictions_class.append(outputs_class)
+ predictions_mask.append(outputs_mask)
+ decoder_outputs.append(dec_out)
+
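+        # Same round-robin feature-level scheme as MultiScaleMaskedTransformerDecoder: rows of
+        # the attention mask that would block every memory position are reset to attend everywhere.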
+ for i in range(self.num_layers):
+ level_index = i % self.num_feature_levels
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
+ # attention: cross-attention first
+ output = self.transformer_cross_attention_layers[i](
+ output, src[level_index],
+ memory_mask=attn_mask,
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
+ pos=pos[level_index], query_pos=query_embed
+ )
+ output = self.transformer_self_attention_layers[i](
+ output, tgt_mask=None, tgt_key_padding_mask=None, query_pos=query_embed
+ )
+ output = self.transformer_ffn_layers[i](output)
+ dec_out, outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(
+ output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels]
+ )
+
+ predictions_class.append(outputs_class)
+ predictions_mask.append(outputs_mask)
+ decoder_outputs.append(dec_out)
+
+ assert len(predictions_class) == self.num_layers + 1
+
+ out = {
+ 'pred_logits': predictions_class,
+ 'pred_masks': predictions_mask,
+ 'decoder_outputs': decoder_outputs
+ }
+ return out
+
+ def forward_prediction_heads(self, output, mask_features, attn_mask_target_size):
+ decoder_output = self.decoder_norm(output)
+ decoder_output = decoder_output.transpose(0, 1) # (b, q, c') (b, 700, c')
+ decoder_output_split = decoder_output.split(self.cls_split, dim=1)
+
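+        # Point tokens are laid out per class as n_q * n_pt consecutive queries; regroup them
+        # so the class and mask heads operate on the concatenated point embeddings
+        # (n_pt * c') of each instance query.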
+ outputs_class, mask_embed = None, None
+ for i in range(len(self.query_split)):
+ x_cls_pt = decoder_output_split[i] # (bs, n_q*n_pt, c')
+ n_q, n_pt = self.query_split[i], self.max_pieces[i]
+ x_cls_pt = x_cls_pt.reshape(x_cls_pt.shape[0], n_q, -1) # (bs, n_q, n_pt * c')
+ x_cls = self.class_embed[i](x_cls_pt) # (bs, n_q, 2)
+ x_cls_ins = self.mask_embed[i](x_cls_pt) # (bs, n_q, c)
+ if outputs_class is None:
+ outputs_class = x_cls
+ mask_embed = x_cls_ins
+ else:
+ outputs_class = torch.cat([outputs_class, x_cls], dim=1)
+ mask_embed = torch.cat([mask_embed, x_cls_ins], dim=1)
+
+ outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features) # for loss cal
+
+ outputs_mask_split = outputs_mask.split(self.query_split, dim=1)
+
+ attn_mask = []
+ for i in range(len(self.query_split)):
+ # [bs, n_q, h, w] -> [bs, n_q, 1, h, w] -> [bs, n_q, n_pt, h, w] -> [bs, n_q*n_pt, h, w]
+ attn_mask.append(outputs_mask_split[i].unsqueeze(2).repeat(1, 1, self.max_pieces[i], 1, 1).flatten(1, 2))
+ attn_mask = torch.cat(attn_mask, dim=1)
+
+ # NOTE: prediction is of higher-resolution
+ # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
+ attn_mask = F.interpolate(attn_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
+ # must use bool type
+ # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
+ attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
+ attn_mask = attn_mask.detach()
+
+ return decoder_output, outputs_class, outputs_mask, attn_mask
diff --git a/mapmaster/models/network.py b/mapmaster/models/network.py
new file mode 100644
index 0000000..85d6caa
--- /dev/null
+++ b/mapmaster/models/network.py
@@ -0,0 +1,68 @@
+import torch.nn as nn
+from mapmaster.models import backbone, bev_decoder, ins_decoder, output_head
+# os.environ['TORCH_DISTRIBUTED_DEBUG'] = "INFO"
+# warnings.filterwarnings('ignore')
+
+
+class MapMaster(nn.Module):
+ def __init__(self, model_config, *args, **kwargs):
+ super(MapMaster, self).__init__()
+ self.im_backbone = self.create_backbone(**model_config["im_backbone"])
+ self.bev_decoder = self.create_bev_decoder(**model_config["bev_decoder"])
+ self.ins_decoder = self.create_ins_decoder(**model_config["ins_decoder"])
+ self.output_head = self.create_output_head(**model_config["output_head"])
+ self.post_processor = self.create_post_processor(**model_config["post_processor"])
+
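+    # Each stage reads what it needs from the shared `outputs` dict and appends its own
+    # results: image features -> BEV features -> instance queries -> map-element outputs.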
+ def forward(self, inputs):
+ outputs = {}
+ outputs.update({k: inputs[k] for k in ["images", "extra_infos"]})
+ outputs.update({k: inputs[k].float() for k in ["extrinsic", "intrinsic"]})
+ if "ida_mats" in inputs:
+ outputs.update({"ida_mats": inputs["ida_mats"].float()})
+ outputs.update(self.im_backbone(outputs))
+ outputs.update(self.bev_decoder(outputs))
+ outputs.update(self.ins_decoder(outputs))
+ outputs.update(self.output_head(outputs))
+ return outputs
+
+ @staticmethod
+ def create_backbone(arch_name, ret_layers, bkb_kwargs, fpn_kwargs, up_shape=None):
+ __factory_dict__ = {
+ "resnet": backbone.ResNetBackbone,
+ "efficient_net": backbone.EfficientNetBackbone,
+ "swin_transformer": backbone.SwinTRBackbone,
+ }
+ return __factory_dict__[arch_name](bkb_kwargs, fpn_kwargs, up_shape, ret_layers)
+
+ @staticmethod
+ def create_bev_decoder(arch_name, net_kwargs):
+ __factory_dict__ = {
+ "transformer": bev_decoder.TransformerBEVDecoder,
+ "ipm_deformable_transformer": bev_decoder.DeformTransformerBEVEncoder,
+ }
+ return __factory_dict__[arch_name](**net_kwargs)
+
+ @staticmethod
+ def create_ins_decoder(arch_name, net_kwargs):
+ __factory_dict__ = {
+ "mask2former": ins_decoder.Mask2formerINSDecoder,
+ "line_aware_decoder": ins_decoder.PointMask2formerINSDecoder,
+ }
+
+ return __factory_dict__[arch_name](**net_kwargs)
+
+ @staticmethod
+ def create_output_head(arch_name, net_kwargs):
+ __factory_dict__ = {
+ "bezier_output_head": output_head.PiecewiseBezierMapOutputHead,
+ "pivot_point_predictor": output_head.PivotMapOutputHead,
+ }
+ return __factory_dict__[arch_name](**net_kwargs)
+
+ @staticmethod
+ def create_post_processor(arch_name, net_kwargs):
+ __factory_dict__ = {
+ "bezier_post_processor": output_head.PiecewiseBezierMapPostProcessor,
+ "pivot_post_processor": output_head.PivotMapPostProcessor,
+ }
+ return __factory_dict__[arch_name](**net_kwargs)
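+
+
+# Illustrative sketch (not taken from the repo's config files): the structure of `model_config`
+# expected by MapMaster, inferred from the factory methods above; the contents of the
+# bkb_kwargs / fpn_kwargs / net_kwargs dicts depend on the chosen architectures.
+#
+# model_config = {
+#     "im_backbone": {"arch_name": "resnet", "ret_layers": ..., "bkb_kwargs": {...}, "fpn_kwargs": {...}},
+#     "bev_decoder": {"arch_name": "transformer", "net_kwargs": {...}},
+#     "ins_decoder": {"arch_name": "line_aware_decoder", "net_kwargs": {...}},
+#     "output_head": {"arch_name": "pivot_point_predictor", "net_kwargs": {...}},
+#     "post_processor": {"arch_name": "pivot_post_processor", "net_kwargs": {...}},
+# }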
diff --git a/mapmaster/models/output_head/__init__.py b/mapmaster/models/output_head/__init__.py
new file mode 100644
index 0000000..08f39e3
--- /dev/null
+++ b/mapmaster/models/output_head/__init__.py
@@ -0,0 +1,4 @@
+from .bezier_outputs import PiecewiseBezierMapOutputHead
+from .bezier_post_processor import PiecewiseBezierMapPostProcessor
+from .pivot_outputs import PivotMapOutputHead
+from .pivot_post_processor import PivotMapPostProcessor
\ No newline at end of file
diff --git a/mapmaster/models/output_head/bezier_outputs.py b/mapmaster/models/output_head/bezier_outputs.py
new file mode 100644
index 0000000..9fecd39
--- /dev/null
+++ b/mapmaster/models/output_head/bezier_outputs.py
@@ -0,0 +1,127 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FFN(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2, basic_type='linear'):
+ super().__init__()
+ self.basic_type = basic_type
+ if output_dim == 0:
+ self.basic_type = "identity"
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(self.basic_layer(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+ def basic_layer(self, n, k):
+ if self.basic_type == 'linear':
+ return nn.Linear(n, k)
+ elif self.basic_type == 'conv':
+ return nn.Conv2d(n, k, kernel_size=1, stride=1)
+ elif self.basic_type == 'identity':
+ return nn.Identity()
+ else:
+ raise NotImplementedError
+
+
+class PiecewiseBezierMapOutputHead(nn.Module):
+ def __init__(self, in_channel, num_queries, tgt_shape, num_degree, max_pieces, bev_channels=-1, ins_channel=64):
+ super(PiecewiseBezierMapOutputHead, self).__init__()
+ self.num_queries = num_queries
+ self.num_classes = len(num_queries)
+ self.tgt_shape = tgt_shape
+ self.bev_channels = bev_channels
+ self.semantic_heads = None
+ if self.bev_channels > 0:
+ self.semantic_heads = nn.ModuleList(
+ nn.Sequential(nn.Conv2d(bev_channels, 2, kernel_size=1, stride=1)) for _ in range(self.num_classes)
+ )
+ self.num_degree = num_degree
+ self.max_pieces = max_pieces
+ self.num_ctr_im = [(n + 1) for n in self.max_pieces]
+ self.num_ctr_ex = [n * (d - 1) for n, d in zip(self.max_pieces, self.num_degree)]
+ _N = self.num_classes
+
+ _C = ins_channel
+ self.im_ctr_heads = nn.ModuleList(FFN(in_channel, 256, (self.num_ctr_im[i] * 2) * _C, 3) for i in range(_N))
+ self.ex_ctr_heads = nn.ModuleList(FFN(in_channel, 256, (self.num_ctr_ex[i] * 2) * _C, 3) for i in range(_N))
+ self.npiece_heads = nn.ModuleList(FFN(in_channel, 256, self.max_pieces[i], 3) for i in range(_N))
+ self.gap_layer = nn.AdaptiveAvgPool2d((1, 1))
+ self.coords = self.compute_locations(device='cuda')
+ self.coords_head = FFN(2, 256, _C, 3, 'conv')
+
+ def forward(self, inputs):
+ num_decoders = len(inputs["mask_features"])
+ dt_obj_logit = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)]
+ dt_ins_masks = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)]
+ im_ctr_coord = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)]
+ ex_ctr_coord = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)]
+ dt_end_logit = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)]
+ coords_feats = self.coords_head.forward(self.coords.repeat((inputs["mask_features"][0].shape[0], 1, 1, 1)))
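+        # For every decoder stage and map class: upsample the instance masks, keep the
+        # objectness logits, classify the number of Bezier pieces, and regress the
+        # (max_pieces + 1) piece-endpoint and max_pieces * (degree - 1) intermediate control
+        # points by correlating query features with the normalized coordinate feature map and
+        # average-pooling the resulting score maps.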
+ for i in range(num_decoders):
+ x_ins_cw = inputs["mask_features"][i].split(self.num_queries, dim=1)
+ x_obj_cw = inputs["obj_scores"][i].split(self.num_queries, dim=1)
+ x_qry_cw = inputs["decoder_outputs"][i].split(self.num_queries, dim=1)
+ batch_size = x_qry_cw[0].shape[0]
+ for j in range(self.num_classes):
+ num_qry = self.num_queries[j]
+ # if self.training:
+ dt_ins_masks[i][j] = self.up_sample(x_ins_cw[j])
+ dt_obj_logit[i][j] = x_obj_cw[j]
+ dt_end_logit[i][j] = self.npiece_heads[j](x_qry_cw[j])
+ # im
+ im_feats = self.im_ctr_heads[j](x_qry_cw[j])
+ im_feats = im_feats.reshape(batch_size, num_qry, self.num_ctr_im[j] * 2, -1).flatten(1, 2)
+ im_coords_map = torch.einsum("bqc,bchw->bqhw", im_feats, coords_feats)
+ im_coords = self.gap_layer(im_coords_map)
+ im_ctr_coord[i][j] = im_coords.reshape(batch_size, num_qry, self.max_pieces[j] + 1, 2)
+ # ex
+ if self.num_ctr_ex[j] == 0:
+ ex_ctr_coord[i][j] = torch.zeros(batch_size, num_qry, self.max_pieces[j], 0, 2).cuda()
+ else:
+ ex_feats = self.ex_ctr_heads[j](x_qry_cw[j])
+ ex_feats = ex_feats.reshape(batch_size, num_qry, self.num_ctr_ex[j] * 2, -1).flatten(1, 2)
+ ex_coords_map = torch.einsum("bqc,bchw->bqhw", ex_feats, coords_feats)
+ ex_coords = self.gap_layer(ex_coords_map)
+ ex_ctr_coord[i][j] = ex_coords.reshape(batch_size, num_qry, self.max_pieces[j], self.num_degree[j] - 1, 2)
+ ret = {"outputs": {"obj_logits": dt_obj_logit, "ins_masks": dt_ins_masks,
+ "ctr_im": im_ctr_coord, "ctr_ex": ex_ctr_coord, "end_logits": dt_end_logit}}
+ if self.semantic_heads is not None:
+ num_decoders = len(inputs["bev_enc_features"])
+ dt_sem_masks = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)]
+ for i in range(num_decoders):
+ x_sem = inputs["bev_enc_features"][i]
+ for j in range(self.num_classes):
+ dt_sem_masks[i][j] = self.up_sample(self.semantic_heads[j](x_sem))
+ ret["outputs"].update({"sem_masks": dt_sem_masks})
+ return ret
+
+ def up_sample(self, x, tgt_shape=None):
+ tgt_shape = self.tgt_shape if tgt_shape is None else tgt_shape
+ if tuple(x.shape[-2:]) == tuple(tgt_shape):
+ return x
+ return F.interpolate(x, size=tgt_shape, mode="bilinear", align_corners=True)
+
+ def compute_locations(self, stride=1, device='cpu'):
+
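+        # Build a (1, 2, H, W) grid of pixel-center coordinates normalized to [0, 1]; it is
+        # fed to `coords_head` in `forward` to produce the coordinate feature map.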
+ fh, fw = self.tgt_shape
+
+ shifts_x = torch.arange(0, fw * stride, step=stride, dtype=torch.float32, device=device)
+ shifts_y = torch.arange(0, fh * stride, step=stride, dtype=torch.float32, device=device)
+ shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
+ shift_x = shift_x.reshape(-1)
+ shift_y = shift_y.reshape(-1)
+ locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2
+
+ locations = locations.unsqueeze(0).permute(0, 2, 1).contiguous().float().view(1, 2, fh, fw)
+ locations[:, 0, :, :] /= fw
+ locations[:, 1, :, :] /= fh
+
+ return locations
diff --git a/mapmaster/models/output_head/bezier_post_processor.py b/mapmaster/models/output_head/bezier_post_processor.py
new file mode 100644
index 0000000..c6716ac
--- /dev/null
+++ b/mapmaster/models/output_head/bezier_post_processor.py
@@ -0,0 +1,393 @@
+import cv2
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+from scipy.special import comb as n_over_k
+from scipy.optimize import linear_sum_assignment
+from mapmaster.models.utils.mask_loss import SegmentationLoss
+from mapmaster.models.utils.recovery_loss import PointRecoveryLoss
+from mapmaster.utils.misc import nested_tensor_from_tensor_list
+from mapmaster.utils.misc import get_world_size, is_available, is_distributed
+
+
+class HungarianMatcher(nn.Module):
+
+ def __init__(self, cost_obj=1., cost_ctr=1., cost_end=1., cost_mask=1., cost_curve=1., cost_recovery=1.,
+ ins_mask_loss_conf=None, point_loss_conf=None, class_weight=None):
+ super().__init__()
+ self.cost_obj, self.cost_ctr, self.cost_end = cost_obj, cost_ctr, cost_end
+ self.cost_mask, self.cost_curve, self.cost_recovery = cost_mask, cost_curve, cost_recovery
+ self.ins_mask_loss = SegmentationLoss(**ins_mask_loss_conf)
+ self.recovery_loss = PointRecoveryLoss(**point_loss_conf)
+ self.class_weight = class_weight
+
+ @torch.no_grad()
+ def forward(self, outputs, targets):
+ num_decoders, num_classes = len(outputs["ins_masks"]), len(outputs["ins_masks"][0])
+ matching_indices = [[[] for _ in range(num_classes)] for _ in range(num_decoders)]
+ for dec_id in range(num_decoders):
+ for cid in range(num_classes):
+ if self.class_weight is not None and self.class_weight[cid] == 0:
+ continue
+
+ bs, num_queries = outputs["obj_logits"][dec_id][cid].shape[:2]
+
+ dt_probs = outputs["obj_logits"][dec_id][cid].flatten(0, 1).softmax(-1) # [n_dt, 2], n_dt in a batch
+ gt_idxes = torch.cat([tgt["obj_labels"][cid] for tgt in targets]) # [n_gt, ]
+ cost_mat_obj = -dt_probs[:, gt_idxes] # [n_dt, n_gt]
+
+ dt_curves = outputs["curve_points"][dec_id][cid].flatten(0, 1) # [n_dt, n, 2]
+ dt_masks = outputs["ins_masks"][dec_id][cid].flatten(0, 1) # [n_dt, h, w]
+ gt_masks = torch.cat([tgt["ins_masks"][cid] for tgt in targets]) # [n_gt, h, w]
+ cost_mat_mask, cost_mat_rec = 0, 0
+ if gt_masks.shape[0] > 0:
+ dt_num, gt_num = dt_masks.shape[0], gt_masks.shape[0]
+ dt_masks = dt_masks.unsqueeze(1).expand(dt_num, gt_num, *dt_masks.shape[1:]).flatten(0, 1)
+ gt_masks = gt_masks.unsqueeze(0).expand(dt_num, gt_num, *gt_masks.shape[1:]).flatten(0, 1)
+ cost_mat_mask = self.ins_mask_loss(dt_masks, gt_masks, "matcher").reshape(dt_num, gt_num)
+ dt_curves = dt_curves.unsqueeze(1).expand(dt_num, gt_num, *dt_curves.shape[1:]).flatten(0, 1)
+ cost_mat_rec = self.recovery_loss(dt_curves, gt_masks).reshape(dt_num, gt_num)
+
+                dt_ctrs = outputs["ctr_points"][dec_id][cid].flatten(0, 1).flatten(1)  # [n_dt, n*2]
+                gt_ctrs = torch.cat([tgt["ctr_points"][cid] for tgt in targets]).flatten(1)  # [n_gt, n*2]
+ cost_mat_ctr = torch.cdist(dt_ctrs, gt_ctrs, p=1) / gt_ctrs.shape[1]
+
+ dt_end_probs = outputs["end_logits"][dec_id][cid].flatten(0, 1).softmax(-1)
+ gt_end_idxes = torch.cat([tgt["end_labels"][cid] for tgt in targets])
+ cost_mat_end = -dt_end_probs[:, gt_end_idxes] # [n_dt, n_gt]
+
+                dt_curves = outputs["curve_points"][dec_id][cid].flatten(0, 1).flatten(1)  # [n_dt, n*2]
+                gt_curves = torch.cat([tgt["curve_points"][cid] for tgt in targets]).flatten(1)  # [n_gt, n*2]
+ cost_mat_curve = torch.cdist(dt_curves, gt_curves, p=1) / gt_curves.shape[1]
+
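+                # Total matching cost: a weighted sum of objectness, instance-mask, control-point
+                # L1, piece-count, sampled-curve L1 and point-recovery terms; the assignment is
+                # solved per sample with the Hungarian algorithm (linear_sum_assignment).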
+ sizes = [len(tgt["obj_labels"][cid]) for tgt in targets]
+ C = self.cost_obj * cost_mat_obj + self.cost_mask * cost_mat_mask + \
+ self.cost_ctr * cost_mat_ctr + self.cost_end * cost_mat_end + self.cost_curve * cost_mat_curve +\
+ self.cost_recovery * cost_mat_rec
+ C = C.view(bs, num_queries, -1).cpu()
+ indices = [linear_sum_assignment(c[i].detach().numpy()) for i, c in enumerate(C.split(sizes, -1))]
+ matching_indices[dec_id][cid] = [self.to_tensor(i, j) for i, j in indices]
+
+ return matching_indices
+
+ @staticmethod
+ def to_tensor(i, j):
+ return torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)
+
+
+class SetCriterion(nn.Module):
+
+ def __init__(self, criterion_conf, matcher, num_degree, no_object_coe=1.0):
+ super().__init__()
+ self.num_degree = num_degree
+ self.matcher = matcher
+ self.criterion_conf = criterion_conf
+ self.loss_weight_dict = self.criterion_conf['loss_weight']
+ self.sem_mask_loss = SegmentationLoss(**criterion_conf['bev_decoder']['sem_mask_loss'])
+ self.register_buffer("empty_weight", torch.tensor([1.0, no_object_coe]))
+
+ def forward(self, outputs, targets):
+ losses = {}
+ matching_indices = self.matcher(outputs, targets)
+ losses.update(self.criterion_instance(outputs, targets, matching_indices))
+ losses.update(self.criterion_instance_labels(outputs, targets, matching_indices))
+ losses.update(self.criterion_semantic_masks(outputs, targets))
+ losses = {key: self.criterion_conf['loss_weight'][key] * losses[key] for key in losses}
+ return sum(losses.values()), losses
+
+ def criterion_instance(self, outputs, targets, matching_indices):
+ loss_masks, loss_ctr, loss_end, loss_curve, loss_rec = 0, 0, 0, 0, 0
+ device = outputs["ins_masks"][0][0].device
+ num_decoders, num_classes = len(matching_indices), len(matching_indices[0])
+ for i in range(num_decoders):
+ w = self.criterion_conf['ins_decoder']['weight'][i]
+ for j in range(num_classes):
+ w2 = self.criterion_conf["class_weights"][j] if "class_weights" in self.criterion_conf else 1.0
+ num_instances = sum(len(t["obj_labels"][j]) for t in targets)
+ num_instances = torch.as_tensor([num_instances], dtype=torch.float, device=device)
+ if is_distributed() and is_available():
+ torch.distributed.all_reduce(num_instances)
+ num_instances = torch.clamp(num_instances / get_world_size(), min=1).item()
+
+ indices = matching_indices[i][j]
+ src_idx = self._get_src_permutation_idx(indices) # dt
+ tgt_idx = self._get_tgt_permutation_idx(indices) # gt
+
+ # instance masks
+ src_masks = outputs["ins_masks"][i][j][src_idx]
+ tgt_masks = [t["ins_masks"][j] for t in targets]
+ tgt_masks, _ = nested_tensor_from_tensor_list(tgt_masks).decompose()
+ tgt_masks = tgt_masks.to(src_masks)[tgt_idx]
+ loss_masks += w * self.matcher.ins_mask_loss(src_masks, tgt_masks, "loss").sum() / num_instances * w2
+
+ # eof indices classification
+ src_logits = outputs["end_logits"][i][j][src_idx] # [M, K]
+ tgt_labels = torch.cat([t["end_labels"][j][J] for t, (_, J) in zip(targets, indices)]) # (M, )
+ loss_end += w * F.cross_entropy(src_logits, tgt_labels, ignore_index=-1, reduction='sum') / num_instances * w2
+
+ # control points
+ src_ctrs = outputs["ctr_points"][i][j][src_idx] # [bs, num_queries, o, 2]
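+                # A prediction of k pieces uses k * degree + 1 control points, so only that
+                # prefix (taken from the predicted piece count) is supervised; ground-truth
+                # control points are masked analogously with their validity masks.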
+ end_labels = torch.max(src_logits.softmax(dim=-1), dim=-1)[1] # [m, k] 0, 1, 2, 3...
+ end_labels_new = (end_labels + 1) * self.num_degree[j] + 1
+ valid_mask = torch.zeros(src_ctrs.shape[:2], device=device).long()
+ for a, b in enumerate(end_labels_new):
+ valid_mask[a][:b] = 1
+ src_ctrs_masked = (src_ctrs * valid_mask.unsqueeze(-1))
+ tgt_ctrs = torch.zeros((len(tgt_idx[0]), *src_ctrs.shape[-2:]), device=device).float()
+ valid_mask = torch.zeros((len(tgt_idx[0]), src_ctrs.shape[-2]), device=device).float()
+ for idx in range(len(tgt_idx[0])):
+ batch_id, gt_id = tgt_idx[0][idx], tgt_idx[1][idx]
+ tgt_ctrs[idx] = targets[batch_id]['ctr_points'][j][gt_id]
+ valid_mask[idx] = targets[batch_id]['valid_masks'][j][gt_id]
+ tgt_ctrs_masked = (tgt_ctrs * valid_mask.unsqueeze(-1))
+ num_pt = src_ctrs.shape[-2] * src_ctrs.shape[-1]
+ loss_ctr += w * F.l1_loss(src_ctrs_masked, tgt_ctrs_masked, reduction='sum') / num_instances / num_pt * w2
+
+ # curve loss
+ src_curves = outputs["curve_points"][i][j][src_idx]
+ tgt_curves = torch.zeros((len(tgt_idx[0]), *src_curves.shape[-2:]), device=device).float()
+ for idx in range(len(tgt_idx[0])):
+ batch_id, gt_id = tgt_idx[0][idx], tgt_idx[1][idx]
+ tgt_curves[idx] = targets[batch_id]['curve_points'][j][gt_id]
+ num_pt = src_curves.shape[-2] * src_curves.shape[-1]
+ loss_curve += w * F.l1_loss(src_curves, tgt_curves, reduction='sum') / num_instances / num_pt * w2
+
+ # recovery loss
+ loss_rec += w * self.matcher.recovery_loss(src_curves, tgt_masks).sum() / num_instances * w2
+
+ loss_masks /= (num_decoders * num_classes)
+ loss_ctr /= (num_decoders * num_classes)
+ loss_curve /= (num_decoders * num_classes)
+ loss_end /= (num_decoders * num_classes)
+ loss_rec /= (num_decoders * num_classes)
+
+ return {"ctr_loss": loss_ctr, "end_loss": loss_end, "msk_loss": loss_masks, "curve_loss": loss_curve,
+ "recovery_loss": loss_rec}
+
+ def criterion_instance_labels(self, outputs, targets, matching_indices):
+ loss_labels = 0
+ num_decoders, num_classes = len(matching_indices), len(matching_indices[0])
+ for i in range(num_decoders):
+ w = self.criterion_conf['ins_decoder']['weight'][i]
+ for j in range(num_classes):
+ w2 = self.criterion_conf["class_weights"][j] if "class_weights" in self.criterion_conf else 1.0
+ indices = matching_indices[i][j]
+ idx = self._get_src_permutation_idx(indices) # (batch_id, query_id)
+ logits = outputs["obj_logits"][i][j]
+ target_classes_o = torch.cat([t["obj_labels"][j][J] for t, (_, J) in zip(targets, indices)])
+ target_classes = torch.full(logits.shape[:2], 1, dtype=torch.int64, device=logits.device)
+ target_classes[idx] = target_classes_o
+ loss_labels += (w * F.cross_entropy(logits.transpose(1, 2), target_classes, self.empty_weight)) * w2
+ loss_labels /= (num_decoders * num_classes)
+ return {"obj_loss": loss_labels}
+
+ def criterion_semantic_masks(self, outputs, targets):
+ loss_masks = 0
+ num_decoders, num_classes = len(outputs["sem_masks"]), len(outputs["sem_masks"][0])
+ for i in range(num_decoders):
+ w = self.criterion_conf['bev_decoder']['weight'][i]
+ for j in range(num_classes):
+ w2 = self.criterion_conf["class_weights"][j] if "class_weights" in self.criterion_conf else 1.0
+ dt_masks = outputs["sem_masks"][i][j] # (B, 2, H, W)
+ gt_masks = torch.stack([t["sem_masks"][j] for t in targets], dim=0) # (B, H, W)
+ loss_masks += w * self.sem_mask_loss(dt_masks[:, 1, :, :], gt_masks, "loss").mean() * w2
+ loss_masks /= num_decoders
+ return {"sem_loss": loss_masks}
+
+ @staticmethod
+ def _get_src_permutation_idx(indices):
+ # permute predictions following indices
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+ src_idx = torch.cat([src for (src, _) in indices])
+ return batch_idx, src_idx
+
+ @staticmethod
+ def _get_tgt_permutation_idx(indices):
+ # permute targets following indices
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+ return batch_idx, tgt_idx
+
+
+class PiecewiseBezierMapPostProcessor(nn.Module):
+ def __init__(self, criterion_conf, matcher_conf, bezier_conf, map_conf, no_object_coe=1.0):
+ super(PiecewiseBezierMapPostProcessor, self).__init__()
+ # setting
+ self.num_classes = map_conf['num_classes']
+ self.ego_size = map_conf['ego_size']
+ self.map_size = map_conf['map_size']
+ self.line_width = map_conf['line_width']
+ self.num_degree = bezier_conf['num_degree']
+ self.num_pieces = bezier_conf['max_pieces']
+ self.num_points = bezier_conf['num_points']
+ self.curve_size = bezier_conf['piece_length']
+ self.class_indices = torch.tensor(list(range(self.num_classes)), dtype=torch.int).cuda()
+ self.bezier_coefficient_np = self._get_bezier_coefficients()
+ self.bezier_coefficient = [torch.from_numpy(x).float().cuda() for x in self.bezier_coefficient_np]
+ self.matcher = HungarianMatcher(**matcher_conf)
+ self.criterion = SetCriterion(criterion_conf, self.matcher, self.num_degree, no_object_coe)
+ self.save_thickness = map_conf['save_thickness'] if 'save_thickness' in map_conf else 1
+
+ def forward(self, outputs, targets=None):
+ outputs.update(self.bezier_curve_outputs(outputs))
+ if self.training:
+ targets = self.refactor_targets(targets)
+ return self.criterion.forward(outputs, targets)
+ else:
+ return self.post_processing(outputs)
+
+ def bezier_curve_outputs(self, outputs):
+ dt_ctr_im, dt_ctr_ex, dt_ends = outputs["ctr_im"], outputs["ctr_ex"], outputs["end_logits"]
+ num_decoders, num_classes = len(dt_ends), len(dt_ends[0])
+ ctr_points = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)]
+ curve_points = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)]
+ for i in range(num_decoders):
+ for j in range(num_classes):
+ batch_size, num_queries = dt_ctr_im[i][j].shape[:2]
+
+ im_coords = dt_ctr_im[i][j].sigmoid()
+ ex_offsets = dt_ctr_ex[i][j].sigmoid() - 0.5
+ im_center_coords = ((im_coords[:, :, :-1] + im_coords[:, :, 1:]) / 2).unsqueeze(-2)
+ ex_coords = torch.stack([im_center_coords[:, :, :, :, 0] + ex_offsets[:, :, :, :, 0],
+ im_center_coords[:, :, :, :, 1] + ex_offsets[:, :, :, :, 1]], dim=-1)
+ im_coords = im_coords.unsqueeze(-2)
+ ctr_coords = torch.cat([im_coords[:, :, :-1], ex_coords], dim=-2).flatten(2, 3)
+ ctr_coords = torch.cat([ctr_coords, im_coords[:, :, -1:, 0, :]], dim=-2)
+ ctr_points[i][j] = ctr_coords.clone()
+
+ end_inds = torch.max(torch.softmax(dt_ends[i][j].flatten(0, 1), dim=-1), dim=-1)[1]
+ curve_pts = self.curve_recovery_with_bezier(ctr_coords.flatten(0, 1), end_inds, j)
+ curve_points[i][j] = curve_pts.reshape(batch_size, num_queries, *curve_pts.shape[-2:])
+
+ return {"curve_points": curve_points, 'ctr_points': ctr_points}
+
+ def refactor_targets(self, targets):
+ targets_refactored = []
+ batch_size, num_classes = len(targets["masks"]), len(targets["masks"][0])
+ targets["masks"] = targets["masks"].cuda()
+ targets["points"] = targets["points"].cuda()
+ targets["labels"] = targets["labels"].cuda()
+ for batch_id in range(batch_size):
+
+ sem_masks, ins_masks, ins_objects = [], [], []
+ ctr_points, curve_points, end_labels, valid_masks = [], [], [], []
+
+ ins_classes = targets['labels'][batch_id][:, 0].int()
+ cls_ids, ins_ids = torch.where((ins_classes.unsqueeze(0) == self.class_indices.unsqueeze(1)).int())
+ for cid in range(num_classes):
+ indices = ins_ids[torch.where(cls_ids == cid)]
+ num_ins = indices.shape[0]
+
+ # object class: 0 or 1
+ ins_obj = torch.zeros((num_ins,), dtype=torch.long).cuda()
+ ins_objects.append(ins_obj)
+
+ # bezier control points coords
+ num_max = self.num_points[cid]
+ ctr_pts = targets['points'][batch_id][indices][:, :num_max].float()
+ ctr_pts[:, :, 0] = ctr_pts[:, :, 0] / self.ego_size[1]
+ ctr_pts[:, :, 1] = ctr_pts[:, :, 1] / self.ego_size[0]
+ ctr_points.append(ctr_pts)
+
+ # piecewise end indices
+ end_indices = targets['labels'][batch_id][indices][:, 1].long()
+ end_labels.append(end_indices)
+
+ # bezier valid masks
+ v_mask = torch.zeros((num_ins, num_max), dtype=torch.int8).cuda()
+ for ins_id in range(num_ins):
+ k = targets['labels'][batch_id][indices[ins_id]][2].long()
+ v_mask[ins_id][:k] = 1
+ valid_masks.append(v_mask)
+
+ # curve points
+ curve_pts = self.curve_recovery_with_bezier(ctr_pts, end_indices, cid)
+ curve_points.append(curve_pts)
+
+ # instance mask
+ mask_pc = targets["masks"][batch_id][cid] # mask supervision
+ unique_ids = torch.unique(mask_pc, sorted=True)[1:]
+ if num_ins == unique_ids.shape[0]:
+ ins_msk = (mask_pc.unsqueeze(0).repeat(num_ins, 1, 1) == unique_ids.view(-1, 1, 1)).float()
+ else:
+ ins_msk = np.zeros((num_ins, *self.map_size), dtype=np.uint8)
+ for i, ins_pts in enumerate(curve_pts):
+ ins_pts[:, 0] *= self.map_size[1]
+ ins_pts[:, 1] *= self.map_size[0]
+ ins_pts = ins_pts.cpu().data.numpy().astype(np.int32)
+ cv2.polylines(ins_msk[i], [ins_pts], False, color=1, thickness=self.line_width)
+ ins_msk = torch.from_numpy(ins_msk).float().cuda()
+ ins_masks.append(ins_msk)
+
+ # semantic mask
+ sem_msk = (ins_msk.sum(0) > 0).float()
+ sem_masks.append(sem_msk)
+
+ targets_refactored.append({
+ "sem_masks": sem_masks, "ins_masks": ins_masks, "obj_labels": ins_objects,
+ "ctr_points": ctr_points, "end_labels": end_labels, "curve_points": curve_points,
+ "valid_masks": valid_masks,
+ })
+
+ return targets_refactored
+
+ def curve_recovery_with_bezier(self, ctr_points, end_indices, cid):
+ device = ctr_points.device
+ curve_pts_ret = torch.zeros((0, self.curve_size, 2), dtype=torch.float, device=device)
+ num_instances, num_pieces = ctr_points.shape[0], ctr_points.shape[1]
+ pieces_ids = [[i+j for j in range(self.num_degree[cid]+1)] for i in range(0, num_pieces - 1, self.num_degree[cid])]
+ pieces_ids = torch.tensor(pieces_ids).long().to(device)
+ points_ids = torch.tensor(list(range(self.curve_size))).long().to(device)
+ points_ids = (end_indices + 1).unsqueeze(1) * points_ids.unsqueeze(0)
+ if num_instances > 0:
+ ctr_points_flatten = ctr_points[:, pieces_ids, :].flatten(0, 1)
+ curve_pts = torch.matmul(self.bezier_coefficient[cid], ctr_points_flatten)
+ curve_pts = curve_pts.reshape(num_instances, pieces_ids.shape[0], *curve_pts.shape[-2:])
+ curve_pts = curve_pts.flatten(1, 2)
+ curve_pts_ret = torch.stack([curve_pts[i][points_ids[i]] for i in range(points_ids.shape[0])])
+ return curve_pts_ret
+
+ def _get_bezier_coefficients(self):
+
+ def bernstein_func(n, t, k):
+ return (t ** k) * ((1 - t) ** (n - k)) * n_over_k(n, k)
+
+ ts = np.linspace(0, 1, self.curve_size)
+ bezier_coefficient_list = []
+ for nn in self.num_degree:
+ bezier_coefficient_list.append(np.array([[bernstein_func(nn, t, k) for k in range(nn + 1)] for t in ts]))
+ return bezier_coefficient_list
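+    # Sketch (added for clarity, not in the original): for degree n = 2 and
+    # curve_size = 3 the coefficient matrix stacks the Bernstein weights B_{k,2}(t)
+    # at t = 0, 0.5, 1, i.e. [[1, 0, 0], [0.25, 0.5, 0.25], [0, 0, 1]]; matmul with
+    # the (n + 1) control points evaluates one Bezier piece at those t values.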
+
+ def post_processing(self, outputs):
+ batch_results, batch_masks, batch_masks5 = [], [], []
+ batch_size = outputs["obj_logits"][-1][0].shape[0]
+ for i in range(batch_size):
+ points, scores, labels = [None], [-1], [0]
+ masks = np.zeros((self.num_classes, *self.map_size)).astype(np.uint8)
+ masks5 = np.zeros((self.num_classes, *self.map_size)).astype(np.uint8)
+ instance_index = 1
+ for j in range(self.num_classes):
+ pred_scores, pred_labels = torch.max(F.softmax(outputs["obj_logits"][-1][j][i], dim=-1), dim=-1)
+ keep_ids = torch.where((pred_labels == 0).int())[0]
+ if keep_ids.shape[0] == 0:
+ continue
+ curve_pts = outputs['curve_points'][-1][j][i][keep_ids].cpu().data.numpy()
+ curve_pts[:, :, 0] *= self.map_size[1]
+ curve_pts[:, :, 1] *= self.map_size[0]
+ for dt_curve, dt_score in zip(curve_pts, pred_scores[keep_ids]):
+ cv2.polylines(masks[j], [dt_curve.astype(np.int32)], False, color=instance_index,
+ thickness=self.save_thickness)
+ cv2.polylines(masks5[j], [dt_curve.astype(np.int32)], False, color=instance_index, thickness=12)
+ instance_index += 1
+                    points.append(dt_curve)
+ scores.append(self._to_np(dt_score).item())
+ labels.append(j + 1)
+ batch_results.append({'map': points, 'confidence_level': scores, 'pred_label': labels})
+ batch_masks.append(masks)
+ batch_masks5.append(masks5)
+ return batch_results, batch_masks, batch_masks5
+
+ @staticmethod
+ def _to_np(tensor):
+ return tensor.cpu().data.numpy()
diff --git a/mapmaster/models/output_head/line_matching.py b/mapmaster/models/output_head/line_matching.py
new file mode 100644
index 0000000..c690ffb
--- /dev/null
+++ b/mapmaster/models/output_head/line_matching.py
@@ -0,0 +1,65 @@
+import numpy as np
+
+def seq_matching_dist_parallel(cost, gt_lens, coe_endpts=0):
+ # Time complexity: O(m*n)
+ bs, m, n = cost.shape
+ assert m <= n
+ min_cost = np.ones((bs, m, n)) * np.inf
+ mem_sort_value = np.ones((bs, m, n)) * np.inf # v[i][j] = np.min(min_cost[i][:j+1])
+
+ # initialization
+ for j in range(0, n):
+ if j == 0:
+ min_cost[:, 0, j] = cost[:, 0, j]
+        mem_sort_value[:, 0, j] = min_cost[:, 0, 0]  # first-row running minimum is the (0, 0) match cost
+
+ for i in range(1, m):
+ for j in range(i, n):
+ min_cost[:, i, j] = mem_sort_value[:, i-1, j-1] + cost[:, i, j]
+ indexes = (min_cost[:, i, j] < mem_sort_value[:, i, j-1])
+            indexes_inv = ~indexes  # indexes is already boolean; np.bool was removed from NumPy
+ mem_sort_value[indexes, i, j] = min_cost[indexes, i, j]
+ mem_sort_value[indexes_inv, i, j] = mem_sort_value[indexes_inv, i, j-1]
+
+ indexes = []
+ for i, ll in enumerate(gt_lens):
+ indexes.append([i, ll-1, n-1])
+ indexes = np.array(indexes)
+ xs, ys, zs = indexes[:, 0], indexes[:, 1], indexes[:, 2]
+ res_cost = min_cost[xs, ys, zs] + (cost[xs, 0, 0] + cost[xs, ys, zs]) * coe_endpts
+ return res_cost / (indexes[:, 1]+1+coe_endpts*2)
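+
+# Illustrative sketch (not part of the original code): for one gt polyline with
+# 3 valid points matched against 4 ordered predictions, `cost` has shape (1, 3, 4):
+#
+#     cost = np.random.rand(1, 3, 4)
+#     dists = seq_matching_dist_parallel(cost, gt_lens=[3], coe_endpts=0)
+#     assert dists.shape == (1,)
+#
+# The DP keeps the ordering (gt point i may only match prediction j >= i), the
+# first and last gt points are matched to the first and last predictions, and
+# their costs receive the extra `coe_endpts` weight before normalisation.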
+
+def pivot_dynamic_matching(cost: np.ndarray):
+ # Time complexity: O(m*n)
+ m, n = cost.shape
+ assert m <= n
+
+ min_cost = np.ones((m, n)) * np.inf
+ mem_sort_value = np.ones((m, n)) * np.inf
+ match_res1 = [[] for _ in range(n)]
+ match_res2 = [[] for _ in range(n)]
+
+ # initialization
+ for j in range(0, n-m+1):
+ match_res1[j] = [0]
+ mem_sort_value[0][j] = cost[0][0]
+ if j == 0:
+ min_cost[0][j] = cost[0][0]
+
+ for i in range(1, m):
+        for j in range(i, n - m + i + 1):
+ min_cost[i][j] = mem_sort_value[i-1][j-1] + cost[i][j]
+ if min_cost[i][j] < mem_sort_value[i][j-1]:
+ mem_sort_value[i][j] = min_cost[i][j]
+ if i < m-1:
+ match_res2[j] = match_res1[j-1] + [j]
+ else:
+ mem_sort_value[i][j] = mem_sort_value[i][j-1]
+                if i < m - 1:
+ match_res2[j] = match_res2[j-1]
+ if i < m-1:
+ match_res1, match_res2 = match_res2.copy(), [[] for _ in range(n)]
+
+ total_cost = min_cost[-1][-1]
+ final_match_res = match_res1[-2] + [n-1]
+ return total_cost, final_match_res
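+
+# Illustrative sketch (not from the original code): for an ordered one-to-one
+# matching of 3 gt pivots to 5 ordered predictions,
+#
+#     cost = np.random.rand(3, 5)
+#     total_cost, matched_idx = pivot_dynamic_matching(cost)
+#
+# `matched_idx` is a strictly increasing list of 3 column indices that always
+# starts at 0 and ends at 4, so both endpoints of the polyline are preserved.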
\ No newline at end of file
diff --git a/mapmaster/models/output_head/pivot_outputs.py b/mapmaster/models/output_head/pivot_outputs.py
new file mode 100644
index 0000000..2a63e6b
--- /dev/null
+++ b/mapmaster/models/output_head/pivot_outputs.py
@@ -0,0 +1,115 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FFN(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2, basic_type='linear'):
+ super().__init__()
+ self.basic_type = basic_type
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(self.basic_layer(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+ def basic_layer(self, n, k):
+ if self.basic_type == 'linear':
+ return nn.Linear(n, k)
+ elif self.basic_type == 'conv':
+ return nn.Conv2d(n, k, kernel_size=1, stride=1)
+ else:
+ raise NotImplementedError
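+
+# Usage sketch (illustrative, not in the original): FFN(256, 256, 2, num_layers=3)
+# applies two ReLU-activated 256-d linear layers followed by a linear 2-d output;
+# with basic_type='conv' the same stack uses 1x1 convolutions on (B, C, H, W) maps.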
+
+class PivotMapOutputHead(nn.Module):
+ def __init__(self, in_channel, num_queries, tgt_shape, max_pieces, bev_channels=-1, ins_channel=64):
+ super(PivotMapOutputHead, self).__init__()
+ self.num_queries = num_queries
+ self.num_classes = len(num_queries)
+ self.tgt_shape = tgt_shape
+ self.bev_channels = bev_channels
+ self.semantic_heads = None
+ if self.bev_channels > 0:
+ self.semantic_heads = nn.ModuleList(
+ nn.Sequential(nn.Conv2d(bev_channels, 2, kernel_size=1, stride=1)) for _ in range(self.num_classes)
+ )
+
+ self.max_pieces = max_pieces # [10, 2, 30]
+ self.pts_split = [num_queries[i]*max_pieces[i] for i in range(len(num_queries))]
+ _N = self.num_classes
+ _C = ins_channel
+ self.im_ctr_heads = nn.ModuleList(FFN(in_channel, 256, 2 * _C, 3) for _ in range(_N))
+        self.pts_cls_heads = nn.ModuleList(FFN(_C * 2, _C * 2, 2, 3) for _ in range(_N))
+ self.gap_layer = nn.AdaptiveAvgPool2d((1, 1))
+ self.coords = self.compute_locations(device='cuda') # (1, 2, h, w)
+ self.coords_head = FFN(2, 256, _C, 3, 'conv')
+
+ def forward(self, inputs):
+ num_decoders = len(inputs["mask_features"])
+ dt_obj_logit = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)]
+ dt_ins_masks = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)]
+ im_ctr_coord = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)]
+ dt_pivots_logits = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)]
+ coords_feats = self.coords_head.forward(self.coords.repeat((inputs["mask_features"][0].shape[0], 1, 1, 1)))
+
+ for i in range(num_decoders):
+ x_ins_cw = inputs["mask_features"][i].split(self.num_queries, dim=1)
+ x_obj_cw = inputs["obj_scores"][i].split(self.num_queries, dim=1)
+ x_qry_cw = inputs["decoder_outputs"][i].split(self.pts_split, dim=1) # [(b, 200, c), (b, 50, c), (b, 450, c)]
+ batch_size = x_qry_cw[0].shape[0]
+ for j in range(self.num_classes):
+ dt_ins_masks[i][j] = self.up_sample(x_ins_cw[j]) # (B, P, H, W)
+ dt_obj_logit[i][j] = x_obj_cw[j] # (B, P, 2)
+ # im
+ num_qry, n_pts = self.num_queries[j], self.max_pieces[j]
+ im_feats = self.im_ctr_heads[j](x_qry_cw[j]) # (bs, n_q * n_pts, 2*c)
+ im_feats_tmp = im_feats.reshape(batch_size, num_qry*n_pts*2, -1) # (bs, n_q*n_pts*2, c)
+ im_coords_map = torch.einsum("bqc,bchw->bqhw", im_feats_tmp, coords_feats) # [bs, n_q*n_pts*2, h, w]
+                im_coords = self.gap_layer(im_coords_map)  # [bs, n_q*n_pts*2, 1, 1]
+ im_coords = im_coords.reshape(batch_size, num_qry, self.max_pieces[j], 2).sigmoid()
+ im_ctr_coord[i][j] = im_coords
+
+ pt_feats = im_feats.reshape(batch_size, num_qry, self.max_pieces[j], -1).flatten(1, 2) # [bs, n_q * n_pts, 2*C]
+ pt_logits = self.pts_cls_heads[j](pt_feats)
+ dt_pivots_logits[i][j] = pt_logits.reshape(batch_size, num_qry, self.max_pieces[j], 2)
+
+ ret = {"outputs": {"obj_logits": dt_obj_logit, "ins_masks": dt_ins_masks,
+ "ctr_im": im_ctr_coord, "pts_logits": dt_pivots_logits}}
+
+ if self.semantic_heads is not None:
+ num_decoders = len(inputs["bev_enc_features"])
+ dt_sem_masks = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)]
+ for i in range(num_decoders):
+ x_sem = inputs["bev_enc_features"][i]
+ for j in range(self.num_classes):
+                    dt_sem_masks[i][j] = self.up_sample(self.semantic_heads[j](x_sem))  # (B, 2, H, W)
+ ret["outputs"].update({"sem_masks": dt_sem_masks})
+ return ret
+
+ def up_sample(self, x, tgt_shape=None):
+ tgt_shape = self.tgt_shape if tgt_shape is None else tgt_shape
+ if tuple(x.shape[-2:]) == tuple(tgt_shape):
+ return x
+ return F.interpolate(x, size=tgt_shape, mode="bilinear", align_corners=True)
+
+ def compute_locations(self, stride=1, device='cpu'):
+
+ fh, fw = self.tgt_shape
+
+ shifts_x = torch.arange(0, fw * stride, step=stride, dtype=torch.float32, device=device)
+ shifts_y = torch.arange(0, fh * stride, step=stride, dtype=torch.float32, device=device)
+ shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
+ shift_x = shift_x.reshape(-1)
+ shift_y = shift_y.reshape(-1)
+ locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2
+
+ locations = locations.unsqueeze(0).permute(0, 2, 1).contiguous().float().view(1, 2, fh, fw)
+ locations[:, 0, :, :] /= fw
+ locations[:, 1, :, :] /= fh
+
+ return locations
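+
+# Shape note (added for clarity): compute_locations returns a (1, 2, H, W) grid of
+# pixel coordinates normalised to [0, 1) (channel 0 = x / W, channel 1 = y / H),
+# which `coords_head` turns into dense positional features for the einsum above.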
diff --git a/mapmaster/models/output_head/pivot_post_processor.py b/mapmaster/models/output_head/pivot_post_processor.py
new file mode 100644
index 0000000..cc49dcc
--- /dev/null
+++ b/mapmaster/models/output_head/pivot_post_processor.py
@@ -0,0 +1,340 @@
+import cv2
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+from scipy.optimize import linear_sum_assignment
+from mapmaster.models.utils.mask_loss import SegmentationLoss
+from mapmaster.utils.misc import nested_tensor_from_tensor_list
+from mapmaster.utils.misc import get_world_size, is_available, is_distributed
+
+from .line_matching import pivot_dynamic_matching, seq_matching_dist_parallel
+
+
+class HungarianMatcher(nn.Module):
+
+ def __init__(self, cost_obj=1., cost_mask=1., coe_endpts=1., cost_pts=2., mask_loss_conf=None):
+ super().__init__()
+ self.cost_obj, self.cost_mask = cost_obj, cost_mask
+ self.coe_endpts = coe_endpts # end points weight: 1 + coe_endpts
+ self.cost_pts = cost_pts
+ self.mask_loss = SegmentationLoss(**mask_loss_conf)
+
+ @torch.no_grad()
+ def forward(self, outputs, targets):
+ num_decoders, num_classes = len(outputs["ins_masks"]), len(outputs["ins_masks"][0])
+ matching_indices = [[[] for _ in range(num_classes)] for _ in range(num_decoders)]
+ for dec_id in range(num_decoders):
+ for cid in range(num_classes):
+ bs, num_queries = outputs["obj_logits"][dec_id][cid].shape[:2]
+
+ # 1. obj class cost mat
+ dt_probs = outputs["obj_logits"][dec_id][cid].flatten(0, 1).softmax(-1) # [n_dt, 2], n_dt in a batch
+ gt_idxes = torch.cat([tgt["obj_labels"][cid] for tgt in targets]) # [n_gt, ]
+ cost_mat_obj = -dt_probs[:, gt_idxes] # [n_dt, n_gt]
+
+ # 2. masks cost mat
+ dt_masks = outputs["ins_masks"][dec_id][cid].flatten(0, 1) # [n_dt, h, w]
+ gt_masks = torch.cat([tgt["ins_masks"][cid] for tgt in targets]) # [n_gt, h, w]
+ cost_mat_mask = 0
+ if gt_masks.shape[0] == 0:
+ matching_indices[dec_id][cid] = [(torch.tensor([], dtype=torch.int64), torch.tensor([], dtype=torch.int64))]
+ continue
+ dt_num, gt_num = dt_masks.shape[0], gt_masks.shape[0]
+ dt_masks = dt_masks.unsqueeze(1).expand(dt_num, gt_num, *dt_masks.shape[1:]).flatten(0, 1)
+ gt_masks = gt_masks.unsqueeze(0).expand(dt_num, gt_num, *gt_masks.shape[1:]).flatten(0, 1)
+
+ cost_mat_mask = self.mask_loss(dt_masks, gt_masks, "Matcher").reshape(dt_num, gt_num)
+
+ # 3. sequence matching costmat
+ dt_pts = outputs["ctr_im"][dec_id][cid].flatten(0, 1) # [n_dt, n_pts, 2]
+ n_pt = dt_pts.shape[1]
+ dt_pts = dt_pts.unsqueeze(0).repeat(gt_num, 1, 1, 1).flatten(0, 1) # [n_gt, n_dt, n_pts, 2] -> [n_gt*n_dt, n_pts, 2]
+ gt_pts = targets[0]["points"][cid].to(torch.float32)
+ gt_pts = gt_pts.unsqueeze(1).repeat(1, dt_num, 1, 1).flatten(0, 1) # [n_gt, n_dt, n_pts, 2] -> [n_gt*n_dt, n_pts, 2]
+
+ gt_pts_mask = torch.zeros(gt_num, n_pt, dtype=torch.double, device=gt_pts.device)
+ gt_lens = torch.tensor([ll for ll in targets[0]["valid_len"][cid]]) # n_gt
+ gt_lens = gt_lens.unsqueeze(-1).repeat(1, dt_num).flatten()
+ for i, ll in enumerate(targets[0]["valid_len"][cid]):
+ gt_pts_mask[i][:ll] = 1
+ gt_pts_mask = gt_pts_mask.unsqueeze(1).unsqueeze(-1).repeat(1, dt_num, 1, n_pt).flatten(0, 1)
+ cost_mat_seqmatching = torch.cdist(gt_pts, dt_pts, p=1) * gt_pts_mask # [n_gt*n_dt, n_pts, n_pts]
+ cost_mat_seqmatching = seq_matching_dist_parallel(
+ cost_mat_seqmatching.detach().cpu().numpy(),
+ gt_lens,
+                    self.coe_endpts).reshape(gt_num, dt_num).transpose(1, 0)  # [n_dt, n_gt]
+ cost_mat_seqmatching = torch.from_numpy(cost_mat_seqmatching).to(cost_mat_mask.device)
+
+ # 4. sum mat
+ sizes = [len(tgt["obj_labels"][cid]) for tgt in targets]
+ C = self.cost_obj * cost_mat_obj + self.cost_mask * cost_mat_mask + self.cost_pts * cost_mat_seqmatching
+ C = C.view(bs, num_queries, -1).cpu()
+ indices = [linear_sum_assignment(c[i].detach().numpy()) for i, c in enumerate(C.split(sizes, -1))]
+
+ matching_indices[dec_id][cid] = [self.to_tensor(i, j) for i, j in indices]
+
+ return matching_indices
+
+ @staticmethod
+ def to_tensor(i, j):
+ return torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)
+
+
+
+class SetCriterion(nn.Module):
+ def __init__(self, criterion_conf, matcher, sem_loss_conf=None, no_object_coe=1.0, collinear_pts_coe=1.0, coe_endpts=1.0):
+ super().__init__()
+ self.matcher = matcher
+ self.criterion_conf = criterion_conf
+ self.register_buffer("empty_weight", torch.tensor([1.0, no_object_coe]))
+ self.register_buffer("collinear_pt_weight", torch.tensor([collinear_pts_coe, 1.0]))
+ self.coe_endpts = coe_endpts
+
+ self.sem_loss_conf = sem_loss_conf
+ self.mask_loss = SegmentationLoss(**sem_loss_conf["mask_loss_conf"])
+
+ def forward(self, outputs, targets):
+ matching_indices = self.matcher(outputs, targets)
+ ins_msk_loss, pts_loss, collinear_pts_loss, pt_logits_loss = \
+ self.criterion_instance(outputs, targets, matching_indices)
+ ins_obj_loss = self.criterion_instance_labels(outputs, targets, matching_indices)
+ losses = {"ins_msk_loss": ins_msk_loss, "ins_obj_loss": ins_obj_loss,
+ "pts_loss": pts_loss, "collinear_pts_loss": collinear_pts_loss,
+ "pt_logits_loss": pt_logits_loss}
+ if self.sem_loss_conf is not None:
+            losses.update({"sem_msk_loss": self.criterion_semantic_masks(outputs, targets)})
+ losses = {key: self.criterion_conf['weight_dict'][key] * losses[key] for key in losses}
+ return sum(losses.values()), losses
+
+ def criterion_instance(self, outputs, targets, matching_indices):
+ loss_masks, loss_pts, loss_collinear_pts, loss_logits = 0, 0, 0, 0
+ device = outputs["ins_masks"][0][0].device
+ num_decoders, num_classes = len(matching_indices), len(matching_indices[0])
+ for i in range(num_decoders):
+ w = self.criterion_conf['decoder_weights'][i]
+ for j in range(num_classes):
+ num_instances = sum(len(t["obj_labels"][j]) for t in targets)
+ num_instances = torch.as_tensor([num_instances], dtype=torch.float, device=device)
+ if is_distributed() and is_available():
+ torch.distributed.all_reduce(num_instances)
+ num_instances = torch.clamp(num_instances / get_world_size(), min=1).item()
+ indices = matching_indices[i][j]
+ src_idx = self._get_src_permutation_idx(indices) # dt
+ tgt_idx = self._get_tgt_permutation_idx(indices) # gt
+
+ # instance masks
+ src_masks = outputs["ins_masks"][i][j][src_idx]
+ tgt_masks = [t["ins_masks"][j] for t in targets]
+ tgt_masks, _ = nested_tensor_from_tensor_list(tgt_masks).decompose()
+ tgt_masks = tgt_masks.to(src_masks)[tgt_idx]
+ loss_masks += w * self.mask_loss(src_masks, tgt_masks, "Loss").sum() / num_instances
+
+ # prepare tgt points
+ src_ctrs = outputs["ctr_im"][i][j][src_idx] # [num_queries, o, 2]
+ tgt_ctrs = [] # [num_queries, o, 2]
+ for info in targets: # B
+ for pts, valid_len in zip(info["points"][j][tgt_idx[1]], info["valid_len"][j][tgt_idx[1]]): # n_gt
+ tgt_ctrs.append(pts[:valid_len])
+
+ # pts match, valid pts loss, collinear pts loss
+ n_match_q, n_dt_pts = src_ctrs.shape[0], src_ctrs.shape[1]
+ logits_gt = torch.zeros((n_match_q, n_dt_pts), dtype=torch.long, device=src_ctrs.device)
+ if n_match_q == 0: # avoid unused parameters
+ loss_pts += w * F.l1_loss(src_ctrs, src_ctrs, reduction='sum')
+ loss_logits += w * F.l1_loss(outputs["pts_logits"][i][j][src_idx].flatten(0, 1), outputs["pts_logits"][i][j][src_idx].flatten(0, 1), reduction="sum")
+ continue
+
+ for ii, (src_pts, tgt_pts) in enumerate(zip(src_ctrs, tgt_ctrs)): # B=1, traverse matched query pairs
+ n_gt_pt = len(tgt_pts)
+ weight_pt = torch.ones((n_gt_pt), device=tgt_pts.device)
+ weight_pt[0] += self.coe_endpts
+ weight_pt[-1] += self.coe_endpts
+ cost_mat = torch.cdist(tgt_pts.to(torch.float32), src_pts, p=1)
+ _, matched_pt_idx = pivot_dynamic_matching(cost_mat.detach().cpu().numpy())
+ matched_pt_idx = torch.tensor(matched_pt_idx)
+ # match pts loss
+                    loss_match = w * F.l1_loss(src_pts[matched_pt_idx], tgt_pts, reduction="none").sum(dim=-1)  # [n_gt_pt, 2] -> [n_gt_pt]
+ loss_match = (loss_match * weight_pt).sum() / weight_pt.sum()
+ loss_pts += loss_match / num_instances
+ # interpolate pts loss
+ loss_collinear_pts += w * self.interpolate_loss(src_pts, tgt_pts, matched_pt_idx) / num_instances
+ # pt logits
+                    logits_gt[ii][matched_pt_idx] = 1
+                # pivot/collinear classification over all matched queries of this class
+                loss_logits += w * F.cross_entropy(outputs["pts_logits"][i][j][src_idx].flatten(0, 1), logits_gt.flatten(), self.collinear_pt_weight) / num_instances
+
+ loss_masks /= (num_decoders * num_classes)
+ loss_pts /= (num_decoders * num_classes)
+ loss_logits /= (num_decoders * num_classes)
+ loss_collinear_pts /= (num_decoders * num_classes)
+
+ return loss_masks, loss_pts, loss_collinear_pts, loss_logits
+
+ def interpolate_loss(self, src_pts, tgt_pts, matched_pt_idx):
+ # 1. pick collinear pt idx
+ collinear_idx = torch.ones(src_pts.shape[0], dtype=torch.bool)
+ collinear_idx[matched_pt_idx] = 0
+ collinear_src_pts = src_pts[collinear_idx]
+ # 2. interpolate tgt_pts
+ inter_tgt = torch.zeros_like(collinear_src_pts)
+ cnt = 0
+ for i in range(len(matched_pt_idx)-1):
+ start_pt, end_pt = tgt_pts[i], tgt_pts[i+1]
+ inter_num = matched_pt_idx[i+1] - matched_pt_idx[i] - 1
+ inter_tgt[cnt:cnt+inter_num] = self.interpolate(start_pt, end_pt, inter_num)
+ cnt += inter_num
+ assert collinear_src_pts.shape[0] == cnt
+ # 3. cal loss
+ if cnt > 0:
+ inter_loss = F.l1_loss(collinear_src_pts, inter_tgt, reduction="sum") / cnt
+ else:
+ inter_loss = F.l1_loss(collinear_src_pts, inter_tgt, reduction="sum")
+ return inter_loss
+
+ @staticmethod
+ def interpolate(start_pt, end_pt, inter_num):
+ res = torch.zeros((inter_num, 2), dtype=start_pt.dtype, device=start_pt.device)
+ num_len = inter_num + 1 # segment num.
+ for i in range(1, num_len):
+ ratio = i / num_len
+ res[i-1] = (1 - ratio) * start_pt + ratio * end_pt
+ return res
+
+ def criterion_instance_labels(self, outputs, targets, matching_indices):
+ loss_labels = 0
+ num_decoders, num_classes = len(matching_indices), len(matching_indices[0])
+ for i in range(num_decoders):
+ w = self.criterion_conf['decoder_weights'][i]
+ for j in range(num_classes):
+ indices = matching_indices[i][j]
+ idx = self._get_src_permutation_idx(indices) # (batch_id, query_id)
+ logits = outputs["obj_logits"][i][j]
+ target_classes_o = torch.cat([t["obj_labels"][j][J] for t, (_, J) in zip(targets, indices)])
+ target_classes = torch.full(logits.shape[:2], 1, dtype=torch.int64, device=logits.device)
+ target_classes[idx] = target_classes_o
+ loss_labels += (w * F.cross_entropy(logits.transpose(1, 2), target_classes, self.empty_weight))
+ loss_labels /= (num_decoders * num_classes)
+ return loss_labels
+
+    def criterion_semantic_masks(self, outputs, targets):
+ loss_masks = 0
+ num_decoders, num_classes = len(outputs["sem_masks"]), len(outputs["sem_masks"][0])
+ for i in range(num_decoders):
+ w = self.sem_loss_conf['decoder_weights'][i]
+ for j in range(num_classes):
+ dt_masks = outputs["sem_masks"][i][j] # (B, 2, H, W)
+ gt_masks = torch.stack([t["sem_masks"][j] for t in targets], dim=0) # (B, H, W)
+ loss_masks += w * self.mask_loss(dt_masks[:, 1, :, :], gt_masks).mean()
+ loss_masks /= (num_decoders * num_classes)
+ return loss_masks
+
+ @staticmethod
+ def _get_src_permutation_idx(indices):
+ # permute predictions following indices
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+ src_idx = torch.cat([src for (src, _) in indices])
+ return batch_idx, src_idx
+
+ @staticmethod
+ def _get_tgt_permutation_idx(indices):
+ # permute targets following indices
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+ return batch_idx, tgt_idx
+
+
+class PivotMapPostProcessor(nn.Module):
+ def __init__(self, criterion_conf, matcher_conf, pivot_conf, map_conf,
+ sem_loss_conf=None, no_object_coe=1.0, collinear_pts_coe=1.0, coe_endpts=0.0):
+ super(PivotMapPostProcessor, self).__init__()
+ self.criterion = SetCriterion(criterion_conf, HungarianMatcher(**matcher_conf), sem_loss_conf, no_object_coe, collinear_pts_coe, coe_endpts)
+ self.ego_size = map_conf['ego_size']
+ self.map_size = map_conf['map_size']
+ self.line_width = map_conf['line_width']
+ self.num_pieces = pivot_conf['max_pieces'] # (10, 2, 30)
+ self.num_classes = len(self.num_pieces)
+ self.class_indices = torch.tensor(list(range(self.num_classes)), dtype=torch.int).cuda()
+
+ def forward(self, outputs, targets=None):
+ if self.training:
+ targets = self.refactor_targets(targets)
+ return self.criterion.forward(outputs, targets)
+ else:
+ return self.post_processing(outputs)
+
+
+ def refactor_targets(self, targets):
+ # only support bs == 1
+ targets_refactored = []
+ targets["masks"] = targets["masks"].cuda()
+        for key in range(self.num_classes):  # per map-element class
+ targets["points"][key] = targets["points"][key].cuda()[0] # [0] remove batch dim
+ targets["valid_len"][key] = targets["valid_len"][key].cuda()[0] # [0] remove batch dim
+
+ for instance_mask in targets["masks"]: # bs, only support bs == 1
+ sem_masks, ins_masks, ins_objects = [], [], []
+ for i, mask_pc in enumerate(instance_mask): # class
+ sem_masks.append((mask_pc > 0).float())
+ unique_ids = torch.unique(mask_pc, sorted=True)[1:]
+ ins_num = unique_ids.shape[0]
+ pt_ins_num = len(targets["points"][i])
+ if pt_ins_num == ins_num:
+ ins_msk = (mask_pc.unsqueeze(0).repeat(ins_num, 1, 1) == unique_ids.view(-1, 1, 1)).float()
+ else:
+ ins_msk = np.zeros((pt_ins_num, *self.map_size), dtype=np.uint8)
+ for j, ins_pts in enumerate(targets["points"][i]):
+ ins_pts_tmp = ins_pts.clone()
+ ins_pts_tmp[:, 0] *= self.map_size[0]
+ ins_pts_tmp[:, 1] *= self.map_size[1]
+ ins_pts_tmp = ins_pts_tmp.cpu().data.numpy().astype(np.int32)
+ cv2.polylines(ins_msk[j], [ins_pts_tmp[:, ::-1]], False, color=1, thickness=self.line_width)
+ ins_msk = torch.from_numpy(ins_msk).float().cuda()
+ assert len(ins_msk) == len(targets["points"][i])
+ ins_obj = torch.zeros(pt_ins_num, dtype=torch.long, device=unique_ids.device)
+ ins_masks.append(ins_msk)
+ ins_objects.append(ins_obj)
+ targets_refactored.append({
+ "sem_masks": sem_masks,
+ "ins_masks": ins_masks,
+ "obj_labels": ins_objects,
+ "points": targets["points"],
+ "valid_len": targets["valid_len"],
+ })
+ return targets_refactored
+
+ def post_processing(self, outputs):
+ batch_results, batch_masks = [], []
+ batch_size = outputs["obj_logits"][-1][0].shape[0]
+ for i in range(batch_size):
+ points, scores, labels = [None], [-1], [0]
+ masks = np.zeros((self.num_classes, *self.map_size)).astype(np.uint8)
+ instance_index = 1
+ for j in range(self.num_classes):
+ pred_scores, pred_labels = torch.max(F.softmax(outputs["obj_logits"][-1][j][i], dim=-1), dim=-1)
+ keep_ids = torch.where((pred_labels == 0).int())[0] # fore-ground
+ if keep_ids.shape[0] == 0:
+ continue
+ keypts = outputs["ctr_im"][-1][j][i][keep_ids].cpu().data.numpy() # [P, N, 2]
+ keypts[:, :, 0] *= self.map_size[0]
+ keypts[:, :, 1] *= self.map_size[1]
+
+ valid_pt_idx = F.softmax(outputs["pts_logits"][-1][j][i][keep_ids], dim=-1)[:,:,1].cpu().data.numpy() > 0.5 # [P, N]
+ valid_pt_idx[:, 0] = 1
+ valid_pt_idx[:, -1] = 1
+
+ for k, (dt_pts, dt_score) in enumerate(zip(keypts, pred_scores[keep_ids])):
+ select_pt = dt_pts[valid_pt_idx[k]]
+ cv2.polylines(masks[j], [select_pt.astype(np.int32)[:, ::-1]], False, color=instance_index, thickness=1)
+ instance_index += 1
+ points.append(select_pt)
+ scores.append(self._to_np(dt_score).item())
+ labels.append(j + 1)
+ batch_results.append({'map': points, 'confidence_level': scores, 'pred_label': labels})
+ batch_masks.append(masks)
+ return batch_results, batch_masks
+
+ @staticmethod
+ def _to_np(tensor):
+ return tensor.cpu().data.numpy()
+
+
\ No newline at end of file
diff --git a/mapmaster/models/utils/mask_loss.py b/mapmaster/models/utils/mask_loss.py
new file mode 100644
index 0000000..8897cbf
--- /dev/null
+++ b/mapmaster/models/utils/mask_loss.py
@@ -0,0 +1,89 @@
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+from detectron2.projects.point_rend.point_features import point_sample
+from detectron2.projects.point_rend.point_features import get_uncertain_point_coords_with_randomness
+
+
+class SegmentationLoss(nn.Module):
+
+ def __init__(self, ce_weight, dice_weight, use_point_render=False, num_points=8000, oversample=3.0, importance=0.75):
+ super(SegmentationLoss, self).__init__()
+ self.ce_weight = ce_weight
+ self.dice_weight = dice_weight
+ self.use_point_render = use_point_render
+ self.num_points = num_points
+ self.oversample = oversample
+ self.importance = importance
+
+ def forward(self, dt_masks, gt_masks, stage="loss"):
+ loss = 0
+ if self.use_point_render:
+ dt_masks, gt_masks = self.points_render(dt_masks, gt_masks, stage)
+ if self.ce_weight > 0:
+ loss += self.ce_weight * self.forward_sigmoid_ce_loss(dt_masks, gt_masks)
+ if self.dice_weight > 0:
+ loss += self.dice_weight * self.forward_dice_loss(dt_masks, gt_masks)
+ return loss
+
+ @staticmethod
+ def forward_dice_loss(inputs, targets):
+ inputs = inputs.sigmoid()
+ inputs = inputs.flatten(1)
+ targets = targets.flatten(1)
+ numerator = 2 * (inputs * targets).sum(-1)
+ denominator = inputs.sum(-1) + targets.sum(-1)
+ loss = 1 - (numerator + 1) / (denominator + 1)
+ return loss
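+    # Worked intuition (illustrative): if the sigmoid of `inputs` matches `targets`
+    # almost everywhere, the ratio approaches 1 and the dice loss approaches 0; a
+    # completely wrong prediction drives the loss towards 1.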
+
+ @staticmethod
+ def forward_sigmoid_ce_loss(inputs, targets):
+ inputs = inputs.flatten(1)
+ targets = targets.flatten(1)
+ loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+ return loss.mean(1)
+
+ def points_render(self, src_masks, tgt_masks, stage):
+        stage = stage.lower()  # callers pass both "Loss"/"Matcher" and lowercase variants
+        assert stage in ["loss", "matcher"]
+ assert src_masks.shape == tgt_masks.shape
+
+ src_masks = src_masks[:, None]
+ tgt_masks = tgt_masks[:, None]
+
+ if stage == "matcher":
+ point_coords = torch.rand(1, self.num_points, 2, device=src_masks.device)
+ point_coords_src = point_coords.repeat(src_masks.shape[0], 1, 1)
+ point_coords_tgt = point_coords.repeat(tgt_masks.shape[0], 1, 1)
+ else:
+ point_coords = get_uncertain_point_coords_with_randomness(
+ src_masks,
+ lambda logits: self.calculate_uncertainty(logits),
+ self.num_points,
+ self.oversample,
+ self.importance,
+ )
+ point_coords_src = point_coords.clone()
+ point_coords_tgt = point_coords.clone()
+
+ src_masks = point_sample(src_masks, point_coords_src, align_corners=False).squeeze(1)
+ tgt_masks = point_sample(tgt_masks, point_coords_tgt, align_corners=False).squeeze(1)
+
+ return src_masks, tgt_masks
+
+ @staticmethod
+ def calculate_uncertainty(logits):
+ """
+        We estimate uncertainty as the L1 distance between 0.0 and the logit prediction in 'logits' for the
+ foreground class in `classes`.
+ Args:
+ logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or
+ class-agnostic, where R is the total number of predicted masks in all images and C is
+ the number of foreground classes. The values are logits.
+ Returns:
+ scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
+ the most uncertain locations having the highest uncertainty score.
+ """
+ assert logits.shape[1] == 1
+ gt_class_logits = logits.clone()
+ return -(torch.abs(gt_class_logits))
diff --git a/mapmaster/models/utils/misc.py b/mapmaster/models/utils/misc.py
new file mode 100644
index 0000000..25e394b
--- /dev/null
+++ b/mapmaster/models/utils/misc.py
@@ -0,0 +1,78 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+import warnings
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+def c2_xavier_fill(module: nn.Module) -> None:
+ """
+ Initialize `module.weight` using the "XavierFill" implemented in Caffe2.
+ Also initializes `module.bias` to 0.
+ Args:
+ module (torch.nn.Module): module to initialize.
+ """
+ # Caffe2 implementation of XavierFill in fact
+ # corresponds to kaiming_uniform_ in PyTorch
+ # pyre-fixme[6]: For 1st param expected `Tensor` but got `Union[Module, Tensor]`.
+ nn.init.kaiming_uniform_(module.weight, a=1)
+ if module.bias is not None:
+ # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module,
+ # torch.Tensor]`.
+ nn.init.constant_(module.bias, 0)
+
+
+class Conv2d(torch.nn.Conv2d):
+ """
+ A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
+ """
+
+ def __init__(self, *args, **kwargs):
+ """
+ Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:
+ Args:
+ norm (nn.Module, optional): a normalization layer
+ activation (callable(Tensor) -> Tensor): a callable activation function
+ It assumes that norm layer is used before activation.
+ """
+ norm = kwargs.pop("norm", None)
+ activation = kwargs.pop("activation", None)
+ super().__init__(*args, **kwargs)
+
+ self.norm = norm
+ self.activation = activation
+
+ def forward(self, x):
+ # torchscript does not support SyncBatchNorm yet
+ # https://github.com/pytorch/pytorch/issues/40507
+ # and we skip these codes in torchscript since:
+ # 1. currently we only support torchscript in evaluation mode
+ # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
+ # later version, `Conv2d` in these PyTorch versions has already supported empty inputs.
+ if not torch.jit.is_scripting():
+ with warnings.catch_warnings(record=True):
+ if x.numel() == 0 and self.training:
+ # https://github.com/pytorch/pytorch/issues/12013
+ assert not isinstance(
+ self.norm, torch.nn.SyncBatchNorm
+ ), "SyncBatchNorm does not support empty inputs!"
+
+ x = F.conv2d(
+ x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
+ )
+ if self.norm is not None:
+ x = self.norm(x)
+ if self.activation is not None:
+ x = self.activation(x)
+ return x
+
+
+def get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+    raise RuntimeError(f"activation should be relu, gelu or glu, not {activation}.")
diff --git a/mapmaster/models/utils/position_encoding.py b/mapmaster/models/utils/position_encoding.py
new file mode 100644
index 0000000..565d4ce
--- /dev/null
+++ b/mapmaster/models/utils/position_encoding.py
@@ -0,0 +1,217 @@
+"""
+Various positional encodings for the transformer.
+"""
+import math
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, mask):
+ assert mask is not None
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=mask.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
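+    # Usage sketch (illustrative): for a padding mask of shape (B, H, W) and
+    # num_pos_feats=128, forward() returns a (B, 256, H, W) embedding whose first
+    # 128 channels encode the normalised y position and whose last 128 encode x,
+    # with interleaved sin/cos at geometrically spaced frequencies.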
+
+
+class PositionEmbeddingLearned(nn.Module):
+ """
+ Absolute pos embedding, learned.
+ """
+
+ def __init__(self, num_pos=(50, 50), num_pos_feats=256):
+ super().__init__()
+ self.num_pos = num_pos
+ self.pos_embed = nn.Embedding(num_pos[0] * num_pos[1], num_pos_feats)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.normal_(self.pos_embed.weight)
+
+ def forward(self, mask):
+ h, w = mask.shape[-2:]
+ pos = self.pos_embed.weight.view(*self.num_pos, -1)[:h, :w]
+ pos = pos.permute(2, 0, 1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1)
+ return pos
+
+
+class PositionEmbeddingIPM(nn.Module):
+
+ def __init__(self,
+ encoder=None,
+ num_pos=(16, 168),
+ input_shape=(512, 896),
+ num_pos_feats=64,
+ sine_encoding=False,
+ temperature=10000):
+ super().__init__()
+
+ h, w_expand = num_pos
+ self.current_shape = (h, w_expand // 6)
+ self.input_shape = input_shape
+
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.encoder = encoder
+ self.sine_encoding = sine_encoding
+
+ def get_embedding(self, extrinsic, intrinsic, ida_mats):
+ """
+ Get the BeV Coordinate for Image
+
+ Return
+ xy_world_coord (N, H, W, 2) Ego x, y coordinate
+ Valid (N, H, W, 1) -- Valid Points or Not 1 -- valid; 0 -- invalid
+ """
+ # extrinsic -> (B, M, 4, 4)
+ device, b, n = extrinsic.device, extrinsic.shape[0], extrinsic.shape[1]
+
+ x = torch.linspace(0, self.input_shape[1] - 1, self.current_shape[1], dtype=torch.float)
+ y = torch.linspace(0, self.input_shape[0] - 1, self.current_shape[0], dtype=torch.float)
+ y_grid, x_grid = torch.meshgrid(y, x)
+ z = torch.ones(self.current_shape)
+ feat_coords = torch.stack([x_grid, y_grid, z], dim=-1).to(device) # (H, W, 3)
+ feat_coords = feat_coords.unsqueeze(0).repeat(n, 1, 1, 1).unsqueeze(0).repeat(b, 1, 1, 1, 1) # (B, N, H, W, 3)
+
+ ida_mats = ida_mats.view(b, n, 1, 1, 3, 3)
+ image_coords = ida_mats.inverse().matmul(feat_coords.unsqueeze(-1)) # (B, N, H, W, 3, 1)
+
+ intrinsic = intrinsic.view(b, n, 1, 1, 3, 3) # (B, N, 1, 1, 3, 3)
+ normed_coords = torch.linalg.inv(intrinsic) @ image_coords # (B, N, H, W, 3, 1)
+
+ ext_rots = extrinsic[:, :, :3, :3] # (B, N, 3, 3)
+ ext_trans = extrinsic[:, :, :3, 3] # (B, N, 3)
+
+ ext_rots = ext_rots.view(b, n, 1, 1, 3, 3) # (B, N, 1, 1, 3, 3)
+ world_coords = (ext_rots @ normed_coords).squeeze(-1) # (B, N, H, W, 3)
+ world_coords = F.normalize(world_coords, p=2, dim=-1)
+ z_coord = world_coords[:, :, :, :, 2] # (B, N, H, W)
+
+ trans_z = ext_trans[:, :, 2].unsqueeze(-1).unsqueeze(-1) # (B, N, 1, 1)
+ depth = - trans_z / z_coord # (B, N, H, W)
+ valid = depth > 0 # (B, N, H, W)
+
+ xy_world_coords = world_coords[:, :, :, :, :2] # (B, N, H, W, 2)
+ xy_world_coords = xy_world_coords * depth.unsqueeze(-1)
+ valid = valid.unsqueeze(-1) # (B, N, H, W, 1)
+
+ return xy_world_coords, valid
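+    # Geometry note (added for clarity): each feature-map pixel is back-projected
+    # through ida_mats^-1 and K^-1 into a viewing ray, rotated into the ego frame,
+    # and intersected with the z = 0 ground plane (depth = -t_z / ray_z); rays that
+    # never reach the ground plane get valid = False.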
+
+ def forward(self, extrinsic, intrinsic, ida_mats, do_flip):
+ """
+ extrinsic (N, 6, 4, 4) torch.Tensor
+ intrinsic (N, 6, 3, 3)
+ """
+ device = extrinsic.device
+ xy_pos_embed, valid = self.get_embedding(extrinsic, intrinsic, ida_mats)
+ if do_flip:
+ xy_pos_embed[:, :, :, :, 1] = -1 * xy_pos_embed[:, :, :, :, 1]
+ # along with w
+ xy_pos_embed = torch.cat(torch.unbind(xy_pos_embed, dim=1), dim=-2) # (B, H, N*W, 2)
+        valid = torch.cat(torch.unbind(valid, dim=1), dim=-2)  # (B, H, N*W, 1)
+ if self.sine_encoding:
+ # Use Sine encoding to get 256 dim embeddings
+ dim_t = torch.arange(self.num_pos_feats // 2, dtype=torch.float32, device=device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / (self.num_pos_feats // 2))
+ pos_embed = xy_pos_embed[:, :, :, :, None] / dim_t
+ pos_x = torch.stack((pos_embed[:, :, :, 0, 0::2].sin(), pos_embed[:, :, :, 0, 1::2].cos()), dim=4)
+ pos_y = torch.stack((pos_embed[:, :, :, 1, 0::2].sin(), pos_embed[:, :, :, 1, 1::2].cos()), dim=4)
+ pos_full_embed = torch.cat((pos_y.flatten(3), pos_x.flatten(3)), dim=3)
+ pos_combined = torch.where(valid, pos_full_embed, torch.tensor(0., dtype=torch.float32, device=device))
+            pos_combined = pos_combined.permute(0, 3, 1, 2)  # (B, num_pos_feats, H, N*W)
+ else:
+            raise NotImplementedError("PositionEmbeddingIPM currently requires sine_encoding=True")
+ # pos_combined = torch.where(valid, xy_pos_embed, torch.tensor(0., dtype=torch.float32, device=device))
+ # pos_combined = pos_combined.permute(0, 3, 1, 2)
+
+ if self.encoder is None:
+ return pos_combined, valid.squeeze(-1)
+ else:
+ pos_embed_contiguous = pos_combined.contiguous()
+ return self.encoder(pos_embed_contiguous), valid.squeeze(-1)
+
+
+class PositionEmbeddingTgt(nn.Module):
+ def __init__(self,
+ encoder=None,
+ tgt_shape=(40, 20),
+ map_size=(400, 200),
+ map_resolution=0.15,
+ num_pos_feats=64,
+ sine_encoding=False,
+ temperature=10000):
+ super().__init__()
+ self.tgt_shape = tgt_shape
+ self.encoder = encoder
+ self.map_size = map_size
+ self.map_resolution = map_resolution
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.sine_encoding = sine_encoding
+
+ def forward(self, mask):
+ B = mask.shape[0]
+
+ map_forward_ratio = self.tgt_shape[0] / self.map_size[0]
+ map_lateral_ratio = self.tgt_shape[1] / self.map_size[1]
+
+ map_forward_res = self.map_resolution / map_forward_ratio
+ map_lateral_res = self.map_resolution / map_lateral_ratio
+
+        X = (torch.arange(self.tgt_shape[0] - 1, -1, -1, device=mask.device) + 0.5 - self.tgt_shape[0] / 2) * map_forward_res
+        Y = (torch.arange(self.tgt_shape[1] - 1, -1, -1, device=mask.device) + 0.5 - self.tgt_shape[1] / 2) * map_lateral_res
+
+ grid_X, grid_Y = torch.meshgrid(X, Y)
+ pos_embed = torch.stack([grid_X, grid_Y], dim=-1) # (H, W, 2)
+
+ if self.sine_encoding:
+ dim_t = torch.arange(self.num_pos_feats // 2, dtype=torch.float32, device=mask.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / (self.num_pos_feats // 2))
+
+ pos_embed = pos_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_embed[:, :, 0, 0::2].sin(), pos_embed[:, :, 0, 1::2].cos()), dim=3).flatten(2)
+ pos_y = torch.stack((pos_embed[:, :, 1, 0::2].sin(), pos_embed[:, :, 1, 1::2].cos()), dim=3).flatten(2)
+ pos_full_embed = torch.cat((pos_y, pos_x), dim=2)
+
+ pos_embed = pos_full_embed.unsqueeze(0).repeat(B, 1, 1, 1).permute(0, 3, 1, 2)
+ else:
+ pos_embed = pos_embed.unsqueeze(0).repeat(B, 1, 1, 1).permute(0, 3, 1, 2)
+
+ if self.encoder is None:
+ return pos_embed
+ else:
+ pos_embed_contiguous = pos_embed.contiguous()
+ return self.encoder(pos_embed_contiguous)
\ No newline at end of file
diff --git a/mapmaster/models/utils/recovery_loss.py b/mapmaster/models/utils/recovery_loss.py
new file mode 100644
index 0000000..b25ce3b
--- /dev/null
+++ b/mapmaster/models/utils/recovery_loss.py
@@ -0,0 +1,49 @@
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+from detectron2.projects.point_rend.point_features import point_sample
+
+
+class PointRecoveryLoss(nn.Module):
+
+ def __init__(self, ce_weight, dice_weight, curve_width, tgt_shape):
+ super(PointRecoveryLoss, self).__init__()
+ self.ce_weight = ce_weight
+ self.dice_weight = dice_weight
+ self.kernel = self.generate_kernel(curve_width, tgt_shape)
+
+ def forward(self, points, gt_masks):
+ points_expanded = points.unsqueeze(2) - self.kernel.repeat(points.shape[0], 1, 1, 1)
+ points_expanded = torch.clamp(points_expanded.flatten(1, 2), min=0, max=1) # (N, P*w*w, 2) [0, 1]
+ dt_points = point_sample(gt_masks[:, None], points_expanded, align_corners=False).squeeze(1).flatten(1)
+ gt_points = torch.ones_like(dt_points)
+ loss = 0
+ if self.ce_weight > 0:
+ loss += self.ce_weight * self.forward_ce_loss(dt_points, gt_points)
+ if self.dice_weight > 0:
+ loss += self.dice_weight * self.forward_dice_loss(dt_points, gt_points)
+ return loss
+
+ @staticmethod
+ def generate_kernel(curve_width, tgt_shape, device='cuda'):
+ width = torch.tensor(list(range(curve_width)))
+ kernel = torch.stack(torch.meshgrid(width, width), dim=-1).float()
+ kernel = kernel - curve_width // 2
+ kernel[..., 0] = kernel[..., 0] / tgt_shape[1]
+ kernel[..., 1] = kernel[..., 1] / tgt_shape[0]
+ kernel = kernel.flatten(0, 1).unsqueeze(0).unsqueeze(0) # (1, 1, w*w, 2)
+ kernel = kernel.cuda() if device == 'cuda' else kernel
+ return kernel
+
+ @staticmethod
+ def forward_dice_loss(inputs, targets):
+ numerator = 2 * (inputs * targets).sum(-1)
+ denominator = inputs.sum(-1) + targets.sum(-1)
+ loss = 1 - (numerator + 1) / (denominator + 1)
+ return loss
+
+ @staticmethod
+ def forward_ce_loss(inputs, targets):
+ loss = F.binary_cross_entropy(inputs, targets, reduction="none")
+ return loss.mean(1)
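+
+# Note (added for clarity): the loss samples the rasterised GT mask at every
+# predicted curve point (dilated by `curve_width` through the kernel offsets) via
+# point_sample and pushes the sampled values towards 1, i.e. predicted Bezier
+# curves are encouraged to lie inside the ground-truth mask.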
diff --git a/mapmaster/utils/env.py b/mapmaster/utils/env.py
new file mode 100644
index 0000000..e7994a8
--- /dev/null
+++ b/mapmaster/utils/env.py
@@ -0,0 +1,129 @@
+import os
+import re
+import sys
+import PIL
+import importlib
+import warnings
+import subprocess
+import torch
+import torchvision
+import numpy as np
+from tabulate import tabulate
+from collections import defaultdict
+
+__all__ = ["collect_env_info"]
+
+
+def collect_torch_env():
+ import torch.__config__
+ return torch.__config__.show()
+
+
+def collect_git_info():
+ try:
+ import git
+ from git import InvalidGitRepositoryError
+ except ImportError:
+ warnings.warn("Please consider to install gitpython for git info collection by 'pip install gitpython'.")
+ return "Git status: unknown\n"
+
+ try:
+ repo = git.Repo(get_root_dir())
+ except InvalidGitRepositoryError:
+ warnings.warn("Current path is possibly not a valid git repository.")
+ return "Git status: unknown\n"
+
+    msg = "***Git status:***\n{}\nHEAD Commit-id: {}\n".format(repo.git.status().replace("<", "\\<"), repo.head.commit)
+    msg = "{}\n{}".format(msg, "***Git Diff:***\n{}\n".format(repo.git.diff().replace("<", "\\<")))
+ return msg
+
+
+def detect_compute_compatibility(CUDA_HOME, so_file):
+ try:
+ cuobjdump = os.path.join(CUDA_HOME, "bin", "cuobjdump")
+ if os.path.isfile(cuobjdump):
+ output = subprocess.check_output("'{}' --list-elf '{}'".format(cuobjdump, so_file), shell=True)
+ output = output.decode("utf-8").strip().split("\n")
+ sm = []
+ for line in output:
+ line = re.findall(r"\.sm_[0-9]*\.", line)[0]
+ sm.append(line.strip("."))
+ sm = sorted(set(sm))
+ return ", ".join(sm)
+ else:
+ return so_file + "; cannot find cuobjdump"
+ except Exception:
+ # unhandled failure
+ return so_file
+
+
+def collect_env_info():
+ data = []
+ data.append(("sys.platform", sys.platform))
+ data.append(("Python", sys.version.replace("\n", "")))
+ data.append(("numpy", np.__version__))
+ data.append(("Pillow", PIL.__version__))
+
+ data.append(("PyTorch", torch.__version__ + " @" + os.path.dirname(torch.__file__)))
+ data.append(("PyTorch debug build", torch.version.debug))
+
+ has_cuda = torch.cuda.is_available()
+
+ data.append(("CUDA available", has_cuda))
+ if has_cuda:
+ devices = defaultdict(list)
+ for k in range(torch.cuda.device_count()):
+ devices[torch.cuda.get_device_name(k)].append(str(k))
+ for name, devids in devices.items():
+ data.append(("GPU " + ",".join(devids), name))
+
+ from torch.utils.cpp_extension import CUDA_HOME
+
+ data.append(("CUDA_HOME", str(CUDA_HOME)))
+
+ if CUDA_HOME is not None and os.path.isdir(CUDA_HOME):
+ try:
+ nvcc = os.path.join(CUDA_HOME, "bin", "nvcc")
+ nvcc = subprocess.check_output("'{}' -V | tail -n1".format(nvcc), shell=True)
+ nvcc = nvcc.decode("utf-8").strip()
+ except subprocess.SubprocessError:
+ nvcc = "Not Available"
+ data.append(("NVCC", nvcc))
+
+ cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
+ if cuda_arch_list:
+ data.append(("TORCH_CUDA_ARCH_LIST", cuda_arch_list))
+
+ try:
+ data.append(
+ (
+ "torchvision",
+ str(torchvision.__version__) + " @" + os.path.dirname(torchvision.__file__),
+ )
+ )
+ if has_cuda:
+ try:
+ torchvision_C = importlib.util.find_spec("torchvision._C").origin
+ msg = detect_compute_compatibility(CUDA_HOME, torchvision_C)
+ data.append(("torchvision arch flags", msg))
+ except ImportError:
+ data.append(("torchvision._C", "failed to find"))
+ except AttributeError:
+ data.append(("torchvision", "unknown"))
+
+ try:
+ import cv2
+
+ data.append(("cv2", cv2.__version__))
+ except ImportError:
+ pass
+
+ env_str = tabulate(data) + "\n"
+ env_str += collect_git_info()
+ env_str += "-" * 100 + "\n"
+ env_str += collect_torch_env()
+ return env_str
+
+
+def get_root_dir():
+ return os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
diff --git a/mapmaster/utils/misc.py b/mapmaster/utils/misc.py
new file mode 100644
index 0000000..78e13ea
--- /dev/null
+++ b/mapmaster/utils/misc.py
@@ -0,0 +1,411 @@
+import os
+import re
+import torch
+import torchvision
+import unicodedata
+from sys import stderr
+from torch import Tensor
+from loguru import logger
+from argparse import Action
+from collections import deque
+from typing import Optional, List
+from torch import distributed as dist
+
+
+__all__ = [
+ "PyDecorator", "NestedTensor", "AvgMeter", "DictAction", "sanitize_filename", "parse_devices",
+ "_max_by_axis", "nested_tensor_from_tensor_list", "_onnx_nested_tensor_from_tensor_list",
+ "get_param_groups", "setup_logger", "get_rank", "get_world_size", "synchronize", "reduce_sum",
+ "reduce_mean", "all_gather_object", "is_distributed", "is_available"
+]
+
+
+class PyDecorator:
+ @staticmethod
+ def overrides(interface_class):
+ def overrider(method):
+ assert method.__name__ in dir(interface_class), "{} function not in {}".format(
+ method.__name__, interface_class
+ )
+ return method
+ return overrider
+
+
+class NestedTensor(object):
+ def __init__(self, tensors, mask: Optional[Tensor]):
+ self.tensors = tensors
+ self.mask = mask
+
+ def to(self, device):
+ # type: (Device) -> NestedTensor # noqa
+ cast_tensor = self.tensors.to(device)
+ mask = self.mask
+ if mask is not None:
+ assert mask is not None
+ cast_mask = mask.to(device)
+ else:
+ cast_mask = None
+ return NestedTensor(cast_tensor, cast_mask)
+
+ def decompose(self):
+ return self.tensors, self.mask
+
+ def __repr__(self):
+ return str(self.tensors)
+
+
+class AvgMeter(object):
+ def __init__(self, window_size=50):
+ self.window_size = window_size
+ self._value_deque = deque(maxlen=window_size)
+ self._total_value = 0.0
+ self._wdsum_value = 0.0
+ self._count_deque = deque(maxlen=window_size)
+ self._total_count = 0.0
+ self._wdsum_count = 0.0
+
+ def reset(self):
+ self._value_deque.clear()
+ self._total_value = 0.0
+ self._wdsum_value = 0.0
+ self._count_deque.clear()
+ self._total_count = 0.0
+ self._wdsum_count = 0.0
+
+ def update(self, value, n=1):
+ if len(self._value_deque) >= self.window_size:
+ self._wdsum_value -= self._value_deque.popleft()
+ self._wdsum_count -= self._count_deque.popleft()
+ self._value_deque.append(value * n)
+ self._total_value += value * n
+ self._wdsum_value += value * n
+ self._count_deque.append(n)
+ self._total_count += n
+ self._wdsum_count += n
+
+ @property
+ def avg(self):
+ return self.global_avg
+
+ @property
+ def global_avg(self):
+ return self._total_value / max(self._total_count, 1e-5)
+
+ @property
+ def window_avg(self):
+ return self._wdsum_value / max(self._wdsum_count, 1e-5)
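+    # Usage sketch (illustrative): m = AvgMeter(window_size=2); after
+    # m.update(1.0); m.update(3.0); m.update(5.0) the windowed average
+    # m.window_avg is 4.0 (last two values) while m.global_avg is 3.0.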
+
+
+class DictAction(Action):
+ """
+ argparse action to split an argument into KEY=VALUE form
+ on the first = and append to a dictionary. List options can
+    be passed as comma-separated values, i.e. 'KEY=V1,V2,V3', or with explicit
+    brackets, i.e. 'KEY=[V1,V2,V3]'. It also supports nested brackets to build
+ list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
+ """
+
+ @staticmethod
+ def _parse_int_float_bool(val):
+ try:
+ return int(val)
+ except ValueError:
+ pass
+ try:
+ return float(val)
+ except ValueError:
+ pass
+ if val.lower() in ["true", "false"]:
+ return True if val.lower() == "true" else False
+ return val
+
+ @staticmethod
+ def _parse_iterable(val):
+ """Parse iterable values in the string.
+ All elements inside '()' or '[]' are treated as iterable values.
+ Args:
+ val (str): Value string.
+ Returns:
+ list | tuple: The expanded list or tuple from the string.
+ Examples:
+ >>> DictAction._parse_iterable('1,2,3')
+ [1, 2, 3]
+ >>> DictAction._parse_iterable('[a, b, c]')
+ ['a', 'b', 'c']
+ >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
+            [(1, 2, 3), ['a', 'b'], 'c']
+ """
+
+ def find_next_comma(string):
+ """Find the position of next comma in the string.
+ If no ',' is found in the string, return the string length. All
+ chars inside '()' and '[]' are treated as one element and thus ','
+ inside these brackets are ignored.
+ """
+ assert (string.count("(") == string.count(")")) and (
+ string.count("[") == string.count("]")
+ ), f"Imbalanced brackets exist in {string}"
+ end = len(string)
+ for idx, char in enumerate(string):
+ pre = string[:idx]
+ # The string before this ',' is balanced
+ if (char == ",") and (pre.count("(") == pre.count(")")) and (pre.count("[") == pre.count("]")):
+ end = idx
+ break
+ return end
+
+ # Strip ' and " characters and replace whitespace.
+ val = val.strip("'\"").replace(" ", "")
+ is_tuple = False
+ if val.startswith("(") and val.endswith(")"):
+ is_tuple = True
+ val = val[1:-1]
+ elif val.startswith("[") and val.endswith("]"):
+ val = val[1:-1]
+ elif "," not in val:
+ # val is a single value
+ return DictAction._parse_int_float_bool(val)
+
+ values = []
+ while len(val) > 0:
+ comma_idx = find_next_comma(val)
+ element = DictAction._parse_iterable(val[:comma_idx])
+ values.append(element)
+ val = val[comma_idx + 1 :]
+ if is_tuple:
+ values = tuple(values)
+ return values
+
+ def __call__(self, parser, namespace, values, option_string=None):
+ options = {}
+ for kv in values:
+ key, val = kv.split("=", maxsplit=1)
+ options[key] = self._parse_iterable(val)
+ setattr(namespace, self.dest, options)
+
+
+def sanitize_filename(value, allow_unicode=False):
+ value = str(value)
+ if allow_unicode:
+ value = unicodedata.normalize("NFKC", value)
+ else:
+ value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
+ value = re.sub(r"[^\w\s-]", "", value.lower())
+ return re.sub(r"[-\s]+", "-", value).strip("-_")
+
+
+def parse_devices(gpu_ids):
+ if "-" in gpu_ids:
+ gpus = gpu_ids.split("-")
+ gpus[0] = int(gpus[0])
+ gpus[1] = int(gpus[1]) + 1
+ parsed_ids = ",".join(map(lambda x: str(x), list(range(*gpus))))
+ return parsed_ids
+ else:
+ return gpu_ids
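+# Example (illustrative): parse_devices("0-3") expands to "0,1,2,3", while an
+# explicit list such as "0,2" is returned unchanged.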
+
+
+def _max_by_axis(the_list):
+ # type: (List[List[int]]) -> List[int]
+ # copy so the caller's first sublist is not mutated in place
+ maxes = list(the_list[0])
+ for sublist in the_list[1:]:
+ for index, item in enumerate(sublist):
+ maxes[index] = max(maxes[index], item)
+ return maxes
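+
+
+# Example (illustrative): _max_by_axis([[3, 480, 800], [3, 448, 768]]) -> [3, 480, 800]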
+
+
+def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+ # TODO make this more general
+ if tensor_list[0].ndim == 3:
+ if torchvision._is_tracing():
+ # nested_tensor_from_tensor_list() does not export well to ONNX
+ # call _onnx_nested_tensor_from_tensor_list() instead
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
+
+ # TODO make it support different-sized images
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+ batch_shape = [len(tensor_list)] + max_size
+ b, c, h, w = batch_shape
+ dtype = tensor_list[0].dtype
+ device = tensor_list[0].device
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ m[: img.shape[1], : img.shape[2]] = False
+ else:
+ raise ValueError("not supported")
+ return NestedTensor(tensor, mask)
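+
+
+# Usage sketch (illustrative): pad differently sized 3xHxW images into one batch
+# tensor plus a boolean mask where True marks padded pixels. It assumes the
+# NestedTensor wrapper used above exposes .tensors and .mask attributes, as in the
+# DETR reference implementation.
+#     imgs = [torch.rand(3, 480, 800), torch.rand(3, 448, 768)]
+#     batch = nested_tensor_from_tensor_list(imgs)
+#     # batch.tensors.shape == (2, 3, 480, 800); batch.mask.shape == (2, 480, 800)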
+
+
+# _onnx_nested_tensor_from_tensor_list() is an implementation of
+# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
+@torch.jit.unused
+def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
+ max_size = []
+ for i in range(tensor_list[0].dim()):
+ max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
+ max_size.append(max_size_i)
+ max_size = tuple(max_size)
+
+ # work around for
+ # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ # m[: img.shape[1], :img.shape[2]] = False
+ # which is not yet supported in onnx
+ padded_imgs = []
+ padded_masks = []
+ for img in tensor_list:
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+ padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
+ padded_imgs.append(padded_img)
+
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
+ padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
+ padded_masks.append(padded_mask.to(torch.bool))
+
+ tensor = torch.stack(padded_imgs)
+ mask = torch.stack(padded_masks)
+
+ return NestedTensor(tensor, mask=mask)
+
+
+def get_param_groups(model, optimizer_setup):
+ def match_name_keywords(n, name_keywords):
+ out = False
+ for b in name_keywords:
+ if b in n:
+ out = True
+ break
+ return out
+
+ for n, p in model.named_parameters():
+ if match_name_keywords(n, optimizer_setup["freeze_names"]):
+ p.requires_grad = False
+
+ param_groups = [
+ {
+ "params": [
+ p
+ for n, p in model.named_parameters()
+ if not match_name_keywords(n, optimizer_setup["backb_names"])
+ and not match_name_keywords(n, optimizer_setup["extra_names"])
+ and p.requires_grad
+ ],
+ "lr": optimizer_setup["base_lr"],
+ "wd": optimizer_setup["wd"],
+ },
+ {
+ "params": [
+ p
+ for n, p in model.named_parameters()
+ if match_name_keywords(n, optimizer_setup["backb_names"]) and p.requires_grad
+ ],
+ "lr": optimizer_setup["backb_lr"],
+ "wd": optimizer_setup["wd"],
+ },
+ {
+ "params": [
+ p
+ for n, p in model.named_parameters()
+ if match_name_keywords(n, optimizer_setup["extra_names"]) and p.requires_grad
+ ],
+ "lr": optimizer_setup["extra_lr"],
+ "wd": optimizer_setup["wd"],
+ },
+ ]
+
+ return param_groups
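+
+
+# Usage sketch (illustrative): a hypothetical optimizer_setup that puts backbone and
+# "extra" parameters in their own groups with separate learning rates; parameters
+# matching freeze_names are frozen before grouping. The returned groups carry
+# "lr"/"wd" keys, so the consumer is expected to map "wd" to the optimizer's
+# weight-decay argument. All names and values below are placeholders.
+#     optimizer_setup = dict(
+#         base_lr=2e-4, backb_lr=2e-5, extra_lr=2e-4, wd=1e-4,
+#         freeze_names=[], backb_names=["backbone"], extra_names=["reference_points"],
+#     )
+#     param_groups = get_param_groups(model, optimizer_setup)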
+
+
+def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):
+ """setup logger for training and testing.
+ Args:
+ save_dir(str): loaction to save log file
+ distributed_rank(int): device rank when multi-gpu environment
+ mode(str): log file write mode, `append` or `override`. default is `a`.
+ Return:
+ logger instance.
+ """
+ save_file = os.path.join(save_dir, filename)
+ if mode == "o" and os.path.exists(save_file):
+ os.remove(save_file)
+ format = f"[Rank #{distributed_rank}] | " + "{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}"
+ if distributed_rank > 0:
+ logger.remove()
+ logger.add(
+ stderr,
+ format=format,
+ level="WARNING",
+ )
+ logger.add(
+ save_file,
+ format=format,
+ filter="",
+ level="INFO" if distributed_rank == 0 else "WARNING",
+ enqueue=True,
+ )
+
+ return logger
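+
+
+# Usage sketch (illustrative): every rank logs to the same file, while ranks > 0 only
+# emit WARNING and above on the console; the output directory below is an assumed example.
+#     logger = setup_logger("outputs/exp1", distributed_rank=get_rank())
+#     logger.info("training started")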
+
+
+def get_rank() -> int:
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def get_world_size() -> int:
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def synchronize():
+ """Helper function to synchronize (barrier) among all processes when using distributed training"""
+ if not dist.is_available():
+ return
+ if not dist.is_initialized():
+ return
+ current_world_size = dist.get_world_size()
+ if current_world_size == 1:
+ return
+ dist.barrier()
+
+
+def reduce_sum(tensor):
+ world_size = get_world_size()
+ if world_size < 2:
+ return tensor
+ tensor = tensor.clone()
+ dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+ return tensor
+
+
+def reduce_mean(tensor):
+ return reduce_sum(tensor) / float(get_world_size())
+
+
+def all_gather_object(obj):
+ world_size = get_world_size()
+ if world_size < 2:
+ return [obj]
+ output = [None for _ in range(world_size)]
+ dist.all_gather_object(output, obj)
+ return output
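+
+
+# Usage sketch (illustrative): average a scalar tensor across ranks and gather
+# arbitrary picklable objects; both fall back to single-process behaviour when
+# torch.distributed is not initialized. `loss` and `local_results` are placeholders.
+#     avg_loss = reduce_mean(loss.detach())
+#     all_results = all_gather_object(local_results)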
+
+
+def is_distributed() -> bool:
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def is_available() -> bool:
+ return dist.is_available()