diff --git a/mapmaster/__init__.py b/mapmaster/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mapmaster/dataset/nuscenes_bemapnet.py b/mapmaster/dataset/nuscenes_bemapnet.py new file mode 100644 index 0000000..2498965 --- /dev/null +++ b/mapmaster/dataset/nuscenes_bemapnet.py @@ -0,0 +1,141 @@ +import os +import torch +import numpy as np +from PIL import Image +from copy import deepcopy +from skimage import io as skimage_io +from torch.utils.data import Dataset + + +class NuScenesMapDataset(Dataset): + def __init__(self, img_key_list, map_conf, ida_conf, bezier_conf, transforms, data_split="training"): + super().__init__() + self.img_key_list = img_key_list + self.map_conf = map_conf + self.ida_conf = ida_conf + self.bez_conf = bezier_conf + self.ego_size = map_conf["ego_size"] + self.mask_key = map_conf["mask_key"] + self.nusc_root = map_conf["nusc_root"] + self.anno_root = map_conf["anno_root"] + self.split_dir = map_conf["split_dir"] + self.num_degree = bezier_conf["num_degree"] + self.max_pieces = bezier_conf["max_pieces"] + self.max_instances = bezier_conf["max_instances"] + self.split_mode = 'train' if data_split == "training" else 'val' + split_path = os.path.join(self.split_dir, f'{self.split_mode}.txt') + self.tokens = [token.strip() for token in open(split_path).readlines()] + self.transforms = transforms + + def __getitem__(self, idx: int): + token = self.tokens[idx] + sample = np.load(os.path.join(self.anno_root, f'{token}.npz'), allow_pickle=True) + resize_dims, crop, flip, rotate = self.sample_ida_augmentation() + images, ida_mats = [], [] + for im_view in self.img_key_list: + for im_path in sample['image_paths']: + if im_path.startswith(f'samples/{im_view}/'): + im_path = os.path.join(self.nusc_root, im_path) + img = skimage_io.imread(im_path) + img, ida_mat = self.img_transform(img, resize_dims, crop, flip, rotate) + images.append(img) + ida_mats.append(ida_mat) + extrinsic = np.stack([np.eye(4) for _ in range(sample["trans"].shape[0])], axis=0) + extrinsic[:, :3, :3] = sample["rots"] + extrinsic[:, :3, 3] = sample["trans"] + intrinsic = sample['intrins'] + ctr_points = np.zeros((self.max_instances, max(self.max_pieces) * max(self.num_degree) + 1, 2), dtype=np.float) + ins_labels = np.zeros((self.max_instances, 3), dtype=np.int16) - 1 + for ins_id, ctr_info in enumerate(sample['ctr_points']): + cls_id = int(ctr_info['type']) + ctr_pts_raw = np.array(ctr_info['pts']) + max_points = self.max_pieces[cls_id] * self.num_degree[cls_id] + 1 + num_points = max_points if max_points <= ctr_pts_raw.shape[0] else ctr_pts_raw.shape[0] + assert num_points >= self.num_degree[cls_id] + 1 + ctr_points[ins_id][:num_points] = np.array(ctr_pts_raw[:num_points]) + ins_labels[ins_id] = [cls_id, (num_points - 1) // self.num_degree[cls_id] - 1, num_points] + masks = sample[self.mask_key] + if flip: + new_order = [2, 1, 0, 5, 4, 3] + img_key_list = [self.img_key_list[i] for i in new_order] + images = [images[i] for i in new_order] + ida_mats = [ida_mats[i] for i in new_order] + extrinsic = [extrinsic[i] for i in new_order] + intrinsic = [intrinsic[i] for i in new_order] + masks = [np.flip(mask, axis=1) for mask in masks] + ctr_points = self.point_flip(ctr_points, ins_labels, self.ego_size) + item = dict( + images=images, targets=dict(masks=masks, points=ctr_points, labels=ins_labels), + extrinsic=np.stack(extrinsic), intrinsic=np.stack(intrinsic), ida_mats=np.stack(ida_mats), + extra_infos=dict(token=token, img_key_list=self.img_key_list, map_size=self.ego_size, do_flip=flip) 
+ ) + if self.transforms is not None: + item = self.transforms(item) + return item + + def __len__(self): + return len(self.tokens) + + def sample_ida_augmentation(self): + """Generate ida augmentation values based on ida_config.""" + resize_dims = w, h = self.ida_conf["resize_dims"] + crop = (0, 0, w, h) + if self.ida_conf["up_crop_ratio"] > 0: + crop = (0, int(self.ida_conf["up_crop_ratio"] * h), w, h) + flip, color, rotate_ida = False, False, 0 + if self.split_mode == "train": + if self.ida_conf["rand_flip"] and np.random.choice([0, 1]): + flip = True + if self.ida_conf["rot_lim"]: + assert isinstance(self.ida_conf["rot_lim"], (tuple, list)) + rotate_ida = np.random.uniform(*self.ida_conf["rot_lim"]) + return resize_dims, crop, flip, rotate_ida + + def img_transform(self, img, resize_dims, crop, flip, rotate): + img = Image.fromarray(img) + ida_rot = torch.eye(2) + ida_tran = torch.zeros(2) + W, H = img.size + img = img.resize(resize_dims) + img = img.crop(crop) + if flip: + img = img.transpose(method=Image.FLIP_LEFT_RIGHT) + img = img.rotate(rotate) + + # post-homography transformation + scales = torch.tensor([resize_dims[0] / W, resize_dims[1] / H]) + ida_rot *= torch.Tensor(scales) + ida_tran -= torch.Tensor(crop[:2]) + if flip: + A = torch.Tensor([[-1, 0], [0, 1]]) + b = torch.Tensor([crop[2] - crop[0], 0]) + ida_rot = A.matmul(ida_rot) + ida_tran = A.matmul(ida_tran) + b + A = self.get_rot(rotate / 180 * np.pi) + b = torch.Tensor([crop[2] - crop[0], crop[3] - crop[1]]) / 2 + b = A.matmul(-b) + b + ida_rot = A.matmul(ida_rot) + ida_tran = A.matmul(ida_tran) + b + ida_mat = ida_rot.new_zeros(3, 3) + ida_mat[2, 2] = 1 + ida_mat[:2, :2] = ida_rot + ida_mat[:2, 2] = ida_tran + return np.asarray(img), ida_mat + + @staticmethod + def point_flip(points, labels, map_shape): + + def _flip(pts): + pts[:, 0] = map_shape[1] - pts[:, 0] + return pts.copy() + + points_ret = deepcopy(points) + for ins_id in range(points.shape[0]): + end = labels[ins_id, 2] + points_ret[ins_id][:end] = _flip(points[ins_id][:end]) + + return points_ret + + @staticmethod + def get_rot(h): + return torch.Tensor([[np.cos(h), np.sin(h)], [-np.sin(h), np.cos(h)]]) diff --git a/mapmaster/dataset/nuscenes_pivotnet.py b/mapmaster/dataset/nuscenes_pivotnet.py new file mode 100644 index 0000000..270db9b --- /dev/null +++ b/mapmaster/dataset/nuscenes_pivotnet.py @@ -0,0 +1,56 @@ +import os +import numpy as np +import pickle as pkl +from PIL import Image +from torch.utils.data import Dataset + +class NuScenesMapDataset(Dataset): + def __init__(self, img_key_list, map_conf, transforms, data_split="training"): + super().__init__() + self.img_key_list = img_key_list + self.map_conf = map_conf + + self.ego_size = map_conf["ego_size"] + self.mask_key = map_conf["mask_key"] + self.nusc_root = map_conf["nusc_root"] + self.anno_root = map_conf["anno_root"] + self.split_dir = map_conf["split_dir"] # instance_mask/instance_mask8 + + self.split_mode = 'train' if data_split == "training" else 'val' + split_path = os.path.join(self.split_dir, f'{self.split_mode}.txt') + self.tokens = [token.strip() for token in open(split_path).readlines()] + self.transforms = transforms + + def __getitem__(self, idx: int): + token = self.tokens[idx] + sample = np.load(os.path.join(self.anno_root, f'{token}.npz'), allow_pickle=True) + # images + images = [] + for im_view in self.img_key_list: + for im_path in sample['image_paths']: + if im_path.startswith(f'samples/{im_view}/'): + im_path = os.path.join(self.nusc_root, im_path) + img = 
np.asarray(Image.open(im_path)) + images.append(img) + # pivot pts + pivot_pts = sample["pivot_pts"].item() + valid_length = sample["pivot_length"].item() + # targets + masks=sample[self.mask_key] + targets = dict(masks=masks, points=pivot_pts, valid_len=valid_length) + # pose + extrinsic = np.stack([np.eye(4) for _ in range(sample["trans"].shape[0])], axis=0) + extrinsic[:, :3, :3] = sample["rots"] + extrinsic[:, :3, 3] = sample["trans"] + intrinsic = sample['intrins'] + # transform + item = dict(images=images, targets=targets, + extra_infos=dict(token=token, map_size=self.ego_size), + extrinsic=np.stack(extrinsic, axis=0), intrinsic=np.stack(intrinsic, axis=0)) + if self.transforms is not None: + item = self.transforms(item) + + return item + + def __len__(self): + return len(self.tokens) diff --git a/mapmaster/dataset/sampler.py b/mapmaster/dataset/sampler.py new file mode 100644 index 0000000..23cd2cd --- /dev/null +++ b/mapmaster/dataset/sampler.py @@ -0,0 +1,61 @@ +import torch +import itertools +import torch.distributed as dist +from typing import Optional +from torch.utils.data.sampler import Sampler + + +class InfiniteSampler(Sampler): + """ + In training, we only care about the "infinite stream" of training data. + So this sampler produces an infinite stream of indices and + all workers cooperate to correctly shuffle the indices and sample different indices. + The samplers in each worker effectively produces `indices[worker_id::num_workers]` + where `indices` is an infinite stream of indices consisting of + `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True) + or `range(size) + range(size) + ...` (if shuffle is False) + """ + + def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = 0, rank=0, world_size=1, drop_last=False): + """ + Args: + size (int): the total number of data of the underlying dataset to sample from + shuffle (bool): whether to shuffle the indices or not + seed (int): the initial seed of the shuffle. Must be the same + across all workers. If None, will use a random seed shared + among workers (require synchronization among all workers). 
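+            rank (int): fallback process rank, used only when torch.distributed is not initialized
+            world_size (int): fallback world size, used only when torch.distributed is not initialized
+            drop_last (bool): if True, __len__ reports size // world_size instead of the rounded-up share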
+ """ + self._size = size + assert size > 0 + self._shuffle = shuffle + self._seed = int(seed) + self.drop_last = drop_last + + if dist.is_available() and dist.is_initialized(): + self._rank = dist.get_rank() + self._world_size = dist.get_world_size() + else: + self._rank = rank + self._world_size = world_size + + def set_epoch(self, epoch): + pass + + def __iter__(self): + start = self._rank + yield from itertools.islice(self._infinite_indices(), start, None, self._world_size) + + def _infinite_indices(self): + g = torch.Generator() + g.manual_seed(self._seed) + while True: + if self._shuffle: + yield from torch.randperm(self._size, generator=g).tolist() + else: + yield from list(range(self._size)) + + def __len__(self): + if self.drop_last: + return self._size // self._world_size + else: + return (self._size + self._world_size - 1) // self._world_size diff --git a/mapmaster/dataset/transform.py b/mapmaster/dataset/transform.py new file mode 100644 index 0000000..0d8d2f0 --- /dev/null +++ b/mapmaster/dataset/transform.py @@ -0,0 +1,274 @@ +import cv2 +import mmcv +import torch +import numpy as np +from PIL import Image +from collections.abc import Sequence + +class Resize(object): + def __init__(self, img_scale=None, backend="cv2", interpolation="bilinear"): + self.size = img_scale + self.backend = backend + self.interpolation = interpolation + self.cv2_interp_codes = { + "nearest": cv2.INTER_NEAREST, + "bilinear": cv2.INTER_LINEAR, + "bicubic": cv2.INTER_CUBIC, + "area": cv2.INTER_AREA, + "lanczos": cv2.INTER_LANCZOS4, + } + self.pillow_interp_codes = { + "nearest": Image.NEAREST, + "bilinear": Image.BILINEAR, + "bicubic": Image.BICUBIC, + "box": Image.BOX, + "lanczos": Image.LANCZOS, + "hamming": Image.HAMMING, + } + + def __call__(self, data_dict): + """Call function to resize images. + + Args: + data_dict (dict): Result dict from loading pipeline. + + Returns: + dict: Resized data_dict, 'scale_factor' keys are added into result dict. + """ + + imgs = [] + for img in data_dict["images"]: + img = self.im_resize(img, self.size, backend=self.backend) + imgs.append(img) + data_dict["images"] = imgs + + new_h, new_w = imgs[0].shape[:2] + h, w = data_dict["images"][0].shape[:2] + w_scale = new_w / w + h_scale = new_h / h + scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], dtype=np.float32) + data_dict["extra_infos"].update({"scale_factor": scale_factor}) + + return data_dict + + def im_resize(self, img, size, return_scale=False, interpolation="bilinear", out=None, backend="cv2"): + """Resize image to a given size. + Args: + img (ndarray): The input image. + size (tuple[int]): Target size (w, h). + return_scale (bool): Whether to return `w_scale` and `h_scale`. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. + out (ndarray): The output destination. + backend (str | None): The image resize backend type. Options are `cv2`, + `pillow`, `None`. + Returns: + tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or + `resized_img`. + """ + h, w = img.shape[:2] + if backend not in ["cv2", "pillow"]: + raise ValueError( + f"backend: {backend} is not supported for resize." 
f"Supported backends are 'cv2', 'pillow'" + ) + + if backend == "pillow": + assert img.dtype == np.uint8, "Pillow backend only support uint8 type" + pil_image = Image.fromarray(img) + pil_image = pil_image.resize(size, self.pillow_interp_codes[interpolation]) + resized_img = np.array(pil_image) + else: + resized_img = cv2.resize(img, size, dst=out, interpolation=self.cv2_interp_codes[interpolation]) + if not return_scale: + return resized_img + else: + w_scale = size[0] / w + h_scale = size[1] / h + return resized_img, w_scale, h_scale + +class Normalize(object): + """Normalize the image. + + Added key is "img_norm_cfg". + + Args: + mean (sequence): Mean values of 3 channels. + std (sequence): Std values of 3 channels. + to_rgb (bool): Whether to convert the image from BGR to RGB, + default is true. + """ + + def __init__(self, mean, std, to_rgb=True): + self.mean = np.array(mean, dtype=np.float32) + self.std = np.array(std, dtype=np.float32) + self.to_rgb = to_rgb + + def __call__(self, data_dict): + imgs = [] + for img in data_dict["images"]: + if self.to_rgb: + img = img.astype(np.float32) / 255.0 + img = self.im_normalize(img, self.mean, self.std, self.to_rgb) + imgs.append(img) + data_dict["images"] = imgs + data_dict["extra_infos"]["img_norm_cfg"] = dict(mean=self.mean, std=self.std, to_rgb=self.to_rgb) + return data_dict + + @staticmethod + def im_normalize(img, mean, std, to_rgb=True): + img = img.copy().astype(np.float32) + assert img.dtype != np.uint8 # cv2 inplace normalization does not accept uint8 + mean = np.float64(mean.reshape(1, -1)) + stdinv = 1 / np.float64(std.reshape(1, -1)) + if to_rgb: + cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace + cv2.subtract(img, mean, img) # inplace + cv2.multiply(img, stdinv, img) # inplace + return img + + +class ToTensor(object): + """Default formatting bundle.""" + + def __call__(self, data_dict): + """Call function to transform and format common fields in data_dict. + + Args: + data_dict (dict): Data dict contains the data to convert. + + Returns: + dict: The result dict contains the data that is formatted with default bundle. + """ + + for k in ["images", "extrinsic", "intrinsic", "ida_mats"]: + if k == "images": + data_dict[k] = np.stack([img.transpose(2, 0, 1) for img in data_dict[k]], axis=0) + data_dict[k] = self.to_tensor(np.ascontiguousarray(data_dict[k])) + + for k in ["masks", "points", "labels"]: + data_dict["targets"][k] = self.to_tensor(np.ascontiguousarray(data_dict["targets"][k])) + + return data_dict + + @staticmethod + def to_tensor(data): + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + return torch.from_numpy(data) + elif isinstance(data, Sequence) and not mmcv.is_str(data): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError(f"type {type(data)} cannot be converted to tensor.") + +class ToTensor_Pivot(object): + """Default formatting bundle.""" + + def __call__(self, data_dict): + """Call function to transform and format common fields in data_dict. + + Args: + data_dict (dict): Data dict contains the data to convert. + + Returns: + dict: The result dict contains the data that is formatted with default bundle. 
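+            Note: only "images" (stacked into a (num_cams, C, H, W) tensor) and targets["masks"]
+            are converted here; pivot points and valid lengths are left untouched.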
+ """ + if "images" in data_dict: + if isinstance(data_dict["images"], list): + # process multiple imgs in single frame + imgs = [img.transpose(2, 0, 1) for img in data_dict["images"]] + imgs = np.ascontiguousarray(np.stack(imgs, axis=0)) + data_dict["images"] = self.to_tensor(imgs) + else: + img = np.ascontiguousarray(data_dict["img"].transpose(2, 0, 1)) + data_dict["images"] = self.to_tensor(img) + + for k in ["masks"]: + data_dict["targets"][k] = self.to_tensor(np.ascontiguousarray(data_dict["targets"][k])) + + return data_dict + + @staticmethod + def to_tensor(data): + """Convert objects of various python types to :obj:`torch.Tensor`. + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int` and :class:`float`. + Args: + data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to + be converted. + """ + + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + return torch.from_numpy(data) + elif isinstance(data, Sequence) and not mmcv.is_str(data): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError(f"type {type(data)} cannot be converted to tensor.") + + + +class Pad(object): + """Pad the image & mask. + + There are two padding modes: (1) pad to a fixed size and (2) pad to the + minimum size that is divisible by some number. + Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor", + + Args: + size (tuple, optional): Fixed padding size. + size_divisor (int, optional): The divisor of padded size. + pad_val (float, optional): Padding value, 0 by default. + """ + + def __init__(self, size_divisor=None, pad_val=0): + self.size_divisor = size_divisor + self.pad_val = pad_val + # only one of size and size_divisor should be valid + assert size_divisor is not None + + def __call__(self, data_dict): + """Call function to pad images, masks, semantic segmentation maps. + + Args: + data_dict (dict): Result dict from loading pipeline. + + Returns: + dict: Updated result dict. + """ + padded_img = None + padded_imgs = [] + for img in data_dict["images"]: + padded_img = self.im_pad_to_multiple(img, self.size_divisor, pad_val=self.pad_val) + padded_imgs.append(padded_img) + data_dict["images"] = padded_imgs + data_dict["extra_infos"].update( + { + "pad_shape": padded_img.shape, + "pad_size_divisor": self.size_divisor if self.size_divisor is not None else "None", + } + ) + return data_dict + + def im_pad_to_multiple(self, img, divisor, pad_val=0): + """Pad an image to ensure each edge to be multiple to some number. + Args: + img (ndarray): Image to be padded. + divisor (int): Padded image edges will be multiple to divisor. + pad_val (Number | Sequence[Number]): Same as :func:`impad`. + Returns: + ndarray: The padded image. 
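+            For example, with divisor=32 a 900 x 1600 image is padded to 928 x 1600,
+            since ceil(900 / 32) * 32 = 928 and 1600 is already a multiple of 32.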
+ """ + pad_h = int(np.ceil(img.shape[0] / divisor)) * divisor + pad_w = int(np.ceil(img.shape[1] / divisor)) * divisor + return self.im_pad(img, shape=(pad_h, pad_w), pad_val=pad_val) diff --git a/mapmaster/engine/callbacks.py b/mapmaster/engine/callbacks.py new file mode 100644 index 0000000..f410930 --- /dev/null +++ b/mapmaster/engine/callbacks.py @@ -0,0 +1,299 @@ +import os +import glob +import tqdm +import torch +import pickle +from clearml import Task +from loguru import logger +from typing import Callable, Optional +from tensorboardX import SummaryWriter +from torch.nn.utils import clip_grad_norm_ +from mapmaster.engine.executor import Callback, BaseExecutor, Trainer +from mapmaster.utils.misc import AvgMeter + + +__all__ = ["Callback", "MasterOnlyCallback", "CheckPointSaver", "CheckPointLoader", "CheckPointC2Loader", + "ClearMLCallback", "EvalResultsSaver", "LamdaCallback", "ClipGrad", "ProgressBar", + "LearningRateMonitor", "TextMonitor", "TensorBoardMonitor"] + + +class MasterOnlyCallback(Callback): + enabled_rank = [0] + + +class CheckPointSaver(MasterOnlyCallback): + def __init__( + self, + local_path, + filename=r"checkpoint_epoch_{epoch}.pth", + remote_path=None, + save_interval: int = 1, + num_keep_latest=None, + ): + self.local_path = local_path + self.filename = filename + self.remote_path = remote_path + self.save_interval = save_interval + self.num_keep_latest = num_keep_latest + os.makedirs(local_path, exist_ok=True) + + def _make_checkpoint(self, trainer: Trainer): + model_state = None + if hasattr(trainer, "ema_model"): + model = trainer.ema_model.ema + else: + model = trainer.model + if model: + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + model_state = model.module.state_dict() + model_state_cpu = type(model_state)() + for key, val in model_state.items(): + model_state_cpu[key] = val.cpu() + model_state = model_state_cpu + else: + model_state = model.state_dict() + + optim_state = trainer.optimizer.state_dict() if trainer.optimizer else None + + callback_states = {} + for cb in trainer.callbacks: + if hasattr(cb, "state_dict"): + cls_name = cb.__class__.__name__ + callback_states[cls_name] = cb.state_dict() + + ckpt = { + "epoch": trainer.epoch, + "it": trainer.global_step, + "global_step": trainer.global_step, + "model_state": model_state, + "optimizer_state": optim_state, + "callback": callback_states, + } + + # save grad_scaler + if hasattr(trainer, "grad_scaler"): + ckpt["grad_scaler_state"] = trainer.grad_scaler.state_dict() + + return ckpt + + def after_epoch(self, trainer: Trainer, epoch: int, update_best_ckpt: bool = False): + if (epoch + 1) % self.save_interval != 0: + return + filename = self.filename.format(epoch=epoch) + save_path = os.path.join(self.local_path, filename) + torch.save(self._make_checkpoint(trainer), save_path) + if update_best_ckpt: + torch.save(self._make_checkpoint(trainer), os.path.join(self.local_path, f"checkpoint_best.pth")) + self._remove_out_of_date_ckpt() + + def _remove_out_of_date_ckpt(self): + if not self.num_keep_latest: + return + + ckpt_list = glob.glob(os.path.join(self.local_path, self.filename.format(epoch="*"))) + ckpt_list.sort(key=os.path.getmtime) + if len(ckpt_list) > self.num_keep_latest: + for cur_file_idx in range(0, len(ckpt_list) - self.num_keep_latest): + os.remove(ckpt_list[cur_file_idx]) + + +class CheckPointLoader(Callback): + def __init__( + self, + path, + weight_only=False, + ): + self.path = path + self.weight_only = weight_only + + def load_checkpoint(self, trainer: Trainer): + 
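+        """Restore model weights from self.path, skipping checkpoint keys whose shapes differ
+        from the current model; unless weight_only is set, also resume epoch, global_step,
+        optimizer, callback and grad-scaler states."""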
logger.info(f"Loading parameters from checkpoint {self.path}") + with open(self.path, "rb") as f: + checkpoint = torch.load(f, map_location=torch.device("cpu")) + + # TODO bulid model finetune callback + model_state_dict = trainer.model.state_dict() + checkpoint_state_dict = checkpoint["model_state"] + for k in list(checkpoint_state_dict.keys()): + if k in model_state_dict: + shape_model = tuple(model_state_dict[k].shape) + shape_checkpoint = tuple(checkpoint_state_dict[k].shape) + if shape_model != shape_checkpoint: + logger.info( + "'{}' has shape {} in the checkpoint but {} in the " + "model! Skipped.".format(k, shape_checkpoint, shape_model) + ) + checkpoint_state_dict.pop(k) + trainer.model.load_state_dict(checkpoint_state_dict, strict=False) + + if self.weight_only: + return + + trainer.epoch = checkpoint.get("epoch", -1) + 1 + trainer.global_step = checkpoint.get("global_step", -1) + 1 + if "optimizer_state" in checkpoint: + trainer.optimizer.load_state_dict(checkpoint["optimizer_state"]) + # resume callback + for cb in trainer.callbacks: + if hasattr(cb, "state_dict"): + cls_name = cb.__class__.__name__ + if cls_name in checkpoint["callback"]: + cb.load_state_dict(checkpoint["callback"][cls_name]) + # resume grad_scaler + if hasattr(trainer, "grad_scaler") and "grad_scaler_state" in checkpoint: + trainer.grad_scaler.load_state_dict(checkpoint["grad_scaler_state"]) + + +class ClearMLCallback(MasterOnlyCallback): + def __init__(self): + super().__init__() + self.task_id = None + + def after_init(self, executor: BaseExecutor): + if self.task_id is None: + self.task = Task.init( + project_name="det3d", + task_name=executor.exp.exp_name, + auto_connect_frameworks={"pytorch": False}, + reuse_last_task_id=False, + continue_last_task=False, + ) + else: + self.task = Task.get_task(task_id=self.task_id) + self.task.add_tags(["resume"]) + logger.info(f"continue from clearml task {self.task_id}") + self.task.connect(executor.exp) + if hasattr(executor.exp, "get_pcdet_cfg"): + self.task.connect(executor.exp.get_pcdet_cfg(), "pcdet_config") + + def state_dict(self): + return {"task_id": self.task.task_id} + + def load_state_dict(self, state_dict): + self.task_id = state_dict["task_id"] + + +class EvalResultsSaver(MasterOnlyCallback): + def __init__(self, out_dir: str): + self.out_dir = out_dir + + def after_eval(self, executor, det_annos: list): + out_file = os.path.join(self.out_dir, "result.pkl") + pickle.dump(det_annos, open(out_file, "wb")) + + +class LamdaCallback: + def __init__( + self, + setup: Optional[Callable] = None, + load_checkpoint: Optional[Callable] = None, + after_init: Optional[Callable] = None, + before_train: Optional[Callable] = None, + before_epoch: Optional[Callable] = None, + before_step: Optional[Callable] = None, + before_backward: Optional[Callable] = None, + before_optimize: Optional[Callable] = None, + after_step: Optional[Callable] = None, + after_epoch: Optional[Callable] = None, + after_train: Optional[Callable] = None, + ) -> None: + for k, v in locals().items(): + if k == "self": + continue + if v is not None: + setattr(self, k, v) + + +class ClipGrad(Callback): + def __init__(self, max_norm: float): + self.max_norm = max_norm + + def before_optimize(self, trainer): + clip_grad_norm_(trainer.model.parameters(), self.max_norm) + + +class ProgressBar(MasterOnlyCallback): + def __init__(self, logger=None) -> None: + self.epoch_bar = None + self.step_bar = None + self.logger = logger + + def setup(self, trainer: Trainer): + self.epoch_bar = tqdm.tqdm(initial=0, 
total=trainer.exp.max_epoch, desc="[Epoch]", dynamic_ncols=True) + self.step_bar = tqdm.tqdm(initial=0, desc="[Step]", dynamic_ncols=True, leave=False) + if self.logger: + self.logger.remove(0) + self.logger.add(lambda msg: self.step_bar.write(msg, end="")) + + def before_epoch(self, trainer: Trainer, epoch: int): + self.epoch_bar.update(epoch - self.epoch_bar.n) + self.step_bar.reset(len(trainer.train_dataloader)) + + def after_step(self, trainer: Trainer, step, data_dict, *args, **kwargs): + self.step_bar.update() + + def after_train(self, trainer: Trainer): + if self.step_bar: + self.step_bar.close() + if self.epoch_bar: + self.epoch_bar.close() + + +class LearningRateMonitor: + def _get_learning_rate(self, optimizer): + if hasattr(optimizer, "lr"): + lr = float(optimizer.lr) + else: + lr = optimizer.param_groups[0]["lr"] + return lr + + +class TextMonitor(MasterOnlyCallback, LearningRateMonitor): + def __init__(self, interval=10): + self.interval = interval + self.avg_loss = AvgMeter() + self.ext_dict = None + + def after_step(self, trainer: Trainer, step, data_dict, *args, **kwargs): + self.avg_loss.update(kwargs["loss"]) + + lr = self._get_learning_rate(trainer.optimizer) + + ext_info = "" + if kwargs["extra"] is not None: + if self.ext_dict is None: + self.ext_dict = {k: AvgMeter() for k in kwargs["extra"]} + for key, val in kwargs["extra"].items(): + self.ext_dict[key].update(val) + ext_info = "".join([f" {k}={v.window_avg :.4f}" for k, v in self.ext_dict.items()]) + + if step % self.interval != 0: + return + + trainer.logger.info( + f"e:{trainer.epoch}[{step}/{self.total_step}] lr={lr :.6f} loss={self.avg_loss.window_avg :.4f}{ext_info}" + ) + + def before_epoch(self, trainer: Trainer, epoch: int): + lr = trainer.optimizer.param_groups[0]["lr"] + trainer.logger.info(f"e:{epoch} learning rate = {lr :.6f}") + self.total_step = len(trainer.train_dataloader) + + +class TensorBoardMonitor(MasterOnlyCallback, LearningRateMonitor): + def __init__(self, log_dir, interval=10): + os.makedirs(log_dir, exist_ok=True) + self.tb_log = SummaryWriter(log_dir=log_dir) + self.interval = interval + + def after_step(self, trainer: Trainer, step, data_dict, *args, **kwargs): + accumulated_iter = trainer.global_step + if accumulated_iter % self.interval != 0: + return + lr = self._get_learning_rate(trainer.optimizer) + self.tb_log.add_scalar("epoch", trainer.epoch, accumulated_iter) + self.tb_log.add_scalar("train/loss", kwargs["loss"], accumulated_iter) + self.tb_log.add_scalar("meta_data/learning_rate", lr, accumulated_iter) + if kwargs["extra"] is not None: + for key, val in kwargs["extra"].items(): + self.tb_log.add_scalar(f"train/{key}", val, accumulated_iter) diff --git a/mapmaster/engine/core.py b/mapmaster/engine/core.py new file mode 100644 index 0000000..4ecc69f --- /dev/null +++ b/mapmaster/engine/core.py @@ -0,0 +1,191 @@ +import os +import sys +import argparse +import datetime +import warnings +import subprocess +from mapmaster.engine.executor import Trainer, BeMapNetEvaluator +from mapmaster.engine.environ import ShareFSUUIDNameServer, RlaunchReplicaEnv +from mapmaster.engine.callbacks import CheckPointLoader, CheckPointSaver, ClearMLCallback, ProgressBar +from mapmaster.engine.callbacks import TensorBoardMonitor, TextMonitor, ClipGrad +from mapmaster.utils.env import collect_env_info, get_root_dir +from mapmaster.utils.misc import setup_logger, sanitize_filename, PyDecorator, all_gather_object + + +__all__ = ["BaseCli", "BeMapNetCli"] + + +class BaseCli: + """Command line tools for any 
exp.""" + + def __init__(self, Exp): + """Make sure the order of initialization is: build_args --> build_env --> build_exp, + since experiments depend on the environment and the environment depends on args. + + Args: + Exp : experiment description class + """ + self.ExpCls = Exp + self.args = self._get_parser(Exp).parse_args() + self.env = RlaunchReplicaEnv(self.args.sync_bn, self.args.devices, self.args.find_unused_parameters) + + @property + def exp(self): + if not hasattr(self, "_exp"): + exp = self.ExpCls( + **{x if y is not None else "none": y for (x, y) in vars(self.args).items()}, + total_devices=self.env.world_size(), + ) + self.exp_updated_cfg_msg = exp.update_attr(self.args.exp_options) + self._exp = exp + return self._exp + + def _get_parser(self, Exp): + parser = argparse.ArgumentParser() + parser = Exp.add_argparse_args(parser) + parser = self.add_argparse_args(parser) + return parser + + @staticmethod + def add_argparse_args(parser: argparse.ArgumentParser): + parser.add_argument("--eval", dest="eval", action="store_true", help="conduct evaluation only") + parser.add_argument("-te", "--train_and_eval", dest="train_and_eval", action="store_true", help="train+eval") + parser.add_argument("--find_unused_parameters", dest="find_unused_parameters", action="store_true") + parser.add_argument("-d", "--devices", default="0-7", type=str, help="device for training") + parser.add_argument("--ckpt", type=str, default=None, help="checkpoint to start from or be evaluated") + parser.add_argument("--pretrained_model", type=str, default=None, help="pretrained_model used by training") + parser.add_argument("--sync_bn", type=int, default=0, help="0-> disable sync_bn, 1-> whole world") + clearml_parser = parser.add_mutually_exclusive_group(required=False) + clearml_parser.add_argument("--clearml", dest="clearml", action="store_true", help="enabel clearml for train") + clearml_parser.add_argument("--no-clearml", dest="clearml", action="store_false", help="disable clearml") + parser.set_defaults(clearml=True) + return parser + + def _get_exp_output_dir(self): + exp_dir = os.path.join(os.path.join(get_root_dir(), "outputs"), sanitize_filename(self.exp.exp_name)) + os.makedirs(exp_dir, exist_ok=True) + output_dir = None + if self.args.ckpt: + output_dir = os.path.dirname(os.path.dirname(os.path.abspath(self.args.ckpt))) + elif self.env.global_rank() == 0: + output_dir = os.path.join(exp_dir, datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")) + os.makedirs(output_dir, exist_ok=True) + # make a symlink "latest" + symlink, symlink_tmp = os.path.join(exp_dir, "latest"), os.path.join(exp_dir, "latest_tmp") + if os.path.exists(symlink_tmp): + os.remove(symlink_tmp) + os.symlink(os.path.relpath(output_dir, exp_dir), symlink_tmp) + os.rename(symlink_tmp, symlink) + output_dir = all_gather_object(output_dir)[0] + return output_dir + + def get_evaluator(self, callbacks=None): + exp = self.exp + if self.args.ckpt is None: + warnings.warn("No checkpoint is specified for evaluation") + if exp.eval_executor_class is None: + sys.exit("No evaluator is specified for evaluation") + + output_dir = self._get_exp_output_dir() + logger = setup_logger(output_dir, distributed_rank=self.env.global_rank(), filename="eval.log") + self._set_basic_log_message(logger) + if callbacks is None: + callbacks = [self.env, CheckPointLoader(self.args.ckpt)] + evaluator = exp.eval_executor_class(exp=exp, callbacks=callbacks, logger=logger) + return evaluator + + def _set_basic_log_message(self, logger): + 
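+        """Log the CLI arguments, the experiment config table, any overridden options, and environment info."""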
logger.opt(ansi=True).info("Cli arguments:\n{}".format(self.args)) + logger.info(f"exp_name: {self.exp.exp_name}") + logger.opt(ansi=True).info( + "Used experiment configs:\n{}".format(self.exp.get_cfg_as_str()) + ) + if self.exp_updated_cfg_msg: + logger.opt(ansi=True).info( + "List of override configs:\n{}".format(self.exp_updated_cfg_msg) + ) + logger.opt(ansi=True).info("Environment info:\n{}".format(collect_env_info())) + + def get_trainer(self, callbacks=None, evaluator=None): + args = self.args + exp = self.exp + if evaluator is not None: + output_dir = self.exp.output_dir + else: + output_dir = self._get_exp_output_dir() + + logger = setup_logger(output_dir, distributed_rank=self.env.global_rank(), filename="train.log") + self._set_basic_log_message(logger) + + if callbacks is None: + callbacks = [ + self.env, + ProgressBar(logger=logger), + TextMonitor(interval=exp.print_interval), + TensorBoardMonitor(os.path.join(output_dir, "tensorboard"), interval=exp.print_interval), + CheckPointSaver( + local_path=os.path.join(output_dir, "dump_model"), + remote_path=exp.ckpt_oss_save_dir, + save_interval=exp.dump_interval, + num_keep_latest=exp.num_keep_latest_ckpt, + ), + ] + if "grad_clip_value" in exp.__dict__: + callbacks.append(ClipGrad(exp.grad_clip_value)) + if args.clearml: + callbacks.append(ClearMLCallback()) + if args.ckpt: + callbacks.append(CheckPointLoader(args.ckpt)) + if args.pretrained_model: + callbacks.append(CheckPointLoader(args.pretrained_model, weight_only=True)) + callbacks.extend(exp.callbacks) + + trainer = Trainer(exp=exp, callbacks=callbacks, logger=logger, evaluator=evaluator) + return trainer + + def executor(self): + if self.args.eval: + self.get_evaluator().eval() + elif self.args.train_and_eval: + evaluator = self.get_evaluator(callbacks=[]) + self.get_trainer(evaluator=evaluator).train() + else: + self.get_trainer().train() + + def dispatch(self, executor_func): + is_master = self.env.global_rank() == 0 + with ShareFSUUIDNameServer(is_master) as ns: + self.env.set_master_uri(ns) + self.env.setup_nccl() + if self.env.local_rank() == 0: + command = sys.argv.copy() + command[0] = os.path.abspath(command[0]) + command = [sys.executable] + command + for local_rank in range(1, self.env.nr_gpus): + env_copy = os.environ.copy() + env_copy["LOCAL_RANK"] = f"{local_rank}" + subprocess.Popen(command, env=env_copy) + self.env.init_dist() + executor_func() + + def run(self): + self.dispatch(self.executor) + + +class MapMasterCli(BaseCli): + @PyDecorator.overrides(BaseCli) + def get_evaluator(self, callbacks=None): + exp = self.exp + + output_dir = self._get_exp_output_dir() + self.exp.output_dir = output_dir + logger = setup_logger(output_dir, distributed_rank=self.env.global_rank(), filename="eval.log") + self._set_basic_log_message(logger) + if callbacks is None: + callbacks = [ + self.env, + CheckPointLoader(self.args.ckpt), + ] + + evaluator = BeMapNetEvaluator(exp=exp, callbacks=callbacks, logger=logger) + return evaluator diff --git a/mapmaster/engine/environ.py b/mapmaster/engine/environ.py new file mode 100644 index 0000000..d6a2757 --- /dev/null +++ b/mapmaster/engine/environ.py @@ -0,0 +1,151 @@ +import os +import time +import uuid +import torch +import subprocess +import numpy as np +from torch import nn +from loguru import logger +import torch.distributed as dist +from mapmaster.utils.env import get_root_dir +from mapmaster.utils.misc import parse_devices +from mapmaster.engine.callbacks import Callback + + +__all__ = ["ShareFSUUIDNameServer", 
"RlaunchReplicaEnv"] +output_root_dir = os.path.join(get_root_dir(), "outputs") + + +class ShareFSUUIDNameServer: + def __init__(self, is_master): + self.exp_id = self._get_exp_id() + self.is_master = is_master + os.makedirs(os.path.dirname(self.filepath), exist_ok=True) + + def _get_exp_id(self): + if "DET3D_EXPID" not in os.environ: + if int(os.environ.get("RLAUNCH_REPLICA_TOTAL", 1)) == 1: + return str(uuid.uuid4()) + msg = """cannot find DET3D_EXPID in environ please use following + command DET3D_EXPID=$(cat /proc/sys/kernel/random/uuid) rlaunch ... + """ + logger.error(msg) + raise RuntimeError + return str(os.environ["DET3D_EXPID"]) + + @property + def filepath(self): + return os.path.join(output_root_dir, f"master_ip_{self.exp_id}.txt") + + def __enter__(self): + if self.is_master: + self.set_master() + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + if self.is_master: + os.remove(self.filepath) + + def set_master(self): + assert not os.path.exists(self.filepath) + hostname = "Host" + with open(self.filepath, "w") as f: + f.write(hostname) + + def get_master(self): + while True: + if os.path.exists(self.filepath): + with open(self.filepath, "r") as f: + return f.read() + else: + time.sleep(5) + + +class _DDPEnv(Callback): + def __init__(self, sync_bn=0, devices=None, find_unused_parameters=False): + if devices: + devices = parse_devices(devices) + os.environ["CUDA_VISIBLE_DEVICES"] = devices + self.nr_gpus = torch.cuda.device_count() + self.sync_bn = sync_bn + self.find_unused_parameters = find_unused_parameters + + @staticmethod + def setup_nccl(): + ifname = filter(lambda x: x not in ("lo",), os.listdir("/sys/class/net/")) + os.environ["NCCL_SOCKET_IFNAME"] = ",".join(ifname) + os.environ["NCCL_IB_DISABLE"] = "1" + + # os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL" + os.environ["NCCL_IB_HCA"] = subprocess.getoutput( + "cd /sys/class/infiniband/ > /dev/null; for i in mlx5_*; " + "do cat $i/ports/1/gid_attrs/types/* 2>/dev/null " + "| grep v >/dev/null && echo $i ; done; > /dev/null" + ) + os.environ["NCCL_IB_GID_INDEX"] = "3" + os.environ["NCCL_IB_TC"] = "106" + + def after_init(self, trainer): + trainer.model.cuda() + if int(self.sync_bn) > 1: + ranks = np.arange(self.world_size()).reshape(-1, self.sync_bn) + process_groups = [torch.distributed.new_group(list(pids)) for pids in ranks] + trainer.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm( + trainer.model, process_groups[self.global_rank() // self.sync_bn] + ) + elif int(self.sync_bn) == 1: + trainer.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(trainer.model) + trainer.model = nn.parallel.DistributedDataParallel( + trainer.model, device_ids=[self.local_rank()], find_unused_parameters=self.find_unused_parameters + ) + + def cleanup(self): + dist.destroy_process_group() + + def init_dist(self): + torch.cuda.set_device(self.local_rank()) + dist.init_process_group( + backend="nccl", + init_method=self._master_uri, + rank=self.global_rank(), + world_size=self.world_size(), + ) + dist.barrier() + + +class RlaunchReplicaEnv(_DDPEnv): + def __init__(self, sync_bn=0, devices=None, find_unused_parameters=False): + super().__init__(sync_bn, devices, find_unused_parameters) + + def set_master_uri(self, ns): + self._master_uri = f"tcp://{self.master_address(ns)}:{self.master_port()}" + logger.info(self._master_uri) + + @staticmethod + def is_brainpp_mm_env(): + return int(os.environ.get("RLAUNCH_REPLICA_TOTAL", 1)) > 1 + + def master_address(self, ns) -> str: + if self.node_rank() == 0: + root_node = 
"localhost" + else: + root_node = ns.get_master() + os.environ["MASTER_ADDR"] = root_node + return root_node + + def master_port(self) -> int: + port = os.environ.get("MASTER_PORT", 12345) + os.environ["MASTER_PORT"] = str(port) + return int(port) + + def world_size(self) -> int: + return int(os.environ.get("RLAUNCH_REPLICA_TOTAL", 1)) * int(self.nr_gpus) + + def global_rank(self) -> int: + return int(self.nr_gpus) * self.node_rank() + self.local_rank() + + def local_rank(self) -> int: + return int(os.environ.get("LOCAL_RANK", 0)) + + def node_rank(self) -> int: + return int(os.environ.get("RLAUNCH_REPLICA", 0)) diff --git a/mapmaster/engine/executor.py b/mapmaster/engine/executor.py new file mode 100644 index 0000000..a7bcf4c --- /dev/null +++ b/mapmaster/engine/executor.py @@ -0,0 +1,227 @@ +import os +import torch +from tqdm import tqdm +from typing import Sequence +from mapmaster.engine.experiment import BaseExp +from mapmaster.utils.misc import get_rank, synchronize + + +__all__ = ["Callback", "BaseExecutor", "Trainer", "BeMapNetEvaluator"] + + +class Callback: + + # callback enabled rank list + # None means callback is always enabled + enabled_rank = None + + def setup(self, executor): + pass + + def load_checkpoint(self, executor): + pass + + def after_init(self, executor): + pass + + def before_train(self, executor): + pass + + def before_epoch(self, executor, epoch: int): + pass + + def before_step(self, executor, step, data_dict): + pass + + def before_backward(self, executor): + pass + + def before_optimize(self, executor): + pass + + def after_step(self, executor, step, data_dict, *args, **kwargs): + pass + + def after_epoch(self, executor, epoch: int, update_best_ckpt: bool = False): + pass + + def after_train(self, executor): + pass + + +class BaseExecutor: + def __init__(self, exp: BaseExp, callbacks: Sequence["Callback"], logger=None) -> None: + self.exp = exp + self.logger = logger + self.callbacks = callbacks + self._invoke_callback("setup") + + self.epoch = 0 + self.global_step = 0 + self._invoke_callback("load_checkpoint") + self._invoke_callback("after_init") + + @property + def train_dataloader(self): + return self.exp.train_dataloader + + @property + def val_dataloader(self): + return self.exp.val_dataloader + + @property + def model(self): + return self.exp.model + + @model.setter + def model(self, value): + self.exp.model = value + + @property + def optimizer(self): + return self.exp.optimizer + + @property + def lr_scheduler(self): + return self.exp.lr_scheduler + + def _invoke_callback(self, callback_name, *args, **kwargs): + for cb in self.callbacks: + if cb.enabled_rank is None or self.global_rank in cb.enabled_rank: + func = getattr(cb, callback_name, None) + if func: + func(self, *args, **kwargs) + + @property + def global_rank(self): + return get_rank() + + +class Trainer(BaseExecutor): + def __init__( + self, exp: BaseExp, callbacks: Sequence["Callback"], logger=None, use_amp=False, evaluator=None + ) -> None: + super(Trainer, self).__init__(exp, callbacks, logger) + self.use_amp = use_amp + self.evaluator = evaluator + if self.use_amp: + self.grad_scaler = torch.cuda.amp.GradScaler() + + def train(self): + self.train_iter = iter(self.train_dataloader) + self._invoke_callback("before_train") + self.model.cuda() + self.model.train() + self.optimizer_to(self.optimizer, next(self.model.parameters()).device) + start_epoch = self.epoch + for epoch in range(start_epoch, self.exp.max_epoch): + self.epoch = epoch + self.model.train() + self.train_epoch(epoch) + 
self._invoke_callback("after_train") + + def train_epoch(self, epoch): + self._invoke_callback("before_epoch", epoch) + sampler = self.train_dataloader.sampler + if hasattr(sampler, "set_epoch"): + sampler.set_epoch(epoch) + for step in range(len(self.train_dataloader)): + try: + data = next(self.train_iter) + except StopIteration: + self.train_iter = iter(self.train_dataloader) + data = next(self.train_iter) + self.train_step(data, step) + if self.evaluator is not None: + self.evaluator.eval() + self._invoke_callback("after_epoch", epoch, update_best_ckpt=False) + + def train_step(self, data, step): + self._invoke_callback("before_step", step, data) + self.lr_scheduler.step(self.global_step) + self.model.train() + self.optimizer.zero_grad() + if not self.use_amp: + ret = self.exp.training_step(data) + else: + with torch.cuda.amp.autocast(): + ret = self.exp.training_step(data) + if isinstance(ret, torch.Tensor): + loss = ret + ext_dict = None + elif isinstance(ret, tuple): + loss, ext_dict = ret + ext_dict = {k: v.detach() if isinstance(v, torch.Tensor) else v for k, v in ext_dict.items()} + else: + raise TypeError + self._invoke_callback("before_backward") + if not self.use_amp: + loss.backward() + self._invoke_callback("before_optimize") + self.optimizer.step() + else: + self.grad_scaler.scale(loss).backward() + self.grad_scaler.unscale_(self.optimizer) # NOTE: grads are unscaled before "before_optimize" callbacks + self._invoke_callback("before_optimize") + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + self._invoke_callback("after_step", step, data, loss=loss.detach(), extra=ext_dict) + self.global_step += 1 + + # refer to: https://github.com/pytorch/pytorch/issues/8741 + @staticmethod + def optimizer_to(optim, device): + for param in optim.state.values(): + # Not sure there are any global tensors in the state dict + if isinstance(param, torch.Tensor): + param.data = param.data.to(device) + if param._grad is not None: + param._grad.data = param._grad.data.to(device) + elif isinstance(param, dict): + for subparam in param.values(): + if isinstance(subparam, torch.Tensor): + subparam.data = subparam.data.to(device) + if subparam._grad is not None: + subparam._grad.data = subparam._grad.data.to(device) + + +class BeMapNetEvaluator(BaseExecutor): + def __init__(self, exp: BaseExp, callbacks: Sequence["Callback"], logger=None) -> None: + super(BeMapNetEvaluator, self).__init__(exp, callbacks, logger) + + def eval(self, ckpt_name=None): + + exp = self.exp + val_iter = iter(self.val_dataloader) + + self._invoke_callback("before_eval") + + if ckpt_name is not None: + if get_rank() == 0: + self.logger.info("Eval with best checkpoint!") + path = os.path.join(exp.output_dir, 'dump_model', ckpt_name) + checkpoint = torch.load(open(path, "rb"), map_location=torch.device("cpu")) + self.model.load_state_dict(checkpoint["model_state"], strict=False) + + self.model.cuda() + self.model.eval() + + for step in tqdm(range(len(self.val_dataloader))): + batch_data = next(val_iter) + with torch.no_grad(): + exp.test_step(batch_data) + self._invoke_callback("after_step", step, {}) + + synchronize() + + if get_rank() == 0: + self.logger.info("Done with inference, start evaluation later!") + gt_dir = exp.exp_config.map_conf['anno_root'] + dt_dir = exp.evaluation_save_dir + val_txts = exp.exp_config.VAL_TXT + + for val_txt in val_txts: + ap_table = "".join(os.popen(f"python3 tools/evaluation/eval.py {gt_dir} {dt_dir} {val_txt}").readlines()) + self.logger.info(" AP-Performance with 
HDMapNetAPI: \n" + val_txt + "\n" + ap_table) + + self._invoke_callback("after_eval") diff --git a/mapmaster/engine/experiment.py b/mapmaster/engine/experiment.py new file mode 100644 index 0000000..af5aee2 --- /dev/null +++ b/mapmaster/engine/experiment.py @@ -0,0 +1,187 @@ +import os +import sys +import torch +import functools +import numpy as np +from torch.nn import Module +from tabulate import tabulate +from abc import ABCMeta, abstractmethod +from mapmaster.utils.misc import DictAction + + +class BaseExp(metaclass=ABCMeta): + """Basic class for any experiment in Perceptron. + + Args: + batch_size_per_device (int): + batch_size of each device + + total_devices (int): + number of devices to use + + max_epoch (int): + total training epochs, the reason why we need to give max_epoch + is that lr_scheduler may need to be adapted according to max_epoch + """ + + def __init__(self, batch_size_per_device, total_devices, max_epoch): + self._batch_size_per_device = batch_size_per_device + self._max_epoch = max_epoch + self._total_devices = total_devices + # ----------------------------------------------- extra configure ------------------------- # + self.seed = None + self.exp_name = os.path.splitext(os.path.basename(sys.argv.copy()[0]))[0] # entrypoint filename as exp_name + self.print_interval = 100 + self.dump_interval = 10 + self.eval_interval = 10 + self.num_keep_latest_ckpt = 10 + self.ckpt_oss_save_dir = None + self.enable_tensorboard = False + self.eval_executor_class = None + + @property + def train_dataloader(self): + if "_train_dataloader" not in self.__dict__: + self._train_dataloader = self._configure_train_dataloader() + return self._train_dataloader + + @property + def val_dataloader(self): + if "_val_dataloader" not in self.__dict__: + self._val_dataloader = self._configure_val_dataloader() + return self._val_dataloader + + @property + def test_dataloader(self): + if "_test_dataloader" not in self.__dict__: + self._test_dataloader = self._configure_test_dataloader() + return self._test_dataloader + + @property + def model(self): + if "_model" not in self.__dict__: + self._model = self._configure_model() + return self._model + + @model.setter + def model(self, value): + self._model = value + + @property + def callbacks(self): + if not hasattr(self, "_callbacks"): + self._callbacks = self._configure_callbacks() + return self._callbacks + + @property + def optimizer(self): + if "_optimizer" not in self.__dict__: + self._optimizer = self._configure_optimizer() + return self._optimizer + + @property + def lr_scheduler(self): + if "_lr_scheduler" not in self.__dict__: + self._lr_scheduler = self._configure_lr_scheduler() + return self._lr_scheduler + + @property + def batch_size_per_device(self): + return self._batch_size_per_device + + @property + def max_epoch(self): + return self._max_epoch + + @property + def total_devices(self): + return self._total_devices + + @abstractmethod + def _configure_model(self) -> Module: + pass + + @abstractmethod + def _configure_train_dataloader(self): + """""" + + def _configure_callbacks(self): + return [] + + @abstractmethod + def _configure_val_dataloader(self): + """""" + + @abstractmethod + def _configure_test_dataloader(self): + """""" + + def training_step(self, *args, **kwargs): + pass + + @abstractmethod + def _configure_optimizer(self) -> torch.optim.Optimizer: + pass + + @abstractmethod + def _configure_lr_scheduler(self, **kwargs): + pass + + def update_attr(self, options: dict) -> str: + if options is None: + return "" + assert 
isinstance(options, dict) + msg = "" + for k, v in options.items(): + if k in self.__dict__: + old_v = self.__getattribute__(k) + if not v == old_v: + self.__setattr__(k, v) + msg = "{}\n'{}' is overriden from '{}' to '{}'".format(msg, k, old_v, v) + else: + self.__setattr__(k, v) + msg = "{}\n'{}' is set to '{}'".format(msg, k, v) + + # update exp_name + exp_name_suffix = "-".join(sorted([f"{k}-{v}" for k, v in options.items()])) + self.exp_name = f"{self.exp_name}--{exp_name_suffix}" + return msg + + def get_cfg_as_str(self) -> str: + config_table = [] + for c, v in self.__dict__.items(): + if not isinstance(v, (int, float, str, list, tuple, dict, np.ndarray)): + if hasattr(v, "__name__"): + v = v.__name__ + elif hasattr(v, "__class__"): + v = v.__class__ + elif type(v) == functools.partial: + v = v.func.__name__ + if c[0] == "_": + c = c[1:] + config_table.append((str(c), str(v))) + + headers = ["config key", "value"] + config_table = tabulate(config_table, headers, tablefmt="plain") + return config_table + + def __str__(self): + return self.get_cfg_as_str() + + def to_onnx(self): + pass + + @classmethod + def add_argparse_args(cls, parser): # pragma: no-cover + parser.add_argument( + "--exp_options", + nargs="+", + action=DictAction, + help="override some settings in the exp, the key-value pair in xxx=yyy format will be merged into exp. " + 'If the value to be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + "Note that the quotation marks are necessary and that no white space is allowed.", + ) + parser.add_argument("-b", "--batch-size-per-device", type=int, default=None) + parser.add_argument("-e", "--max-epoch", type=int, default=None) + return parser diff --git a/mapmaster/models/__init__.py b/mapmaster/models/__init__.py new file mode 100644 index 0000000..f8f8380 --- /dev/null +++ b/mapmaster/models/__init__.py @@ -0,0 +1 @@ +from .network import MapMaster diff --git a/mapmaster/models/backbone/__init__.py b/mapmaster/models/backbone/__init__.py new file mode 100644 index 0000000..e9a30cc --- /dev/null +++ b/mapmaster/models/backbone/__init__.py @@ -0,0 +1 @@ +from .model import ResNetBackbone, EfficientNetBackbone, SwinTRBackbone diff --git a/mapmaster/models/backbone/bifpn/__init__.py b/mapmaster/models/backbone/bifpn/__init__.py new file mode 100644 index 0000000..0ed49b5 --- /dev/null +++ b/mapmaster/models/backbone/bifpn/__init__.py @@ -0,0 +1 @@ +from .model import BiFPN diff --git a/mapmaster/models/backbone/bifpn/model.py b/mapmaster/models/backbone/bifpn/model.py new file mode 100644 index 0000000..0615998 --- /dev/null +++ b/mapmaster/models/backbone/bifpn/model.py @@ -0,0 +1,372 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from .utils import Swish, Conv2dStaticSamePadding, MaxPool2dStaticSamePadding + + +class SeparableConvBlock(nn.Module): + """ + created by Zylo117 + """ + + def __init__(self, in_channels, out_channels=None, norm=True, activation=False, norm_layer=nn.BatchNorm2d): + super(SeparableConvBlock, self).__init__() + if out_channels is None: + out_channels = in_channels + + # Q: whether separate conv + # share bias between depthwise_conv and pointwise_conv + # or just pointwise_conv apply bias. + # A: Confirmed, just pointwise_conv applies bias, depthwise_conv has no bias. 
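+        # The 3x3 depthwise + 1x1 pointwise factorization needs roughly k*k*C_in + C_in*C_out
+        # weights, versus k*k*C_in*C_out for a standard convolution of the same kernel size.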
+ + self.depthwise_conv = Conv2dStaticSamePadding( + in_channels, in_channels, kernel_size=3, stride=1, groups=in_channels, bias=False + ) + self.pointwise_conv = Conv2dStaticSamePadding(in_channels, out_channels, kernel_size=1, stride=1) + + self.norm = norm + if self.norm: + # Warning: pytorch momentum is different from tensorflow's, momentum_pytorch = 1 - momentum_tensorflow + self.bn = norm_layer(num_features=out_channels, momentum=0.01, eps=1e-3) + + self.activation = activation + if self.activation: + self.swish = Swish() + + def forward(self, x): + x = self.depthwise_conv(x) + x = self.pointwise_conv(x) + + if self.norm: + x = self.bn(x) + + if self.activation: + x = self.swish(x) + + return x + + +class BiFPNLayer(nn.Module): + """ + modified by Zylo117 + """ + + def __init__( + self, + num_channels, + conv_channels, + first_time=False, + epsilon=1e-4, + attention=True, + use_p8=False, + norm_layer=nn.BatchNorm2d, + ): + """ + Args: + num_channels: + conv_channels: + first_time: whether the input comes directly from the efficientnet, + if True, downchannel it first, and downsample P5 to generate P6 then P7 + epsilon: epsilon of fast weighted attention sum of BiFPN, not the BN's epsilon + onnx_export: if True, use Swish instead of MemoryEfficientSwish + """ + super(BiFPNLayer, self).__init__() + self.epsilon = epsilon + self.use_p8 = use_p8 + + # Conv layers + self.conv6_up = SeparableConvBlock(num_channels, norm_layer=norm_layer) + self.conv5_up = SeparableConvBlock(num_channels, norm_layer=norm_layer) + self.conv4_up = SeparableConvBlock(num_channels, norm_layer=norm_layer) + self.conv3_up = SeparableConvBlock(num_channels, norm_layer=norm_layer) + self.conv4_down = SeparableConvBlock(num_channels, norm_layer=norm_layer) + self.conv5_down = SeparableConvBlock(num_channels, norm_layer=norm_layer) + self.conv6_down = SeparableConvBlock(num_channels, norm_layer=norm_layer) + self.conv7_down = SeparableConvBlock(num_channels, norm_layer=norm_layer) + if use_p8: + self.conv7_up = SeparableConvBlock(num_channels, norm_layer=norm_layer) + self.conv8_down = SeparableConvBlock(num_channels, norm_layer=norm_layer) + + # Feature scaling layers + self.p6_upsample = nn.Upsample(scale_factor=2, mode="nearest") + self.p5_upsample = nn.Upsample(scale_factor=2, mode="nearest") + self.p4_upsample = nn.Upsample(scale_factor=2, mode="nearest") + self.p3_upsample = nn.Upsample(scale_factor=2, mode="nearest") + + self.p4_downsample = MaxPool2dStaticSamePadding(3, 2) + self.p5_downsample = MaxPool2dStaticSamePadding(3, 2) + self.p6_downsample = MaxPool2dStaticSamePadding(3, 2) + self.p7_downsample = MaxPool2dStaticSamePadding(3, 2) + if use_p8: + self.p7_upsample = nn.Upsample(scale_factor=2, mode="nearest") + self.p8_downsample = MaxPool2dStaticSamePadding(3, 2) + + self.swish = Swish() + + self.first_time = first_time + if self.first_time: + self.p5_down_channel = nn.Sequential( + Conv2dStaticSamePadding(conv_channels[2], num_channels, 1), + norm_layer(num_channels, momentum=0.01, eps=1e-3), + ) + self.p4_down_channel = nn.Sequential( + Conv2dStaticSamePadding(conv_channels[1], num_channels, 1), + norm_layer(num_channels, momentum=0.01, eps=1e-3), + ) + self.p3_down_channel = nn.Sequential( + Conv2dStaticSamePadding(conv_channels[0], num_channels, 1), + norm_layer(num_channels, momentum=0.01, eps=1e-3), + ) + + self.p5_to_p6 = nn.Sequential( + Conv2dStaticSamePadding(conv_channels[2], num_channels, 1), + norm_layer(num_channels, momentum=0.01, eps=1e-3), + MaxPool2dStaticSamePadding(3, 2), + ) + 
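+            # P6 -> P7 (and optionally P7 -> P8) are obtained by a further stride-2 max-pool of the level below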
self.p6_to_p7 = nn.Sequential(MaxPool2dStaticSamePadding(3, 2)) + if use_p8: + self.p7_to_p8 = nn.Sequential(MaxPool2dStaticSamePadding(3, 2)) + + self.p4_down_channel_2 = nn.Sequential( + Conv2dStaticSamePadding(conv_channels[1], num_channels, 1), + norm_layer(num_channels, momentum=0.01, eps=1e-3), + ) + self.p5_down_channel_2 = nn.Sequential( + Conv2dStaticSamePadding(conv_channels[2], num_channels, 1), + norm_layer(num_channels, momentum=0.01, eps=1e-3), + ) + + # Weight + self.p6_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) + self.p6_w1_relu = nn.ReLU() + self.p5_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) + self.p5_w1_relu = nn.ReLU() + self.p4_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) + self.p4_w1_relu = nn.ReLU() + self.p3_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) + self.p3_w1_relu = nn.ReLU() + + self.p4_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True) + self.p4_w2_relu = nn.ReLU() + self.p5_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True) + self.p5_w2_relu = nn.ReLU() + self.p6_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True) + self.p6_w2_relu = nn.ReLU() + self.p7_w2 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) + self.p7_w2_relu = nn.ReLU() + + self.attention = attention + + def forward(self, inputs): + """ + illustration of a minimal bifpn unit + P7_0 -------------------------> P7_2 --------> + |-------------| ↑ + ↓ | + P6_0 ---------> P6_1 ---------> P6_2 --------> + |-------------|--------------↑ ↑ + ↓ | + P5_0 ---------> P5_1 ---------> P5_2 --------> + |-------------|--------------↑ ↑ + ↓ | + P4_0 ---------> P4_1 ---------> P4_2 --------> + |-------------|--------------↑ ↑ + |--------------↓ | + P3_0 -------------------------> P3_2 --------> + """ + + # downsample channels using same-padding conv2d to target phase's if not the same + # judge: same phase as target, + # if same, pass; + # elif earlier phase, downsample to target phase's by pooling + # elif later phase, upsample to target phase's by nearest interpolation + + if self.attention: + outs = self._forward_fast_attention(inputs) + else: + outs = self._forward(inputs) + + return outs + + def _forward_fast_attention(self, inputs): + if self.first_time: + p3, p4, p5 = inputs + + p6_in = self.p5_to_p6(p5) + p7_in = self.p6_to_p7(p6_in) + + p3_in = self.p3_down_channel(p3) + p4_in = self.p4_down_channel(p4) + p5_in = self.p5_down_channel(p5) + + else: + # P3_0, P4_0, P5_0, P6_0 and P7_0 + p3_in, p4_in, p5_in, p6_in, p7_in = inputs + + # P7_0 to P7_2 + + # Weights for P6_0 and P7_0 to P6_1 + p6_w1 = self.p6_w1_relu(self.p6_w1) + weight = p6_w1 / (torch.sum(p6_w1, dim=0) + self.epsilon) + # Connections for P6_0 and P7_0 to P6_1 respectively + p6_up = self.conv6_up.forward(self.swish(weight[0] * p6_in + weight[1] * self.p6_upsample(p7_in))) + + # Weights for P5_0 and P6_1 to P5_1 + p5_w1 = self.p5_w1_relu(self.p5_w1) + weight = p5_w1 / (torch.sum(p5_w1, dim=0) + self.epsilon) + # Connections for P5_0 and P6_1 to P5_1 respectively + p5_up = self.conv5_up.forward(self.swish(weight[0] * p5_in + weight[1] * self.p5_upsample(p6_up))) + + # Weights for P4_0 and P5_1 to P4_1 + p4_w1 = self.p4_w1_relu(self.p4_w1) + weight = p4_w1 / (torch.sum(p4_w1, dim=0) + self.epsilon) + # Connections for P4_0 and P5_1 to P4_1 respectively + p4_up = self.conv4_up.forward(self.swish(weight[0] * p4_in + weight[1] * 
self.p4_upsample(p5_up))) + + # Weights for P3_0 and P4_1 to P3_2 + p3_w1 = self.p3_w1_relu(self.p3_w1) + weight = p3_w1 / (torch.sum(p3_w1, dim=0) + self.epsilon) + # Connections for P3_0 and P4_1 to P3_2 respectively + p3_out = self.conv3_up.forward(self.swish(weight[0] * p3_in + weight[1] * self.p3_upsample(p4_up))) + + if self.first_time: + p4_in = self.p4_down_channel_2(p4) + p5_in = self.p5_down_channel_2(p5) + + # Weights for P4_0, P4_1 and P3_2 to P4_2 + p4_w2 = self.p4_w2_relu(self.p4_w2) + weight = p4_w2 / (torch.sum(p4_w2, dim=0) + self.epsilon) + # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively + p4_out = self.conv4_down.forward( + self.swish(weight[0] * p4_in + weight[1] * p4_up + weight[2] * self.p4_downsample(p3_out)) + ) + + # Weights for P5_0, P5_1 and P4_2 to P5_2 + p5_w2 = self.p5_w2_relu(self.p5_w2) + weight = p5_w2 / (torch.sum(p5_w2, dim=0) + self.epsilon) + # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively + p5_out = self.conv5_down.forward( + self.swish(weight[0] * p5_in + weight[1] * p5_up + weight[2] * self.p5_downsample(p4_out)) + ) + + # Weights for P6_0, P6_1 and P5_2 to P6_2 + p6_w2 = self.p6_w2_relu(self.p6_w2) + weight = p6_w2 / (torch.sum(p6_w2, dim=0) + self.epsilon) + # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively + p6_out = self.conv6_down.forward( + self.swish(weight[0] * p6_in + weight[1] * p6_up + weight[2] * self.p6_downsample(p5_out)) + ) + + # Weights for P7_0 and P6_2 to P7_2 + p7_w2 = self.p7_w2_relu(self.p7_w2) + weight = p7_w2 / (torch.sum(p7_w2, dim=0) + self.epsilon) + # Connections for P7_0 and P6_2 to P7_2 + p7_out = self.conv7_down.forward(self.swish(weight[0] * p7_in + weight[1] * self.p7_downsample(p6_out))) + + return p3_out, p4_out, p5_out, p6_out, p7_out + + def _forward(self, inputs): + if self.first_time: + p3, p4, p5 = inputs + + p6_in = self.p5_to_p6(p5) + p7_in = self.p6_to_p7(p6_in) + if self.use_p8: + p8_in = self.p7_to_p8(p7_in) + + p3_in = self.p3_down_channel(p3) + p4_in = self.p4_down_channel(p4) + p5_in = self.p5_down_channel(p5) + + else: + if self.use_p8: + # P3_0, P4_0, P5_0, P6_0, P7_0 and P8_0 + p3_in, p4_in, p5_in, p6_in, p7_in, p8_in = inputs + else: + # P3_0, P4_0, P5_0, P6_0 and P7_0 + p3_in, p4_in, p5_in, p6_in, p7_in = inputs + + if self.use_p8: + # P8_0 to P8_2 + + # Connections for P7_0 and P8_0 to P7_1 respectively + p7_up = self.conv7_up.forward(self.swish(p7_in + self.p7_upsample(p8_in))) + + # Connections for P6_0 and P7_0 to P6_1 respectively + p6_up = self.conv6_up.forward(self.swish(p6_in + self.p6_upsample(p7_up))) + else: + # P7_0 to P7_2 + + # Connections for P6_0 and P7_0 to P6_1 respectively + p6_up = self.conv6_up.forward(self.swish(p6_in + self.p6_upsample(p7_in))) + + # Connections for P5_0 and P6_1 to P5_1 respectively + p5_up = self.conv5_up.forward(self.swish(p5_in + self.p5_upsample(p6_up))) + + # Connections for P4_0 and P5_1 to P4_1 respectively + p4_up = self.conv4_up.forward(self.swish(p4_in + self.p4_upsample(p5_up))) + + # Connections for P3_0 and P4_1 to P3_2 respectively + p3_out = self.conv3_up.forward(self.swish(p3_in + self.p3_upsample(p4_up))) + + if self.first_time: + p4_in = self.p4_down_channel_2(p4) + p5_in = self.p5_down_channel_2(p5) + + # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively + p4_out = self.conv4_down.forward(self.swish(p4_in + p4_up + self.p4_downsample(p3_out))) + + # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively + p5_out = self.conv5_down.forward(self.swish(p5_in + p5_up + self.p5_downsample(p4_out))) + + # 
Connections for P6_0, P6_1 and P5_2 to P6_2 respectively + p6_out = self.conv6_down.forward(self.swish(p6_in + p6_up + self.p6_downsample(p5_out))) + + if self.use_p8: + # Connections for P7_0, P7_1 and P6_2 to P7_2 respectively + p7_out = self.conv7_down.forward(self.swish(p7_in + p7_up + self.p7_downsample(p6_out))) + + # Connections for P8_0 and P7_2 to P8_2 + p8_out = self.conv8_down.forward(self.swish(p8_in + self.p8_downsample(p7_out))) + + return p3_out, p4_out, p5_out, p6_out, p7_out, p8_out + else: + # Connections for P7_0 and P6_2 to P7_2 + p7_out = self.conv7_down.forward(self.swish(p7_in + self.p7_downsample(p6_out))) + + return p3_out, p4_out, p5_out, p6_out, p7_out + + +class BiFPN(nn.Module): + def __init__( + self, conv_channels, fpn_cell_repeat=2, fpn_num_filters=64, norm_layer=nn.BatchNorm2d, use_checkpoint=False, + tgt_shape=(12, 28*6), + ): + super(BiFPN, self).__init__() + self.model = nn.Sequential( + *[ + BiFPNLayer(fpn_num_filters, conv_channels, True if i == 0 else False, norm_layer=norm_layer) + for i in range(fpn_cell_repeat) + ] + ) + self.tgt_shape = tgt_shape + self.use_checkpoint = use_checkpoint + + def forward(self, im_bkb_features): + if self.use_checkpoint and self.training: + im_nek_features = checkpoint.checkpoint(self._forward, *im_bkb_features) + else: + im_nek_features = self._forward(*im_bkb_features) + im_nek_features = [torch.cat([self.up_sample(x, tgt_shape=self.tgt_shape) for x in im_nek_features], dim=1)] + return im_nek_features + + def _forward(self, *inputs): + outputs = self.model(inputs[-3:]) + return outputs + + def up_sample(self, x, tgt_shape=None): + tgt_shape = self.tgt_shape if tgt_shape is None else tgt_shape + if tuple(x.shape[-2:]) == tuple(tgt_shape): + return x + return F.interpolate(x, size=tgt_shape, mode="bilinear", align_corners=True) diff --git a/mapmaster/models/backbone/bifpn/utils.py b/mapmaster/models/backbone/bifpn/utils.py new file mode 100644 index 0000000..0f7bb4b --- /dev/null +++ b/mapmaster/models/backbone/bifpn/utils.py @@ -0,0 +1,90 @@ +# Author: Zylo117 + +import math +import torch +from torch import nn +import torch.nn.functional as F + + +class Swish(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class Conv2dStaticSamePadding(nn.Module): + """ + created by Zylo117 + The real keras/tensorflow conv2d with same padding + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, groups=1, dilation=1, **kwargs): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, bias=bias, groups=groups) + self.stride = self.conv.stride + self.kernel_size = self.conv.kernel_size + self.dilation = self.conv.dilation + + if isinstance(self.stride, int): + self.stride = [self.stride] * 2 + elif len(self.stride) == 1: + self.stride = [self.stride[0]] * 2 + + if isinstance(self.kernel_size, int): + self.kernel_size = [self.kernel_size] * 2 + elif len(self.kernel_size) == 1: + self.kernel_size = [self.kernel_size[0]] * 2 + + def forward(self, x): + h, w = x.shape[-2:] + + extra_h = (math.ceil(w / self.stride[1]) - 1) * self.stride[1] - w + self.kernel_size[1] + extra_v = (math.ceil(h / self.stride[0]) - 1) * self.stride[0] - h + self.kernel_size[0] + + left = extra_h // 2 + right = extra_h - left + top = extra_v // 2 + bottom = extra_v - top + + x = F.pad(x, [left, right, top, bottom]) + + x = self.conv(x) + return x + + +class MaxPool2dStaticSamePadding(nn.Module): + """ + created by Zylo117 + The real keras/tensorflow MaxPool2d with 
same padding + """ + + def __init__(self, *args, **kwargs): + super().__init__() + self.pool = nn.MaxPool2d(*args, **kwargs) + self.stride = self.pool.stride + self.kernel_size = self.pool.kernel_size + + if isinstance(self.stride, int): + self.stride = [self.stride] * 2 + elif len(self.stride) == 1: + self.stride = [self.stride[0]] * 2 + + if isinstance(self.kernel_size, int): + self.kernel_size = [self.kernel_size] * 2 + elif len(self.kernel_size) == 1: + self.kernel_size = [self.kernel_size[0]] * 2 + + def forward(self, x): + h, w = x.shape[-2:] + + extra_h = (math.ceil(w / self.stride[1]) - 1) * self.stride[1] - w + self.kernel_size[1] + extra_v = (math.ceil(h / self.stride[0]) - 1) * self.stride[0] - h + self.kernel_size[0] + + left = extra_h // 2 + right = extra_h - left + top = extra_v // 2 + bottom = extra_v - top + + x = F.pad(x, [left, right, top, bottom]) + + x = self.pool(x) + return x diff --git a/mapmaster/models/backbone/efficientnet/__init__.py b/mapmaster/models/backbone/efficientnet/__init__.py new file mode 100644 index 0000000..464fc1b --- /dev/null +++ b/mapmaster/models/backbone/efficientnet/__init__.py @@ -0,0 +1 @@ +from .model import EfficientNet diff --git a/mapmaster/models/backbone/efficientnet/model.py b/mapmaster/models/backbone/efficientnet/model.py new file mode 100644 index 0000000..63d869d --- /dev/null +++ b/mapmaster/models/backbone/efficientnet/model.py @@ -0,0 +1,468 @@ +"""model.py - Model and module class for EfficientNet. + They are built to mirror those in the official TensorFlow implementation. +""" + +# Author: lukemelas (github username) +# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch +# With adjustments and added comments by workingcoder (github username). + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils import checkpoint as cp +from mapmaster.models.backbone.efficientnet.utils import ( + round_filters, + round_repeats, + drop_connect, + get_same_padding_conv2d, + get_model_params, + efficientnet_params, + load_pretrained_weights, + Swish, + MemoryEfficientSwish, + calculate_output_image_size, +) + + +VALID_MODELS = ( + "efficientnet-b0", + "efficientnet-b1", + "efficientnet-b2", + "efficientnet-b3", + "efficientnet-b4", + "efficientnet-b5", + "efficientnet-b6", + "efficientnet-b7", + "efficientnet-b8", + # Support the construction of 'efficientnet-l2' without pretrained weights + "efficientnet-l2", +) + + +class MBConvBlock(nn.Module): + """Mobile Inverted Residual Bottleneck Block. + + Args: + block_args (namedtuple): BlockArgs, defined in utils.py. + global_params (namedtuple): GlobalParam, defined in utils.py. + image_size (tuple or list): [image_height, image_width]. 
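        The block runs an optional 1x1 expansion conv, a k x k depthwise conv carrying
        the block's stride, squeeze-and-excitation gating, and a 1x1 projection conv;
        a residual connection (with optional drop connect) is added only when the
        stride is 1 and the input and output filter counts match.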
+ + References: + [1] https://arxiv.org/abs/1704.04861 (MobileNet v1) + [2] https://arxiv.org/abs/1801.04381 (MobileNet v2) + [3] https://arxiv.org/abs/1905.02244 (MobileNet v3) + """ + + def __init__(self, block_args, global_params, image_size=None, norm_layer=nn.BatchNorm2d): + super().__init__() + self._block_args = block_args + self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow + self._bn_eps = global_params.batch_norm_epsilon + self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) + self.id_skip = block_args.id_skip # whether to use skip connection and drop connect + + # Expansion phase (Inverted Bottleneck) + inp = self._block_args.input_filters # number of input channels + oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels + if self._block_args.expand_ratio != 1: + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) + self._bn0 = norm_layer(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size + + # Depthwise convolution phase + k = self._block_args.kernel_size + s = self._block_args.stride + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._depthwise_conv = Conv2d( + in_channels=oup, + out_channels=oup, + groups=oup, # groups makes it depthwise + kernel_size=k, + stride=s, + bias=False, + ) + self._bn1 = norm_layer(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + image_size = calculate_output_image_size(image_size, s) + + # Squeeze and Excitation layer, if desired + if self.has_se: + Conv2d = get_same_padding_conv2d(image_size=(1, 1)) + num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) + self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) + self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) + + # Pointwise convolution phase + final_oup = self._block_args.output_filters + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) + self._bn2 = norm_layer(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) + self._swish = MemoryEfficientSwish() + + def forward(self, inputs, drop_connect_rate=None): + """MBConvBlock's forward function. + + Args: + inputs (tensor): Input tensor. + drop_connect_rate (bool): Drop connect rate (float, between 0 and 1). + + Returns: + Output of this block after processing. 
+ """ + + # Expansion and Depthwise Convolution + x = inputs + if self._block_args.expand_ratio != 1: + x = self._expand_conv(inputs) + x = self._bn0(x) + x = self._swish(x) + + x = self._depthwise_conv(x) + x = self._bn1(x) + x = self._swish(x) + + # Squeeze and Excitation + if self.has_se: + x_squeezed = F.adaptive_avg_pool2d(x, 1) + x_squeezed = self._se_reduce(x_squeezed) + x_squeezed = self._swish(x_squeezed) + x_squeezed = self._se_expand(x_squeezed) + x = torch.sigmoid(x_squeezed) * x + + # Pointwise Convolution + x = self._project_conv(x) + x = self._bn2(x) + + # Skip connection and drop connect + input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters + if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: + # The combination of skip connection and drop connect brings about stochastic depth. + if drop_connect_rate: + x = drop_connect(x, p=drop_connect_rate, training=self.training) + x = x + inputs # skip connection + return x + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export). + + Args: + memory_efficient (bool): Whether to use memory-efficient version of swish. + """ + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + + +class EfficientNet(nn.Module): + """EfficientNet model. + Most easily loaded with the .from_name or .from_pretrained methods. + + Args: + blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks. + global_params (namedtuple): A set of GlobalParams shared between blocks. + + References: + [1] https://arxiv.org/abs/1905.11946 (EfficientNet) + + Example: + >>> import torch + >>> from efficientnet.model import EfficientNet + >>> inputs = torch.rand(1, 3, 224, 224) + >>> model = EfficientNet.from_pretrained('efficientnet-b0') + >>> model.eval() + >>> outputs = model(inputs) + """ + + def __init__(self, blocks_args=None, global_params=None, with_head=True, with_cp=False, norm_layer=nn.BatchNorm2d): + super().__init__() + assert isinstance(blocks_args, list), "blocks_args should be a list" + assert len(blocks_args) > 0, "block args must be greater than 0" + self._global_params = global_params + self._blocks_args = blocks_args + self.with_head = with_head + self.with_cp = with_cp + + # Batch norm parameters + bn_mom = 1 - self._global_params.batch_norm_momentum + bn_eps = self._global_params.batch_norm_epsilon + + # Get stem static or dynamic convolution depending on image size + image_size = global_params.image_size + Conv2d = get_same_padding_conv2d(image_size=image_size) + + # Stem + in_channels = 3 # rgb + out_channels = round_filters(32, self._global_params) # number of output channels + self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) + self._bn0 = norm_layer(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + image_size = calculate_output_image_size(image_size, 2) + + # Build blocks + self._blocks = nn.ModuleList([]) + + for block_args in self._blocks_args: + + # Update block input and output filters based on depth multiplier. + block_args = block_args._replace( + input_filters=round_filters(block_args.input_filters, self._global_params), + output_filters=round_filters(block_args.output_filters, self._global_params), + num_repeat=round_repeats(block_args.num_repeat, self._global_params), + ) + + # The first block needs to take care of stride and filter size increase. 
+ self._blocks.append( + MBConvBlock(block_args, self._global_params, image_size=image_size, norm_layer=norm_layer) + ) + image_size = calculate_output_image_size(image_size, block_args.stride) + if block_args.num_repeat > 1: # modify block_args to keep same output size + block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) + for _ in range(block_args.num_repeat - 1): + self._blocks.append( + MBConvBlock(block_args, self._global_params, image_size=image_size, norm_layer=norm_layer) + ) + # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1 + + # Head + in_channels = block_args.output_filters # output of final block + out_channels = round_filters(1280, self._global_params) + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + self._bn1 = norm_layer(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + + # Final linear layer + self._avg_pooling = nn.AdaptiveAvgPool2d(1) + if self._global_params.include_top: + self._dropout = nn.Dropout(self._global_params.dropout_rate) + self._fc = nn.Linear(out_channels, self._global_params.num_classes) + + # set activation to memory efficient swish by default + self._swish = MemoryEfficientSwish() + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export). + + Args: + memory_efficient (bool): Whether to use memory-efficient version of swish. + """ + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + for block in self._blocks: + block.set_swish(memory_efficient) + + def extract_endpoints(self, inputs): + """Use convolution layer to extract features + from reduction levels i in [1, 2, 3, 4, 5]. + + Args: + inputs (tensor): Input tensor. + + Returns: + Dictionary of last intermediate features + with reduction levels i in [1, 2, 3, 4, 5]. 
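            An endpoint is recorded each time the spatial resolution drops (the stored
            tensor is the last feature map before that downsampling block), plus the
            output of the final block; when the model was built with with_head=True,
            the post-head features are appended as one extra reduction level, which is
            where 'reduction_6' in the example below comes from.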
+ Example: + >>> import torch + >>> from efficientnet.model import EfficientNet + >>> inputs = torch.rand(1, 3, 224, 224) + >>> model = EfficientNet.from_pretrained('efficientnet-b0') + >>> endpoints = model.extract_endpoints(inputs) + >>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112]) + >>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56]) + >>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28]) + >>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14]) + >>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7]) + >>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7]) + """ + endpoints = dict() + + # Stem + x = self._swish(self._bn0(self._conv_stem(inputs))) + prev_x = x + + # Blocks + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate + + # x = block(x, drop_connect_rate=drop_connect_rate) + if self.with_cp and x.requires_grad: + x = cp.checkpoint(block, x, drop_connect_rate) + # x = block(x, drop_connect_rate=drop_connect_rate) + else: + x = block(x, drop_connect_rate=drop_connect_rate) + + if prev_x.size(2) > x.size(2): + endpoints["reduction_{}".format(len(endpoints) + 1)] = prev_x + elif idx == len(self._blocks) - 1: + endpoints["reduction_{}".format(len(endpoints) + 1)] = x + prev_x = x + + if self.with_head: + # Head + x = self._swish(self._bn1(self._conv_head(x))) + endpoints["reduction_{}".format(len(endpoints) + 1)] = x + + return endpoints + + def extract_features(self, inputs): + """use convolution layer to extract feature . + + Args: + inputs (tensor): Input tensor. + + Returns: + Output of the final convolution + layer in the efficientnet model. + """ + # Stem + x = self._swish(self._bn0(self._conv_stem(inputs))) + + # Blocks + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate + x = block(x, drop_connect_rate=drop_connect_rate) + + # Head + x = self._swish(self._bn1(self._conv_head(x))) + + return x + + def forward(self, inputs): + """EfficientNet's forward function. + Calls extract_features to extract features, applies final linear layer, and returns logits. + + Args: + inputs (tensor): Input tensor. + + Returns: + Output of this model after processing. + """ + # Convolution layers + x = self.extract_features(inputs) + # Pooling and final linear layer + x = self._avg_pooling(x) + if self._global_params.include_top: + x = x.flatten(start_dim=1) + x = self._dropout(x) + x = self._fc(x) + return x + + @classmethod + def from_name( + cls, model_name, in_channels=3, out_stride=32, with_head=True, with_cp=False, norm_layer=nn.BatchNorm2d, + **override_params + ): + """Create an efficientnet model according to name. + + Args: + model_name (str): Name for efficientnet. + in_channels (int): Input data's channel number. + override_params (other key word params): + Params to override model's global_params. + Optional key: + 'width_coefficient', 'depth_coefficient', + 'image_size', 'dropout_rate', + 'num_classes', 'batch_norm_momentum', + 'batch_norm_epsilon', 'drop_connect_rate', + 'depth_divisor', 'min_depth' + + Returns: + An efficientnet model. 
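            The remaining keyword arguments are not listed above: out_stride (32 or 16)
            picks the stride-32 or stride-16 block table in utils.efficientnet(),
            with_head controls whether extract_endpoints appends the post-head features,
            and with_cp enables per-block gradient checkpointing to trade compute for
            GPU memory.

        Example (builds the network from scratch, so no pretrained weights are downloaded):
            >>> import torch
            >>> from mapmaster.models.backbone.efficientnet import EfficientNet
            >>> model = EfficientNet.from_name('efficientnet-b0', out_stride=16).eval()
            >>> feats = model.extract_endpoints(torch.rand(1, 3, 224, 224))
            >>> print([f.shape[-1] for f in feats.values()])  # [112, 56, 28, 14, 14], deepest kept at stride 16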
+ """ + cls._check_model_name_is_valid(model_name) + blocks_args, global_params = get_model_params(model_name, out_stride, override_params) + model = cls(blocks_args, global_params, with_head=with_head, with_cp=with_cp, norm_layer=norm_layer) + model._change_in_channels(in_channels) + return model + + @classmethod + def from_pretrained( + cls, + model_name, + weights_path=None, + advprop=False, + in_channels=3, + num_classes=100, + out_stride=32, + with_head=True, + with_cp=False, + norm_layer=nn.BatchNorm2d, + **override_params + ): + """Create an efficientnet model according to name. + + Args: + model_name (str): Name for efficientnet. + weights_path (None or str): + str: path to pretrained weights file on the local disk. + None: use pretrained weights downloaded from the Internet. + advprop (bool): + Whether to load pretrained weights + trained with advprop (valid when weights_path is None). + in_channels (int): Input data's channel number. + num_classes (int): + Number of categories for classification. + It controls the output size for final linear layer. + override_params (other key word params): + Params to override model's global_params. + Optional key: + 'width_coefficient', 'depth_coefficient', + 'image_size', 'dropout_rate', + 'batch_norm_momentum', + 'batch_norm_epsilon', 'drop_connect_rate', + 'depth_divisor', 'min_depth' + + Returns: + A pretrained efficientnet model. + """ + model = cls.from_name( + model_name, + num_classes=num_classes, + out_stride=out_stride, + with_head=with_head, + with_cp=with_cp, + norm_layer=norm_layer, + **override_params + ) + load_pretrained_weights( + model, model_name, weights_path=weights_path, load_fc=(num_classes == 1000), advprop=advprop + ) + model._change_in_channels(in_channels) + return model + + @classmethod + def get_image_size(cls, model_name): + """Get the input image size for a given efficientnet model. + + Args: + model_name (str): Name for efficientnet. + + Returns: + Input image size (resolution). + """ + cls._check_model_name_is_valid(model_name) + _, _, res, _ = efficientnet_params(model_name) + return res + + @classmethod + def _check_model_name_is_valid(cls, model_name): + """Validates model name. + + Args: + model_name (str): Name for efficientnet. + + Returns: + bool: Is a valid name or not. + """ + if model_name not in VALID_MODELS: + raise ValueError("model_name should be one of: " + ", ".join(VALID_MODELS)) + + def _change_in_channels(self, in_channels): + """Adjust model's first convolution layer to in_channels, if in_channels not equals 3. + + Args: + in_channels (int): Input data's channel number. + """ + if in_channels != 3: + Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size) + out_channels = round_filters(32, self._global_params) + self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) diff --git a/mapmaster/models/backbone/efficientnet/utils.py b/mapmaster/models/backbone/efficientnet/utils.py new file mode 100644 index 0000000..c6b715f --- /dev/null +++ b/mapmaster/models/backbone/efficientnet/utils.py @@ -0,0 +1,656 @@ +"""utils.py - Helper functions for building the model and for loading model parameters. + These helper functions are built to mirror those in the official TensorFlow implementation. +""" + +# Author: lukemelas (github username) +# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch +# With adjustments and added comments by workingcoder (github username). 
+ +import re +import math +import collections +from functools import partial +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils import model_zoo + + +################################################################################ +# Help functions for model architecture +################################################################################ + +# GlobalParams and BlockArgs: Two namedtuples +# Swish and MemoryEfficientSwish: Two implementations of the method +# round_filters and round_repeats: +# Functions to calculate params for scaling model width and depth ! ! ! +# get_width_and_height_from_size and calculate_output_image_size +# drop_connect: A structural design +# get_same_padding_conv2d: +# Conv2dDynamicSamePadding +# Conv2dStaticSamePadding +# get_same_padding_maxPool2d: +# MaxPool2dDynamicSamePadding +# MaxPool2dStaticSamePadding +# It's an additional function, not used in EfficientNet, +# but can be used in other model (such as EfficientDet). + +# Parameters for the entire model (stem, all blocks, and head) +GlobalParams = collections.namedtuple( + "GlobalParams", + [ + "width_coefficient", + "depth_coefficient", + "image_size", + "dropout_rate", + "num_classes", + "batch_norm_momentum", + "batch_norm_epsilon", + "drop_connect_rate", + "depth_divisor", + "min_depth", + "include_top", + ], +) + +# Parameters for an individual model block +BlockArgs = collections.namedtuple( + "BlockArgs", + ["num_repeat", "kernel_size", "stride", "expand_ratio", "input_filters", "output_filters", "se_ratio", "id_skip"], +) + +# Set GlobalParams and BlockArgs's defaults +GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) +BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) + +# Swish activation function +if hasattr(nn, "SiLU"): + Swish = nn.SiLU +else: + # For compatibility with old PyTorch versions + class Swish(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +# A memory-efficient implementation of Swish function +class SwishImplementation(torch.autograd.Function): + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_tensors[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +class MemoryEfficientSwish(nn.Module): + def forward(self, x): + return SwishImplementation.apply(x) + + +def round_filters(filters, global_params): + """Calculate and round number of filters based on width multiplier. + Use width_coefficient, depth_divisor and min_depth of global_params. + + Args: + filters (int): Filters number to be calculated. + global_params (namedtuple): Global params of the model. + + Returns: + new_filters: New filters number after calculating. + """ + multiplier = global_params.width_coefficient + if not multiplier: + return filters + # TODO: modify the params names. + # maybe the names (width_divisor,min_width) + # are more suitable than (depth_divisor,min_depth). 
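    # Worked example of the rounding below, using width_coefficient=1.1 and
    # depth_divisor=8 (the b2 setting from efficientnet_params):
    #   filters=40 -> 40 * 1.1 = 44.0 -> int(44.0 + 4) // 8 * 8 = 48, and 48 >= 0.9 * 44.0,
    #     so the result is 48;
    #   filters=24 -> 26.4 -> int(26.4 + 4) // 8 * 8 = 24, and 24 >= 0.9 * 26.4 = 23.76,
    #     so no correction is needed and the result is 24.
    # The "+ divisor / 2" term rounds to the nearest multiple of depth_divisor, and the
    # 10% check only bumps the value back up when that rounding shrank the channel
    # count too much.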
+ divisor = global_params.depth_divisor + min_depth = global_params.min_depth + filters *= multiplier + min_depth = min_depth or divisor # pay attention to this line when using min_depth + # follow the formula transferred from official TensorFlow implementation + new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) + if new_filters < 0.9 * filters: # prevent rounding by more than 10% + new_filters += divisor + return int(new_filters) + + +def round_repeats(repeats, global_params): + """Calculate module's repeat number of a block based on depth multiplier. + Use depth_coefficient of global_params. + + Args: + repeats (int): num_repeat to be calculated. + global_params (namedtuple): Global params of the model. + + Returns: + new repeat: New repeat number after calculating. + """ + multiplier = global_params.depth_coefficient + if not multiplier: + return repeats + # follow the formula transferred from official TensorFlow implementation + return int(math.ceil(multiplier * repeats)) + + +def drop_connect(inputs, p, training): + """Drop connect. + + Args: + input (tensor: BCWH): Input of this structure. + p (float: 0.0~1.0): Probability of drop connection. + training (bool): The running mode. + + Returns: + output: Output after drop connection. + """ + assert 0 <= p <= 1, "p must be in range of [0,1]" + + if not training: + return inputs + + batch_size = inputs.shape[0] + keep_prob = 1 - p + + # generate binary_tensor mask according to probability (p for 0, 1-p for 1) + random_tensor = keep_prob + random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) + binary_tensor = torch.floor(random_tensor) + + output = inputs / keep_prob * binary_tensor + return output + + +def get_width_and_height_from_size(x): + """Obtain height and width from x. + + Args: + x (int, tuple or list): Data size. + + Returns: + size: A tuple or list (H,W). + """ + if isinstance(x, int): + return x, x + if isinstance(x, list) or isinstance(x, tuple): + return x + else: + raise TypeError() + + +def calculate_output_image_size(input_image_size, stride): + """Calculates the output image size when using Conv2dSamePadding with a stride. + Necessary for static padding. Thanks to mannatsingh for pointing this out. + + Args: + input_image_size (int, tuple or list): Size of input image. + stride (int, tuple or list): Conv2d operation's stride. + + Returns: + output_image_size: A list [H,W]. + """ + if input_image_size is None: + return None + image_height, image_width = get_width_and_height_from_size(input_image_size) + stride = stride if isinstance(stride, int) else stride[0] + image_height = int(math.ceil(image_height / stride)) + image_width = int(math.ceil(image_width / stride)) + return [image_height, image_width] + + +# Note: +# The following 'SamePadding' functions make output size equal ceil(input size/stride). +# Only when stride equals 1, can the output size be the same as input size. +# Don't be confused by their function names ! ! ! + + +def get_same_padding_conv2d(image_size=None): + """Chooses static padding if you have specified an image size, and dynamic padding otherwise. + Static padding is necessary for ONNX exporting of models. + + Args: + image_size (int or tuple): Size of the image. + + Returns: + Conv2dDynamicSamePadding or Conv2dStaticSamePadding. 
+ """ + if image_size is None: + return Conv2dDynamicSamePadding + else: + return partial(Conv2dStaticSamePadding, image_size=image_size) + + +class Conv2dDynamicSamePadding(nn.Conv2d): + """2D Convolutions like TensorFlow, for a dynamic image size. + The padding is operated in forward function by calculating dynamically. + """ + + # Tips for 'SAME' mode padding. + # Given the following: + # i: width or height + # s: stride + # k: kernel size + # d: dilation + # p: padding + # Output after Conv2d: + # o = floor((i+p-((k-1)*d+1))/s+1) + # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1), + # => p = (i-1)*s+((k-1)*d+1)-i + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True): + super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 + + def forward(self, x): + ih, iw = x.size()[-2:] + kh, kw = self.weight.size()[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) # change the output size according to stride ! ! ! + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class Conv2dStaticSamePadding(nn.Conv2d): + """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size. + The padding mudule is calculated in construction function, then used in forward. + """ + + # With the same calculation as Conv2dDynamicSamePadding + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs): + super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs) + self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 + + # Calculate padding based on image size and save it + assert image_size is not None + ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size + kh, kw = self.weight.size()[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)) + else: + self.static_padding = nn.Identity() + + def forward(self, x): + x = self.static_padding(x) + x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + return x + + +def get_same_padding_maxPool2d(image_size=None): + """Chooses static padding if you have specified an image size, and dynamic padding otherwise. + Static padding is necessary for ONNX exporting of models. + + Args: + image_size (int or tuple): Size of the image. + + Returns: + MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding. + """ + if image_size is None: + return MaxPool2dDynamicSamePadding + else: + return partial(MaxPool2dStaticSamePadding, image_size=image_size) + + +class MaxPool2dDynamicSamePadding(nn.MaxPool2d): + """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size. + The padding is operated in forward function by calculating dynamically. 
+ """ + + def __init__(self, kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False): + super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode) + self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride + self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size + self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation + + def forward(self, x): + ih, iw = x.size()[-2:] + kh, kw = self.kernel_size + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return F.max_pool2d( + x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode, self.return_indices + ) + + +class MaxPool2dStaticSamePadding(nn.MaxPool2d): + """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size. + The padding mudule is calculated in construction function, then used in forward. + """ + + def __init__(self, kernel_size, stride, image_size=None, **kwargs): + super().__init__(kernel_size, stride, **kwargs) + self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride + self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size + self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation + + # Calculate padding based on image size and save it + assert image_size is not None + ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size + kh, kw = self.kernel_size + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)) + else: + self.static_padding = nn.Identity() + + def forward(self, x): + x = self.static_padding(x) + x = F.max_pool2d( + x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode, self.return_indices + ) + return x + + +################################################################################ +# Helper functions for loading model params +################################################################################ + +# BlockDecoder: A Class for encoding and decoding BlockArgs +# efficientnet_params: A function to query compound coefficient +# get_model_params and efficientnet: +# Functions to get BlockArgs and GlobalParams for efficientnet +# url_map and url_map_advprop: Dicts of url_map for pretrained weights +# load_pretrained_weights: A function to load pretrained weights + + +class BlockDecoder(object): + """Block Decoder for readability, + straight from the official TensorFlow repository. + """ + + @staticmethod + def _decode_block_string(block_string): + """Get a block through a string notation of arguments. + + Args: + block_string (str): A string notation of arguments. + Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'. + + Returns: + BlockArgs: The namedtuple defined at the top of this file. 
+ """ + assert isinstance(block_string, str) + + ops = block_string.split("_") + options = {} + for op in ops: + splits = re.split(r"(\d.*)", op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # Check stride + assert ("s" in options and len(options["s"]) == 1) or ( + len(options["s"]) == 2 and options["s"][0] == options["s"][1] + ) + + return BlockArgs( + num_repeat=int(options["r"]), + kernel_size=int(options["k"]), + stride=[int(options["s"][0])], + expand_ratio=int(options["e"]), + input_filters=int(options["i"]), + output_filters=int(options["o"]), + se_ratio=float(options["se"]) if "se" in options else None, + id_skip=("noskip" not in block_string), + ) + + @staticmethod + def _encode_block_string(block): + """Encode a block to a string. + + Args: + block (namedtuple): A BlockArgs type argument. + + Returns: + block_string: A String form of BlockArgs. + """ + args = [ + "r%d" % block.num_repeat, + "k%d" % block.kernel_size, + "s%d%d" % (block.strides[0], block.strides[1]), + "e%s" % block.expand_ratio, + "i%d" % block.input_filters, + "o%d" % block.output_filters, + ] + if 0 < block.se_ratio <= 1: + args.append("se%s" % block.se_ratio) + if block.id_skip is False: + args.append("noskip") + return "_".join(args) + + @staticmethod + def decode(string_list): + """Decode a list of string notations to specify blocks inside the network. + + Args: + string_list (list[str]): A list of strings, each string is a notation of block. + + Returns: + blocks_args: A list of BlockArgs namedtuples of block args. + """ + assert isinstance(string_list, list) + blocks_args = [] + for block_string in string_list: + blocks_args.append(BlockDecoder._decode_block_string(block_string)) + return blocks_args + + @staticmethod + def encode(blocks_args): + """Encode a list of BlockArgs to a list of strings. + + Args: + blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args. + + Returns: + block_strings: A list of strings, each string is a notation of block. + """ + block_strings = [] + for block in blocks_args: + block_strings.append(BlockDecoder._encode_block_string(block)) + return block_strings + + +def efficientnet_params(model_name): + """Map EfficientNet model name to parameter coefficients. + + Args: + model_name (str): Model name to be queried. + + Returns: + params_dict[model_name]: A (width,depth,res,dropout) tuple. + """ + params_dict = { + # Coefficients: width,depth,res,dropout + "efficientnet-b0": (1.0, 1.0, 224, 0.2), + "efficientnet-b1": (1.0, 1.1, 240, 0.2), + "efficientnet-b2": (1.1, 1.2, 260, 0.3), + "efficientnet-b3": (1.2, 1.4, 300, 0.3), + "efficientnet-b4": (1.4, 1.8, 380, 0.4), + "efficientnet-b5": (1.6, 2.2, 456, 0.4), + "efficientnet-b6": (1.8, 2.6, 528, 0.5), + "efficientnet-b7": (2.0, 3.1, 600, 0.5), + "efficientnet-b8": (2.2, 3.6, 672, 0.5), + "efficientnet-l2": (4.3, 5.3, 800, 0.5), + } + return params_dict[model_name] + + +def efficientnet( + width_coefficient=None, + depth_coefficient=None, + image_size=None, + dropout_rate=0.2, + drop_connect_rate=0.2, + num_classes=1000, + include_top=True, + out_stride=32, +): + """Create BlockArgs and GlobalParams for efficientnet model. + + Args: + width_coefficient (float) + depth_coefficient (float) + image_size (int) + dropout_rate (float) + drop_connect_rate (float) + num_classes (int) + + Meaning as the name suggests. + + Returns: + blocks_args, global_params. 
+ """ + + # Blocks args for the whole model(efficientnet-b0 by default) + # It will be modified in the construction of EfficientNet Class according to model + if out_stride == 32: + blocks_args = [ + "r1_k3_s11_e1_i32_o16_se0.25", + "r2_k3_s22_e6_i16_o24_se0.25", + "r2_k5_s22_e6_i24_o40_se0.25", + "r3_k3_s22_e6_i40_o80_se0.25", + "r3_k5_s11_e6_i80_o112_se0.25", + "r4_k5_s22_e6_i112_o192_se0.25", + "r1_k3_s11_e6_i192_o320_se0.25", + ] + elif out_stride == 16: + blocks_args = [ + "r1_k3_s11_e1_i32_o16_se0.25", + "r2_k3_s22_e6_i16_o24_se0.25", + "r2_k5_s22_e6_i24_o40_se0.25", + "r3_k3_s22_e6_i40_o80_se0.25", + "r3_k5_s11_e6_i80_o112_se0.25", + "r4_k5_s11_e6_i112_o192_se0.25", + "r1_k3_s11_e6_i192_o320_se0.25", + ] + else: + raise NotImplementedError + blocks_args = BlockDecoder.decode(blocks_args) + + global_params = GlobalParams( + width_coefficient=width_coefficient, + depth_coefficient=depth_coefficient, + image_size=image_size, + dropout_rate=dropout_rate, + num_classes=num_classes, + batch_norm_momentum=0.99, + batch_norm_epsilon=1e-3, + drop_connect_rate=drop_connect_rate, + depth_divisor=8, + min_depth=None, + include_top=include_top, + ) + + return blocks_args, global_params + + +def get_model_params(model_name, out_stride, override_params): + """Get the block args and global params for a given model name. + + Args: + model_name (str): Model's name. + override_params (dict): A dict to modify global_params. + + Returns: + blocks_args, global_params + """ + if model_name.startswith("efficientnet"): + w, d, s, p = efficientnet_params(model_name) + # note: all models have drop connect rate = 0.2 + blocks_args, global_params = efficientnet( + width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s, out_stride=out_stride + ) + else: + raise NotImplementedError("model name is not pre-defined: {}".format(model_name)) + if override_params: + # ValueError will be raised here if override_params has fields not included in global_params. 
+ global_params = global_params._replace(**override_params) + return blocks_args, global_params + + +# train with Standard methods +# check more details in paper(EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks) +url_map = { + "efficientnet-b0": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth", + "efficientnet-b1": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth", + "efficientnet-b2": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth", + "efficientnet-b3": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth", + "efficientnet-b4": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth", + "efficientnet-b5": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth", + "efficientnet-b6": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth", + "efficientnet-b7": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth", +} + +# train with Adversarial Examples(AdvProp) +# check more details in paper(Adversarial Examples Improve Image Recognition) +url_map_advprop = { + "efficientnet-b0": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth", + "efficientnet-b1": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth", + "efficientnet-b2": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth", + "efficientnet-b3": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth", + "efficientnet-b4": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth", + "efficientnet-b5": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth", + "efficientnet-b6": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth", + "efficientnet-b7": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth", + "efficientnet-b8": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth", +} + +# TODO: add the petrained weights url map of 'efficientnet-l2' + + +def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False, verbose=True): + """Loads pretrained weights from weights path or download using url. + + Args: + model (Module): The whole model of efficientnet. + model_name (str): Model name of efficientnet. + weights_path (None or str): + str: path to pretrained weights file on the local disk. + None: use pretrained weights downloaded from the Internet. + load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model. + advprop (bool): Whether to load pretrained weights + trained with advprop (valid when weights_path is None). 
+ """ + if isinstance(weights_path, str): + state_dict = torch.load(weights_path) + else: + # AutoAugment or Advprop (different preprocessing) + url_map_ = url_map_advprop if advprop else url_map + state_dict = model_zoo.load_url(url_map_[model_name]) + + if load_fc: + ret = model.load_state_dict(state_dict, strict=False) + assert not ret.missing_keys, "Missing keys when loading pretrained weights: {}".format(ret.missing_keys) + else: + state_dict.pop("_fc.weight") + state_dict.pop("_fc.bias") + ret = model.load_state_dict(state_dict, strict=False) + assert set(ret.missing_keys) == set( + ["_fc.weight", "_fc.bias"] + ), "Missing keys when loading pretrained weights: {}".format(ret.missing_keys) + assert not ret.unexpected_keys, "Missing keys when loading pretrained weights: {}".format(ret.unexpected_keys) + + if verbose: + print("Loaded pretrained weights for {}".format(model_name)) diff --git a/mapmaster/models/backbone/model.py b/mapmaster/models/backbone/model.py new file mode 100644 index 0000000..003124b --- /dev/null +++ b/mapmaster/models/backbone/model.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from mapmaster.models.backbone.resnet import ResNet +from mapmaster.models.backbone.efficientnet import EfficientNet +from mapmaster.models.backbone.swin_transformer import SwinTransformer +from mapmaster.models.backbone.bifpn import BiFPN + + +class ResNetBackbone(nn.Module): + def __init__(self, bkb_kwargs, fpn_kwarg=None, up_shape=None, ret_layers=1): + super(ResNetBackbone, self).__init__() + assert 0 < ret_layers < 4 + self.ret_layers = ret_layers + self.bkb = ResNet(**bkb_kwargs) + self.fpn = None if fpn_kwarg is None else BiFPN(**fpn_kwarg) + self.up_shape = None if up_shape is None else up_shape + self.bkb.init_weights() + + def forward(self, inputs): + images = inputs["images"] + images = images.view(-1, *images.shape[-3:]) + bkb_features = list(self.bkb(images)[-self.ret_layers:]) + nek_features = self.fpn(bkb_features) if self.fpn is not None else None + return {"im_bkb_features": bkb_features, "im_nek_features": nek_features} + + +class EfficientNetBackbone(nn.Module): + def __init__(self, bkb_kwargs, fpn_kwarg=None, up_shape=None, ret_layers=1): + super(EfficientNetBackbone, self).__init__() + assert 0 < ret_layers < 4 + self.ret_layers = ret_layers + self.bkb = EfficientNet.from_pretrained(**bkb_kwargs) + self.fpn = None if fpn_kwarg is None else BiFPN(**fpn_kwarg) + self.up_shape = None if up_shape is None else up_shape + del self.bkb._conv_head + del self.bkb._bn1 + del self.bkb._avg_pooling + del self.bkb._dropout + del self.bkb._fc + + def forward(self, inputs): + images = inputs["images"] + images = images.view(-1, *images.shape[-3:]) + endpoints = self.bkb.extract_endpoints(images) + bkb_features = [] + for i, (key, value) in enumerate(endpoints.items()): + if i > 0: + bkb_features.append(value) + bkb_features = list(bkb_features[-self.ret_layers:]) + nek_features = self.fpn(bkb_features) if self.fpn is not None else None + return {"im_bkb_features": bkb_features, "im_nek_features": nek_features} + + +class SwinTRBackbone(nn.Module): + def __init__(self, bkb_kwargs, fpn_kwarg=None, up_shape=None, ret_layers=1): + super(SwinTRBackbone, self).__init__() + assert 0 < ret_layers < 4 + self.ret_layers = ret_layers + self.bkb = SwinTransformer(**bkb_kwargs) + self.fpn = None if fpn_kwarg is None else BiFPN(**fpn_kwarg) + self.up_shape = None if up_shape is None else up_shape + + def forward(self, inputs): + images = 
inputs["images"] + images = images.view(-1, *images.shape[-3:]) + bkb_features = list(self.bkb(images)[-self.ret_layers:]) + nek_features = None + if self.fpn is not None: + nek_features = self.fpn(bkb_features) + else: + if self.up_shape is not None: + nek_features = [torch.cat([self.up_sample(x, self.up_shape) for x in bkb_features], dim=1)] + + return {"im_bkb_features": bkb_features, "im_nek_features": nek_features} + + def up_sample(self, x, tgt_shape=None): + tgt_shape = self.tgt_shape if tgt_shape is None else tgt_shape + if tuple(x.shape[-2:]) == tuple(tgt_shape): + return x + return F.interpolate(x, size=tgt_shape, mode="bilinear", align_corners=True) diff --git a/mapmaster/models/backbone/resnet/__init__.py b/mapmaster/models/backbone/resnet/__init__.py new file mode 100644 index 0000000..4f73423 --- /dev/null +++ b/mapmaster/models/backbone/resnet/__init__.py @@ -0,0 +1 @@ +from .resnet import ResNet diff --git a/mapmaster/models/backbone/resnet/resnet.py b/mapmaster/models/backbone/resnet/resnet.py new file mode 100644 index 0000000..3762e25 --- /dev/null +++ b/mapmaster/models/backbone/resnet/resnet.py @@ -0,0 +1,596 @@ +import warnings + +import torch.nn as nn +import torch.utils.checkpoint as cp +from torch.nn.modules.batchnorm import _BatchNorm +from mmcv.runner import BaseModule +from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer +from .utils import ResLayer + + +class BasicBlock(BaseModule): + expansion = 1 + + def __init__( + self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style="pytorch", + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type="BN"), + dcn=None, + plugins=None, + init_cfg=None, + ): + super(BasicBlock, self).__init__(init_cfg) + assert dcn is None, "Not implemented yet." + assert plugins is None, "Not implemented yet." + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + + self.conv1 = build_conv_layer( + conv_cfg, inplanes, planes, 3, stride=stride, padding=dilation, dilation=dilation, bias=False + ) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer(conv_cfg, planes, planes, 3, padding=1, bias=False) + self.add_module(self.norm2_name, norm2) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.with_cp = with_cp + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class Bottleneck(BaseModule): + expansion = 4 + + def __init__( + self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style="pytorch", + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type="BN"), + dcn=None, + plugins=None, + init_cfg=None, + ): + """Bottleneck block for ResNet. 
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer, if + it is "caffe", the stride-two layer is the first 1x1 conv layer. + """ + super(Bottleneck, self).__init__(init_cfg) + assert style in ["pytorch", "caffe"] + assert dcn is None or isinstance(dcn, dict) + assert plugins is None or isinstance(plugins, list) + if plugins is not None: + allowed_position = ["after_conv1", "after_conv2", "after_conv3"] + assert all(p["position"] in allowed_position for p in plugins) + + self.inplanes = inplanes + self.planes = planes + self.stride = stride + self.dilation = dilation + self.style = style + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.dcn = dcn + self.with_dcn = dcn is not None + self.plugins = plugins + self.with_plugins = plugins is not None + + if self.with_plugins: + # collect plugins for conv1/conv2/conv3 + self.after_conv1_plugins = [plugin["cfg"] for plugin in plugins if plugin["position"] == "after_conv1"] + self.after_conv2_plugins = [plugin["cfg"] for plugin in plugins if plugin["position"] == "after_conv2"] + self.after_conv3_plugins = [plugin["cfg"] for plugin in plugins if plugin["position"] == "after_conv3"] + + if self.style == "pytorch": + self.conv1_stride = 1 + self.conv2_stride = stride + else: + self.conv1_stride = stride + self.conv2_stride = 1 + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + self.norm3_name, norm3 = build_norm_layer(norm_cfg, planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer(conv_cfg, inplanes, planes, kernel_size=1, stride=self.conv1_stride, bias=False) + self.add_module(self.norm1_name, norm1) + fallback_on_stride = False + if self.with_dcn: + fallback_on_stride = dcn.pop("fallback_on_stride", False) + if not self.with_dcn or fallback_on_stride: + self.conv2 = build_conv_layer( + conv_cfg, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False, + ) + else: + assert self.conv_cfg is None, "conv_cfg must be None for DCN" + self.conv2 = build_conv_layer( + dcn, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False, + ) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer(conv_cfg, planes, planes * self.expansion, kernel_size=1, bias=False) + self.add_module(self.norm3_name, norm3) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + if self.with_plugins: + self.after_conv1_plugin_names = self.make_block_plugins(planes, self.after_conv1_plugins) + self.after_conv2_plugin_names = self.make_block_plugins(planes, self.after_conv2_plugins) + self.after_conv3_plugin_names = self.make_block_plugins(planes * self.expansion, self.after_conv3_plugins) + + def make_block_plugins(self, in_channels, plugins): + """make plugins for block. + Args: + in_channels (int): Input channels of plugin. + plugins (list[dict]): List of plugins cfg to build. + Returns: + list[str]: List of the names of plugin. 
+ """ + assert isinstance(plugins, list) + plugin_names = [] + for plugin in plugins: + plugin = plugin.copy() + name, layer = build_plugin_layer(plugin, in_channels=in_channels, postfix=plugin.pop("postfix", "")) + assert not hasattr(self, name), f"duplicate plugin {name}" + self.add_module(name, layer) + plugin_names.append(name) + return plugin_names + + def forward_plugin(self, x, plugin_names): + out = x + for name in plugin_names: + out = getattr(self, name)(x) + return out + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + @property + def norm3(self): + """nn.Module: normalization layer after the third convolution layer""" + return getattr(self, self.norm3_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv1_plugin_names) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv2_plugin_names) + + out = self.conv3(out) + out = self.norm3(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv3_plugin_names) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class ResNet(BaseModule): + """ResNet backbone. + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + stem_channels (int | None): Number of stem channels. If not specified, + it will be the same as `base_channels`. Default: None. + base_channels (int): Number of base channels of res layer. Default: 64. + in_channels (int): Number of input image channels. Default: 3. + num_stages (int): Resnet stages. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + norm_cfg (dict): Dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + plugins (list[dict]): List of plugins for stages, each dict contains: + - cfg (dict, required): Cfg dict to build plugin. + - position (str, required): Position inside block to insert + plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'. + - stages (tuple[bool], optional): Stages to apply plugin, length + should be same as 'num_stages'. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. 
+ zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + Example: + >>> from mmdet.models import ResNet + >>> import torch + >>> self = ResNet(depth=18) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 64, 8, 8) + (1, 128, 4, 4) + (1, 256, 2, 2) + (1, 512, 1, 1) + """ + + arch_settings = { + 18: (BasicBlock, (2, 2, 2, 2)), + 34: (BasicBlock, (3, 4, 6, 3)), + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)), + } + + def __init__( + self, + depth, + in_channels=3, + stem_channels=None, + base_channels=64, + num_stages=4, + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(0, 1, 2, 3), + style="pytorch", + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type="BN", requires_grad=True), + norm_eval=True, + dcn=None, + stage_with_dcn=(False, False, False, False), + plugins=None, + with_cp=False, + zero_init_residual=True, + pretrained=None, + init_cfg=None, + ): + super(ResNet, self).__init__(init_cfg) + self.zero_init_residual = zero_init_residual + if depth not in self.arch_settings: + raise KeyError(f"invalid depth {depth} for resnet") + + block_init_cfg = None + assert not (init_cfg and pretrained), "init_cfg and pretrained cannot be specified at the same time" + if isinstance(pretrained, str): + warnings.warn("DeprecationWarning: pretrained is deprecated, " 'please use "init_cfg" instead') + self.init_cfg = dict(type="Pretrained", checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type="Kaiming", layer="Conv2d"), + dict(type="Constant", val=1, layer=["_BatchNorm", "GroupNorm"]), + ] + block = self.arch_settings[depth][0] + if self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict(type="Constant", val=0, override=dict(name="norm2")) + elif block is Bottleneck: + block_init_cfg = dict(type="Constant", val=0, override=dict(name="norm3")) + else: + raise TypeError("pretrained must be a str or None") + + self.depth = depth + if stem_channels is None: + stem_channels = base_channels + self.stem_channels = stem_channels + self.base_channels = base_channels + self.num_stages = num_stages + assert num_stages >= 1 and num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.dcn = dcn + self.stage_with_dcn = stage_with_dcn + if dcn is not None: + assert len(stage_with_dcn) == num_stages + self.plugins = plugins + self.block, stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + self.inplanes = stem_channels + + self._make_stem_layer(in_channels, stem_channels) + + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + stride = strides[i] + dilation = dilations[i] + dcn = self.dcn if self.stage_with_dcn[i] else None + if plugins is not None: + stage_plugins = 
self.make_stage_plugins(plugins, i) + else: + stage_plugins = None + planes = base_channels * 2 ** i + res_layer = self.make_res_layer( + block=self.block, + inplanes=self.inplanes, + planes=planes, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + dcn=dcn, + plugins=stage_plugins, + init_cfg=block_init_cfg, + ) + self.inplanes = planes * self.block.expansion + layer_name = f"layer{i + 1}" + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = self.block.expansion * base_channels * 2 ** (len(self.stage_blocks) - 1) + + def make_stage_plugins(self, plugins, stage_idx): + """Make plugins for ResNet ``stage_idx`` th stage. + Currently we support to insert ``context_block``, + ``empirical_attention_block``, ``nonlocal_block`` into the backbone + like ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of + Bottleneck. + An example of plugins format could be: + Examples: + >>> plugins=[ + ... dict(cfg=dict(type='xxx', arg1='xxx'), + ... stages=(False, True, True, True), + ... position='after_conv2'), + ... dict(cfg=dict(type='yyy'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='1'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='2'), + ... stages=(True, True, True, True), + ... position='after_conv3') + ... ] + >>> self = ResNet(depth=18) + >>> stage_plugins = self.make_stage_plugins(plugins, 0) + >>> assert len(stage_plugins) == 3 + Suppose ``stage_idx=0``, the structure of blocks in the stage would be: + .. code-block:: none + conv1-> conv2->conv3->yyy->zzz1->zzz2 + Suppose 'stage_idx=1', the structure of blocks in the stage would be: + .. code-block:: none + conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2 + If stages is missing, the plugin would be applied to all stages. + Args: + plugins (list[dict]): List of plugins cfg to build. The postfix is + required if multiple same type plugins are inserted. 
+ stage_idx (int): Index of stage to build + Returns: + list[dict]: Plugins for current stage + """ + stage_plugins = [] + for plugin in plugins: + plugin = plugin.copy() + stages = plugin.pop("stages", None) + assert stages is None or len(stages) == self.num_stages + # whether to insert plugin into current stage + if stages is None or stages[stage_idx]: + stage_plugins.append(plugin) + + return stage_plugins + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``.""" + return ResLayer(**kwargs) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + def _make_stem_layer(self, in_channels, stem_channels): + if self.deep_stem: + self.stem = nn.Sequential( + build_conv_layer( + self.conv_cfg, in_channels, stem_channels // 2, kernel_size=3, stride=2, padding=1, bias=False + ), + build_norm_layer(self.norm_cfg, stem_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, + stem_channels // 2, + stem_channels // 2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + build_norm_layer(self.norm_cfg, stem_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, stem_channels // 2, stem_channels, kernel_size=3, stride=1, padding=1, bias=False + ), + build_norm_layer(self.norm_cfg, stem_channels)[1], + nn.ReLU(inplace=True), + ) + else: + self.conv1 = build_conv_layer( + self.conv_cfg, in_channels, stem_channels, kernel_size=7, stride=2, padding=3, bias=False + ) + self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, stem_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + if self.deep_stem: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + else: + self.norm1.eval() + for m in [self.conv1, self.norm1]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = getattr(self, f"layer{i}") + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def forward(self, x): + """Forward function.""" + if self.deep_stem: + x = self.stem(x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode while keep normalization layer + freezed.""" + super(ResNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + +class ResNetV1d(ResNet): + r"""ResNetV1d variant described in `Bag of Tricks + `_. + Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in + the input stem with three 3x3 convs. And in the downsampling block, a 2x2 + avg_pool with stride 2 is added before conv, whose stride is changed to 1. 
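+ Example (an illustrative sketch mirroring the ResNet example above; shapes assume depth=50 and a 32x32 input):
+ >>> import torch
+ >>> self = ResNetV1d(depth=50)
+ >>> inputs = torch.rand(1, 3, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 256, 8, 8)
+ (1, 512, 4, 4)
+ (1, 1024, 2, 2)
+ (1, 2048, 1, 1)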
+ """ + + def __init__(self, **kwargs): + super(ResNetV1d, self).__init__(deep_stem=True, avg_down=True, **kwargs) diff --git a/mapmaster/models/backbone/resnet/utils.py b/mapmaster/models/backbone/resnet/utils.py new file mode 100644 index 0000000..affe659 --- /dev/null +++ b/mapmaster/models/backbone/resnet/utils.py @@ -0,0 +1,93 @@ +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmcv.runner import Sequential +from torch import nn as nn + + +class ResLayer(Sequential): + """ResLayer to build ResNet style backbone. + Args: + block (nn.Module): block used to build ResLayer. + inplanes (int): inplanes of block. + planes (int): planes of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Default: 1 + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + downsample_first (bool): Downsample at the first block or last block. + False for Hourglass, True for ResNet. Default: True + """ + + def __init__( + self, + block, + inplanes, + planes, + num_blocks, + stride=1, + avg_down=False, + conv_cfg=None, + norm_cfg=dict(type="BN"), + downsample_first=True, + **kwargs + ): + self.block = block + + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = [] + conv_stride = stride + if avg_down: + conv_stride = 1 + downsample.append( + nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False) + ) + downsample.extend( + [ + build_conv_layer( + conv_cfg, inplanes, planes * block.expansion, kernel_size=1, stride=conv_stride, bias=False + ), + build_norm_layer(norm_cfg, planes * block.expansion)[1], + ] + ) + downsample = nn.Sequential(*downsample) + + layers = [] + if downsample_first: + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs + ) + ) + inplanes = planes * block.expansion + for _ in range(1, num_blocks): + layers.append( + block(inplanes=inplanes, planes=planes, stride=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, **kwargs) + ) + + else: # downsample_first=False is for HourglassModule + for _ in range(num_blocks - 1): + layers.append( + block(inplanes=inplanes, planes=inplanes, stride=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, **kwargs) + ) + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs + ) + ) + super(ResLayer, self).__init__(*layers) diff --git a/mapmaster/models/backbone/swin_transformer/__init__.py b/mapmaster/models/backbone/swin_transformer/__init__.py new file mode 100644 index 0000000..55af267 --- /dev/null +++ b/mapmaster/models/backbone/swin_transformer/__init__.py @@ -0,0 +1,85 @@ +import os +import torch +from .model import SwinTransformer as _SwinTransformer +from torch.utils import model_zoo + +model_urls = { + "tiny": "https://github.com/SwinTransformer/storage/releases/download/v1.0.1/upernet_swin_tiny_patch4_window7_512x512.pth", + "base": "https://github.com/SwinTransformer/storage/releases/download/v1.0.1/upernet_swin_base_patch4_window7_512x512.pth", +} + + +class SwinTransformer(_SwinTransformer): + def __init__( + self, + arch="tiny", + pretrained=False, + window_size=7, + shift_mode=1, + mlp_ratio=4.0, + qkv_bias=True, + 
qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.3, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + use_checkpoint=False, + **kwargs + ): + if arch == "tiny": + embed_dim = 96 + depths = (2, 2, 6, 2) + num_heads = (3, 6, 12, 24) + elif arch == "small": + embed_dim = 96 + depths = (2, 2, 18, 2) + num_heads = (3, 6, 12, 24) + elif arch == "base": + embed_dim = 128 + depths = (2, 2, 18, 2) + num_heads = (4, 8, 16, 32) + else: + raise NotImplementedError + + super(SwinTransformer, self).__init__( + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + window_size=window_size, + shift_mode=shift_mode, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + ape=ape, + patch_norm=patch_norm, + out_indices=out_indices, + use_checkpoint=use_checkpoint, + **kwargs + ) + if isinstance(pretrained, bool): + assert pretrained is True + print(model_urls[arch]) + state_dict = model_zoo.load_url(model_urls[arch])["state_dict"] + elif isinstance(pretrained, str): + assert os.path.exists(pretrained) + print(pretrained) + state_dict = torch.load(pretrained)["state_dict"] + else: + raise NotImplementedError + + self.arch = arch + self.init_weights(state_dict=state_dict) + + def init_weights(self, state_dict): + new_state_dict = {} + for key, value in state_dict.items(): + if "backbone" in key: + new_state_dict[key.replace("backbone.", "")] = value + ret = self.load_state_dict(new_state_dict, strict=False) + print("Backbone missing_keys: {}".format(ret.missing_keys)) + print("Backbone unexpected_keys: {}".format(ret.unexpected_keys)) diff --git a/mapmaster/models/backbone/swin_transformer/model.py b/mapmaster/models/backbone/swin_transformer/model.py new file mode 100644 index 0000000..f2e0da2 --- /dev/null +++ b/mapmaster/models/backbone/swin_transformer/model.py @@ -0,0 +1,670 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu, Yutong Lin, Yixuan Wei +# -------------------------------------------------------- + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +import numpy as np +from .utils import DropPath, to_2tuple, trunc_normal_, get_root_logger, load_checkpoint + + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): 
Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0.0, proj_drop=0.0): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=0.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """Forward function. 
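+ A minimal shape sketch (illustrative; assumes 96-dim features split over 3 heads inside 7x7 windows, no attention mask):
+ >>> import torch
+ >>> attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
+ >>> x = torch.rand(64, 49, 96) # 64 windows of 7*7=49 tokens each
+ >>> tuple(attn(x).shape)
+ (64, 49, 96)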
+ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """Swin Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__( + self, + dim, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + """Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. 
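+ Example (illustrative sketch; a non-shifted block on a 56x56 feature map, so no attention mask is needed and None may be passed):
+ >>> import torch
+ >>> blk = SwinTransformerBlock(dim=96, num_heads=3, window_size=7, shift_size=0)
+ >>> blk.H, blk.W = 56, 56
+ >>> x = torch.rand(1, 56 * 56, 96)
+ >>> tuple(blk(x, None).shape)
+ (1, 3136, 96)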
+ """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + """Patch Merging Layer + + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + depth, + num_heads, + window_size=7, + shift_mode=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + ): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.shift_mode = shift_mode + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for i, h in enumerate(h_slices): + for j, w in enumerate(w_slices): + img_mask[:, h, w, :] = cnt + if self.shift_mode == 1 and j == 1: + continue + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint and self.training: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding + + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class SwinTransformer(nn.Module): + """Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
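+ Example (illustrative; shapes assume the default tiny-sized settings and a 224x224 input):
+ >>> import torch
+ >>> self = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24])
+ >>> level_outputs = self.forward(torch.rand(1, 3, 224, 224))
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 96, 56, 56)
+ (1, 192, 28, 28)
+ (1, 384, 14, 14)
+ (1, 768, 7, 7)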
+ """ + + def __init__( + self, + pretrain_img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + shift_mode=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + ): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + ) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]) + ) + trunc_normal_(self.absolute_pos_embed, std=0.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + shift_mode=shift_mode, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + ) + self.layers.append(layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f"norm{i_layer}" + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. 
+ """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError("pretrained must be a str or None") + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic") + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + + if i in self.out_indices: + norm_layer = getattr(self, f"norm{i}") + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + outs.append(out) + + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() diff --git a/mapmaster/models/backbone/swin_transformer/utils.py b/mapmaster/models/backbone/swin_transformer/utils.py new file mode 100644 index 0000000..4d239b9 --- /dev/null +++ b/mapmaster/models/backbone/swin_transformer/utils.py @@ -0,0 +1,695 @@ +# Copyright (c) Open-MMLab. All rights reserved. +import io +import os +import os.path as osp +import pkgutil +import time +import math +import logging +import warnings +from collections import OrderedDict +from importlib import import_module +from tempfile import TemporaryDirectory +from itertools import repeat + +import torch +import torch.nn as nn +import torch.distributed as dist +import torchvision +from torch import Tensor +from torch.optim import Optimizer +from torch.utils import model_zoo +from torch.nn import functional as F + +# from torch._six import container_abcs +import collections.abc as container_abcs + +import mmcv +from mmcv.fileio import FileClient +from mmcv.fileio import load as load_file +from mmcv.parallel import is_module_wrapper +from mmcv.utils import mkdir_or_exist +from mmcv.runner import get_dist_info + +ENV_MMCV_HOME = "MMCV_HOME" +ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" +DEFAULT_CACHE_DIR = "~/.cache" + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, container_abcs.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... 
I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +def _get_mmcv_home(): + mmcv_home = os.path.expanduser( + os.getenv(ENV_MMCV_HOME, os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "mmcv")) + ) + + mkdir_or_exist(mmcv_home) + return mmcv_home + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +logger_initialized = {} + + +def get_logger(name, log_file=None, log_level=logging.INFO, file_mode="w"): + """Initialize and get a logger by name. + + If the logger has not been initialized, this method will initialize the + logger by adding one or two handlers, otherwise the initialized logger will + be directly returned. During initialization, a StreamHandler will always be + added. 
If `log_file` is specified and the process rank is 0, a FileHandler + will also be added. + + Args: + name (str): Logger name. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the logger. + log_level (int): The logger level. Note that only the process of + rank 0 is affected, and other processes will set the level to + "Error" thus be silent most of the time. + file_mode (str): The file mode used in opening log file. + Defaults to 'w'. + + Returns: + logging.Logger: The expected logger. + """ + logger = logging.getLogger(name) + if name in logger_initialized: + return logger + # handle hierarchical names + # e.g., logger "a" is initialized, then logger "a.b" will skip the + # initialization since it is a child of "a". + for logger_name in logger_initialized: + if name.startswith(logger_name): + return logger + + # handle duplicate logs to the console + # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET) + # to the root logger. As logger.propagate is True by default, this root + # level handler causes logging messages from rank>0 processes to + # unexpectedly show up on the console, creating much unwanted clutter. + # To fix this issue, we set the root logger's StreamHandler, if any, to log + # at the ERROR level. + for handler in logger.root.handlers: + if type(handler) is logging.StreamHandler: + handler.setLevel(logging.ERROR) + + stream_handler = logging.StreamHandler() + handlers = [stream_handler] + + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + else: + rank = 0 + + # only rank 0 will add a FileHandler + if rank == 0 and log_file is not None: + # Here, the default behaviour of the official logger is 'a'. Thus, we + # provide an interface to change the file mode to the default + # behaviour. + file_handler = logging.FileHandler(log_file, file_mode) + handlers.append(file_handler) + + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + for handler in handlers: + handler.setFormatter(formatter) + handler.setLevel(log_level) + logger.addHandler(handler) + + if rank == 0: + logger.setLevel(log_level) + else: + logger.setLevel(logging.ERROR) + + logger_initialized[name] = True + + return logger + + +def get_root_logger(log_file=None, log_level=logging.INFO): + """Get the root logger. + + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. The name of the root logger is the top-level package name, + e.g., "mmseg". + + Args: + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + + Returns: + logging.Logger: The root logger. + """ + + logger = get_logger(name="mmseg", log_file=log_file, log_level=log_level) + + return logger + + +def load_state_dict(module, state_dict, strict=False, logger=None): + """Load state_dict to a module. + + This method is modified from :meth:`torch.nn.Module.load_state_dict`. + Default value for ``strict`` is set to ``False`` and the message for + param mismatch will be shown even if strict is False. + + Args: + module (Module): Module that receives the state_dict. + state_dict (OrderedDict): Weights. 
+ strict (bool): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. + logger (:obj:`logging.Logger`, optional): Logger to log the error + message. If not specified, print function will be used. + """ + unexpected_keys = [] + all_missing_keys = [] + err_msg = [] + + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + # use _load_from_state_dict to enable checkpoint version control + def load(module, prefix=""): + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, all_missing_keys, unexpected_keys, err_msg + ) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(module) + load = None # break load->load reference cycle + + # ignore "num_batches_tracked" of BN layers + missing_keys = [key for key in all_missing_keys if "num_batches_tracked" not in key] + + if unexpected_keys: + err_msg.append("unexpected key in source " f'state_dict: {", ".join(unexpected_keys)}\n') + if missing_keys: + err_msg.append(f'missing keys in source state_dict: {", ".join(missing_keys)}\n') + + rank, _ = get_dist_info() + if len(err_msg) > 0 and rank == 0: + err_msg.insert(0, "The model and loaded state dict do not match exactly\n") + err_msg = "\n".join(err_msg) + if strict: + raise RuntimeError(err_msg) + elif logger is not None: + logger.warning(err_msg) + else: + print(err_msg) + + +def load_url_dist(url, model_dir=None): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + rank, world_size = get_dist_info() + rank = int(os.environ.get("LOCAL_RANK", rank)) + if rank == 0: + checkpoint = model_zoo.load_url(url, model_dir=model_dir) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + checkpoint = model_zoo.load_url(url, model_dir=model_dir) + return checkpoint + + +def load_pavimodel_dist(model_path, map_location=None): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + try: + from pavi import modelcloud + except ImportError: + raise ImportError("Please install pavi to load checkpoint from modelcloud.") + rank, world_size = get_dist_info() + rank = int(os.environ.get("LOCAL_RANK", rank)) + if rank == 0: + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load(downloaded_file, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load(downloaded_file, map_location=map_location) + return checkpoint + + +def load_fileclient_dist(filename, backend, map_location): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + rank, world_size = get_dist_info() + rank = int(os.environ.get("LOCAL_RANK", rank)) + allowed_backends = ["ceph"] + if backend not in allowed_backends: + raise ValueError(f"Load 
from Backend {backend} is not supported.") + if rank == 0: + fileclient = FileClient(backend=backend) + buffer = io.BytesIO(fileclient.get(filename)) + checkpoint = torch.load(buffer, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + fileclient = FileClient(backend=backend) + buffer = io.BytesIO(fileclient.get(filename)) + checkpoint = torch.load(buffer, map_location=map_location) + return checkpoint + + +def get_torchvision_models(): + model_urls = dict() + for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): + if ispkg: + continue + _zoo = import_module(f"torchvision.models.{name}") + if hasattr(_zoo, "model_urls"): + _urls = getattr(_zoo, "model_urls") + model_urls.update(_urls) + return model_urls + + +def get_external_models(): + mmcv_home = _get_mmcv_home() + default_json_path = osp.join(mmcv.__path__[0], "model_zoo/open_mmlab.json") + default_urls = load_file(default_json_path) + assert isinstance(default_urls, dict) + external_json_path = osp.join(mmcv_home, "open_mmlab.json") + if osp.exists(external_json_path): + external_urls = load_file(external_json_path) + assert isinstance(external_urls, dict) + default_urls.update(external_urls) + + return default_urls + + +def get_mmcls_models(): + mmcls_json_path = osp.join(mmcv.__path__[0], "model_zoo/mmcls.json") + mmcls_urls = load_file(mmcls_json_path) + + return mmcls_urls + + +def get_deprecated_model_names(): + deprecate_json_path = osp.join(mmcv.__path__[0], "model_zoo/deprecated.json") + deprecate_urls = load_file(deprecate_json_path) + assert isinstance(deprecate_urls, dict) + + return deprecate_urls + + +def _process_mmcls_checkpoint(checkpoint): + state_dict = checkpoint["state_dict"] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k.startswith("backbone."): + new_state_dict[k[9:]] = v + new_checkpoint = dict(state_dict=new_state_dict) + + return new_checkpoint + + +def _load_checkpoint(filename, map_location=None): + """Load checkpoint from somewhere (modelzoo, file, url). + + Args: + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str | None): Same as :func:`torch.load`. Default: None. + + Returns: + dict | OrderedDict: The loaded checkpoint. It can be either an + OrderedDict storing model weights or a dict containing other + information, which depends on the checkpoint. 
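+ Example (illustrative; the model name and local path below are placeholders):
+ >>> checkpoint = _load_checkpoint('torchvision://resnet50')
+ >>> checkpoint = _load_checkpoint('work_dirs/swin_tiny.pth', map_location='cpu')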
+ """ + if filename.startswith("modelzoo://"): + warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' 'use "torchvision://" instead') + model_urls = get_torchvision_models() + model_name = filename[11:] + checkpoint = load_url_dist(model_urls[model_name]) + elif filename.startswith("torchvision://"): + model_urls = get_torchvision_models() + model_name = filename[14:] + checkpoint = load_url_dist(model_urls[model_name]) + elif filename.startswith("open-mmlab://"): + model_urls = get_external_models() + model_name = filename[13:] + deprecated_urls = get_deprecated_model_names() + if model_name in deprecated_urls: + warnings.warn( + f"open-mmlab://{model_name} is deprecated in favor " f"of open-mmlab://{deprecated_urls[model_name]}" + ) + model_name = deprecated_urls[model_name] + model_url = model_urls[model_name] + # check if is url + if model_url.startswith(("http://", "https://")): + checkpoint = load_url_dist(model_url) + else: + filename = osp.join(_get_mmcv_home(), model_url) + if not osp.isfile(filename): + raise IOError(f"{filename} is not a checkpoint file") + checkpoint = torch.load(filename, map_location=map_location) + elif filename.startswith("mmcls://"): + model_urls = get_mmcls_models() + model_name = filename[8:] + checkpoint = load_url_dist(model_urls[model_name]) + checkpoint = _process_mmcls_checkpoint(checkpoint) + elif filename.startswith(("http://", "https://")): + checkpoint = load_url_dist(filename) + elif filename.startswith("pavi://"): + model_path = filename[7:] + checkpoint = load_pavimodel_dist(model_path, map_location=map_location) + elif filename.startswith("s3://"): + checkpoint = load_fileclient_dist(filename, backend="ceph", map_location=map_location) + else: + if not osp.isfile(filename): + raise IOError(f"{filename} is not a checkpoint file") + checkpoint = torch.load(filename, map_location=map_location) + return checkpoint + + +def load_checkpoint(model, filename, map_location="cpu", strict=False, logger=None): + """Load checkpoint from a file or URI. + + Args: + model (Module): Module to load checkpoint. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str): Same as :func:`torch.load`. + strict (bool): Whether to allow different params for the model and + checkpoint. + logger (:mod:`logging.Logger` or None): The logger for error message. + + Returns: + dict or OrderedDict: The loaded checkpoint. 
+ """ + checkpoint = _load_checkpoint(filename, map_location) + # OrderedDict is a subclass of dict + if not isinstance(checkpoint, dict): + raise RuntimeError(f"No state_dict found in checkpoint file {filename}") + # get state_dict from checkpoint + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + elif "model" in checkpoint: + state_dict = checkpoint["model"] + else: + state_dict = checkpoint + # strip prefix of state_dict + if list(state_dict.keys())[0].startswith("module."): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + # for MoBY, load model of online branch + if sorted(list(state_dict.keys()))[0].startswith("encoder"): + state_dict = {k.replace("encoder.", ""): v for k, v in state_dict.items() if k.startswith("encoder.")} + + # reshape absolute position embedding + if state_dict.get("absolute_pos_embed") is not None: + absolute_pos_embed = state_dict["absolute_pos_embed"] + N1, L, C1 = absolute_pos_embed.size() + N2, C2, H, W = model.absolute_pos_embed.size() + if N1 != N2 or C1 != C2 or L != H * W: + logger.warning("Error in loading absolute_pos_embed, pass") + else: + state_dict["absolute_pos_embed"] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2) + + # interpolate position bias table if needed + relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] + for table_key in relative_position_bias_table_keys: + table_pretrained = state_dict[table_key] + table_current = model.state_dict()[table_key] + L1, nH1 = table_pretrained.size() + L2, nH2 = table_current.size() + if nH1 != nH2: + logger.warning(f"Error in loading {table_key}, pass") + else: + if L1 != L2: + S1 = int(L1 ** 0.5) + S2 = int(L2 ** 0.5) + table_pretrained_resized = F.interpolate( + table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), mode="bicubic" + ) + state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0) + + # load state_dict + load_state_dict(model, state_dict, strict, logger) + return checkpoint + + +def weights_to_cpu(state_dict): + """Copy a model state_dict to cpu. + + Args: + state_dict (OrderedDict): Model weights on GPU. + + Returns: + OrderedDict: Model weights on GPU. + """ + state_dict_cpu = OrderedDict() + for key, val in state_dict.items(): + state_dict_cpu[key] = val.cpu() + return state_dict_cpu + + +def _save_to_state_dict(module, destination, prefix, keep_vars): + """Saves module state to `destination` dictionary. + + This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. + + Args: + module (nn.Module): The module to generate state_dict. + destination (dict): A dict where state will be stored. + prefix (str): The prefix for parameters and buffers used in this + module. + """ + for name, param in module._parameters.items(): + if param is not None: + destination[prefix + name] = param if keep_vars else param.detach() + for name, buf in module._buffers.items(): + # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d + if buf is not None: + destination[prefix + name] = buf if keep_vars else buf.detach() + + +def get_state_dict(module, destination=None, prefix="", keep_vars=False): + """Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. 
+ + This method is modified from :meth:`torch.nn.Module.state_dict` to + recursively check parallel module in case that the model has a complicated + structure, e.g., nn.Module(nn.Module(DDP)). + + Args: + module (nn.Module): The module to generate state_dict. + destination (OrderedDict): Returned dict for the state of the + module. + prefix (str): Prefix of the key. + keep_vars (bool): Whether to keep the variable property of the + parameters. Default: False. + + Returns: + dict: A dictionary containing a whole state of the module. + """ + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + + # below is the same as torch.nn.Module.state_dict() + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict(version=module._version) + _save_to_state_dict(module, destination, prefix, keep_vars) + for name, child in module._modules.items(): + if child is not None: + get_state_dict(child, destination, prefix + name + ".", keep_vars=keep_vars) + for hook in module._state_dict_hooks.values(): + hook_result = hook(module, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + +def save_checkpoint(model, filename, optimizer=None, meta=None): + """Save checkpoint to file. + + The checkpoint will have 3 fields: ``meta``, ``state_dict`` and + ``optimizer``. By default ``meta`` will contain version and time info. + + Args: + model (Module): Module whose params are to be saved. + filename (str): Checkpoint filename. + optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. + meta (dict, optional): Metadata to be saved in checkpoint. 
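+
+    Example (a minimal sketch; the output path and the ``epoch`` meta key are
+        made up for illustration):
+        >>> import torch.nn as nn
+        >>> save_checkpoint(nn.Linear(2, 2), "/tmp/epoch_1.pth", meta=dict(epoch=1))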
+ """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError(f"meta must be a dict or None, but got {type(meta)}") + meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) + + if is_module_wrapper(model): + model = model.module + + if hasattr(model, "CLASSES") and model.CLASSES is not None: + # save class name to the meta + meta.update(CLASSES=model.CLASSES) + + checkpoint = {"meta": meta, "state_dict": weights_to_cpu(get_state_dict(model))} + # save optimizer state dict in the checkpoint + if isinstance(optimizer, Optimizer): + checkpoint["optimizer"] = optimizer.state_dict() + elif isinstance(optimizer, dict): + checkpoint["optimizer"] = {} + for name, optim in optimizer.items(): + checkpoint["optimizer"][name] = optim.state_dict() + + if filename.startswith("pavi://"): + try: + from pavi import modelcloud + from pavi.exception import NodeNotFoundError + except ImportError: + raise ImportError("Please install pavi to load checkpoint from modelcloud.") + model_path = filename[7:] + root = modelcloud.Folder() + model_dir, model_name = osp.split(model_path) + try: + model = modelcloud.get(model_dir) + except NodeNotFoundError: + model = root.create_training_model(model_dir) + with TemporaryDirectory() as tmp_dir: + checkpoint_file = osp.join(tmp_dir, model_name) + with open(checkpoint_file, "wb") as f: + torch.save(checkpoint, f) + f.flush() + model.create_file(checkpoint_file, name=model_name) + else: + mmcv.mkdir_or_exist(osp.dirname(filename)) + # immediately flush buffer + with open(filename, "wb") as f: + torch.save(checkpoint, f) + f.flush() diff --git a/mapmaster/models/bev_decoder/__init__.py b/mapmaster/models/bev_decoder/__init__.py new file mode 100644 index 0000000..cd39500 --- /dev/null +++ b/mapmaster/models/bev_decoder/__init__.py @@ -0,0 +1 @@ +from .model import TransformerBEVDecoder, DeformTransformerBEVEncoder diff --git a/mapmaster/models/bev_decoder/deform_transformer/__init__.py b/mapmaster/models/bev_decoder/deform_transformer/__init__.py new file mode 100644 index 0000000..810662e --- /dev/null +++ b/mapmaster/models/bev_decoder/deform_transformer/__init__.py @@ -0,0 +1 @@ +from .deform_transformer import DeformTransformer diff --git a/mapmaster/models/bev_decoder/deform_transformer/deform_transformer.py b/mapmaster/models/bev_decoder/deform_transformer/deform_transformer.py new file mode 100644 index 0000000..28f9ea1 --- /dev/null +++ b/mapmaster/models/bev_decoder/deform_transformer/deform_transformer.py @@ -0,0 +1,672 @@ +import copy +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from .ops import MSDeformAttn +from .position_encoding import PositionEmbeddingSine +from .position_encoding import PositionEmbeddingLearned + +class DeformTransformer(nn.Module): + def __init__( + self, + in_channels, + src_shape=(16, 168), + tgt_shape=(32, 32), + d_model=256, + n_heads=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=1024, + dropout=0.1, + activation="relu", + return_intermediate_dec=False, + dec_n_points=4, + enc_n_points=4, + src_pos_encode="sine", + tgt_pos_encode="learned", + norm_layer=nn.BatchNorm2d, + use_checkpoint=False, + use_projection=False, + map_size=(400, 200), + image_shape=(900, 1600), + map_resolution=0.15, + image_order=(2, 1, 0, 5, 4, 3) + ): + super().__init__() + + if isinstance(in_channels, int): + in_channels = [in_channels] + if isinstance(src_shape[0], int): + src_shape = [src_shape] + assert len(src_shape) == 
len(in_channels) + n_levels = len(in_channels) + + self.input_proj = nn.ModuleList() + for i in range(len(in_channels)): + self.input_proj.append( + nn.Sequential( + nn.Conv2d(in_channels[i], d_model, kernel_size=1, bias=False), + norm_layer(d_model), + ) + ) + + encoder_layer = DeformTransformerEncoderLayer( + d_model, dim_feedforward, dropout, activation, n_levels, n_heads, enc_n_points, use_checkpoint + ) + self.encoder = DeformTransformerEncoder(encoder_layer, num_encoder_layers) + + decoder_layer = DeformTransformerDecoderLayer( + d_model, dim_feedforward, dropout, activation, n_levels, n_heads, dec_n_points, use_checkpoint + ) + self.decoder = DeformTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec) + + self.dropout = nn.Dropout(dropout) + + self.t2s_reference_points = nn.Linear(d_model, 2) + + self._reset_parameters() + + if src_pos_encode == "sine": + self.src_pos_embed = PositionEmbeddingSine(d_model, normalize=True) + self.src_lvl_embed = nn.Embedding(n_levels, d_model) + elif src_pos_encode == "learned": + self.src_pos_embed = nn.ModuleList( + [PositionEmbeddingLearned(shape, d_model) for shape in src_shape], + ) + else: + raise NotImplementedError + + if tgt_pos_encode == "sine": + self.tgt_pos_embed = PositionEmbeddingSine(d_model, normalize=True) + elif tgt_pos_encode == "learned": + self.tgt_pos_embed = PositionEmbeddingLearned(tgt_shape, d_model) + else: + raise NotImplementedError + + self.tgt_embed = PositionEmbeddingLearned(tgt_shape, d_model) + + self.src_shape = src_shape + self.tgt_shape = tgt_shape + self.src_pos_encode = src_pos_encode + self.tgt_pos_encode = tgt_pos_encode + + """ + use_projection: bool / whether to use IPM as the reference points + map_size: (x_width, y_width) shape of the original Map (400, 200) + image_shape: (Height, Width) + map_resolution: map resolution (m / pixel) + """ + self.use_projection = use_projection # Use IPM Projection to get reference points + self.map_size = map_size + self.map_resolution = map_resolution + self.image_shape = image_shape + image_order = torch.tensor(image_order, dtype=torch.long) + self.register_buffer("image_order", image_order) + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MSDeformAttn): + m._reset_parameters() + nn.init.xavier_uniform_(self.t2s_reference_points.weight, gain=1.0) + nn.init.constant_(self.t2s_reference_points.bias, 0.0) + + @staticmethod + def get_valid_ratio(mask): + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def get_projection_points(self, extrinsic, intrinsic, flip=False): + """ + extrinsic: + torch.Tensor (6, 4, 4) + intrinsic: + torch.Tensor (6, 3, 3) + flip: + flip or not + + Return + reference points (N, L, 2) + mask (N, ) + """ + map_forward_ratio = self.tgt_shape[0] / self.map_size[0] + map_lateral_ratio = self.tgt_shape[1] / self.map_size[1] + + map_forward_res = self.map_resolution / map_forward_ratio + map_lateral_res = self.map_resolution / map_lateral_ratio + + X = (torch.arange(self.tgt_shape[0] - 1, -1, -1, device=extrinsic.device) + 0.5 - self.tgt_shape[0] / 2) * map_forward_res + Y = (torch.arange(self.tgt_shape[1] - 1, -1, -1, device=extrinsic.device) + 0.5 - self.tgt_shape[1] / 2) * map_lateral_res + if flip: + Y = -1 
* Y # Flip the Y axis + + Z = torch.zeros(self.tgt_shape, device=extrinsic.device) + grid_X, grid_Y = torch.meshgrid(X, Y) + coords = torch.stack([grid_X, grid_Y, Z, torch.ones(self.tgt_shape, device=extrinsic.device)], dim=-1) # (H, W, 4) homogeneous coordinates + coords_flatten = coords.reshape(-1, 4) # (N, 4) + + cams = [] + for cam in extrinsic: + cam_coords = torch.linalg.inv(cam) @ coords_flatten.T # (4, N) + cam_coords = cam_coords[:3, :] # (3, N) -- x, y, z + cams.append(cam_coords) + cams = torch.stack(cams, dim=0) # (6, 3, N) Coordinates in Camera Frame + normed_coors = F.normalize(cams, p=1, dim=0) # (6, 3, N) Normalized Coordinates in Camera Frame + + cams_z = normed_coors[:, 2, :] # (6, N) z coord + cam_id = torch.argmax(cams_z, dim=0) # (N, ) -- bev to img idx, Choose the camera with the smallest angle of view + + max_z = cams_z[cam_id, torch.arange(cams.shape[-1])] + valid_mask = max_z > 0 + + intrinsic_percam = intrinsic[cam_id] # (N, 3, 3) + + coords_percam = cams[cam_id, :, torch.arange(cams.shape[2])] # (N, 3) + pixel_coord = (intrinsic_percam @ coords_percam[:, :, None]).squeeze() # (N, 3) + pixel_coord = pixel_coord[:, :2] / pixel_coord[:, [2]] # divided by Z / (N, 2) + + if not isinstance(self.image_shape, list): + image_shape = torch.tensor([self.image_shape for _ in range(len(extrinsic))], device=extrinsic.device)[cam_id] + else: + image_shape = torch.tensor(self.image_shape, device=extrinsic.device)[cam_id] + + valid_pixelx = torch.bitwise_and(pixel_coord[:, 0] < image_shape[:,1], pixel_coord[:, 0] >= 0) + valid_pixely = torch.bitwise_and(pixel_coord[:, 1] < image_shape[:,0], pixel_coord[:, 1] >= 0) + valid_mask = valid_mask * valid_pixelx * valid_pixely + + # cast to levels + reference_points = [] + for level_shape in self.src_shape: + level_h, level_w = level_shape + level_w /= 6 + image_h, image_w = image_shape.T + + ratio_h = image_h / level_h + ratio_w = image_w / level_w + + if flip: + img_x = level_w - pixel_coord[:, 0] / ratio_w + cam_id_ = self.image_order[cam_id] + x = cam_id_ * level_w + img_x + + else: + x = cam_id * level_w + pixel_coord[:, 0] / ratio_w + + y = pixel_coord[:, 1] / ratio_h + + x /= (level_w * 6) # Normalize to [0 ~ 1] + y /= level_h + + level_point = torch.stack([x, y], dim=-1) + reference_points.append(level_point) + + reference_points = torch.stack(reference_points, dim=-2) + return reference_points, valid_mask + + def forward(self, srcs, src_masks=None, cameras_info=None): + + if not isinstance(srcs, (list, tuple)): + srcs = [srcs] + if not isinstance(src_masks, (list, tuple)): + src_masks = [src_masks] + + if src_masks[0] is None: + src_masks = [] + + src_flatten = [] + src_mask_flatten = [] + src_pos_embed_flatten = [] + src_spatial_shapes = [] + for i, src in enumerate(srcs): + bs, c, h, w = src.shape + spatial_shape = (h, w) + src_spatial_shapes.append(spatial_shape) + + if len(src_masks) < i + 1: + src_mask = torch.zeros((bs, h, w), dtype=torch.bool, device=src.device) # (N, H, W) + src_masks.append(src_mask) + else: + src_mask = src_masks[i] + + if self.src_pos_encode == "sine": + src_pos_embed = self.src_pos_embed(src_mask) # (N, C, H, W) + src_pos_embed = src_pos_embed + self.src_lvl_embed.weight[i].view(-1, 1, 1) # (N, C, H, W) + else: + src_pos_embed = self.src_pos_embed[i](src_mask) # (N, C, H, W) + + src = self.input_proj[i](src) # (N, C, H, W) + src = src + src_pos_embed # (N, C, H, W) + + src = src.flatten(2).transpose(1, 2) # (N, H * W, C) + src_mask = src_mask.flatten(1) # (N, H * W) + src_pos_embed = 
src_pos_embed.flatten(2).transpose(1, 2) # (N, H * W, C) + + src_flatten.append(src) + src_mask_flatten.append(src_mask) + src_pos_embed_flatten.append(src_pos_embed) + + src = torch.cat(src_flatten, 1) # (N, L * H * W, C) + src_mask = torch.cat(src_mask_flatten, 1) # (N, L * H * W) + src_pos_embed = torch.cat(src_pos_embed_flatten, 1) # (N, L * H * W, C) + src_spatial_shapes = torch.as_tensor(src_spatial_shapes, dtype=torch.long, device=src.device) # (L, 2) + src_level_start_index = torch.cat( + (src_spatial_shapes.new_zeros((1,)), src_spatial_shapes.prod(1).cumsum(0)[:-1]) + ) # (L,) + src_valid_ratios = torch.stack([self.get_valid_ratio(m) for m in src_masks], 1) # (N, L, 2) + + tgt_mask = torch.zeros((srcs[0].size(0), *self.tgt_shape), dtype=torch.bool, device=srcs[0].device) # (N, H, W) + tgt_pos_embed = self.tgt_pos_embed(tgt_mask) # (N, C, H, W) + tgt_pos_embed = tgt_pos_embed.flatten(2).transpose(1, 2) # (N, H * W, C) + # tgt = tgt_pos_embed # (N, H * W, C) + tgt = self.tgt_embed(tgt_mask).flatten(2).transpose(1, 2) + + tgt_spatial_shapes = torch.as_tensor(self.tgt_shape, dtype=torch.long, device=tgt.device).unsqueeze(0) # (1, 2) + tgt_valid_ratios = self.get_valid_ratio(tgt_mask).unsqueeze(1) # (N, 1, 2) + tgt_level_start_index = tgt_spatial_shapes.new_zeros((1,)) # (1,) + tgt_mask = tgt_mask.flatten(1) # (N, 1 * H * W) + + t2s_reference_points = self.t2s_reference_points(tgt_pos_embed).sigmoid() # (N, H * W, 2) + + if self.use_projection: + t2s_reference_points = t2s_reference_points.unsqueeze(-2).repeat(1, 1, len(self.src_shape), 1) # (N, H * W, L, 2) + bs = srcs[0].shape[0] + + do_flip = cameras_info['do_flip'] + if do_flip is None: + do_flip = torch.zeros((bs,), dtype=torch.bool) + + for i in range(bs): + flip = do_flip[i].item() + extrinsic = cameras_info['extrinsic'][i].float() + intrinsic = cameras_info['intrinsic'][i].float() + + # Use IPM to generate reference points, Original Size (900, 1600) + ipm_reference_points, valid_mask = self.get_projection_points(extrinsic, intrinsic, flip) # (N, L, 2) + loc = torch.where(valid_mask > 0)[0] + + # Change the embeddings to reference point coordinate + t2s_reference_points[i, loc, :, :] = ipm_reference_points[loc, :, :] + else: + t2s_reference_points = t2s_reference_points[:, :, None] + + # encoder + memory = self.encoder( + src, src_spatial_shapes, src_level_start_index, src_valid_ratios, src_pos_embed, src_mask + ) # (N, H * W, C) + # decoder + hs = self.decoder( + tgt, + memory, + tgt_pos_embed, + t2s_reference_points, + tgt_spatial_shapes, + src_spatial_shapes, + tgt_level_start_index, + src_level_start_index, + tgt_valid_ratios, + src_valid_ratios, + tgt_mask, + src_mask, + ) # (M, N, H * W, C) + ys = hs.transpose(2, 3) # (M, N, C, H * W) + ys = ys.reshape(*ys.shape[:-1], *self.tgt_shape).contiguous() # (M, N, C, H, W) + return [memory, hs, ys] + + +class DeformTransformerEncoderLayer(nn.Module): + def __init__( + self, + d_model=256, + d_ffn=1024, + dropout=0.1, + activation="relu", + n_levels=4, + n_heads=8, + n_points=4, + use_checkpoint=False, + ): + super().__init__() + + # self attention + self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout2 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout3 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + self.use_checkpoint = 
use_checkpoint + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, src): + src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) + src = src + self.dropout3(src2) + src = self.norm2(src) + return src + + def _forward( + self, src, src_pos_embed, src_reference_points, src_spatial_shapes, src_level_start_index, src_key_padding_mask + ): + # self attention + src2 = self.self_attn( + self.with_pos_embed(src, src_pos_embed), + src_reference_points, + src, + src_spatial_shapes, + src_level_start_index, + src_key_padding_mask, + ) + src = src + self.dropout1(src2) + src = self.norm1(src) + + # ffn + src = self.forward_ffn(src) + return src + + def forward( + self, src, src_pos_embed, src_reference_points, src_spatial_shapes, src_level_start_index, src_key_padding_mask + ): + if self.use_checkpoint and self.training: + src = checkpoint.checkpoint( + self._forward, + src, + src_pos_embed, + src_reference_points, + src_spatial_shapes, + src_level_start_index, + src_key_padding_mask, + ) + else: + src = self._forward( + src, + src_pos_embed, + src_reference_points, + src_spatial_shapes, + src_level_start_index, + src_key_padding_mask, + ) + return src + + +class DeformTransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), + torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device), + ) + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def forward( + self, + src, + src_spatial_shapes, + src_level_start_index, + src_valid_ratios, + src_pos_embed=None, + src_key_padding_mask=None, + ): + + src_reference_points = self.get_reference_points(src_spatial_shapes, src_valid_ratios, device=src.device) + + output = src + for _, layer in enumerate(self.layers): + output = layer( + output, + src_pos_embed, + src_reference_points, + src_spatial_shapes, + src_level_start_index, + src_key_padding_mask, + ) + return output + + +class DeformTransformerDecoderLayer(nn.Module): + def __init__( + self, + d_model=256, + d_ffn=1024, + dropout=0.1, + activation="relu", + n_levels=4, + n_heads=8, + n_points=4, + use_checkpoint=False, + ): + super().__init__() + + # cross attention + self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # self attention + # self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) + + self.self_attn = MSDeformAttn(d_model, 1, n_heads, n_points) + self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout3 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = 
nn.LayerNorm(d_model) + + self.use_checkpoint = use_checkpoint + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + return tgt + + def _forward( + self, + tgt, + src, + tgt_pos_embed, + tgt_reference_points, + t2s_reference_points, + tgt_spatial_shapes, + src_spatial_shapes, + tgt_level_start_index, + src_level_start_index, + tgt_key_padding_mask, + src_key_padding_mask, + ): + # self attention + # q = k = self.with_pos_embed(tgt, tgt_pos_embed) + # tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1) + + tgt2 = self.self_attn( + self.with_pos_embed(tgt, tgt_pos_embed), + tgt_reference_points, + tgt, + tgt_spatial_shapes, + tgt_level_start_index, + tgt_key_padding_mask, + ) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # cross attention + tgt2 = self.cross_attn( + self.with_pos_embed(tgt, tgt_pos_embed), + t2s_reference_points, + src, + src_spatial_shapes, + src_level_start_index, + src_key_padding_mask, + ) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # ffn + tgt = self.forward_ffn(tgt) + return tgt + + def forward( + self, + tgt, + src, + tgt_pos_embed, + tgt_reference_points, + t2s_reference_points, + tgt_spatial_shapes, + src_spatial_shapes, + tgt_level_start_index, + src_level_start_index, + tgt_key_padding_mask, + src_key_padding_mask, + ): + if self.use_checkpoint and self.training: + tgt = checkpoint.checkpoint( + self._forward, + tgt, + src, + tgt_pos_embed, + tgt_reference_points, + t2s_reference_points, + tgt_spatial_shapes, + src_spatial_shapes, + tgt_level_start_index, + src_level_start_index, + tgt_key_padding_mask, + src_key_padding_mask, + ) + else: + tgt = self._forward( + tgt, + src, + tgt_pos_embed, + tgt_reference_points, + t2s_reference_points, + tgt_spatial_shapes, + src_spatial_shapes, + tgt_level_start_index, + src_level_start_index, + tgt_key_padding_mask, + src_key_padding_mask, + ) + return tgt + + +class DeformTransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.return_intermediate = return_intermediate + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), + torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device), + ) + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def forward( + self, + tgt, + src, + tgt_pos_embed, + t2s_reference_points, + tgt_spatial_shapes, + src_spatial_shapes, + tgt_level_start_index, + src_level_start_index, + tgt_valid_ratios, + src_valid_ratios, + tgt_key_padding_mask=None, + src_key_padding_mask=None, + ): + + tgt_reference_points = self.get_reference_points(tgt_spatial_shapes, tgt_valid_ratios, device=tgt.device) + 
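+        # src_valid_ratios has shape (N, L, 2); indexing it with [:, None] gives
+        # (N, 1, L, 2), which broadcasts against the (N, H*W, L, 2) (or
+        # (N, H*W, 1, 2)) target-to-source reference points. This rescales the
+        # normalized coordinates into the un-padded region of each source level
+        # before the cross-attention layers sample from it.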
t2s_reference_points = t2s_reference_points * src_valid_ratios[:, None] + # t2s_reference_points = t2s_reference_points[:, :, None] * src_valid_ratios[:, None] + + intermediate = [] + output = tgt + for _, layer in enumerate(self.layers): + output = layer( + output, + src, + tgt_pos_embed, + tgt_reference_points, + t2s_reference_points, + tgt_spatial_shapes, + src_spatial_shapes, + tgt_level_start_index, + src_level_start_index, + tgt_key_padding_mask, + src_key_padding_mask, + ) + + if self.return_intermediate: + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/__init__.py b/mapmaster/models/bev_decoder/deform_transformer/ops/__init__.py new file mode 100644 index 0000000..d8f3e3c --- /dev/null +++ b/mapmaster/models/bev_decoder/deform_transformer/ops/__init__.py @@ -0,0 +1 @@ +from .modules import MSDeformAttn diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/functions/__init__.py b/mapmaster/models/bev_decoder/deform_transformer/ops/functions/__init__.py new file mode 100644 index 0000000..06ebc91 --- /dev/null +++ b/mapmaster/models/bev_decoder/deform_transformer/ops/functions/__init__.py @@ -0,0 +1,9 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn_func import MSDeformAttnFunction diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/functions/ms_deform_attn_func.py b/mapmaster/models/bev_decoder/deform_transformer/ops/functions/ms_deform_attn_func.py new file mode 100644 index 0000000..7cd8d12 --- /dev/null +++ b/mapmaster/models/bev_decoder/deform_transformer/ops/functions/ms_deform_attn_func.py @@ -0,0 +1,73 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import torch.nn.functional as F +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +import MultiScaleDeformableAttention as MSDA + + +class MSDeformAttnFunction(Function): + @staticmethod + def forward( + ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step + ): + ctx.im2col_step = im2col_step + output = MSDA.ms_deform_attn_forward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step + ) + ctx.save_for_backward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights + ) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors + grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + grad_output, + ctx.im2col_step, + ) + + return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None + + +def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): + # for debug and test only, + # need to use cuda version instead + N_, S_, M_, D_ = value.shape + _, Lq_, M_, L_, P_, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for lid_, (H_, W_) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_ * D_, Lq_) + return output.transpose(1, 2).contiguous() diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/make.sh b/mapmaster/models/bev_decoder/deform_transformer/ops/make.sh new file mode 100644 index 0000000..106b685 --- /dev/null +++ b/mapmaster/models/bev_decoder/deform_transformer/ops/make.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +python setup.py build install diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/modules/__init__.py b/mapmaster/models/bev_decoder/deform_transformer/ops/modules/__init__.py new file mode 100644 index 0000000..ff5ce83 --- /dev/null +++ b/mapmaster/models/bev_decoder/deform_transformer/ops/modules/__init__.py @@ -0,0 +1,9 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn import MSDeformAttn diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/modules/ms_deform_attn.py b/mapmaster/models/bev_decoder/deform_transformer/ops/modules/ms_deform_attn.py new file mode 100644 index 0000000..c091ed6 --- /dev/null +++ b/mapmaster/models/bev_decoder/deform_transformer/ops/modules/ms_deform_attn.py @@ -0,0 +1,140 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import warnings +import math + +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.init import xavier_uniform_, constant_ + +from ..functions import MSDeformAttnFunction + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) + return (n & (n - 1) == 0) and n != 0 + + +class MSDeformAttn(nn.Module): + def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): + """ + Multi-Scale Deformable Attention Module + :param d_model hidden dimension + :param n_levels number of feature levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError("d_model must be divisible by n_heads, but got {} and {}".format(d_model, n_heads)) + _d_per_head = d_model // n_heads + # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_head): + warnings.warn( + "You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " + "which is more efficient in our CUDA implementation." + ) + + self.im2col_step = 64 + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, d_model) + self.output_proj = nn.Linear(d_model, d_model) + + self._reset_parameters() + + def _reset_parameters(self): + constant_(self.sampling_offsets.weight.data, 0.0) + thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.n_heads, 1, 1, 2) + .repeat(1, self.n_levels, self.n_points, 1) + ) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.0) + constant_(self.attention_weights.bias.data, 0.0) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.0) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.0) + + def forward( + self, + query, + reference_points, + input_flatten, + input_spatial_shapes, + input_level_start_index, + input_padding_mask=None, + ): + """ + :param query (N, Length_{query}, C) + :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area + or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes + :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), 
(H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements + + :return output (N, Length_{query}, C) + """ + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) + attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) + attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) + # N, Len_q, n_heads, n_levels, n_points, 2 + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + ) + else: + raise ValueError( + "Last dim of reference_points must be 2 or 4, but get {} instead.".format(reference_points.shape[-1]) + ) + output = MSDeformAttnFunction.apply( + value, + input_spatial_shapes, + input_level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + output = self.output_proj(output) + return output diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/setup.py b/mapmaster/models/bev_decoder/deform_transformer/ops/setup.py new file mode 100644 index 0000000..148dc05 --- /dev/null +++ b/mapmaster/models/bev_decoder/deform_transformer/ops/setup.py @@ -0,0 +1,84 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+import os
+import glob
+
+import torch
+
+from torch.utils.cpp_extension import CUDA_HOME
+from torch.utils.cpp_extension import CppExtension
+from torch.utils.cpp_extension import CUDAExtension
+
+from setuptools import find_packages
+from setuptools import setup
+
+requirements = ["torch", "torchvision"]
+
+
+def get_extensions():
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    extensions_dir = os.path.join(this_dir, "src")
+
+    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
+    source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
+    source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
+
+    sources = main_file + source_cpu
+    extension = CppExtension
+    extra_compile_args = {"cxx": []}
+    define_macros = []
+
+    if torch.cuda.is_available() and CUDA_HOME is not None:
+        extension = CUDAExtension
+        sources += source_cuda
+        define_macros += [("WITH_CUDA", None)]
+        extra_compile_args["nvcc"] = [
+            "-DCUDA_HAS_FP16=1",
+            "-D__CUDA_NO_HALF_OPERATORS__",
+            "-D__CUDA_NO_HALF_CONVERSIONS__",
+            "-D__CUDA_NO_HALF2_OPERATORS__",
+            "-arch=sm_60",
+            "-gencode=arch=compute_60,code=sm_60",
+            "-gencode=arch=compute_61,code=sm_61",
+            "-gencode=arch=compute_70,code=sm_70",
+            "-gencode=arch=compute_75,code=sm_75",
+            # "-gencode=arch=compute_80,code=sm_80",
+        ]
+    else:
+        raise NotImplementedError("CUDA is not available")
+
+    sources = [os.path.join(extensions_dir, s) for s in sources]
+    include_dirs = [extensions_dir]
+    ext_modules = [
+        extension(
+            "MultiScaleDeformableAttention",
+            sources,
+            include_dirs=include_dirs,
+            define_macros=define_macros,
+            extra_compile_args=extra_compile_args,
+        )
+    ]
+    return ext_modules
+
+
+setup(
+    name="MultiScaleDeformableAttention",
+    version="1.0",
+    author="Weijie Su",
+    url="https://github.com/fundamentalvision/Deformable-DETR",
+    description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
+    packages=find_packages(
+        exclude=(
+            "configs",
+            "tests",
+        )
+    ),
+    ext_modules=get_extensions(),
+    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
+)
diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/src/cpu/ms_deform_attn_cpu.cpp b/mapmaster/models/bev_decoder/deform_transformer/ops/src/cpu/ms_deform_attn_cpu.cpp
new file mode 100644
index 0000000..dc1c0a1
--- /dev/null
+++ b/mapmaster/models/bev_decoder/deform_transformer/ops/src/cpu/ms_deform_attn_cpu.cpp
@@ -0,0 +1,40 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include <vector>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const int im2col_step)
+{
+    AT_ERROR("Not implemented on CPU");
+}
+
+std::vector<at::Tensor>
+ms_deform_attn_cpu_backward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const at::Tensor &grad_output,
+    const int im2col_step)
+{
+    AT_ERROR("Not implemented on CPU");
+}
diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/src/cpu/ms_deform_attn_cpu.h b/mapmaster/models/bev_decoder/deform_transformer/ops/src/cpu/ms_deform_attn_cpu.h
new file mode 100644
index 0000000..67422ce
--- /dev/null
+++ b/mapmaster/models/bev_decoder/deform_transformer/ops/src/cpu/ms_deform_attn_cpu.h
@@ -0,0 +1,31 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const int im2col_step);
+
+std::vector<at::Tensor>
+ms_deform_attn_cpu_backward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const at::Tensor &grad_output,
+    const int im2col_step);
diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/src/cuda/ms_deform_attn_cuda.cu b/mapmaster/models/bev_decoder/deform_transformer/ops/src/cuda/ms_deform_attn_cuda.cu
new file mode 100644
index 0000000..aa9ee57
--- /dev/null
+++ b/mapmaster/models/bev_decoder/deform_transformer/ops/src/cuda/ms_deform_attn_cuda.cu
@@ -0,0 +1,153 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include +#include "cuda/ms_deform_im2col_cuda.cuh" + +#include +#include +#include +#include + + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); + + const int batch_n = im2col_step_; + auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto columns = output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { + ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), + level_start_index.data(), + sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + columns.data()); + + })); + } + + output = output.view({batch, num_query, num_heads*channels}); + + return output; +} + + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), 
"spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto grad_value = at::zeros_like(value); + auto grad_sampling_loc = at::zeros_like(sampling_loc); + auto grad_attn_weight = at::zeros_like(attn_weight); + + const int batch_n = im2col_step_; + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto grad_output_g = grad_output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { + ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), + grad_output_g.data(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), + level_start_index.data(), + sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + grad_value.data() + n * im2col_step_ * per_value_size, + grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); + + })); + } + + return { + grad_value, grad_sampling_loc, grad_attn_weight + }; +} diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/src/cuda/ms_deform_attn_cuda.h b/mapmaster/models/bev_decoder/deform_transformer/ops/src/cuda/ms_deform_attn_cuda.h new file mode 100644 index 0000000..2d03704 --- /dev/null +++ b/mapmaster/models/bev_decoder/deform_transformer/ops/src/cuda/ms_deform_attn_cuda.h @@ -0,0 +1,29 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. 
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step); + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step); diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/src/cuda/ms_deform_im2col_cuda.cuh b/mapmaster/models/bev_decoder/deform_transformer/ops/src/cuda/ms_deform_im2col_cuda.cuh new file mode 100644 index 0000000..3d86a01 --- /dev/null +++ b/mapmaster/models/bev_decoder/deform_transformer/ops/src/cuda/ms_deform_im2col_cuda.cuh @@ -0,0 +1,1327 @@ +/*! +************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************** +* Modified from DCN (https://github.com/msracver/Deformable-ConvNets) +* Copyright (c) 2018 Microsoft +************************************************************************** +*/ + +#include +#include +#include + +#include +#include + +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N, const int num_threads) +{ + return (N + num_threads - 1) / num_threads; +} + + +template +__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + 
base_ptr; + v4 = bottom_data[ptr4]; + } + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + *grad_attn_weight = top_grad * val; + *grad_sampling_loc = width * grad_w_weight * top_grad_value; + *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int 
w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + atomicAdd(grad_attn_weight, top_grad * val); + atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); + atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); +} + + +template +__global__ void ms_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + scalar_t *data_col_ptr = data_col + index; + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + scalar_t col = 0; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride); + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, 
h_im, w_im, m_col, c_col) * weight; + } + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + } + *data_col_ptr = col; + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockSize; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc 
+= grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockSize/2; s>0; s>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + 
grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockDim.x; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + 
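The backward kernels in this file differ only in how a thread block reduces the per-thread partials cached in shared memory: the *_v1 variants (such as the kernel above) let thread 0 sum the caches serially, while the *_v2 variants (such as the kernel below) use a tree reduction; gradients with respect to value are always accumulated with atomicAdd. What all of them differentiate is the forward sampling computed by ms_deformable_im2col_gpu_kernel. As orientation, here is a rough pure-PyTorch sketch of that forward pass, written against the tensor layout implied by the kernel arguments (value of shape (N, S, M, D), sampling locations in [0, 1], attention weights per level and point); it is only assumed to behave like the ms_deform_attn_core_pytorch reference imported by test.py further below, and is not taken from the project.

import torch
import torch.nn.functional as F

def ms_deform_attn_forward_sketch(value, spatial_shapes, sampling_loc, attn_weight):
    # value: (N, S, M, D), spatial_shapes: (L, 2), sampling_loc: (N, Lq, M, L, P, 2) in [0, 1],
    # attn_weight: (N, Lq, M, L, P); returns (N, Lq, M * D).
    N, S, M, D = value.shape
    _, Lq, _, L, P, _ = sampling_loc.shape
    value_levels = value.split([int(H * W) for H, W in spatial_shapes], dim=1)
    grids = 2 * sampling_loc - 1  # grid_sample expects coordinates in [-1, 1]
    sampled = []
    for lvl, (H, W) in enumerate(spatial_shapes):
        # (N, H*W, M, D) -> (N*M, D, H, W): one feature map per batch element and head
        value_l = value_levels[lvl].flatten(2).transpose(1, 2).reshape(N * M, D, int(H), int(W))
        # (N, Lq, M, P, 2) -> (N*M, Lq, P, 2)
        grid_l = grids[:, :, :, lvl].transpose(1, 2).flatten(0, 1)
        # bilinear sampling with zero padding, mirroring ms_deform_attn_im2col_bilinear
        sampled.append(F.grid_sample(value_l, grid_l, mode="bilinear",
                                     padding_mode="zeros", align_corners=False))
    # (N*M, D, Lq, L*P) weighted by (N*M, 1, Lq, L*P), summed over levels and points
    attn_weight = attn_weight.transpose(1, 2).reshape(N * M, 1, Lq, L * P)
    out = (torch.stack(sampled, dim=-2).flatten(-2) * attn_weight).sum(-1)
    return out.view(N, M * D, Lq).transpose(1, 2).contiguous()

Mapping 2 * sampling_loc - 1 through grid_sample with align_corners=False corresponds to the loc * spatial - 0.5 pixel convention used by the kernels, so a comparison against this sketch should agree with the CUDA path up to floating-point tolerance.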
+template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid 
== 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 
1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); + atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); + atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear_gm( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + grad_sampling_loc, grad_attn_weight); + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +void ms_deformable_im2col_cuda(cudaStream_t stream, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const 
int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* data_col) +{ + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + const int num_threads = CUDA_NUM_THREADS; + ms_deformable_im2col_gpu_kernel + <<>>( + num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } + +} + +template +void ms_deformable_col2im_cuda(cudaStream_t stream, + const scalar_t* grad_col, + const scalar_t* data_value, + const int64_t * data_spatial_shapes, + const int64_t * data_level_start_index, + const scalar_t * data_sampling_loc, + const scalar_t * data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels; + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + if (channels > 1024) + { + if ((channels & 1023) == 0) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_gm + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + else{ + switch(channels) + { + case 1: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 2: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 4: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 8: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + 
num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 16: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 32: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 64: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 128: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 256: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 512: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 1024: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + default: + if (channels < 64) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + 
num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } + +} diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/src/ms_deform_attn.h b/mapmaster/models/bev_decoder/deform_transformer/ops/src/ms_deform_attn.h new file mode 100644 index 0000000..c5c6166 --- /dev/null +++ b/mapmaster/models/bev_decoder/deform_transformer/ops/src/ms_deform_attn.h @@ -0,0 +1,61 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once + +#include "cpu/ms_deform_attn_cpu.h" + +#ifdef WITH_CUDA +#include "cuda/ms_deform_attn_cuda.h" +#endif + + +at::Tensor +ms_deform_attn_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_forward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +std::vector +ms_deform_attn_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_backward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/src/vision.cpp b/mapmaster/models/bev_decoder/deform_transformer/ops/src/vision.cpp new file mode 100644 index 0000000..d37eafb --- /dev/null +++ b/mapmaster/models/bev_decoder/deform_transformer/ops/src/vision.cpp @@ -0,0 +1,16 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. 
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include "ms_deform_attn.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); + m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); +} diff --git a/mapmaster/models/bev_decoder/deform_transformer/ops/test.py b/mapmaster/models/bev_decoder/deform_transformer/ops/test.py new file mode 100644 index 0000000..50bb67e --- /dev/null +++ b/mapmaster/models/bev_decoder/deform_transformer/ops/test.py @@ -0,0 +1,117 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +from torch.autograd import gradcheck + +from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch + + +N, M, D = 1, 2, 2 +Lq, L, P = 2, 2, 2 +shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() +level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1])) +S = sum([(H * W).item() for H, W in shapes]) + + +torch.manual_seed(3) + + +@torch.no_grad() +def check_forward_equal_with_pytorch_double(): + value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + output_pytorch = ( + ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()) + .detach() + .cpu() + ) + output_cuda = ( + MSDeformAttnFunction.apply( + value.double(), + shapes, + level_start_index, + sampling_locations.double(), + attention_weights.double(), + im2col_step, + ) + .detach() + .cpu() + ) + fwdok = torch.allclose(output_cuda, output_pytorch) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() + + print( + f"* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}" + ) + + +@torch.no_grad() +def check_forward_equal_with_pytorch_float(): + value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() + output_cuda = ( + MSDeformAttnFunction.apply(value, shapes, 
level_start_index, sampling_locations, attention_weights, im2col_step) + .detach() + .cpu() + ) + fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() + + print( + f"* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}" + ) + + +def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): + + value = torch.rand(N, S, M, channels).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + func = MSDeformAttnFunction.apply + + value.requires_grad = grad_value + sampling_locations.requires_grad = grad_sampling_loc + attention_weights.requires_grad = grad_attn_weight + + gradok = gradcheck( + func, + ( + value.double(), + shapes, + level_start_index, + sampling_locations.double(), + attention_weights.double(), + im2col_step, + ), + ) + + print(f"* {gradok} check_gradient_numerical(D={channels})") + + +if __name__ == "__main__": + check_forward_equal_with_pytorch_double() + check_forward_equal_with_pytorch_float() + + for channels in [30, 32, 64, 71, 1025, 2048, 3096]: + check_gradient_numerical(channels, True, True, True) diff --git a/mapmaster/models/bev_decoder/deform_transformer/position_encoding.py b/mapmaster/models/bev_decoder/deform_transformer/position_encoding.py new file mode 100644 index 0000000..7ecdcaf --- /dev/null +++ b/mapmaster/models/bev_decoder/deform_transformer/position_encoding.py @@ -0,0 +1,63 @@ +""" +Various positional encodings for the transformer. +""" +import math +import torch +from torch import nn + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, mask): + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats // 2, dtype=torch.float32, device=mask.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / (self.num_pos_feats // 2)) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. 
+ """ + + def __init__(self, num_pos=(50, 50), num_pos_feats=256): + super().__init__() + self.num_pos = num_pos + self.pos_embed = nn.Embedding(num_pos[0] * num_pos[1], num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.normal_(self.pos_embed.weight) + + def forward(self, mask): + h, w = mask.shape[-2:] + pos = self.pos_embed.weight.view(*self.num_pos, -1)[:h, :w] + pos = pos.permute(2, 0, 1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1) + return pos diff --git a/mapmaster/models/bev_decoder/model.py b/mapmaster/models/bev_decoder/model.py new file mode 100644 index 0000000..daddb64 --- /dev/null +++ b/mapmaster/models/bev_decoder/model.py @@ -0,0 +1,53 @@ +import torch +import numpy as np +import torch.nn as nn +from mapmaster.models.bev_decoder.transformer import Transformer +from mapmaster.models.bev_decoder.deform_transformer import DeformTransformer + +class TransformerBEVDecoder(nn.Module): + def __init__(self, key='im_bkb_features', **kwargs): + super(TransformerBEVDecoder, self).__init__() + self.bev_encoder = Transformer(**kwargs) + self.key = key + + def forward(self, inputs): + assert self.key in inputs + feats = inputs[self.key] + fuse_feats = feats[-1] + fuse_feats = fuse_feats.reshape(*inputs['images'].shape[:2], *fuse_feats.shape[-3:]) + fuse_feats = torch.cat(torch.unbind(fuse_feats, dim=1), dim=-1) + + cameras_info = { + 'extrinsic': inputs.get('extrinsic', None), + 'intrinsic': inputs.get('intrinsic', None), + 'ida_mats': inputs.get('ida_mats', None), + 'do_flip': inputs['extra_infos'].get('do_flip', None) + } + + _, _, bev_feats = self.bev_encoder(fuse_feats, cameras_info=cameras_info) + + return {"bev_enc_features": list(bev_feats)} + +class DeformTransformerBEVEncoder(nn.Module): + def __init__(self, **kwargs): + super(DeformTransformerBEVEncoder, self).__init__() + self.bev_encoder = DeformTransformer(**kwargs) + + def forward(self, inputs): + assert "im_bkb_features" in inputs + feats = inputs["im_bkb_features"] + for i in range(len(feats)): + feats[i] = feats[i].reshape(*inputs["images"].shape[:2], *feats[i].shape[-3:]) + feats[i] = feats[i].permute(0, 2, 3, 1, 4) + feats[i] = feats[i].reshape(*feats[i].shape[:3], -1) + cameras_info = { + 'extrinsic': inputs.get('extrinsic', None), + 'intrinsic': inputs.get('intrinsic', None), + 'do_flip': inputs['extra_infos'].get('do_flip', None) + } + # src_feats: (N, H1 * W1, C) tgt_feats: # (M, N, H2 * W2, C) + _, _, bev_feats = self.bev_encoder(feats, cameras_info=cameras_info) + + return { + "bev_enc_features": list(bev_feats), + } diff --git a/mapmaster/models/bev_decoder/transformer.py b/mapmaster/models/bev_decoder/transformer.py new file mode 100644 index 0000000..7e6bab7 --- /dev/null +++ b/mapmaster/models/bev_decoder/transformer.py @@ -0,0 +1,407 @@ +import copy +import torch +import numpy as np +import torch.nn.functional as F +from torch import nn, Tensor +from typing import Optional +from torch.utils.checkpoint import checkpoint +from mapmaster.models.utils.position_encoding import PositionEmbeddingSine, PositionEmbeddingLearned +from mapmaster.models.utils.position_encoding import PositionEmbeddingIPM, PositionEmbeddingTgt + + +class Transformer(nn.Module): + def __init__( + self, + in_channels, + src_shape=(32, 336), + query_shape=(32, 32), + d_model=512, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=2048, + src_pos_embed='sine', + tgt_pos_embed='sine', + dropout=0.1, + activation="relu", + normalize_before=False, + return_intermediate_dec=False, + 
use_checkpoint=True, + ipm_proj_conf=None, + ipmpe_with_sine=True, + enforce_no_aligner=False, + ): + super().__init__() + + self.src_shape = src_shape + self.query_shape = query_shape + self.d_model = d_model + self.nhead = nhead + self.ipm_proj_conf = ipm_proj_conf + + self.ipmpe_with_sine = ipmpe_with_sine + self.enforce_no_aligner = enforce_no_aligner + + num_queries = np.prod(query_shape).item() + self.input_proj = nn.Conv2d(in_channels, d_model, kernel_size=1) + self.query_embed = nn.Embedding(num_queries, d_model) + src_pe, tgt_pe = self._get_pos_embed_layers(src_pos_embed, tgt_pos_embed) + self.src_pos_embed_layer, self.tgt_pos_embed_layer = src_pe, tgt_pe + + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm, use_checkpoint) + + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder( + decoder_layer, num_decoder_layers, decoder_norm, return_intermediate_dec, use_checkpoint + ) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def _get_pos_embed_layers(self, src_pos_embed, tgt_pos_embed): + + pos_embed_encoder = None + if (src_pos_embed.startswith('ipm_learned')) and (tgt_pos_embed.startswith('ipm_learned')) \ + and not self.enforce_no_aligner: + pos_embed_encoder = nn.Sequential( + nn.Conv2d(self.d_model, self.d_model * 4, kernel_size=1, stride=1, padding=0), + nn.ReLU(), + nn.Conv2d(self.d_model * 4, self.d_model, kernel_size=1, stride=1, padding=0) + ) + + if src_pos_embed == 'sine': + src_pos_embed_layer = PositionEmbeddingSine(self.d_model // 2, normalize=True) + elif src_pos_embed == 'learned': + src_pos_embed_layer = PositionEmbeddingLearned(self.src_shape, self.d_model) + elif src_pos_embed == 'ipm_learned': + input_shape = self.ipm_proj_conf['input_shape'] + src_pos_embed_layer = PositionEmbeddingIPM( + pos_embed_encoder, self.src_shape, input_shape, num_pos_feats=self.d_model, + sine_encoding=self.ipmpe_with_sine) + else: + raise NotImplementedError + self.src_pos_embed = src_pos_embed + + if tgt_pos_embed == 'sine': + tgt_pos_embed_layer = PositionEmbeddingSine(self.d_model // 2, normalize=True) + elif tgt_pos_embed == 'learned': + tgt_pos_embed_layer = PositionEmbeddingLearned(self.query_shape, self.d_model) + elif tgt_pos_embed == 'ipm_learned': + map_size, map_res = self.ipm_proj_conf['map_size'], self.ipm_proj_conf['map_resolution'] + tgt_pos_embed_layer = PositionEmbeddingTgt( + pos_embed_encoder, self.query_shape, map_size, map_res, num_pos_feats=self.d_model, sine_encoding=True) + else: + raise NotImplementedError + self.tgt_pos_embed = tgt_pos_embed + + return src_pos_embed_layer, tgt_pos_embed_layer + + def forward(self, src, mask=None, cameras_info=None): + + bs, c, h, w = src.shape + if mask is None: + mask = torch.zeros((bs, h, w), dtype=torch.bool, device=src.device) # (B, H, W) + + src = self.input_proj(src) # (B, C, H, W) + src = src.flatten(2).permute(2, 0, 1) # (H* W, B, C) + + if self.src_pos_embed.startswith('ipm_learned'): + extrinsic = cameras_info['extrinsic'].float() + intrinsic = cameras_info['intrinsic'].float() + ida_mats = cameras_info['ida_mats'].float() + do_flip = cameras_info['do_flip'] + src_pos_embed, 
src_mask = self.src_pos_embed_layer(extrinsic, intrinsic, ida_mats, do_flip) + mask = ~src_mask + else: + src_pos_embed = self.src_pos_embed_layer(mask) + + src_pos_embed = src_pos_embed.flatten(2).permute(2, 0, 1) # (H* W, B, C) + mask = mask.flatten(1) # (B, H * W) + + query_embed = self.query_embed.weight # (H* W, C) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # (H* W, B, C) + tgt = query_embed # (H* W, B, C) + + query_mask = torch.zeros((bs, *self.query_shape), dtype=torch.bool, device=src.device) + query_pos_embed = self.tgt_pos_embed_layer(query_mask) # (B, C, H, W) + query_pos_embed = query_pos_embed.flatten(2).permute(2, 0, 1) # (H* W, B, C) + + memory = self.encoder.forward(src, None, mask, src_pos_embed) # (H* W, B, C) + hs = self.decoder.forward(tgt, memory, None, None, None, mask, src_pos_embed, query_pos_embed) # (M, H* W, B, C) + ys = hs.permute(0, 2, 3, 1) # (M, B, C, H* W) + ys = ys.reshape(*ys.shape[:-1], *self.query_shape) # (M, B, C, H, W) + + return memory, hs, ys + + +class TransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, norm=None, use_checkpoint=True): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.use_checkpoint = use_checkpoint + + def forward( + self, + src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + output = src + + for layer in self.layers: + if self.use_checkpoint and self.training: + output = checkpoint(layer, output, mask, src_key_padding_mask, pos) + else: + output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False, use_checkpoint=True): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + self.use_checkpoint = use_checkpoint + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + output = tgt + + intermediate = [] + + for layer in self.layers: + if self.use_checkpoint and self.training: + output = checkpoint( + layer, + output, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + else: + output = layer( + output, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos + ) + + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + 
self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + 
self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre( + tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos + ) + return self.forward_post( + tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos + ) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") diff --git a/mapmaster/models/ins_decoder/__init__.py b/mapmaster/models/ins_decoder/__init__.py new file mode 100644 index 0000000..4a45803 --- /dev/null +++ b/mapmaster/models/ins_decoder/__init__.py @@ -0,0 +1 @@ +from .model import Mask2formerINSDecoder, PointMask2formerINSDecoder diff --git a/mapmaster/models/ins_decoder/mask2former.py b/mapmaster/models/ins_decoder/mask2former.py new file mode 100644 index 0000000..c220c67 --- /dev/null +++ b/mapmaster/models/ins_decoder/mask2former.py @@ -0,0 +1,315 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
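+# Mask2Former-style multi-scale masked Transformer decoder, used as the BEV instance decoder in this repo.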
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py +import torch +import logging +from typing import Optional +from torch import nn, Tensor +from torch.nn import functional as F +from mapmaster.models.utils.misc import Conv2d, c2_xavier_fill, get_activation_fn +from mapmaster.models.utils.position_encoding import PositionEmbeddingSine + + +class SelfAttentionLayer(nn.Module): + + def __init__(self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = get_activation_fn(activation) + self.normalize_before = normalize_before + self._reset_parameters() + + def forward(self, tgt, tgt_mask=None, tgt_key_padding_mask=None, query_pos=None): + if self.normalize_before: + return self.forward_pre(tgt, tgt_mask, tgt_key_padding_mask, query_pos) + return self.forward_post(tgt, tgt_mask, tgt_key_padding_mask, query_pos) + + def forward_pre(self, tgt, tgt_mask=None, tgt_key_padding_mask=None, query_pos=None): + tgt2 = self.norm(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + return tgt + + def forward_post(self, tgt, tgt_mask=None, tgt_key_padding_mask=None, query_pos=None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + return tgt + + @staticmethod + def with_pos_embed(tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + +class CrossAttentionLayer(nn.Module): + + def __init__(self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False): + super().__init__() + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = get_activation_fn(activation) + self.normalize_before = normalize_before + self._reset_parameters() + + def forward(self, tgt, memory, memory_mask=None, memory_key_padding_mask=None, pos=None, query_pos=None): + if self.normalize_before: + return self.forward_pre(tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos) + + def forward_pre(self, tgt, memory, memory_mask=None, memory_key_padding_mask=None, pos=None, query_pos=None): + tgt2 = self.norm(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + return tgt + + def forward_post(self, tgt, memory, memory_mask=None, memory_key_padding_mask=None, pos=None, query_pos=None): + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + return tgt + + @staticmethod + def with_pos_embed(tensor, pos: Optional[Tensor]): + return tensor if pos is None else 
tensor + pos + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + +class FFNLayer(nn.Module): + + def __init__(self, d_model, dim_feedforward=2048, dropout=0.0, activation="relu", normalize_before=False): + super().__init__() + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + self.norm = nn.LayerNorm(d_model) + self.activation = get_activation_fn(activation) + self.normalize_before = normalize_before + self._reset_parameters() + + def forward(self, tgt): + if self.normalize_before: + return self.forward_pre(tgt) + return self.forward_post(tgt) + + def forward_pre(self, tgt): + tgt2 = self.norm(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout(tgt2) + return tgt + + def forward_post(self, tgt): + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + return tgt + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class MultiScaleMaskedTransformerDecoder(nn.Module): + + def __init__( + self, + in_channels, + num_feature_levels, + mask_classification=True, + *, + num_classes: int, + hidden_dim: int, + num_queries: int, + nheads: int, + dim_feedforward: int, + dec_layers: int, + pre_norm: bool, + mask_dim: int, + enforce_input_project: bool, + ): + """ + NOTE: this interface is experimental. 
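+ Each decoder layer applies masked cross-attention over one of the multi-scale
+ features, then self-attention among the queries, then an FFN; class logits and
+ instance masks are predicted after every layer, giving dec_layers + 1 predictions
+ in total (including the one computed from the initial learnable queries).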
+ Args: + in_channels: channels of the input features + mask_classification: whether to add mask classifier or not + num_classes: number of classes + hidden_dim: Transformer feature dimension + num_queries: number of queries + nheads: number of heads + dim_feedforward: feature dimension in feedforward network + enc_layers: number of Transformer encoder layers + dec_layers: number of Transformer decoder layers + pre_norm: whether to use pre-LayerNorm or not + mask_dim: mask feature dimension + enforce_input_project: add input project 1x1 conv even if input + channels and hidden dim is identical + """ + super().__init__() + + assert mask_classification, "Only support mask classification model" + self.mask_classification = mask_classification + + # positional encoding + self.pe_layer = PositionEmbeddingSine(hidden_dim // 2, normalize=True) + + # define Transformer decoder here + self.num_heads = nheads + self.num_layers = dec_layers + self.transformer_self_attention_layers = nn.ModuleList() + self.transformer_cross_attention_layers = nn.ModuleList() + self.transformer_ffn_layers = nn.ModuleList() + # d_model, nhead, dropout = 0.0, activation = "relu", normalize_before = False + for _ in range(self.num_layers): + self.transformer_self_attention_layers.append( + SelfAttentionLayer(d_model=hidden_dim, nhead=nheads, dropout=0.0, normalize_before=pre_norm) + ) + self.transformer_cross_attention_layers.append( + CrossAttentionLayer(d_model=hidden_dim, nhead=nheads, dropout=0.0, normalize_before=pre_norm) + ) + self.transformer_ffn_layers.append( + FFNLayer(d_model=hidden_dim, dim_feedforward=dim_feedforward, dropout=0.0, normalize_before=pre_norm) + ) + + self.decoder_norm = nn.LayerNorm(hidden_dim) + + self.num_queries = num_queries + # learnable query features + self.query_feat = nn.Embedding(num_queries, hidden_dim) + # learnable query p.e. 
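+ # (added to the queries as query_pos in every self-/cross-attention layer)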
+ self.query_embed = nn.Embedding(num_queries, hidden_dim) + + # level embedding (we always use 3 scales) + self.num_feature_levels = num_feature_levels + self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim) + self.input_proj = nn.ModuleList() + for _ in range(self.num_feature_levels): + if in_channels != hidden_dim or enforce_input_project: + self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1)) + c2_xavier_fill(self.input_proj[-1]) + else: + self.input_proj.append(nn.Sequential()) + + # output FFNs + if self.mask_classification: + self.class_embed = nn.Linear(hidden_dim, num_classes + 1) + self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) + + def forward(self, x, mask_features, mask=None): + # x is a list of multi-scale feature + assert len(x) == self.num_feature_levels + src = [] + pos = [] + size_list = [] + + # disable mask, it does not affect performance + del mask + + for i in range(self.num_feature_levels): + size_list.append(x[i].shape[-2:]) + mask = torch.zeros((x[i].size(0), x[i].size(2), x[i].size(3)), device=x[i].device, dtype=torch.bool) + pos.append(self.pe_layer(mask).flatten(2)) + src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None]) + + # flatten NxCxHxW to HWxNxC + pos[-1] = pos[-1].permute(2, 0, 1) + src[-1] = src[-1].permute(2, 0, 1) + + _, bs, _ = src[0].shape + + # QxNxC + query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1) + output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1) + + predictions_class = [] + predictions_mask = [] + decoder_outputs = [] + + # prediction heads on learnable query features + dec_out, outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads( + output, mask_features, attn_mask_target_size=size_list[0] + ) + predictions_class.append(outputs_class) + predictions_mask.append(outputs_mask) + decoder_outputs.append(dec_out) + + for i in range(self.num_layers): + level_index = i % self.num_feature_levels + attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False + # attention: cross-attention first + output = self.transformer_cross_attention_layers[i]( + output, src[level_index], + memory_mask=attn_mask, + memory_key_padding_mask=None, # here we do not apply masking on padded region + pos=pos[level_index], query_pos=query_embed + ) + output = self.transformer_self_attention_layers[i]( + output, tgt_mask=None, tgt_key_padding_mask=None, query_pos=query_embed + ) + output = self.transformer_ffn_layers[i](output) + dec_out, outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads( + output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels] + ) + + predictions_class.append(outputs_class) + predictions_mask.append(outputs_mask) + decoder_outputs.append(dec_out) + + assert len(predictions_class) == self.num_layers + 1 + + out = { + 'pred_logits': predictions_class, + 'pred_masks': predictions_mask, + 'decoder_outputs': decoder_outputs + } + return out + + def forward_prediction_heads(self, output, mask_features, attn_mask_target_size): + decoder_output = self.decoder_norm(output) + decoder_output = decoder_output.transpose(0, 1) # (b, q, c') + outputs_class = self.class_embed(decoder_output) # (b, q, c') -> (b, q, 2) + mask_embed = self.mask_embed(decoder_output) # (b, q, c') -> (b, q, c) + outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features) + + # NOTE: prediction is of higher-resolution + # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW] + 
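+ # Downsample the predicted masks to attn_mask_target_size, then threshold the
+ # sigmoid at 0.5 and repeat per attention head: True marks positions a query
+ # is not allowed to attend to in the next cross-attention layer.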
attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False) + # must use bool type + # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged. + attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool() + attn_mask = attn_mask.detach() + + return decoder_output, outputs_class, outputs_mask, attn_mask diff --git a/mapmaster/models/ins_decoder/model.py b/mapmaster/models/ins_decoder/model.py new file mode 100644 index 0000000..9ea5ad9 --- /dev/null +++ b/mapmaster/models/ins_decoder/model.py @@ -0,0 +1,39 @@ +import torch.nn as nn +import torch.nn.functional as F +from mapmaster.models.ins_decoder.mask2former import MultiScaleMaskedTransformerDecoder +from mapmaster.models.ins_decoder.pointmask2former import PointMask2TransformerDecoder + + +class INSDecoderBase(nn.Module): + def __init__(self, decoder_ids=(5, ), tgt_shape=None): + super(INSDecoderBase, self).__init__() + self.decoder_ids = tuple(decoder_ids) # [0, 1, 2, 3, 4, 5] + self.tgt_shape = tgt_shape + self.bev_decoder = None + + def forward(self, inputs): + assert "bev_enc_features" in inputs + bev_enc_features = inputs["bev_enc_features"] + if self.tgt_shape is not None: + bev_enc_features = [self.up_sample(x) for x in inputs["bev_enc_features"]] + out = self.bev_decoder(bev_enc_features[-1:], bev_enc_features[-1]) + return {"mask_features": [out["pred_masks"][1:][i] for i in self.decoder_ids], + "obj_scores": [out["pred_logits"][1:][i] for i in self.decoder_ids], + "decoder_outputs": [out["decoder_outputs"][1:][i] for i in self.decoder_ids], + "bev_enc_features": bev_enc_features} + + def up_sample(self, x, tgt_shape=None): + tgt_shape = self.tgt_shape if tgt_shape is None else tgt_shape + if tuple(x.shape[-2:]) == tuple(tgt_shape): + return x + return F.interpolate(x, size=tgt_shape, mode="bilinear", align_corners=True) + +class Mask2formerINSDecoder(INSDecoderBase): + def __init__(self, decoder_ids=(5, ), tgt_shape=None, **kwargs): + super(Mask2formerINSDecoder, self).__init__(decoder_ids, tgt_shape) + self.bev_decoder = MultiScaleMaskedTransformerDecoder(**kwargs) + +class PointMask2formerINSDecoder(INSDecoderBase): + def __init__(self, decoder_ids=(5, ), tgt_shape=None, **kwargs): + super(PointMask2formerINSDecoder, self).__init__(decoder_ids, tgt_shape) + self.bev_decoder = PointMask2TransformerDecoder(**kwargs) diff --git a/mapmaster/models/ins_decoder/pointmask2former.py b/mapmaster/models/ins_decoder/pointmask2former.py new file mode 100644 index 0000000..4e51026 --- /dev/null +++ b/mapmaster/models/ins_decoder/pointmask2former.py @@ -0,0 +1,346 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
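+# Point-level variant of the Mask2Former decoder: each map instance owns a fixed block of point queries (see query_split / max_pieces below).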
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py +import torch +from typing import Optional +from torch import nn, Tensor +from torch.nn import functional as F +from mapmaster.models.utils.misc import Conv2d, c2_xavier_fill, get_activation_fn +from mapmaster.models.utils.position_encoding import PositionEmbeddingSine + + +class SelfAttentionLayer(nn.Module): + + def __init__(self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = get_activation_fn(activation) + self.normalize_before = normalize_before + self._reset_parameters() + + def forward(self, tgt, tgt_mask=None, tgt_key_padding_mask=None, query_pos=None): + if self.normalize_before: + return self.forward_pre(tgt, tgt_mask, tgt_key_padding_mask, query_pos) + return self.forward_post(tgt, tgt_mask, tgt_key_padding_mask, query_pos) + + def forward_pre(self, tgt, tgt_mask=None, tgt_key_padding_mask=None, query_pos=None): + tgt2 = self.norm(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + return tgt + + def forward_post(self, tgt, tgt_mask=None, tgt_key_padding_mask=None, query_pos=None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + return tgt + + @staticmethod + def with_pos_embed(tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + +class CrossAttentionLayer(nn.Module): + + def __init__(self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False): + super().__init__() + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = get_activation_fn(activation) + self.normalize_before = normalize_before + self._reset_parameters() + + def forward(self, tgt, memory, memory_mask=None, memory_key_padding_mask=None, pos=None, query_pos=None): + if self.normalize_before: + return self.forward_pre(tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos) + + def forward_pre(self, tgt, memory, memory_mask=None, memory_key_padding_mask=None, pos=None, query_pos=None): + tgt2 = self.norm(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + return tgt + + def forward_post(self, tgt, memory, memory_mask=None, memory_key_padding_mask=None, pos=None, query_pos=None): + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + return tgt + + @staticmethod + def with_pos_embed(tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + 
+ def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + +class FFNLayer(nn.Module): + + def __init__(self, d_model, dim_feedforward=2048, dropout=0.0, activation="relu", normalize_before=False): + super().__init__() + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + self.norm = nn.LayerNorm(d_model) + self.activation = get_activation_fn(activation) + self.normalize_before = normalize_before + self._reset_parameters() + + def forward(self, tgt): + if self.normalize_before: + return self.forward_pre(tgt) + return self.forward_post(tgt) + + def forward_pre(self, tgt): + tgt2 = self.norm(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout(tgt2) + return tgt + + def forward_post(self, tgt): + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + return tgt + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class PointMask2TransformerDecoder(nn.Module): + def __init__( + self, + in_channels, + num_feature_levels, + mask_classification=True, + *, + num_classes: int, + hidden_dim: int, + nheads: int, + dim_feedforward: int, + dec_layers: int, + pre_norm: bool, + mask_dim: int, + enforce_input_project: bool, + query_split=(20, 25, 15), + max_pieces=(10, 2, 30), + ): + """ + NOTE: this interface is experimental. 
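+ Unlike MultiScaleMaskedTransformerDecoder, the queries here are point-level:
+ class i gets query_split[i] instance slots with max_pieces[i] point slots each,
+ so the total number of queries is sum(query_split[i] * max_pieces[i])
+ (20*10 + 25*2 + 15*30 = 700 with the defaults above).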
+ Args: + in_channels: channels of the input features + mask_classification: whether to add mask classifier or not + num_classes: number of classes + hidden_dim: Transformer feature dimension + num_queries: number of queries + nheads: number of heads + dim_feedforward: feature dimension in feedforward network + enc_layers: number of Transformer encoder layers + dec_layers: number of Transformer decoder layers + pre_norm: whether to use pre-LayerNorm or not + mask_dim: mask feature dimension + enforce_input_project: add input project 1x1 conv even if input + channels and hidden dim is identical + """ + super().__init__() + + assert mask_classification, "Only support mask classification model" + self.mask_classification = mask_classification + + # positional encoding + self.pe_layer = PositionEmbeddingSine(hidden_dim // 2, normalize=True) + + # define Transformer decoder here + self.num_heads = nheads + self.num_layers = dec_layers + self.transformer_self_attention_layers = nn.ModuleList() + self.transformer_cross_attention_layers = nn.ModuleList() + self.transformer_ffn_layers = nn.ModuleList() + # d_model, nhead, dropout = 0.0, activation = "relu", normalize_before = False + for _ in range(self.num_layers): + self.transformer_self_attention_layers.append( + SelfAttentionLayer(d_model=hidden_dim, nhead=nheads, dropout=0.0, normalize_before=pre_norm) + ) + self.transformer_cross_attention_layers.append( + CrossAttentionLayer(d_model=hidden_dim, nhead=nheads, dropout=0.0, normalize_before=pre_norm) + ) + self.transformer_ffn_layers.append( + FFNLayer(d_model=hidden_dim, dim_feedforward=dim_feedforward, dropout=0.0, normalize_before=pre_norm) + ) + + self.decoder_norm = nn.LayerNorm(hidden_dim) + + self.max_pieces = max_pieces + self.query_split = query_split + self.num_queries = sum([self.max_pieces[i] * self.query_split[i] for i in range(len(self.max_pieces))]) + + # learnable pt features + self.query_feat = nn.Embedding(self.num_queries, hidden_dim).weight # [700, C] + # learnable pt p.e. 
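+ # (one positional embedding per (instance slot, point slot) pair)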
+ self.query_embed = nn.Embedding(self.num_queries, hidden_dim).weight # [20 * 10 + 25 * 2 + 15 * 30, C] which is [700, C], note the memory + + # level embedding (we always use 3 scales) + self.num_feature_levels = num_feature_levels + self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim) + self.input_proj = nn.ModuleList() + for _ in range(self.num_feature_levels): + if in_channels != hidden_dim or enforce_input_project: + self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1)) + c2_xavier_fill(self.input_proj[-1]) + else: + self.input_proj.append(nn.Sequential()) + + # output FFNs + if self.mask_classification: + self.class_embed = nn.ModuleList( + nn.Linear(hidden_dim*self.max_pieces[i], num_classes + 1) for i in range(len(query_split)) + ) + self.mask_embed = nn.ModuleList( + MLP(hidden_dim*self.max_pieces[i], hidden_dim, mask_dim, 3) for i in range(len(query_split)) + ) + + # split info + self.cls_split = [self.query_split[i] * self.max_pieces[i] for i in range(len(self.query_split))] + self.ins_split = torch.cat([torch.ones(self.query_split[i], dtype=torch.long)*self.max_pieces[i] for i in range(len(self.query_split))]) + + def forward(self, x, mask_features, mask=None): + # x is a list of multi-scale feature + assert len(x) == self.num_feature_levels + src = [] + pos = [] + size_list = [] + + # disable mask, it does not affect performance + del mask + + for i in range(self.num_feature_levels): + size_list.append(x[i].shape[-2:]) + mask = torch.zeros((x[i].size(0), x[i].size(2), x[i].size(3)), device=x[i].device, dtype=torch.bool) + pos.append(self.pe_layer(mask).flatten(2)) + src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None]) + + # flatten NxCxHxW to HWxNxC + pos[-1] = pos[-1].permute(2, 0, 1) + src[-1] = src[-1].permute(2, 0, 1) + + _, bs, _ = src[0].shape + + query_embed = self.query_embed.unsqueeze(1).repeat(1, bs, 1) # [n_q, bs, C] + output = self.query_feat.unsqueeze(1).repeat(1, bs, 1) # [n_q, bs, C] + + predictions_class = [] + predictions_mask = [] + decoder_outputs = [] + + # prediction heads on learnable query features + dec_out, outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads( + output, mask_features, attn_mask_target_size=size_list[0] + ) + predictions_class.append(outputs_class) + predictions_mask.append(outputs_mask) + decoder_outputs.append(dec_out) + + for i in range(self.num_layers): + level_index = i % self.num_feature_levels + attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False + # attention: cross-attention first + output = self.transformer_cross_attention_layers[i]( + output, src[level_index], + memory_mask=attn_mask, + memory_key_padding_mask=None, # here we do not apply masking on padded region + pos=pos[level_index], query_pos=query_embed + ) + output = self.transformer_self_attention_layers[i]( + output, tgt_mask=None, tgt_key_padding_mask=None, query_pos=query_embed + ) + output = self.transformer_ffn_layers[i](output) + dec_out, outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads( + output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels] + ) + + predictions_class.append(outputs_class) + predictions_mask.append(outputs_mask) + decoder_outputs.append(dec_out) + + assert len(predictions_class) == self.num_layers + 1 + + out = { + 'pred_logits': predictions_class, + 'pred_masks': predictions_mask, + 'decoder_outputs': decoder_outputs + } + return out + + def forward_prediction_heads(self, 
output, mask_features, attn_mask_target_size): + decoder_output = self.decoder_norm(output) + decoder_output = decoder_output.transpose(0, 1) # (b, q, c') (b, 700, c') + decoder_output_split = decoder_output.split(self.cls_split, dim=1) + + outputs_class, mask_embed = None, None + for i in range(len(self.query_split)): + x_cls_pt = decoder_output_split[i] # (bs, n_q*n_pt, c') + n_q, n_pt = self.query_split[i], self.max_pieces[i] + x_cls_pt = x_cls_pt.reshape(x_cls_pt.shape[0], n_q, -1) # (bs, n_q, n_pt * c') + x_cls = self.class_embed[i](x_cls_pt) # (bs, n_q, 2) + x_cls_ins = self.mask_embed[i](x_cls_pt) # (bs, n_q, c) + if outputs_class is None: + outputs_class = x_cls + mask_embed = x_cls_ins + else: + outputs_class = torch.cat([outputs_class, x_cls], dim=1) + mask_embed = torch.cat([mask_embed, x_cls_ins], dim=1) + + outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features) # for loss cal + + outputs_mask_split = outputs_mask.split(self.query_split, dim=1) + + attn_mask = [] + for i in range(len(self.query_split)): + # [bs, n_q, h, w] -> [bs, n_q, 1, h, w] -> [bs, n_q, n_pt, h, w] -> [bs, n_q*n_pt, h, w] + attn_mask.append(outputs_mask_split[i].unsqueeze(2).repeat(1, 1, self.max_pieces[i], 1, 1).flatten(1, 2)) + attn_mask = torch.cat(attn_mask, dim=1) + + # NOTE: prediction is of higher-resolution + # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW] + attn_mask = F.interpolate(attn_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False) + # must use bool type + # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged. + attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool() + attn_mask = attn_mask.detach() + + return decoder_output, outputs_class, outputs_mask, attn_mask diff --git a/mapmaster/models/network.py b/mapmaster/models/network.py new file mode 100644 index 0000000..85d6caa --- /dev/null +++ b/mapmaster/models/network.py @@ -0,0 +1,68 @@ +import torch.nn as nn +from mapmaster.models import backbone, bev_decoder, ins_decoder, output_head +# os.environ['TORCH_DISTRIBUTED_DEBUG'] = "INFO" +# warnings.filterwarnings('ignore') + + +class MapMaster(nn.Module): + def __init__(self, model_config, *args, **kwargs): + super(MapMaster, self).__init__() + self.im_backbone = self.create_backbone(**model_config["im_backbone"]) + self.bev_decoder = self.create_bev_decoder(**model_config["bev_decoder"]) + self.ins_decoder = self.create_ins_decoder(**model_config["ins_decoder"]) + self.output_head = self.create_output_head(**model_config["output_head"]) + self.post_processor = self.create_post_processor(**model_config["post_processor"]) + + def forward(self, inputs): + outputs = {} + outputs.update({k: inputs[k] for k in ["images", "extra_infos"]}) + outputs.update({k: inputs[k].float() for k in ["extrinsic", "intrinsic"]}) + if "ida_mats" in inputs: + outputs.update({"ida_mats": inputs["ida_mats"].float()}) + outputs.update(self.im_backbone(outputs)) + outputs.update(self.bev_decoder(outputs)) + outputs.update(self.ins_decoder(outputs)) + outputs.update(self.output_head(outputs)) + return outputs + + @staticmethod + def create_backbone(arch_name, ret_layers, bkb_kwargs, fpn_kwargs, up_shape=None): + __factory_dict__ = { + "resnet": backbone.ResNetBackbone, + "efficient_net": backbone.EfficientNetBackbone, + "swin_transformer": backbone.SwinTRBackbone, + } + return __factory_dict__[arch_name](bkb_kwargs, fpn_kwargs, up_shape, 
ret_layers) + + @staticmethod + def create_bev_decoder(arch_name, net_kwargs): + __factory_dict__ = { + "transformer": bev_decoder.TransformerBEVDecoder, + "ipm_deformable_transformer": bev_decoder.DeformTransformerBEVEncoder, + } + return __factory_dict__[arch_name](**net_kwargs) + + @staticmethod + def create_ins_decoder(arch_name, net_kwargs): + __factory_dict__ = { + "mask2former": ins_decoder.Mask2formerINSDecoder, + "line_aware_decoder": ins_decoder.PointMask2formerINSDecoder, + } + + return __factory_dict__[arch_name](**net_kwargs) + + @staticmethod + def create_output_head(arch_name, net_kwargs): + __factory_dict__ = { + "bezier_output_head": output_head.PiecewiseBezierMapOutputHead, + "pivot_point_predictor": output_head.PivotMapOutputHead, + } + return __factory_dict__[arch_name](**net_kwargs) + + @staticmethod + def create_post_processor(arch_name, net_kwargs): + __factory_dict__ = { + "bezier_post_processor": output_head.PiecewiseBezierMapPostProcessor, + "pivot_post_processor": output_head.PivotMapPostProcessor, + } + return __factory_dict__[arch_name](**net_kwargs) diff --git a/mapmaster/models/output_head/__init__.py b/mapmaster/models/output_head/__init__.py new file mode 100644 index 0000000..08f39e3 --- /dev/null +++ b/mapmaster/models/output_head/__init__.py @@ -0,0 +1,4 @@ +from .bezier_outputs import PiecewiseBezierMapOutputHead +from .bezier_post_processor import PiecewiseBezierMapPostProcessor +from .pivot_outputs import PivotMapOutputHead +from .pivot_post_processor import PivotMapPostProcessor \ No newline at end of file diff --git a/mapmaster/models/output_head/bezier_outputs.py b/mapmaster/models/output_head/bezier_outputs.py new file mode 100644 index 0000000..9fecd39 --- /dev/null +++ b/mapmaster/models/output_head/bezier_outputs.py @@ -0,0 +1,127 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FFN(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2, basic_type='linear'): + super().__init__() + self.basic_type = basic_type + if output_dim == 0: + self.basic_type = "identity" + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(self.basic_layer(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + def basic_layer(self, n, k): + if self.basic_type == 'linear': + return nn.Linear(n, k) + elif self.basic_type == 'conv': + return nn.Conv2d(n, k, kernel_size=1, stride=1) + elif self.basic_type == 'identity': + return nn.Identity() + else: + raise NotImplementedError + + +class PiecewiseBezierMapOutputHead(nn.Module): + def __init__(self, in_channel, num_queries, tgt_shape, num_degree, max_pieces, bev_channels=-1, ins_channel=64): + super(PiecewiseBezierMapOutputHead, self).__init__() + self.num_queries = num_queries + self.num_classes = len(num_queries) + self.tgt_shape = tgt_shape + self.bev_channels = bev_channels + self.semantic_heads = None + if self.bev_channels > 0: + self.semantic_heads = nn.ModuleList( + nn.Sequential(nn.Conv2d(bev_channels, 2, kernel_size=1, stride=1)) for _ in range(self.num_classes) + ) + self.num_degree = num_degree + self.max_pieces = max_pieces + self.num_ctr_im = [(n + 1) for n in self.max_pieces] + self.num_ctr_ex = [n * (d - 1) for n, d in zip(self.max_pieces, self.num_degree)] + _N = self.num_classes + + _C = 
ins_channel + self.im_ctr_heads = nn.ModuleList(FFN(in_channel, 256, (self.num_ctr_im[i] * 2) * _C, 3) for i in range(_N)) + self.ex_ctr_heads = nn.ModuleList(FFN(in_channel, 256, (self.num_ctr_ex[i] * 2) * _C, 3) for i in range(_N)) + self.npiece_heads = nn.ModuleList(FFN(in_channel, 256, self.max_pieces[i], 3) for i in range(_N)) + self.gap_layer = nn.AdaptiveAvgPool2d((1, 1)) + self.coords = self.compute_locations(device='cuda') + self.coords_head = FFN(2, 256, _C, 3, 'conv') + + def forward(self, inputs): + num_decoders = len(inputs["mask_features"]) + dt_obj_logit = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)] + dt_ins_masks = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)] + im_ctr_coord = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)] + ex_ctr_coord = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)] + dt_end_logit = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)] + coords_feats = self.coords_head.forward(self.coords.repeat((inputs["mask_features"][0].shape[0], 1, 1, 1))) + for i in range(num_decoders): + x_ins_cw = inputs["mask_features"][i].split(self.num_queries, dim=1) + x_obj_cw = inputs["obj_scores"][i].split(self.num_queries, dim=1) + x_qry_cw = inputs["decoder_outputs"][i].split(self.num_queries, dim=1) + batch_size = x_qry_cw[0].shape[0] + for j in range(self.num_classes): + num_qry = self.num_queries[j] + # if self.training: + dt_ins_masks[i][j] = self.up_sample(x_ins_cw[j]) + dt_obj_logit[i][j] = x_obj_cw[j] + dt_end_logit[i][j] = self.npiece_heads[j](x_qry_cw[j]) + # im + im_feats = self.im_ctr_heads[j](x_qry_cw[j]) + im_feats = im_feats.reshape(batch_size, num_qry, self.num_ctr_im[j] * 2, -1).flatten(1, 2) + im_coords_map = torch.einsum("bqc,bchw->bqhw", im_feats, coords_feats) + im_coords = self.gap_layer(im_coords_map) + im_ctr_coord[i][j] = im_coords.reshape(batch_size, num_qry, self.max_pieces[j] + 1, 2) + # ex + if self.num_ctr_ex[j] == 0: + ex_ctr_coord[i][j] = torch.zeros(batch_size, num_qry, self.max_pieces[j], 0, 2).cuda() + else: + ex_feats = self.ex_ctr_heads[j](x_qry_cw[j]) + ex_feats = ex_feats.reshape(batch_size, num_qry, self.num_ctr_ex[j] * 2, -1).flatten(1, 2) + ex_coords_map = torch.einsum("bqc,bchw->bqhw", ex_feats, coords_feats) + ex_coords = self.gap_layer(ex_coords_map) + ex_ctr_coord[i][j] = ex_coords.reshape(batch_size, num_qry, self.max_pieces[j], self.num_degree[j] - 1, 2) + ret = {"outputs": {"obj_logits": dt_obj_logit, "ins_masks": dt_ins_masks, + "ctr_im": im_ctr_coord, "ctr_ex": ex_ctr_coord, "end_logits": dt_end_logit}} + if self.semantic_heads is not None: + num_decoders = len(inputs["bev_enc_features"]) + dt_sem_masks = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)] + for i in range(num_decoders): + x_sem = inputs["bev_enc_features"][i] + for j in range(self.num_classes): + dt_sem_masks[i][j] = self.up_sample(self.semantic_heads[j](x_sem)) + ret["outputs"].update({"sem_masks": dt_sem_masks}) + return ret + + def up_sample(self, x, tgt_shape=None): + tgt_shape = self.tgt_shape if tgt_shape is None else tgt_shape + if tuple(x.shape[-2:]) == tuple(tgt_shape): + return x + return F.interpolate(x, size=tgt_shape, mode="bilinear", align_corners=True) + + def compute_locations(self, stride=1, device='cpu'): + + fh, fw = self.tgt_shape + + shifts_x = torch.arange(0, fw * stride, step=stride, dtype=torch.float32, device=device) + shifts_y = torch.arange(0, fh * stride, step=stride, dtype=torch.float32, device=device) + 
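+ # Build a (1, 2, fh, fw) grid of (x, y) coordinates over the target BEV shape,
+ # normalized to [0, 1) by the map width and height; it conditions the coordinate
+ # regression heads via coords_head in forward().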
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) + shift_x = shift_x.reshape(-1) + shift_y = shift_y.reshape(-1) + locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2 + + locations = locations.unsqueeze(0).permute(0, 2, 1).contiguous().float().view(1, 2, fh, fw) + locations[:, 0, :, :] /= fw + locations[:, 1, :, :] /= fh + + return locations diff --git a/mapmaster/models/output_head/bezier_post_processor.py b/mapmaster/models/output_head/bezier_post_processor.py new file mode 100644 index 0000000..c6716ac --- /dev/null +++ b/mapmaster/models/output_head/bezier_post_processor.py @@ -0,0 +1,393 @@ +import cv2 +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from scipy.special import comb as n_over_k +from scipy.optimize import linear_sum_assignment +from mapmaster.models.utils.mask_loss import SegmentationLoss +from mapmaster.models.utils.recovery_loss import PointRecoveryLoss +from mapmaster.utils.misc import nested_tensor_from_tensor_list +from mapmaster.utils.misc import get_world_size, is_available, is_distributed + + +class HungarianMatcher(nn.Module): + + def __init__(self, cost_obj=1., cost_ctr=1., cost_end=1., cost_mask=1., cost_curve=1., cost_recovery=1., + ins_mask_loss_conf=None, point_loss_conf=None, class_weight=None): + super().__init__() + self.cost_obj, self.cost_ctr, self.cost_end = cost_obj, cost_ctr, cost_end + self.cost_mask, self.cost_curve, self.cost_recovery = cost_mask, cost_curve, cost_recovery + self.ins_mask_loss = SegmentationLoss(**ins_mask_loss_conf) + self.recovery_loss = PointRecoveryLoss(**point_loss_conf) + self.class_weight = class_weight + + @torch.no_grad() + def forward(self, outputs, targets): + num_decoders, num_classes = len(outputs["ins_masks"]), len(outputs["ins_masks"][0]) + matching_indices = [[[] for _ in range(num_classes)] for _ in range(num_decoders)] + for dec_id in range(num_decoders): + for cid in range(num_classes): + if self.class_weight is not None and self.class_weight[cid] == 0: + continue + + bs, num_queries = outputs["obj_logits"][dec_id][cid].shape[:2] + + dt_probs = outputs["obj_logits"][dec_id][cid].flatten(0, 1).softmax(-1) # [n_dt, 2], n_dt in a batch + gt_idxes = torch.cat([tgt["obj_labels"][cid] for tgt in targets]) # [n_gt, ] + cost_mat_obj = -dt_probs[:, gt_idxes] # [n_dt, n_gt] + + dt_curves = outputs["curve_points"][dec_id][cid].flatten(0, 1) # [n_dt, n, 2] + dt_masks = outputs["ins_masks"][dec_id][cid].flatten(0, 1) # [n_dt, h, w] + gt_masks = torch.cat([tgt["ins_masks"][cid] for tgt in targets]) # [n_gt, h, w] + cost_mat_mask, cost_mat_rec = 0, 0 + if gt_masks.shape[0] > 0: + dt_num, gt_num = dt_masks.shape[0], gt_masks.shape[0] + dt_masks = dt_masks.unsqueeze(1).expand(dt_num, gt_num, *dt_masks.shape[1:]).flatten(0, 1) + gt_masks = gt_masks.unsqueeze(0).expand(dt_num, gt_num, *gt_masks.shape[1:]).flatten(0, 1) + cost_mat_mask = self.ins_mask_loss(dt_masks, gt_masks, "matcher").reshape(dt_num, gt_num) + dt_curves = dt_curves.unsqueeze(1).expand(dt_num, gt_num, *dt_curves.shape[1:]).flatten(0, 1) + cost_mat_rec = self.recovery_loss(dt_curves, gt_masks).reshape(dt_num, gt_num) + + dt_ctrs = outputs["ctr_points"][dec_id][cid].flatten(0, 1).flatten(1) # [n_dt, n, 2] + gt_ctrs = torch.cat([tgt["ctr_points"][cid] for tgt in targets]).flatten(1) # [n_gt, h, w] + cost_mat_ctr = torch.cdist(dt_ctrs, gt_ctrs, p=1) / gt_ctrs.shape[1] + + dt_end_probs = outputs["end_logits"][dec_id][cid].flatten(0, 1).softmax(-1) + gt_end_idxes = torch.cat([tgt["end_labels"][cid] for tgt in 
targets]) + cost_mat_end = -dt_end_probs[:, gt_end_idxes] # [n_dt, n_gt] + + dt_curves = outputs["curve_points"][dec_id][cid].flatten(0, 1).flatten(1) # [n_dt, n, 2] + gt_curves = torch.cat([tgt["curve_points"][cid] for tgt in targets]).flatten(1) # [n_gt, n, 2] + cost_mat_curve = torch.cdist(dt_curves, gt_curves, p=1) / gt_curves.shape[1] + + sizes = [len(tgt["obj_labels"][cid]) for tgt in targets] + C = self.cost_obj * cost_mat_obj + self.cost_mask * cost_mat_mask + \ + self.cost_ctr * cost_mat_ctr + self.cost_end * cost_mat_end + self.cost_curve * cost_mat_curve +\ + self.cost_recovery * cost_mat_rec + C = C.view(bs, num_queries, -1).cpu() + indices = [linear_sum_assignment(c[i].detach().numpy()) for i, c in enumerate(C.split(sizes, -1))] + matching_indices[dec_id][cid] = [self.to_tensor(i, j) for i, j in indices] + + return matching_indices + + @staticmethod + def to_tensor(i, j): + return torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64) + + +class SetCriterion(nn.Module): + + def __init__(self, criterion_conf, matcher, num_degree, no_object_coe=1.0): + super().__init__() + self.num_degree = num_degree + self.matcher = matcher + self.criterion_conf = criterion_conf + self.loss_weight_dict = self.criterion_conf['loss_weight'] + self.sem_mask_loss = SegmentationLoss(**criterion_conf['bev_decoder']['sem_mask_loss']) + self.register_buffer("empty_weight", torch.tensor([1.0, no_object_coe])) + + def forward(self, outputs, targets): + losses = {} + matching_indices = self.matcher(outputs, targets) + losses.update(self.criterion_instance(outputs, targets, matching_indices)) + losses.update(self.criterion_instance_labels(outputs, targets, matching_indices)) + losses.update(self.criterion_semantic_masks(outputs, targets)) + losses = {key: self.criterion_conf['loss_weight'][key] * losses[key] for key in losses} + return sum(losses.values()), losses + + def criterion_instance(self, outputs, targets, matching_indices): + loss_masks, loss_ctr, loss_end, loss_curve, loss_rec = 0, 0, 0, 0, 0 + device = outputs["ins_masks"][0][0].device + num_decoders, num_classes = len(matching_indices), len(matching_indices[0]) + for i in range(num_decoders): + w = self.criterion_conf['ins_decoder']['weight'][i] + for j in range(num_classes): + w2 = self.criterion_conf["class_weights"][j] if "class_weights" in self.criterion_conf else 1.0 + num_instances = sum(len(t["obj_labels"][j]) for t in targets) + num_instances = torch.as_tensor([num_instances], dtype=torch.float, device=device) + if is_distributed() and is_available(): + torch.distributed.all_reduce(num_instances) + num_instances = torch.clamp(num_instances / get_world_size(), min=1).item() + + indices = matching_indices[i][j] + src_idx = self._get_src_permutation_idx(indices) # dt + tgt_idx = self._get_tgt_permutation_idx(indices) # gt + + # instance masks + src_masks = outputs["ins_masks"][i][j][src_idx] + tgt_masks = [t["ins_masks"][j] for t in targets] + tgt_masks, _ = nested_tensor_from_tensor_list(tgt_masks).decompose() + tgt_masks = tgt_masks.to(src_masks)[tgt_idx] + loss_masks += w * self.matcher.ins_mask_loss(src_masks, tgt_masks, "loss").sum() / num_instances * w2 + + # eof indices classification + src_logits = outputs["end_logits"][i][j][src_idx] # [M, K] + tgt_labels = torch.cat([t["end_labels"][j][J] for t, (_, J) in zip(targets, indices)]) # (M, ) + loss_end += w * F.cross_entropy(src_logits, tgt_labels, ignore_index=-1, reduction='sum') / num_instances * w2 + + # control points + src_ctrs = 
outputs["ctr_points"][i][j][src_idx] # [bs, num_queries, o, 2] + end_labels = torch.max(src_logits.softmax(dim=-1), dim=-1)[1] # [m, k] 0, 1, 2, 3... + end_labels_new = (end_labels + 1) * self.num_degree[j] + 1 + valid_mask = torch.zeros(src_ctrs.shape[:2], device=device).long() + for a, b in enumerate(end_labels_new): + valid_mask[a][:b] = 1 + src_ctrs_masked = (src_ctrs * valid_mask.unsqueeze(-1)) + tgt_ctrs = torch.zeros((len(tgt_idx[0]), *src_ctrs.shape[-2:]), device=device).float() + valid_mask = torch.zeros((len(tgt_idx[0]), src_ctrs.shape[-2]), device=device).float() + for idx in range(len(tgt_idx[0])): + batch_id, gt_id = tgt_idx[0][idx], tgt_idx[1][idx] + tgt_ctrs[idx] = targets[batch_id]['ctr_points'][j][gt_id] + valid_mask[idx] = targets[batch_id]['valid_masks'][j][gt_id] + tgt_ctrs_masked = (tgt_ctrs * valid_mask.unsqueeze(-1)) + num_pt = src_ctrs.shape[-2] * src_ctrs.shape[-1] + loss_ctr += w * F.l1_loss(src_ctrs_masked, tgt_ctrs_masked, reduction='sum') / num_instances / num_pt * w2 + + # curve loss + src_curves = outputs["curve_points"][i][j][src_idx] + tgt_curves = torch.zeros((len(tgt_idx[0]), *src_curves.shape[-2:]), device=device).float() + for idx in range(len(tgt_idx[0])): + batch_id, gt_id = tgt_idx[0][idx], tgt_idx[1][idx] + tgt_curves[idx] = targets[batch_id]['curve_points'][j][gt_id] + num_pt = src_curves.shape[-2] * src_curves.shape[-1] + loss_curve += w * F.l1_loss(src_curves, tgt_curves, reduction='sum') / num_instances / num_pt * w2 + + # recovery loss + loss_rec += w * self.matcher.recovery_loss(src_curves, tgt_masks).sum() / num_instances * w2 + + loss_masks /= (num_decoders * num_classes) + loss_ctr /= (num_decoders * num_classes) + loss_curve /= (num_decoders * num_classes) + loss_end /= (num_decoders * num_classes) + loss_rec /= (num_decoders * num_classes) + + return {"ctr_loss": loss_ctr, "end_loss": loss_end, "msk_loss": loss_masks, "curve_loss": loss_curve, + "recovery_loss": loss_rec} + + def criterion_instance_labels(self, outputs, targets, matching_indices): + loss_labels = 0 + num_decoders, num_classes = len(matching_indices), len(matching_indices[0]) + for i in range(num_decoders): + w = self.criterion_conf['ins_decoder']['weight'][i] + for j in range(num_classes): + w2 = self.criterion_conf["class_weights"][j] if "class_weights" in self.criterion_conf else 1.0 + indices = matching_indices[i][j] + idx = self._get_src_permutation_idx(indices) # (batch_id, query_id) + logits = outputs["obj_logits"][i][j] + target_classes_o = torch.cat([t["obj_labels"][j][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full(logits.shape[:2], 1, dtype=torch.int64, device=logits.device) + target_classes[idx] = target_classes_o + loss_labels += (w * F.cross_entropy(logits.transpose(1, 2), target_classes, self.empty_weight)) * w2 + loss_labels /= (num_decoders * num_classes) + return {"obj_loss": loss_labels} + + def criterion_semantic_masks(self, outputs, targets): + loss_masks = 0 + num_decoders, num_classes = len(outputs["sem_masks"]), len(outputs["sem_masks"][0]) + for i in range(num_decoders): + w = self.criterion_conf['bev_decoder']['weight'][i] + for j in range(num_classes): + w2 = self.criterion_conf["class_weights"][j] if "class_weights" in self.criterion_conf else 1.0 + dt_masks = outputs["sem_masks"][i][j] # (B, 2, H, W) + gt_masks = torch.stack([t["sem_masks"][j] for t in targets], dim=0) # (B, H, W) + loss_masks += w * self.sem_mask_loss(dt_masks[:, 1, :, :], gt_masks, "loss").mean() * w2 + loss_masks /= num_decoders + return 
{"sem_loss": loss_masks} + + @staticmethod + def _get_src_permutation_idx(indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + @staticmethod + def _get_tgt_permutation_idx(indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + +class PiecewiseBezierMapPostProcessor(nn.Module): + def __init__(self, criterion_conf, matcher_conf, bezier_conf, map_conf, no_object_coe=1.0): + super(PiecewiseBezierMapPostProcessor, self).__init__() + # setting + self.num_classes = map_conf['num_classes'] + self.ego_size = map_conf['ego_size'] + self.map_size = map_conf['map_size'] + self.line_width = map_conf['line_width'] + self.num_degree = bezier_conf['num_degree'] + self.num_pieces = bezier_conf['max_pieces'] + self.num_points = bezier_conf['num_points'] + self.curve_size = bezier_conf['piece_length'] + self.class_indices = torch.tensor(list(range(self.num_classes)), dtype=torch.int).cuda() + self.bezier_coefficient_np = self._get_bezier_coefficients() + self.bezier_coefficient = [torch.from_numpy(x).float().cuda() for x in self.bezier_coefficient_np] + self.matcher = HungarianMatcher(**matcher_conf) + self.criterion = SetCriterion(criterion_conf, self.matcher, self.num_degree, no_object_coe) + self.save_thickness = map_conf['save_thickness'] if 'save_thickness' in map_conf else 1 + + def forward(self, outputs, targets=None): + outputs.update(self.bezier_curve_outputs(outputs)) + if self.training: + targets = self.refactor_targets(targets) + return self.criterion.forward(outputs, targets) + else: + return self.post_processing(outputs) + + def bezier_curve_outputs(self, outputs): + dt_ctr_im, dt_ctr_ex, dt_ends = outputs["ctr_im"], outputs["ctr_ex"], outputs["end_logits"] + num_decoders, num_classes = len(dt_ends), len(dt_ends[0]) + ctr_points = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)] + curve_points = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)] + for i in range(num_decoders): + for j in range(num_classes): + batch_size, num_queries = dt_ctr_im[i][j].shape[:2] + + im_coords = dt_ctr_im[i][j].sigmoid() + ex_offsets = dt_ctr_ex[i][j].sigmoid() - 0.5 + im_center_coords = ((im_coords[:, :, :-1] + im_coords[:, :, 1:]) / 2).unsqueeze(-2) + ex_coords = torch.stack([im_center_coords[:, :, :, :, 0] + ex_offsets[:, :, :, :, 0], + im_center_coords[:, :, :, :, 1] + ex_offsets[:, :, :, :, 1]], dim=-1) + im_coords = im_coords.unsqueeze(-2) + ctr_coords = torch.cat([im_coords[:, :, :-1], ex_coords], dim=-2).flatten(2, 3) + ctr_coords = torch.cat([ctr_coords, im_coords[:, :, -1:, 0, :]], dim=-2) + ctr_points[i][j] = ctr_coords.clone() + + end_inds = torch.max(torch.softmax(dt_ends[i][j].flatten(0, 1), dim=-1), dim=-1)[1] + curve_pts = self.curve_recovery_with_bezier(ctr_coords.flatten(0, 1), end_inds, j) + curve_points[i][j] = curve_pts.reshape(batch_size, num_queries, *curve_pts.shape[-2:]) + + return {"curve_points": curve_points, 'ctr_points': ctr_points} + + def refactor_targets(self, targets): + targets_refactored = [] + batch_size, num_classes = len(targets["masks"]), len(targets["masks"][0]) + targets["masks"] = targets["masks"].cuda() + targets["points"] = targets["points"].cuda() + targets["labels"] = targets["labels"].cuda() + 
for batch_id in range(batch_size): + + sem_masks, ins_masks, ins_objects = [], [], [] + ctr_points, curve_points, end_labels, valid_masks = [], [], [], [] + + ins_classes = targets['labels'][batch_id][:, 0].int() + cls_ids, ins_ids = torch.where((ins_classes.unsqueeze(0) == self.class_indices.unsqueeze(1)).int()) + for cid in range(num_classes): + indices = ins_ids[torch.where(cls_ids == cid)] + num_ins = indices.shape[0] + + # object class: 0 or 1 + ins_obj = torch.zeros((num_ins,), dtype=torch.long).cuda() + ins_objects.append(ins_obj) + + # bezier control points coords + num_max = self.num_points[cid] + ctr_pts = targets['points'][batch_id][indices][:, :num_max].float() + ctr_pts[:, :, 0] = ctr_pts[:, :, 0] / self.ego_size[1] + ctr_pts[:, :, 1] = ctr_pts[:, :, 1] / self.ego_size[0] + ctr_points.append(ctr_pts) + + # piecewise end indices + end_indices = targets['labels'][batch_id][indices][:, 1].long() + end_labels.append(end_indices) + + # bezier valid masks + v_mask = torch.zeros((num_ins, num_max), dtype=torch.int8).cuda() + for ins_id in range(num_ins): + k = targets['labels'][batch_id][indices[ins_id]][2].long() + v_mask[ins_id][:k] = 1 + valid_masks.append(v_mask) + + # curve points + curve_pts = self.curve_recovery_with_bezier(ctr_pts, end_indices, cid) + curve_points.append(curve_pts) + + # instance mask + mask_pc = targets["masks"][batch_id][cid] # mask supervision + unique_ids = torch.unique(mask_pc, sorted=True)[1:] + if num_ins == unique_ids.shape[0]: + ins_msk = (mask_pc.unsqueeze(0).repeat(num_ins, 1, 1) == unique_ids.view(-1, 1, 1)).float() + else: + ins_msk = np.zeros((num_ins, *self.map_size), dtype=np.uint8) + for i, ins_pts in enumerate(curve_pts): + ins_pts[:, 0] *= self.map_size[1] + ins_pts[:, 1] *= self.map_size[0] + ins_pts = ins_pts.cpu().data.numpy().astype(np.int32) + cv2.polylines(ins_msk[i], [ins_pts], False, color=1, thickness=self.line_width) + ins_msk = torch.from_numpy(ins_msk).float().cuda() + ins_masks.append(ins_msk) + + # semantic mask + sem_msk = (ins_msk.sum(0) > 0).float() + sem_masks.append(sem_msk) + + targets_refactored.append({ + "sem_masks": sem_masks, "ins_masks": ins_masks, "obj_labels": ins_objects, + "ctr_points": ctr_points, "end_labels": end_labels, "curve_points": curve_points, + "valid_masks": valid_masks, + }) + + return targets_refactored + + def curve_recovery_with_bezier(self, ctr_points, end_indices, cid): + device = ctr_points.device + curve_pts_ret = torch.zeros((0, self.curve_size, 2), dtype=torch.float, device=device) + num_instances, num_pieces = ctr_points.shape[0], ctr_points.shape[1] + pieces_ids = [[i+j for j in range(self.num_degree[cid]+1)] for i in range(0, num_pieces - 1, self.num_degree[cid])] + pieces_ids = torch.tensor(pieces_ids).long().to(device) + points_ids = torch.tensor(list(range(self.curve_size))).long().to(device) + points_ids = (end_indices + 1).unsqueeze(1) * points_ids.unsqueeze(0) + if num_instances > 0: + ctr_points_flatten = ctr_points[:, pieces_ids, :].flatten(0, 1) + curve_pts = torch.matmul(self.bezier_coefficient[cid], ctr_points_flatten) + curve_pts = curve_pts.reshape(num_instances, pieces_ids.shape[0], *curve_pts.shape[-2:]) + curve_pts = curve_pts.flatten(1, 2) + curve_pts_ret = torch.stack([curve_pts[i][points_ids[i]] for i in range(points_ids.shape[0])]) + return curve_pts_ret + + def _get_bezier_coefficients(self): + + def bernstein_func(n, t, k): + return (t ** k) * ((1 - t) ** (n - k)) * n_over_k(n, k) + + ts = np.linspace(0, 1, self.curve_size) + bezier_coefficient_list = [] + for nn 
in self.num_degree: + bezier_coefficient_list.append(np.array([[bernstein_func(nn, t, k) for k in range(nn + 1)] for t in ts])) + return bezier_coefficient_list + + def post_processing(self, outputs): + batch_results, batch_masks, batch_masks5 = [], [], [] + batch_size = outputs["obj_logits"][-1][0].shape[0] + for i in range(batch_size): + points, scores, labels = [None], [-1], [0] + masks = np.zeros((self.num_classes, *self.map_size)).astype(np.uint8) + masks5 = np.zeros((self.num_classes, *self.map_size)).astype(np.uint8) + instance_index = 1 + for j in range(self.num_classes): + pred_scores, pred_labels = torch.max(F.softmax(outputs["obj_logits"][-1][j][i], dim=-1), dim=-1) + keep_ids = torch.where((pred_labels == 0).int())[0] + if keep_ids.shape[0] == 0: + continue + curve_pts = outputs['curve_points'][-1][j][i][keep_ids].cpu().data.numpy() + curve_pts[:, :, 0] *= self.map_size[1] + curve_pts[:, :, 1] *= self.map_size[0] + for dt_curve, dt_score in zip(curve_pts, pred_scores[keep_ids]): + cv2.polylines(masks[j], [dt_curve.astype(np.int32)], False, color=instance_index, + thickness=self.save_thickness) + cv2.polylines(masks5[j], [dt_curve.astype(np.int32)], False, color=instance_index, thickness=12) + instance_index += 1 + points.append(curve_pts) + scores.append(self._to_np(dt_score).item()) + labels.append(j + 1) + batch_results.append({'map': points, 'confidence_level': scores, 'pred_label': labels}) + batch_masks.append(masks) + batch_masks5.append(masks5) + return batch_results, batch_masks, batch_masks5 + + @staticmethod + def _to_np(tensor): + return tensor.cpu().data.numpy() diff --git a/mapmaster/models/output_head/line_matching.py b/mapmaster/models/output_head/line_matching.py new file mode 100644 index 0000000..c690ffb --- /dev/null +++ b/mapmaster/models/output_head/line_matching.py @@ -0,0 +1,65 @@ +import numpy as np + +def seq_matching_dist_parallel(cost, gt_lens, coe_endpts=0): + # Time complexity: O(m*n) + bs, m, n = cost.shape + assert m <= n + min_cost = np.ones((bs, m, n)) * np.inf + mem_sort_value = np.ones((bs, m, n)) * np.inf # v[i][j] = np.min(min_cost[i][:j+1]) + + # initialization + for j in range(0, n): + if j == 0: + min_cost[:, 0, j] = cost[:, 0, j] + mem_sort_value[:, 0, j] = min_cost[:, 0, 0] + + for i in range(1, m): + for j in range(i, n): + min_cost[:, i, j] = mem_sort_value[:, i-1, j-1] + cost[:, i, j] + indexes = (min_cost[:, i, j] < mem_sort_value[:, i, j-1]) + indexes_inv = np.array(1-indexes, dtype=np.bool) + mem_sort_value[indexes, i, j] = min_cost[indexes, i, j] + mem_sort_value[indexes_inv, i, j] = mem_sort_value[indexes_inv, i, j-1] + + indexes = [] + for i, ll in enumerate(gt_lens): + indexes.append([i, ll-1, n-1]) + indexes = np.array(indexes) + xs, ys, zs = indexes[:, 0], indexes[:, 1], indexes[:, 2] + res_cost = min_cost[xs, ys, zs] + (cost[xs, 0, 0] + cost[xs, ys, zs]) * coe_endpts + return res_cost / (indexes[:, 1]+1+coe_endpts*2) + +def pivot_dynamic_matching(cost: np.array): + # Time complexity: O(m*n) + m, n = cost.shape + assert m <= n + + min_cost = np.ones((m, n)) * np.inf + mem_sort_value = np.ones((m, n)) * np.inf + match_res1 = [[] for _ in range(n)] + match_res2 = [[] for _ in range(n)] + + # initialization + for j in range(0, n-m+1): + match_res1[j] = [0] + mem_sort_value[0][j] = cost[0][0] + if j == 0: + min_cost[0][j] = cost[0][0] + + for i in range(1, m): + for j in range(i, n-m + i+1): + min_cost[i][j] = mem_sort_value[i-1][j-1] + cost[i][j] + if min_cost[i][j] < mem_sort_value[i][j-1]: + mem_sort_value[i][j] = 
min_cost[i][j] + if i < m-1: + match_res2[j] = match_res1[j-1] + [j] + else: + mem_sort_value[i][j] = mem_sort_value[i][j-1] + if i < m -1: + match_res2[j] = match_res2[j-1] + if i < m-1: + match_res1, match_res2 = match_res2.copy(), [[] for _ in range(n)] + + total_cost = min_cost[-1][-1] + final_match_res = match_res1[-2] + [n-1] + return total_cost, final_match_res \ No newline at end of file diff --git a/mapmaster/models/output_head/pivot_outputs.py b/mapmaster/models/output_head/pivot_outputs.py new file mode 100644 index 0000000..2a63e6b --- /dev/null +++ b/mapmaster/models/output_head/pivot_outputs.py @@ -0,0 +1,115 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FFN(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2, basic_type='linear'): + super().__init__() + self.basic_type = basic_type + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(self.basic_layer(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + def basic_layer(self, n, k): + if self.basic_type == 'linear': + return nn.Linear(n, k) + elif self.basic_type == 'conv': + return nn.Conv2d(n, k, kernel_size=1, stride=1) + else: + raise NotImplementedError + +class PivotMapOutputHead(nn.Module): + def __init__(self, in_channel, num_queries, tgt_shape, max_pieces, bev_channels=-1, ins_channel=64): + super(PivotMapOutputHead, self).__init__() + self.num_queries = num_queries + self.num_classes = len(num_queries) + self.tgt_shape = tgt_shape + self.bev_channels = bev_channels + self.semantic_heads = None + if self.bev_channels > 0: + self.semantic_heads = nn.ModuleList( + nn.Sequential(nn.Conv2d(bev_channels, 2, kernel_size=1, stride=1)) for _ in range(self.num_classes) + ) + + self.max_pieces = max_pieces # [10, 2, 30] + self.pts_split = [num_queries[i]*max_pieces[i] for i in range(len(num_queries))] + _N = self.num_classes + _C = ins_channel + self.im_ctr_heads = nn.ModuleList(FFN(in_channel, 256, 2 * _C, 3) for _ in range(_N)) + self.pts_cls_heads = nn.ModuleList(FFN((_C)*2, _C*2, 2, 3) for i in range(_N)) + self.gap_layer = nn.AdaptiveAvgPool2d((1, 1)) + self.coords = self.compute_locations(device='cuda') # (1, 2, h, w) + self.coords_head = FFN(2, 256, _C, 3, 'conv') + + def forward(self, inputs): + num_decoders = len(inputs["mask_features"]) + dt_obj_logit = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)] + dt_ins_masks = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)] + im_ctr_coord = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)] + dt_pivots_logits = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)] + coords_feats = self.coords_head.forward(self.coords.repeat((inputs["mask_features"][0].shape[0], 1, 1, 1))) + + for i in range(num_decoders): + x_ins_cw = inputs["mask_features"][i].split(self.num_queries, dim=1) + x_obj_cw = inputs["obj_scores"][i].split(self.num_queries, dim=1) + x_qry_cw = inputs["decoder_outputs"][i].split(self.pts_split, dim=1) # [(b, 200, c), (b, 50, c), (b, 450, c)] + batch_size = x_qry_cw[0].shape[0] + for j in range(self.num_classes): + dt_ins_masks[i][j] = self.up_sample(x_ins_cw[j]) # (B, P, H, W) + dt_obj_logit[i][j] = x_obj_cw[j] # (B, P, 2) + # im + num_qry, n_pts = self.num_queries[j], 
self.max_pieces[j] + im_feats = self.im_ctr_heads[j](x_qry_cw[j]) # (bs, n_q * n_pts, 2*c) + im_feats_tmp = im_feats.reshape(batch_size, num_qry*n_pts*2, -1) # (bs, n_q*n_pts*2, c) + im_coords_map = torch.einsum("bqc,bchw->bqhw", im_feats_tmp, coords_feats) # [bs, n_q*n_pts*2, h, w] + im_coords = self.gap_layer(im_coords_map) # [bs, n_q * n_pts] + im_coords = im_coords.reshape(batch_size, num_qry, self.max_pieces[j], 2).sigmoid() + im_ctr_coord[i][j] = im_coords + + pt_feats = im_feats.reshape(batch_size, num_qry, self.max_pieces[j], -1).flatten(1, 2) # [bs, n_q * n_pts, 2*C] + pt_logits = self.pts_cls_heads[j](pt_feats) + dt_pivots_logits[i][j] = pt_logits.reshape(batch_size, num_qry, self.max_pieces[j], 2) + + ret = {"outputs": {"obj_logits": dt_obj_logit, "ins_masks": dt_ins_masks, + "ctr_im": im_ctr_coord, "pts_logits": dt_pivots_logits}} + + if self.semantic_heads is not None: + num_decoders = len(inputs["bev_enc_features"]) + dt_sem_masks = [[[] for _ in range(self.num_classes)] for _ in range(num_decoders)] + for i in range(num_decoders): + x_sem = inputs["bev_enc_features"][i] + for j in range(self.num_classes): + dt_sem_masks[i][j] = self.up_sample(self.semantic_heads[j](x_sem)) # (B, P, 2, H, W) + ret["outputs"].update({"sem_masks": dt_sem_masks}) + return ret + + def up_sample(self, x, tgt_shape=None): + tgt_shape = self.tgt_shape if tgt_shape is None else tgt_shape + if tuple(x.shape[-2:]) == tuple(tgt_shape): + return x + return F.interpolate(x, size=tgt_shape, mode="bilinear", align_corners=True) + + def compute_locations(self, stride=1, device='cpu'): + + fh, fw = self.tgt_shape + + shifts_x = torch.arange(0, fw * stride, step=stride, dtype=torch.float32, device=device) + shifts_y = torch.arange(0, fh * stride, step=stride, dtype=torch.float32, device=device) + shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) + shift_x = shift_x.reshape(-1) + shift_y = shift_y.reshape(-1) + locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2 + + locations = locations.unsqueeze(0).permute(0, 2, 1).contiguous().float().view(1, 2, fh, fw) + locations[:, 0, :, :] /= fw + locations[:, 1, :, :] /= fh + + return locations diff --git a/mapmaster/models/output_head/pivot_post_processor.py b/mapmaster/models/output_head/pivot_post_processor.py new file mode 100644 index 0000000..cc49dcc --- /dev/null +++ b/mapmaster/models/output_head/pivot_post_processor.py @@ -0,0 +1,340 @@ +import cv2 +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from scipy.optimize import linear_sum_assignment +from mapmaster.models.utils.mask_loss import SegmentationLoss +from mapmaster.utils.misc import nested_tensor_from_tensor_list +from mapmaster.utils.misc import get_world_size, is_available, is_distributed + +from .line_matching import pivot_dynamic_matching, seq_matching_dist_parallel + + +class HungarianMatcher(nn.Module): + + def __init__(self, cost_obj=1., cost_mask=1., coe_endpts=1., cost_pts=2., mask_loss_conf=None): + super().__init__() + self.cost_obj, self.cost_mask = cost_obj, cost_mask + self.coe_endpts = coe_endpts # end points weight: 1 + coe_endpts + self.cost_pts = cost_pts + self.mask_loss = SegmentationLoss(**mask_loss_conf) + + @torch.no_grad() + def forward(self, outputs, targets): + num_decoders, num_classes = len(outputs["ins_masks"]), len(outputs["ins_masks"][0]) + matching_indices = [[[] for _ in range(num_classes)] for _ in range(num_decoders)] + for dec_id in range(num_decoders): + for cid in range(num_classes): + bs, num_queries = 
outputs["obj_logits"][dec_id][cid].shape[:2] + + # 1. obj class cost mat + dt_probs = outputs["obj_logits"][dec_id][cid].flatten(0, 1).softmax(-1) # [n_dt, 2], n_dt in a batch + gt_idxes = torch.cat([tgt["obj_labels"][cid] for tgt in targets]) # [n_gt, ] + cost_mat_obj = -dt_probs[:, gt_idxes] # [n_dt, n_gt] + + # 2. masks cost mat + dt_masks = outputs["ins_masks"][dec_id][cid].flatten(0, 1) # [n_dt, h, w] + gt_masks = torch.cat([tgt["ins_masks"][cid] for tgt in targets]) # [n_gt, h, w] + cost_mat_mask = 0 + if gt_masks.shape[0] == 0: + matching_indices[dec_id][cid] = [(torch.tensor([], dtype=torch.int64), torch.tensor([], dtype=torch.int64))] + continue + dt_num, gt_num = dt_masks.shape[0], gt_masks.shape[0] + dt_masks = dt_masks.unsqueeze(1).expand(dt_num, gt_num, *dt_masks.shape[1:]).flatten(0, 1) + gt_masks = gt_masks.unsqueeze(0).expand(dt_num, gt_num, *gt_masks.shape[1:]).flatten(0, 1) + + cost_mat_mask = self.mask_loss(dt_masks, gt_masks, "Matcher").reshape(dt_num, gt_num) + + # 3. sequence matching costmat + dt_pts = outputs["ctr_im"][dec_id][cid].flatten(0, 1) # [n_dt, n_pts, 2] + n_pt = dt_pts.shape[1] + dt_pts = dt_pts.unsqueeze(0).repeat(gt_num, 1, 1, 1).flatten(0, 1) # [n_gt, n_dt, n_pts, 2] -> [n_gt*n_dt, n_pts, 2] + gt_pts = targets[0]["points"][cid].to(torch.float32) + gt_pts = gt_pts.unsqueeze(1).repeat(1, dt_num, 1, 1).flatten(0, 1) # [n_gt, n_dt, n_pts, 2] -> [n_gt*n_dt, n_pts, 2] + + gt_pts_mask = torch.zeros(gt_num, n_pt, dtype=torch.double, device=gt_pts.device) + gt_lens = torch.tensor([ll for ll in targets[0]["valid_len"][cid]]) # n_gt + gt_lens = gt_lens.unsqueeze(-1).repeat(1, dt_num).flatten() + for i, ll in enumerate(targets[0]["valid_len"][cid]): + gt_pts_mask[i][:ll] = 1 + gt_pts_mask = gt_pts_mask.unsqueeze(1).unsqueeze(-1).repeat(1, dt_num, 1, n_pt).flatten(0, 1) + cost_mat_seqmatching = torch.cdist(gt_pts, dt_pts, p=1) * gt_pts_mask # [n_gt*n_dt, n_pts, n_pts] + cost_mat_seqmatching = seq_matching_dist_parallel( + cost_mat_seqmatching.detach().cpu().numpy(), + gt_lens, + self.coe_endpts).reshape(gt_num, dt_num).transpose(1, 0) #[n_gt, n_dt] + cost_mat_seqmatching = torch.from_numpy(cost_mat_seqmatching).to(cost_mat_mask.device) + + # 4. 
sum mat + sizes = [len(tgt["obj_labels"][cid]) for tgt in targets] + C = self.cost_obj * cost_mat_obj + self.cost_mask * cost_mat_mask + self.cost_pts * cost_mat_seqmatching + C = C.view(bs, num_queries, -1).cpu() + indices = [linear_sum_assignment(c[i].detach().numpy()) for i, c in enumerate(C.split(sizes, -1))] + + matching_indices[dec_id][cid] = [self.to_tensor(i, j) for i, j in indices] + + return matching_indices + + @staticmethod + def to_tensor(i, j): + return torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64) + + + +class SetCriterion(nn.Module): + def __init__(self, criterion_conf, matcher, sem_loss_conf=None, no_object_coe=1.0, collinear_pts_coe=1.0, coe_endpts=1.0): + super().__init__() + self.matcher = matcher + self.criterion_conf = criterion_conf + self.register_buffer("empty_weight", torch.tensor([1.0, no_object_coe])) + self.register_buffer("collinear_pt_weight", torch.tensor([collinear_pts_coe, 1.0])) + self.coe_endpts = coe_endpts + + self.sem_loss_conf = sem_loss_conf + self.mask_loss = SegmentationLoss(**sem_loss_conf["mask_loss_conf"]) + + def forward(self, outputs, targets): + matching_indices = self.matcher(outputs, targets) + ins_msk_loss, pts_loss, collinear_pts_loss, pt_logits_loss = \ + self.criterion_instance(outputs, targets, matching_indices) + ins_obj_loss = self.criterion_instance_labels(outputs, targets, matching_indices) + losses = {"ins_msk_loss": ins_msk_loss, "ins_obj_loss": ins_obj_loss, + "pts_loss": pts_loss, "collinear_pts_loss": collinear_pts_loss, + "pt_logits_loss": pt_logits_loss} + if self.sem_loss_conf is not None: + losses.update({"sem_msk_loss": self.criterion_semantice_masks(outputs, targets)}) + losses = {key: self.criterion_conf['weight_dict'][key] * losses[key] for key in losses} + return sum(losses.values()), losses + + def criterion_instance(self, outputs, targets, matching_indices): + loss_masks, loss_pts, loss_collinear_pts, loss_logits = 0, 0, 0, 0 + device = outputs["ins_masks"][0][0].device + num_decoders, num_classes = len(matching_indices), len(matching_indices[0]) + for i in range(num_decoders): + w = self.criterion_conf['decoder_weights'][i] + for j in range(num_classes): + num_instances = sum(len(t["obj_labels"][j]) for t in targets) + num_instances = torch.as_tensor([num_instances], dtype=torch.float, device=device) + if is_distributed() and is_available(): + torch.distributed.all_reduce(num_instances) + num_instances = torch.clamp(num_instances / get_world_size(), min=1).item() + indices = matching_indices[i][j] + src_idx = self._get_src_permutation_idx(indices) # dt + tgt_idx = self._get_tgt_permutation_idx(indices) # gt + + # instance masks + src_masks = outputs["ins_masks"][i][j][src_idx] + tgt_masks = [t["ins_masks"][j] for t in targets] + tgt_masks, _ = nested_tensor_from_tensor_list(tgt_masks).decompose() + tgt_masks = tgt_masks.to(src_masks)[tgt_idx] + loss_masks += w * self.mask_loss(src_masks, tgt_masks, "Loss").sum() / num_instances + + # prepare tgt points + src_ctrs = outputs["ctr_im"][i][j][src_idx] # [num_queries, o, 2] + tgt_ctrs = [] # [num_queries, o, 2] + for info in targets: # B + for pts, valid_len in zip(info["points"][j][tgt_idx[1]], info["valid_len"][j][tgt_idx[1]]): # n_gt + tgt_ctrs.append(pts[:valid_len]) + + # pts match, valid pts loss, collinear pts loss + n_match_q, n_dt_pts = src_ctrs.shape[0], src_ctrs.shape[1] + logits_gt = torch.zeros((n_match_q, n_dt_pts), dtype=torch.long, device=src_ctrs.device) + if n_match_q == 0: # avoid unused parameters + loss_pts += w * 
F.l1_loss(src_ctrs, src_ctrs, reduction='sum') + loss_logits += w * F.l1_loss(outputs["pts_logits"][i][j][src_idx].flatten(0, 1), outputs["pts_logits"][i][j][src_idx].flatten(0, 1), reduction="sum") + continue + + for ii, (src_pts, tgt_pts) in enumerate(zip(src_ctrs, tgt_ctrs)): # B=1, traverse matched query pairs + n_gt_pt = len(tgt_pts) + weight_pt = torch.ones((n_gt_pt), device=tgt_pts.device) + weight_pt[0] += self.coe_endpts + weight_pt[-1] += self.coe_endpts + cost_mat = torch.cdist(tgt_pts.to(torch.float32), src_pts, p=1) + _, matched_pt_idx = pivot_dynamic_matching(cost_mat.detach().cpu().numpy()) + matched_pt_idx = torch.tensor(matched_pt_idx) + # match pts loss + loss_match = w * F.l1_loss(src_pts[matched_pt_idx], tgt_pts, reduction="none").sum(dim=-1) # [n_gt_pt, 2] -> [n_gt_dt] + loss_match = (loss_match * weight_pt).sum() / weight_pt.sum() + loss_pts += loss_match / num_instances + # interpolate pts loss + loss_collinear_pts += w * self.interpolate_loss(src_pts, tgt_pts, matched_pt_idx) / num_instances + # pt logits + logits_gt[ii][matched_pt_idx] = 1 + loss_logits += w * F.cross_entropy(outputs["pts_logits"][i][j][src_idx].flatten(0, 1), logits_gt.flatten(), self.collinear_pt_weight) / num_instances + + loss_masks /= (num_decoders * num_classes) + loss_pts /= (num_decoders * num_classes) + loss_logits /= (num_decoders * num_classes) + loss_collinear_pts /= (num_decoders * num_classes) + + return loss_masks, loss_pts, loss_collinear_pts, loss_logits + + def interpolate_loss(self, src_pts, tgt_pts, matched_pt_idx): + # 1. pick collinear pt idx + collinear_idx = torch.ones(src_pts.shape[0], dtype=torch.bool) + collinear_idx[matched_pt_idx] = 0 + collinear_src_pts = src_pts[collinear_idx] + # 2. interpolate tgt_pts + inter_tgt = torch.zeros_like(collinear_src_pts) + cnt = 0 + for i in range(len(matched_pt_idx)-1): + start_pt, end_pt = tgt_pts[i], tgt_pts[i+1] + inter_num = matched_pt_idx[i+1] - matched_pt_idx[i] - 1 + inter_tgt[cnt:cnt+inter_num] = self.interpolate(start_pt, end_pt, inter_num) + cnt += inter_num + assert collinear_src_pts.shape[0] == cnt + # 3. cal loss + if cnt > 0: + inter_loss = F.l1_loss(collinear_src_pts, inter_tgt, reduction="sum") / cnt + else: + inter_loss = F.l1_loss(collinear_src_pts, inter_tgt, reduction="sum") + return inter_loss + + @staticmethod + def interpolate(start_pt, end_pt, inter_num): + res = torch.zeros((inter_num, 2), dtype=start_pt.dtype, device=start_pt.device) + num_len = inter_num + 1 # segment num. 
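+        # fill res with the inter_num evenly spaced points strictly between the two matched pivots:
+        # res[i-1] = start_pt + (i / num_len) * (end_pt - start_pt), for i = 1 .. inter_num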
+ for i in range(1, num_len): + ratio = i / num_len + res[i-1] = (1 - ratio) * start_pt + ratio * end_pt + return res + + def criterion_instance_labels(self, outputs, targets, matching_indices): + loss_labels = 0 + num_decoders, num_classes = len(matching_indices), len(matching_indices[0]) + for i in range(num_decoders): + w = self.criterion_conf['decoder_weights'][i] + for j in range(num_classes): + indices = matching_indices[i][j] + idx = self._get_src_permutation_idx(indices) # (batch_id, query_id) + logits = outputs["obj_logits"][i][j] + target_classes_o = torch.cat([t["obj_labels"][j][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full(logits.shape[:2], 1, dtype=torch.int64, device=logits.device) + target_classes[idx] = target_classes_o + loss_labels += (w * F.cross_entropy(logits.transpose(1, 2), target_classes, self.empty_weight)) + loss_labels /= (num_decoders * num_classes) + return loss_labels + + def criterion_semantice_masks(self, outputs, targets): + loss_masks = 0 + num_decoders, num_classes = len(outputs["sem_masks"]), len(outputs["sem_masks"][0]) + for i in range(num_decoders): + w = self.sem_loss_conf['decoder_weights'][i] + for j in range(num_classes): + dt_masks = outputs["sem_masks"][i][j] # (B, 2, H, W) + gt_masks = torch.stack([t["sem_masks"][j] for t in targets], dim=0) # (B, H, W) + loss_masks += w * self.mask_loss(dt_masks[:, 1, :, :], gt_masks).mean() + loss_masks /= (num_decoders * num_classes) + return loss_masks + + @staticmethod + def _get_src_permutation_idx(indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + @staticmethod + def _get_tgt_permutation_idx(indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + +class PivotMapPostProcessor(nn.Module): + def __init__(self, criterion_conf, matcher_conf, pivot_conf, map_conf, + sem_loss_conf=None, no_object_coe=1.0, collinear_pts_coe=1.0, coe_endpts=0.0): + super(PivotMapPostProcessor, self).__init__() + self.criterion = SetCriterion(criterion_conf, HungarianMatcher(**matcher_conf), sem_loss_conf, no_object_coe, collinear_pts_coe, coe_endpts) + self.ego_size = map_conf['ego_size'] + self.map_size = map_conf['map_size'] + self.line_width = map_conf['line_width'] + self.num_pieces = pivot_conf['max_pieces'] # (10, 2, 30) + self.num_classes = len(self.num_pieces) + self.class_indices = torch.tensor(list(range(self.num_classes)), dtype=torch.int).cuda() + + def forward(self, outputs, targets=None): + if self.training: + targets = self.refactor_targets(targets) + return self.criterion.forward(outputs, targets) + else: + return self.post_processing(outputs) + + + def refactor_targets(self, targets): + # only support bs == 1 + targets_refactored = [] + targets["masks"] = targets["masks"].cuda() + for key in [0, 1, 2]: # map type + targets["points"][key] = targets["points"][key].cuda()[0] # [0] remove batch dim + targets["valid_len"][key] = targets["valid_len"][key].cuda()[0] # [0] remove batch dim + + for instance_mask in targets["masks"]: # bs, only support bs == 1 + sem_masks, ins_masks, ins_objects = [], [], [] + for i, mask_pc in enumerate(instance_mask): # class + sem_masks.append((mask_pc > 0).float()) + unique_ids = torch.unique(mask_pc, sorted=True)[1:] + 
ins_num = unique_ids.shape[0] + pt_ins_num = len(targets["points"][i]) + if pt_ins_num == ins_num: + ins_msk = (mask_pc.unsqueeze(0).repeat(ins_num, 1, 1) == unique_ids.view(-1, 1, 1)).float() + else: + ins_msk = np.zeros((pt_ins_num, *self.map_size), dtype=np.uint8) + for j, ins_pts in enumerate(targets["points"][i]): + ins_pts_tmp = ins_pts.clone() + ins_pts_tmp[:, 0] *= self.map_size[0] + ins_pts_tmp[:, 1] *= self.map_size[1] + ins_pts_tmp = ins_pts_tmp.cpu().data.numpy().astype(np.int32) + cv2.polylines(ins_msk[j], [ins_pts_tmp[:, ::-1]], False, color=1, thickness=self.line_width) + ins_msk = torch.from_numpy(ins_msk).float().cuda() + assert len(ins_msk) == len(targets["points"][i]) + ins_obj = torch.zeros(pt_ins_num, dtype=torch.long, device=unique_ids.device) + ins_masks.append(ins_msk) + ins_objects.append(ins_obj) + targets_refactored.append({ + "sem_masks": sem_masks, + "ins_masks": ins_masks, + "obj_labels": ins_objects, + "points": targets["points"], + "valid_len": targets["valid_len"], + }) + return targets_refactored + + def post_processing(self, outputs): + batch_results, batch_masks = [], [] + batch_size = outputs["obj_logits"][-1][0].shape[0] + for i in range(batch_size): + points, scores, labels = [None], [-1], [0] + masks = np.zeros((self.num_classes, *self.map_size)).astype(np.uint8) + instance_index = 1 + for j in range(self.num_classes): + pred_scores, pred_labels = torch.max(F.softmax(outputs["obj_logits"][-1][j][i], dim=-1), dim=-1) + keep_ids = torch.where((pred_labels == 0).int())[0] # fore-ground + if keep_ids.shape[0] == 0: + continue + keypts = outputs["ctr_im"][-1][j][i][keep_ids].cpu().data.numpy() # [P, N, 2] + keypts[:, :, 0] *= self.map_size[0] + keypts[:, :, 1] *= self.map_size[1] + + valid_pt_idx = F.softmax(outputs["pts_logits"][-1][j][i][keep_ids], dim=-1)[:,:,1].cpu().data.numpy() > 0.5 # [P, N] + valid_pt_idx[:, 0] = 1 + valid_pt_idx[:, -1] = 1 + + for k, (dt_pts, dt_score) in enumerate(zip(keypts, pred_scores[keep_ids])): + select_pt = dt_pts[valid_pt_idx[k]] + cv2.polylines(masks[j], [select_pt.astype(np.int32)[:, ::-1]], False, color=instance_index, thickness=1) + instance_index += 1 + points.append(select_pt) + scores.append(self._to_np(dt_score).item()) + labels.append(j + 1) + batch_results.append({'map': points, 'confidence_level': scores, 'pred_label': labels}) + batch_masks.append(masks) + return batch_results, batch_masks + + @staticmethod + def _to_np(tensor): + return tensor.cpu().data.numpy() + + \ No newline at end of file diff --git a/mapmaster/models/utils/mask_loss.py b/mapmaster/models/utils/mask_loss.py new file mode 100644 index 0000000..8897cbf --- /dev/null +++ b/mapmaster/models/utils/mask_loss.py @@ -0,0 +1,89 @@ +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from detectron2.projects.point_rend.point_features import point_sample +from detectron2.projects.point_rend.point_features import get_uncertain_point_coords_with_randomness + + +class SegmentationLoss(nn.Module): + + def __init__(self, ce_weight, dice_weight, use_point_render=False, num_points=8000, oversample=3.0, importance=0.75): + super(SegmentationLoss, self).__init__() + self.ce_weight = ce_weight + self.dice_weight = dice_weight + self.use_point_render = use_point_render + self.num_points = num_points + self.oversample = oversample + self.importance = importance + + def forward(self, dt_masks, gt_masks, stage="loss"): + loss = 0 + if self.use_point_render: + dt_masks, gt_masks = self.points_render(dt_masks, gt_masks, 
stage) + if self.ce_weight > 0: + loss += self.ce_weight * self.forward_sigmoid_ce_loss(dt_masks, gt_masks) + if self.dice_weight > 0: + loss += self.dice_weight * self.forward_dice_loss(dt_masks, gt_masks) + return loss + + @staticmethod + def forward_dice_loss(inputs, targets): + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + targets = targets.flatten(1) + numerator = 2 * (inputs * targets).sum(-1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss + + @staticmethod + def forward_sigmoid_ce_loss(inputs, targets): + inputs = inputs.flatten(1) + targets = targets.flatten(1) + loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + return loss.mean(1) + + def points_render(self, src_masks, tgt_masks, stage): + assert stage in ["loss", "matcher"] + assert src_masks.shape == tgt_masks.shape + + src_masks = src_masks[:, None] + tgt_masks = tgt_masks[:, None] + + if stage == "matcher": + point_coords = torch.rand(1, self.num_points, 2, device=src_masks.device) + point_coords_src = point_coords.repeat(src_masks.shape[0], 1, 1) + point_coords_tgt = point_coords.repeat(tgt_masks.shape[0], 1, 1) + else: + point_coords = get_uncertain_point_coords_with_randomness( + src_masks, + lambda logits: self.calculate_uncertainty(logits), + self.num_points, + self.oversample, + self.importance, + ) + point_coords_src = point_coords.clone() + point_coords_tgt = point_coords.clone() + + src_masks = point_sample(src_masks, point_coords_src, align_corners=False).squeeze(1) + tgt_masks = point_sample(tgt_masks, point_coords_tgt, align_corners=False).squeeze(1) + + return src_masks, tgt_masks + + @staticmethod + def calculate_uncertainty(logits): + """ + We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the + foreground class in `classes`. + Args: + logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or + class-agnostic, where R is the total number of predicted masks in all images and C is + the number of foreground classes. The values are logits. + Returns: + scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with + the most uncertain locations having the highest uncertainty score. + """ + assert logits.shape[1] == 1 + gt_class_logits = logits.clone() + return -(torch.abs(gt_class_logits)) diff --git a/mapmaster/models/utils/misc.py b/mapmaster/models/utils/misc.py new file mode 100644 index 0000000..25e394b --- /dev/null +++ b/mapmaster/models/utils/misc.py @@ -0,0 +1,78 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +import warnings +import torch.nn as nn +from torch.nn import functional as F + + +def c2_xavier_fill(module: nn.Module) -> None: + """ + Initialize `module.weight` using the "XavierFill" implemented in Caffe2. + Also initializes `module.bias` to 0. + Args: + module (torch.nn.Module): module to initialize. + """ + # Caffe2 implementation of XavierFill in fact + # corresponds to kaiming_uniform_ in PyTorch + # pyre-fixme[6]: For 1st param expected `Tensor` but got `Union[Module, Tensor]`. + nn.init.kaiming_uniform_(module.weight, a=1) + if module.bias is not None: + # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module, + # torch.Tensor]`. + nn.init.constant_(module.bias, 0) + + +class Conv2d(torch.nn.Conv2d): + """ + A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features. 
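+    It optionally chains a normalization layer and an activation after the convolution,
+    and asserts that SyncBatchNorm is not combined with empty inputs during training.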
+ """ + + def __init__(self, *args, **kwargs): + """ + Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`: + Args: + norm (nn.Module, optional): a normalization layer + activation (callable(Tensor) -> Tensor): a callable activation function + It assumes that norm layer is used before activation. + """ + norm = kwargs.pop("norm", None) + activation = kwargs.pop("activation", None) + super().__init__(*args, **kwargs) + + self.norm = norm + self.activation = activation + + def forward(self, x): + # torchscript does not support SyncBatchNorm yet + # https://github.com/pytorch/pytorch/issues/40507 + # and we skip these codes in torchscript since: + # 1. currently we only support torchscript in evaluation mode + # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or + # later version, `Conv2d` in these PyTorch versions has already supported empty inputs. + if not torch.jit.is_scripting(): + with warnings.catch_warnings(record=True): + if x.numel() == 0 and self.training: + # https://github.com/pytorch/pytorch/issues/12013 + assert not isinstance( + self.norm, torch.nn.SyncBatchNorm + ), "SyncBatchNorm does not support empty inputs!" + + x = F.conv2d( + x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups + ) + if self.norm is not None: + x = self.norm(x) + if self.activation is not None: + x = self.activation(x) + return x + + +def get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") diff --git a/mapmaster/models/utils/position_encoding.py b/mapmaster/models/utils/position_encoding.py new file mode 100644 index 0000000..565d4ce --- /dev/null +++ b/mapmaster/models/utils/position_encoding.py @@ -0,0 +1,217 @@ +""" +Various positional encodings for the transformer. +""" +import math +import torch +from torch import nn +import torch.nn.functional as F + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. 
+ """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, mask): + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=mask.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. + """ + + def __init__(self, num_pos=(50, 50), num_pos_feats=256): + super().__init__() + self.num_pos = num_pos + self.pos_embed = nn.Embedding(num_pos[0] * num_pos[1], num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.normal_(self.pos_embed.weight) + + def forward(self, mask): + h, w = mask.shape[-2:] + pos = self.pos_embed.weight.view(*self.num_pos, -1)[:h, :w] + pos = pos.permute(2, 0, 1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1) + return pos + + +class PositionEmbeddingIPM(nn.Module): + + def __init__(self, + encoder=None, + num_pos=(16, 168), + input_shape=(512, 896), + num_pos_feats=64, + sine_encoding=False, + temperature=10000): + super().__init__() + + h, w_expand = num_pos + self.current_shape = (h, w_expand // 6) + self.input_shape = input_shape + + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.encoder = encoder + self.sine_encoding = sine_encoding + + def get_embedding(self, extrinsic, intrinsic, ida_mats): + """ + Get the BeV Coordinate for Image + + Return + xy_world_coord (N, H, W, 2) Ego x, y coordinate + Valid (N, H, W, 1) -- Valid Points or Not 1 -- valid; 0 -- invalid + """ + # extrinsic -> (B, M, 4, 4) + device, b, n = extrinsic.device, extrinsic.shape[0], extrinsic.shape[1] + + x = torch.linspace(0, self.input_shape[1] - 1, self.current_shape[1], dtype=torch.float) + y = torch.linspace(0, self.input_shape[0] - 1, self.current_shape[0], dtype=torch.float) + y_grid, x_grid = torch.meshgrid(y, x) + z = torch.ones(self.current_shape) + feat_coords = torch.stack([x_grid, y_grid, z], dim=-1).to(device) # (H, W, 3) + feat_coords = feat_coords.unsqueeze(0).repeat(n, 1, 1, 1).unsqueeze(0).repeat(b, 1, 1, 1, 1) # (B, N, H, W, 3) + + ida_mats = ida_mats.view(b, n, 1, 1, 3, 3) + image_coords = ida_mats.inverse().matmul(feat_coords.unsqueeze(-1)) # (B, N, H, W, 3, 1) + + intrinsic = intrinsic.view(b, n, 1, 1, 3, 3) # (B, N, 1, 1, 3, 3) + normed_coords = torch.linalg.inv(intrinsic) @ image_coords # (B, N, H, W, 3, 1) + + ext_rots = extrinsic[:, :, :3, :3] # (B, N, 3, 3) + ext_trans = extrinsic[:, :, :3, 3] # (B, N, 3) + + ext_rots = ext_rots.view(b, n, 1, 1, 3, 3) # (B, N, 1, 1, 3, 3) + world_coords = (ext_rots @ 
normed_coords).squeeze(-1) # (B, N, H, W, 3) + world_coords = F.normalize(world_coords, p=2, dim=-1) + z_coord = world_coords[:, :, :, :, 2] # (B, N, H, W) + + trans_z = ext_trans[:, :, 2].unsqueeze(-1).unsqueeze(-1) # (B, N, 1, 1) + depth = - trans_z / z_coord # (B, N, H, W) + valid = depth > 0 # (B, N, H, W) + + xy_world_coords = world_coords[:, :, :, :, :2] # (B, N, H, W, 2) + xy_world_coords = xy_world_coords * depth.unsqueeze(-1) + valid = valid.unsqueeze(-1) # (B, N, H, W, 1) + + return xy_world_coords, valid + + def forward(self, extrinsic, intrinsic, ida_mats, do_flip): + """ + extrinsic (N, 6, 4, 4) torch.Tensor + intrinsic (N, 6, 3, 3) + """ + device = extrinsic.device + xy_pos_embed, valid = self.get_embedding(extrinsic, intrinsic, ida_mats) + if do_flip: + xy_pos_embed[:, :, :, :, 1] = -1 * xy_pos_embed[:, :, :, :, 1] + # along with w + xy_pos_embed = torch.cat(torch.unbind(xy_pos_embed, dim=1), dim=-2) # (B, H, N*W, 2) + valid = torch.cat(torch.unbind(valid, dim=1), dim=-2) # (B, H, N*W, 2) + if self.sine_encoding: + # Use Sine encoding to get 256 dim embeddings + dim_t = torch.arange(self.num_pos_feats // 2, dtype=torch.float32, device=device) + dim_t = self.temperature ** (2 * (dim_t // 2) / (self.num_pos_feats // 2)) + pos_embed = xy_pos_embed[:, :, :, :, None] / dim_t + pos_x = torch.stack((pos_embed[:, :, :, 0, 0::2].sin(), pos_embed[:, :, :, 0, 1::2].cos()), dim=4) + pos_y = torch.stack((pos_embed[:, :, :, 1, 0::2].sin(), pos_embed[:, :, :, 1, 1::2].cos()), dim=4) + pos_full_embed = torch.cat((pos_y.flatten(3), pos_x.flatten(3)), dim=3) + pos_combined = torch.where(valid, pos_full_embed, torch.tensor(0., dtype=torch.float32, device=device)) + pos_combined = pos_combined.permute(0, 3, 1, 2) # (B, 2, H, W') + else: + assert None + # pos_combined = torch.where(valid, xy_pos_embed, torch.tensor(0., dtype=torch.float32, device=device)) + # pos_combined = pos_combined.permute(0, 3, 1, 2) + + if self.encoder is None: + return pos_combined, valid.squeeze(-1) + else: + pos_embed_contiguous = pos_combined.contiguous() + return self.encoder(pos_embed_contiguous), valid.squeeze(-1) + + +class PositionEmbeddingTgt(nn.Module): + def __init__(self, + encoder=None, + tgt_shape=(40, 20), + map_size=(400, 200), + map_resolution=0.15, + num_pos_feats=64, + sine_encoding=False, + temperature=10000): + super().__init__() + self.tgt_shape = tgt_shape + self.encoder = encoder + self.map_size = map_size + self.map_resolution = map_resolution + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.sine_encoding = sine_encoding + + def forward(self, mask): + B = mask.shape[0] + + map_forward_ratio = self.tgt_shape[0] / self.map_size[0] + map_lateral_ratio = self.tgt_shape[1] / self.map_size[1] + + map_forward_res = self.map_resolution / map_forward_ratio + map_lateral_res = self.map_resolution / map_lateral_ratio + + X = (torch.arange(self.tgt_shape[0] - 1, -1, -1, device=mask.device) + 0.5 - self.tgt_shape[ + 0] / 2) * map_forward_res + Y = (torch.arange(self.tgt_shape[1] - 1, -1, -1, device=mask.device) + 0.5 - self.tgt_shape[ + 1] / 2) * map_lateral_res + + grid_X, grid_Y = torch.meshgrid(X, Y) + pos_embed = torch.stack([grid_X, grid_Y], dim=-1) # (H, W, 2) + + if self.sine_encoding: + dim_t = torch.arange(self.num_pos_feats // 2, dtype=torch.float32, device=mask.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / (self.num_pos_feats // 2)) + + pos_embed = pos_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_embed[:, :, 0, 0::2].sin(), pos_embed[:, :, 0, 
1::2].cos()), dim=3).flatten(2) + pos_y = torch.stack((pos_embed[:, :, 1, 0::2].sin(), pos_embed[:, :, 1, 1::2].cos()), dim=3).flatten(2) + pos_full_embed = torch.cat((pos_y, pos_x), dim=2) + + pos_embed = pos_full_embed.unsqueeze(0).repeat(B, 1, 1, 1).permute(0, 3, 1, 2) + else: + pos_embed = pos_embed.unsqueeze(0).repeat(B, 1, 1, 1).permute(0, 3, 1, 2) + + if self.encoder is None: + return pos_embed + else: + pos_embed_contiguous = pos_embed.contiguous() + return self.encoder(pos_embed_contiguous) \ No newline at end of file diff --git a/mapmaster/models/utils/recovery_loss.py b/mapmaster/models/utils/recovery_loss.py new file mode 100644 index 0000000..b25ce3b --- /dev/null +++ b/mapmaster/models/utils/recovery_loss.py @@ -0,0 +1,49 @@ +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from detectron2.projects.point_rend.point_features import point_sample + + +class PointRecoveryLoss(nn.Module): + + def __init__(self, ce_weight, dice_weight, curve_width, tgt_shape): + super(PointRecoveryLoss, self).__init__() + self.ce_weight = ce_weight + self.dice_weight = dice_weight + self.kernel = self.generate_kernel(curve_width, tgt_shape) + + def forward(self, points, gt_masks): + points_expanded = points.unsqueeze(2) - self.kernel.repeat(points.shape[0], 1, 1, 1) + points_expanded = torch.clamp(points_expanded.flatten(1, 2), min=0, max=1) # (N, P*w*w, 2) [0, 1] + dt_points = point_sample(gt_masks[:, None], points_expanded, align_corners=False).squeeze(1).flatten(1) + gt_points = torch.ones_like(dt_points) + loss = 0 + if self.ce_weight > 0: + loss += self.ce_weight * self.forward_ce_loss(dt_points, gt_points) + if self.dice_weight > 0: + loss += self.dice_weight * self.forward_dice_loss(dt_points, gt_points) + return loss + + @staticmethod + def generate_kernel(curve_width, tgt_shape, device='cuda'): + width = torch.tensor(list(range(curve_width))) + kernel = torch.stack(torch.meshgrid(width, width), dim=-1).float() + kernel = kernel - curve_width // 2 + kernel[..., 0] = kernel[..., 0] / tgt_shape[1] + kernel[..., 1] = kernel[..., 1] / tgt_shape[0] + kernel = kernel.flatten(0, 1).unsqueeze(0).unsqueeze(0) # (1, 1, w*w, 2) + kernel = kernel.cuda() if device == 'cuda' else kernel + return kernel + + @staticmethod + def forward_dice_loss(inputs, targets): + numerator = 2 * (inputs * targets).sum(-1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss + + @staticmethod + def forward_ce_loss(inputs, targets): + loss = F.binary_cross_entropy(inputs, targets, reduction="none") + return loss.mean(1) diff --git a/mapmaster/utils/env.py b/mapmaster/utils/env.py new file mode 100644 index 0000000..e7994a8 --- /dev/null +++ b/mapmaster/utils/env.py @@ -0,0 +1,129 @@ +import os +import re +import sys +import PIL +import importlib +import warnings +import subprocess +import torch +import torchvision +import numpy as np +from tabulate import tabulate +from collections import defaultdict + +__all__ = ["collect_env_info"] + + +def collect_torch_env(): + import torch.__config__ + return torch.__config__.show() + + +def collect_git_info(): + try: + import git + from git import InvalidGitRepositoryError + except ImportError: + warnings.warn("Please consider to install gitpython for git info collection by 'pip install gitpython'.") + return "Git status: unknown\n" + + try: + repo = git.Repo(get_root_dir()) + except InvalidGitRepositoryError: + warnings.warn("Current path is possibly not a valid git repository.") + 
return "Git status: unknown\n" + + msg = "***Git status:***\n{}\nHEAD Commit-id: {}\n".format(repo.git.status().replace("<", "\<"), repo.head.commit) + msg = "{}\n{}".format(msg, "***Git Diff:***\n{}\n".format(repo.git.diff().replace("<", "\<"))) + return msg + + +def detect_compute_compatibility(CUDA_HOME, so_file): + try: + cuobjdump = os.path.join(CUDA_HOME, "bin", "cuobjdump") + if os.path.isfile(cuobjdump): + output = subprocess.check_output("'{}' --list-elf '{}'".format(cuobjdump, so_file), shell=True) + output = output.decode("utf-8").strip().split("\n") + sm = [] + for line in output: + line = re.findall(r"\.sm_[0-9]*\.", line)[0] + sm.append(line.strip(".")) + sm = sorted(set(sm)) + return ", ".join(sm) + else: + return so_file + "; cannot find cuobjdump" + except Exception: + # unhandled failure + return so_file + + +def collect_env_info(): + data = [] + data.append(("sys.platform", sys.platform)) + data.append(("Python", sys.version.replace("\n", ""))) + data.append(("numpy", np.__version__)) + data.append(("Pillow", PIL.__version__)) + + data.append(("PyTorch", torch.__version__ + " @" + os.path.dirname(torch.__file__))) + data.append(("PyTorch debug build", torch.version.debug)) + + has_cuda = torch.cuda.is_available() + + data.append(("CUDA available", has_cuda)) + if has_cuda: + devices = defaultdict(list) + for k in range(torch.cuda.device_count()): + devices[torch.cuda.get_device_name(k)].append(str(k)) + for name, devids in devices.items(): + data.append(("GPU " + ",".join(devids), name)) + + from torch.utils.cpp_extension import CUDA_HOME + + data.append(("CUDA_HOME", str(CUDA_HOME))) + + if CUDA_HOME is not None and os.path.isdir(CUDA_HOME): + try: + nvcc = os.path.join(CUDA_HOME, "bin", "nvcc") + nvcc = subprocess.check_output("'{}' -V | tail -n1".format(nvcc), shell=True) + nvcc = nvcc.decode("utf-8").strip() + except subprocess.SubprocessError: + nvcc = "Not Available" + data.append(("NVCC", nvcc)) + + cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None) + if cuda_arch_list: + data.append(("TORCH_CUDA_ARCH_LIST", cuda_arch_list)) + + try: + data.append( + ( + "torchvision", + str(torchvision.__version__) + " @" + os.path.dirname(torchvision.__file__), + ) + ) + if has_cuda: + try: + torchvision_C = importlib.util.find_spec("torchvision._C").origin + msg = detect_compute_compatibility(CUDA_HOME, torchvision_C) + data.append(("torchvision arch flags", msg)) + except ImportError: + data.append(("torchvision._C", "failed to find")) + except AttributeError: + data.append(("torchvision", "unknown")) + + try: + import cv2 + + data.append(("cv2", cv2.__version__)) + except ImportError: + pass + + env_str = tabulate(data) + "\n" + env_str += collect_git_info() + env_str += "-" * 100 + "\n" + env_str += collect_torch_env() + return env_str + + +def get_root_dir(): + return os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) diff --git a/mapmaster/utils/misc.py b/mapmaster/utils/misc.py new file mode 100644 index 0000000..78e13ea --- /dev/null +++ b/mapmaster/utils/misc.py @@ -0,0 +1,411 @@ +import os +import re +import torch +import torchvision +import unicodedata +from sys import stderr +from torch import Tensor +from loguru import logger +from argparse import Action +from collections import deque +from typing import Optional, List +from torch import distributed as dist + + +__all__ = [ + "PyDecorator", "NestedTensor", "AvgMeter", "DictAction", "sanitize_filename", "parse_devices", + "_max_by_axis", "nested_tensor_from_tensor_list", 
"_onnx_nested_tensor_from_tensor_list", + "get_param_groups", "setup_logger", "get_rank", "get_world_size", "synchronize", "reduce_sum", + "reduce_mean", "all_gather_object", "is_distributed", "is_available" +] + + +class PyDecorator: + @staticmethod + def overrides(interface_class): + def overrider(method): + assert method.__name__ in dir(interface_class), "{} function not in {}".format( + method.__name__, interface_class + ) + return method + return overrider + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +class AvgMeter(object): + def __init__(self, window_size=50): + self.window_size = window_size + self._value_deque = deque(maxlen=window_size) + self._total_value = 0.0 + self._wdsum_value = 0.0 + self._count_deque = deque(maxlen=window_size) + self._total_count = 0.0 + self._wdsum_count = 0.0 + + def reset(self): + self._value_deque.clear() + self._total_value = 0.0 + self._wdsum_value = 0.0 + self._count_deque.clear() + self._total_count = 0.0 + self._wdsum_count = 0.0 + + def update(self, value, n=1): + if len(self._value_deque) >= self.window_size: + self._wdsum_value -= self._value_deque.popleft() + self._wdsum_count -= self._count_deque.popleft() + self._value_deque.append(value * n) + self._total_value += value * n + self._wdsum_value += value * n + self._count_deque.append(n) + self._total_count += n + self._wdsum_count += n + + @property + def avg(self): + return self.global_avg + + @property + def global_avg(self): + return self._total_value / max(self._total_count, 1e-5) + + @property + def window_avg(self): + return self._wdsum_value / max(self._wdsum_count, 1e-5) + + +class DictAction(Action): + """ + argparse action to split an argument into KEY=VALUE form + on the first = and append to a dictionary. List options can + be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit + brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build + list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]' + """ + + @staticmethod + def _parse_int_float_bool(val): + try: + return int(val) + except ValueError: + pass + try: + return float(val) + except ValueError: + pass + if val.lower() in ["true", "false"]: + return True if val.lower() == "true" else False + return val + + @staticmethod + def _parse_iterable(val): + """Parse iterable values in the string. + All elements inside '()' or '[]' are treated as iterable values. + Args: + val (str): Value string. + Returns: + list | tuple: The expanded list or tuple from the string. + Examples: + >>> DictAction._parse_iterable('1,2,3') + [1, 2, 3] + >>> DictAction._parse_iterable('[a, b, c]') + ['a', 'b', 'c'] + >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]') + [(1, 2, 3), ['a', 'b], 'c'] + """ + + def find_next_comma(string): + """Find the position of next comma in the string. + If no ',' is found in the string, return the string length. All + chars inside '()' and '[]' are treated as one element and thus ',' + inside these brackets are ignored. 
+ """ + assert (string.count("(") == string.count(")")) and ( + string.count("[") == string.count("]") + ), f"Imbalanced brackets exist in {string}" + end = len(string) + for idx, char in enumerate(string): + pre = string[:idx] + # The string before this ',' is balanced + if (char == ",") and (pre.count("(") == pre.count(")")) and (pre.count("[") == pre.count("]")): + end = idx + break + return end + + # Strip ' and " characters and replace whitespace. + val = val.strip("'\"").replace(" ", "") + is_tuple = False + if val.startswith("(") and val.endswith(")"): + is_tuple = True + val = val[1:-1] + elif val.startswith("[") and val.endswith("]"): + val = val[1:-1] + elif "," not in val: + # val is a single value + return DictAction._parse_int_float_bool(val) + + values = [] + while len(val) > 0: + comma_idx = find_next_comma(val) + element = DictAction._parse_iterable(val[:comma_idx]) + values.append(element) + val = val[comma_idx + 1 :] + if is_tuple: + values = tuple(values) + return values + + def __call__(self, parser, namespace, values, option_string=None): + options = {} + for kv in values: + key, val = kv.split("=", maxsplit=1) + options[key] = self._parse_iterable(val) + setattr(namespace, self.dest, options) + + +def sanitize_filename(value, allow_unicode=False): + value = str(value) + if allow_unicode: + value = unicodedata.normalize("NFKC", value) + else: + value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii") + value = re.sub(r"[^\w\s-]", "", value.lower()) + return re.sub(r"[-\s]+", "-", value).strip("-_") + + +def parse_devices(gpu_ids): + if "-" in gpu_ids: + gpus = gpu_ids.split("-") + gpus[0] = int(gpus[0]) + gpus[1] = int(gpus[1]) + 1 + parsed_ids = ",".join(map(lambda x: str(x), list(range(*gpus)))) + return parsed_ids + else: + return gpu_ids + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], : img.shape[2]] = False + else: + raise ValueError("not supported") + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
+@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +def get_param_groups(model, optimizer_setup): + def match_name_keywords(n, name_keywords): + out = False + for b in name_keywords: + if b in n: + out = True + break + return out + + for n, p in model.named_parameters(): + if match_name_keywords(n, optimizer_setup["freeze_names"]): + p.requires_grad = False + + param_groups = [ + { + "params": [ + p + for n, p in model.named_parameters() + if not match_name_keywords(n, optimizer_setup["backb_names"]) + and not match_name_keywords(n, optimizer_setup["extra_names"]) + and p.requires_grad + ], + "lr": optimizer_setup["base_lr"], + "wd": optimizer_setup["wd"], + }, + { + "params": [ + p + for n, p in model.named_parameters() + if match_name_keywords(n, optimizer_setup["backb_names"]) and p.requires_grad + ], + "lr": optimizer_setup["backb_lr"], + "wd": optimizer_setup["wd"], + }, + { + "params": [ + p + for n, p in model.named_parameters() + if match_name_keywords(n, optimizer_setup["extra_names"]) and p.requires_grad + ], + "lr": optimizer_setup["extra_lr"], + "wd": optimizer_setup["wd"], + }, + ] + + return param_groups + + +def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"): + """setup logger for training and testing. + Args: + save_dir(str): loaction to save log file + distributed_rank(int): device rank when multi-gpu environment + mode(str): log file write mode, `append` or `override`. default is `a`. + Return: + logger instance. 
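+    Note: the log file receives INFO and above only on rank 0; non-zero ranks write
+        WARNING and above, and also restrict their stderr output to WARNING.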
+ """ + save_file = os.path.join(save_dir, filename) + if mode == "o" and os.path.exists(save_file): + os.remove(save_file) + format = f"[Rank #{distributed_rank}] | " + "{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}" + if distributed_rank > 0: + logger.remove() + logger.add( + stderr, + format=format, + level="WARNING", + ) + logger.add( + save_file, + format=format, + filter="", + level="INFO" if distributed_rank == 0 else "WARNING", + enqueue=True, + ) + + return logger + + +def get_rank() -> int: + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def get_world_size() -> int: + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def synchronize(): + """Helper function to synchronize (barrier) among all processes when using distributed training""" + if not dist.is_available(): + return + if not dist.is_initialized(): + return + current_world_size = dist.get_world_size() + if current_world_size == 1: + return + dist.barrier() + + +def reduce_sum(tensor): + world_size = get_world_size() + if world_size < 2: + return tensor + tensor = tensor.clone() + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return tensor + + +def reduce_mean(tensor): + return reduce_sum(tensor) / float(get_world_size()) + + +def all_gather_object(obj): + world_size = get_world_size() + if world_size < 2: + return [obj] + output = [None for _ in range(world_size)] + dist.all_gather_object(output, obj) + return output + + +def is_distributed() -> bool: + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def is_available() -> bool: + return dist.is_available()