Add files via upload #3622

Open · wants to merge 1 commit into base: main
11 changes: 2 additions & 9 deletions mmseg/models/decode_heads/__init__.py

@@ -4,7 +4,6 @@
 from .aspp_head import ASPPHead
 from .cc_head import CCHead
 from .da_head import DAHead
-from .ddr_head import DDRHead
 from .dm_head import DMHead
 from .dnl_head import DNLHead
 from .dpt_head import DPTHead
@@ -13,19 +12,14 @@
 from .fcn_head import FCNHead
 from .fpn_head import FPNHead
 from .gc_head import GCHead
-from .ham_head import LightHamHead
 from .isa_head import ISAHead
 from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator
 from .lraspp_head import LRASPPHead
-from .mask2former_head import Mask2FormerHead
-from .maskformer_head import MaskFormerHead
 from .nl_head import NLHead
 from .ocr_head import OCRHead
-from .pid_head import PIDHead
 from .point_head import PointHead
 from .psa_head import PSAHead
 from .psp_head import PSPHead
-from .san_head import SideAdapterCLIPHead
 from .segformer_head import SegformerHead
 from .segmenter_mask_head import SegmenterMaskTransformerHead
 from .sep_aspp_head import DepthwiseSeparableASPPHead
@@ -34,7 +28,7 @@
 from .setr_up_head import SETRUPHead
 from .stdc_head import STDCHead
 from .uper_head import UPerHead
-from .vpd_depth_head import VPDDepthHead
+from .atm_head import ATMHead

 __all__ = [
     'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
@@ -43,6 +37,5 @@
     'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
     'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead',
     'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
-    'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead',
-    'LightHamHead', 'PIDHead', 'DDRHead', 'VPDDepthHead', 'SideAdapterCLIPHead'
+    'KernelUpdateHead', 'ATMHead', 'KernelUpdator'
 ]
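
Note: once ATMHead is registered in HEADS, a config can select it by type. The snippet below is a minimal sketch of how the head might be wired into a decode_head config; the backbone channel sizes, class count, and the loss type are illustrative assumptions, not values taken from this PR.

# Hypothetical config fragment (all values are assumptions for illustration):
decode_head = dict(
    type='ATMHead',        # registered by this PR
    img_size=512,          # training crop size assumed by the head
    in_channels=768,       # must match the backbone's feature dim
    channels=768,          # required by BaseDecodeHead
    embed_dims=768,
    num_layers=3,
    num_heads=8,
    use_stages=3,
    num_classes=150,       # e.g. ADE20K
    loss_decode=dict(type='CrossEntropyLoss', loss_weight=1.0),
)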
318 changes: 318 additions & 0 deletions mmseg/models/decode_heads/atm_head.py
@@ -0,0 +1,318 @@
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import TransformerDecoder, TransformerDecoderLayer

from mmcv.runner import force_fp32
from timm.models.layers import trunc_normal_

from mmseg.models.builder import HEADS
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy

def trunc_normal_init(module: nn.Module,
                      mean: float = 0,
                      std: float = 1,
                      a: float = -2,
                      b: float = 2,
                      bias: float = 0) -> None:
    """Initialize a module's weight with a truncated normal and its bias
    with a constant."""
    if hasattr(module, 'weight') and module.weight is not None:
        trunc_normal_(module.weight, mean, std, a, b)  # type: ignore
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)  # type: ignore


def constant_init(module, val, bias=0):
    """Initialize a module's weight and bias with constants."""
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)

class TPN_Decoder(TransformerDecoder):
    """Transformer decoder that also returns the attention map of its
    last layer (used downstream as a mask prediction)."""

    def forward(self, tgt: Tensor, memory: Tensor,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None):
        output = tgt
        attn = None
        for mod in self.layers:
            output, attn = mod(output, memory, tgt_mask=tgt_mask,
                               memory_mask=memory_mask,
                               tgt_key_padding_mask=tgt_key_padding_mask,
                               memory_key_padding_mask=memory_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output, attn

class TPN_DecoderLayer(TransformerDecoderLayer):
    """Decoder layer whose cross-attention is replaced by a custom
    Attention module that exposes the raw attention map."""

    def __init__(self, **kwargs):
        super(TPN_DecoderLayer, self).__init__(**kwargs)
        del self.multihead_attn
        self.multihead_attn = Attention(
            kwargs['d_model'], num_heads=kwargs['nhead'],
            qkv_bias=True, attn_drop=0.1)

    def forward(self, tgt: Tensor, memory: Tensor,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None):
        # Self-attention over the queries (sequence-first layout).
        tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # Cross-attention: the custom Attention module is batch-first,
        # hence the transposes; it returns (output, attention map).
        tgt2, attn2 = self.multihead_attn(
            tgt.transpose(0, 1), memory.transpose(0, 1),
            memory.transpose(0, 1))
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        # Feed-forward block.
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt, attn2

class Attention(nn.Module):
    """Multi-head attention with separate q/k/v projections that returns
    both the output and the head-averaged pre-softmax attention map."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set
        # manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, xq, xk, xv):
        # Inputs are batch-first: (B, N, C).
        B, Nq, C = xq.size()
        Nk = xk.size()[1]
        Nv = xv.size()[1]

        # Project and split into heads: (B, num_heads, N, C // num_heads).
        q = self.q(xq).reshape(B, Nq, self.num_heads,
                               C // self.num_heads).permute(0, 2, 1, 3)
        k = self.k(xk).reshape(B, Nk, self.num_heads,
                               C // self.num_heads).permute(0, 2, 1, 3)
        v = self.v(xv).reshape(B, Nv, self.num_heads,
                               C // self.num_heads).permute(0, 2, 1, 3)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Keep the pre-softmax scores; they become the mask logits.
        attn_save = attn.clone()
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        # Output is sequence-first to match the surrounding decoder; the
        # attention map is averaged over heads: (B, Nq, Nk).
        return x.transpose(0, 1), attn_save.sum(dim=1) / self.num_heads


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x

@HEADS.register_module()
class ATMHead(BaseDecodeHead):
    """Attention-to-Mask (ATM) decode head: learnable class queries attend
    to multi-stage backbone features, and the cross-attention maps are
    used directly as mask predictions."""

    def __init__(
            self,
            img_size,
            in_channels,
            embed_dims=768,
            num_layers=3,
            num_heads=8,
            use_stages=3,
            use_proj=True,
            CE_loss=False,
            crop_train=False,
            shrink_ratio=None,
            **kwargs,
    ):
        super(ATMHead, self).__init__(
            in_channels=in_channels, **kwargs)

        self.image_size = img_size
        self.use_stages = use_stages
        self.crop_train = crop_train
        nhead = num_heads
        dim = embed_dims
        input_proj = []
        proj_norm = []
        atm_decoders = []
        for i in range(self.use_stages):
            # FC layer to change ch
            if use_proj:
                proj = nn.Linear(self.in_channels, dim)
                trunc_normal_(proj.weight, std=.02)
            else:
                proj = nn.Identity()
            self.add_module("input_proj_{}".format(i + 1), proj)
            input_proj.append(proj)
            # norm layer
            if use_proj:
                norm = nn.LayerNorm(dim)
            else:
                norm = nn.Identity()
            self.add_module("proj_norm_{}".format(i + 1), norm)
            proj_norm.append(norm)
            # decoder layer
            decoder_layer = TPN_DecoderLayer(
                d_model=dim, nhead=nhead, dim_feedforward=dim * 4)
            decoder = TPN_Decoder(decoder_layer, num_layers)
            self.add_module("decoder_{}".format(i + 1), decoder)
            atm_decoders.append(decoder)

        self.input_proj = input_proj
        self.proj_norm = proj_norm
        self.decoder = atm_decoders
        # One learnable query embedding per class.
        self.q = nn.Embedding(self.num_classes, dim)

        # Per-query classifier with an extra "no-object" class.
        self.class_embed = nn.Linear(dim, self.num_classes + 1)
        self.CE_loss = CE_loss
        # Masks come from attention maps, so the default conv_seg
        # classifier of BaseDecodeHead is not needed.
        delattr(self, 'conv_seg')

    def init_weights(self):
        for n, m in self.named_modules():
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=.02, bias=0)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, val=1.0, bias=0.0)

    def forward(self, inputs):
        x = []
        for stage_ in inputs[:self.use_stages]:
            x.append(self.d4_to_d3(stage_) if stage_.dim() > 3 else stage_)
        x.reverse()  # process the deepest stage first
        bs = x[0].size()[0]

        laterals = []
        attns = []
        maps_size = []
        qs = []
        # Class queries, shared across the batch: (num_classes, bs, dim).
        q = self.q.weight.repeat(bs, 1, 1).transpose(0, 1)

        for idx, (x_, proj_, norm_, decoder_) in enumerate(
                zip(x, self.input_proj, self.proj_norm, self.decoder)):
            lateral = norm_(proj_(x_))
            # The cross-stage lateral fusion below is currently disabled
            # (the original guard was `if idx == 0:`).
            if True:
                laterals.append(lateral)
            else:
                if laterals[idx - 1].size()[1] == lateral.size()[1]:
                    laterals.append(lateral + laterals[idx - 1])
                else:
                    # nearest interpolate
                    l_ = self.d3_to_d4(laterals[idx - 1])
                    l_ = F.interpolate(l_, scale_factor=2, mode="nearest")
                    l_ = self.d4_to_d3(l_)
                    laterals.append(l_ + lateral)

            # Queries attend to this stage; attn is the per-class
            # attention map over the feature tokens.
            q, attn = decoder_(q, lateral.transpose(0, 1))
            attn = attn.transpose(-1, -2)
            if self.crop_train and self.training:
                # Scatter the attention of the cropped tokens back into a
                # full-resolution blank map.
                blank_attn = torch.zeros_like(attn)
                blank_attn = blank_attn[:, 0].unsqueeze(1).repeat(
                    1, (self.image_size // 16) ** 2, 1)
                blank_attn[:, inputs[-1]] = attn
                attn = blank_attn
                self.crop_idx = inputs[-1]
            attn = self.d3_to_d4(attn)
            maps_size.append(attn.size()[-2:])
            qs.append(q.transpose(0, 1))
            attns.append(attn)
        qs = torch.stack(qs, dim=0)
        outputs_class = self.class_embed(qs)
        out = {"pred_logits": outputs_class[-1]}

        outputs_seg_masks = []
        size = maps_size[-1]

        # Accumulate attention maps across stages at the largest map size.
        for i_attn, attn in enumerate(attns):
            if i_attn == 0:
                outputs_seg_masks.append(F.interpolate(
                    attn, size=size, mode='bilinear', align_corners=False))
            else:
                outputs_seg_masks.append(
                    outputs_seg_masks[i_attn - 1] + F.interpolate(
                        attn, size=size, mode='bilinear',
                        align_corners=False))

        out["pred_masks"] = F.interpolate(
            outputs_seg_masks[-1],
            size=(self.image_size, self.image_size),
            mode='bilinear', align_corners=False)

        out["pred"] = self.semantic_inference(
            out["pred_logits"], out["pred_masks"])

        if self.training:
            # [l, bs, queries, embed]
            outputs_seg_masks = torch.stack(outputs_seg_masks, dim=0)
            out["aux_outputs"] = self._set_aux_loss(
                outputs_class, outputs_seg_masks)
            return out

        return out["pred"]

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_seg_masks):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [
            {"pred_logits": a, "pred_masks": b}
            for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
        ]

    def semantic_inference(self, mask_cls, mask_pred):
        # Drop the trailing "no-object" class after the softmax, then
        # combine per-query class scores with per-query mask probabilities.
        mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
        return semseg
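
    # Shape walkthrough for semantic_inference (hypothetical sizes, for
    # illustration only): with batch 2, 150 queries/classes and 512x512
    # masks:
    #   mask_cls:  (2, 150, 151) -> softmax, drop no-object -> (2, 150, 150)
    #   mask_pred: (2, 150, 512, 512) -> sigmoid
    #   einsum("bqc,bqhw->bchw") weights each query's mask by its class
    #   scores and sums over queries -> (2, 150, 512, 512) per-class maps.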

    def d3_to_d4(self, t):
        # (n, hw, c) -> (n, c, h, w); if a class token makes the token
        # count odd, drop the first token before reshaping.
        n, hw, c = t.size()
        if hw % 2 != 0:
            t = t[:, 1:]
        h = w = int(math.sqrt(hw))
        return t.transpose(1, 2).reshape(n, c, h, w)

    def d4_to_d3(self, t):
        # (n, c, h, w) -> (n, h*w, c)
        return t.flatten(-2).transpose(-1, -2)

    @force_fp32(apply_to=('seg_logit', ))
    def losses(self, seg_logit, seg_label):
        """Compute segmentation loss."""
        if self.CE_loss:
            return super().losses(seg_logit["pred"], seg_label)

        if isinstance(seg_logit, dict):
            # atm loss
            seg_label = seg_label.squeeze(1)
            if self.crop_train:
                # Mask seg_label by crop_idx: rearrange the label map into
                # 16x16 patches, keep only the cropped patches, and fill
                # everything else with ignore_index.
                bs, h, w = seg_label.size()
                mask_label = seg_label.reshape(
                    bs, h // 16, 16, w // 16, 16).permute(
                        0, 1, 3, 2, 4).reshape(bs, h * w // 256, 256)
                empty_label = torch.zeros_like(mask_label) + self.ignore_index
                empty_label[:, self.crop_idx] = mask_label[:, self.crop_idx]
                seg_label = empty_label.reshape(
                    bs, h // 16, w // 16, 16, 16).permute(
                        0, 1, 3, 2, 4).reshape(bs, h, w)
            loss = self.loss_decode(
                seg_logit,
                seg_label,
                ignore_index=self.ignore_index)

            loss['acc_seg'] = accuracy(
                seg_logit["pred"], seg_label, ignore_index=self.ignore_index)
            return loss
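
For reference, a minimal smoke test of the head's inference interface. This is not part of the PR; the channel sizes, class count, loss config, and dummy feature shapes are assumptions for illustration, and the exact BaseDecodeHead keyword arguments may differ across mmsegmentation versions.

# Hypothetical standalone check of ATMHead's forward pass.
import torch
from mmseg.models.decode_heads.atm_head import ATMHead

head = ATMHead(
    img_size=512,
    in_channels=768,
    embed_dims=768,
    use_stages=3,
    num_classes=150,
    channels=768,                                # required by BaseDecodeHead
    loss_decode=dict(type='CrossEntropyLoss'),   # placeholder loss config
)
head.eval()

# Three transformer stages of 32x32 = 1024 tokens with dim 768.
feats = [torch.randn(2, 1024, 768) for _ in range(3)]
with torch.no_grad():
    pred = head(feats)
print(pred.shape)  # expected: (2, 150, 512, 512) per-class score maps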