Commit

Vectorize
NielsRogge committed Apr 20, 2024
1 parent 75a8ccb commit 94afdbd
Showing 1 changed file with 27 additions and 23 deletions: src/transformers/models/zoedepth/modeling_zoedepth.py
@@ -15,7 +15,6 @@
 """ PyTorch ZoeDepth model. """


-import math
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

@@ -115,38 +114,41 @@ def __init__(self, config):
             nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), ACT2FN[config.hidden_act])
         )

-    def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]:
+    def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) -> List[torch.Tensor]:
         """
         Args:
             hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`):
                 List of hidden states from the backbone.
         """
-        out = []
+        batch_size = hidden_states[0].shape[0]

-        for i, hidden_state in enumerate(hidden_states):
-            # reshape to (batch_size, num_channels, height, width)
-            cls_token, hidden_state = hidden_state[:, 0], hidden_state[:, 1:]
-            batch_size, sequence_length, num_channels = hidden_state.shape
-            if patch_height is not None and patch_width is not None:
-                hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels)
-            else:
-                size = int(math.sqrt(sequence_length))
-                hidden_state = hidden_state.reshape(batch_size, size, size, num_channels)
-            hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+        # stack along batch dimension
+        # shape (batch_size*num_stages, sequence_length + 1, hidden_size)
+        hidden_states = torch.cat(hidden_states, dim=0)
+
+        # reshape to (batch_size, num_channels, height, width)
+        cls_token, hidden_states = hidden_states[:, 0], hidden_states[:, 1:]
+        total_batch_size, sequence_length, num_channels = hidden_states.shape
+        hidden_states = hidden_states.reshape(total_batch_size, patch_height, patch_width, num_channels)
+        hidden_states = hidden_states.permute(0, 3, 1, 2).contiguous()
+
+        out = []
+        for stage_idx, (hidden_state, cls_token_stage) in enumerate(
+            zip(hidden_states.split(batch_size, dim=0), cls_token.split(batch_size, dim=0))
+        ):
             feature_shape = hidden_state.shape
             if self.readout_type == "project":
                 # reshape to (batch_size, height*width, num_channels)
                 hidden_state = hidden_state.flatten(2).permute((0, 2, 1))
-                readout = cls_token.unsqueeze(1).expand_as(hidden_state)
+                readout = cls_token_stage.unsqueeze(dim=1).expand_as(hidden_state)
                 # concatenate the readout token to the hidden states and project
-                hidden_state = self.readout_projects[i](torch.cat((hidden_state, readout), -1))
+                hidden_state = self.readout_projects[stage_idx](torch.cat((hidden_state, readout), -1))
                 # reshape back to (batch_size, num_channels, height, width)
                 hidden_state = hidden_state.permute(0, 2, 1).reshape(feature_shape)
             elif self.readout_type == "add":
-                hidden_state = hidden_state.flatten(2) + cls_token.unsqueeze(-1)
+                hidden_state = hidden_state.flatten(2) + cls_token_stage.unsqueeze(-1)
                 hidden_state = hidden_state.reshape(feature_shape)
-            hidden_state = self.layers[i](hidden_state)
+            hidden_state = self.layers[stage_idx](hidden_state)
             out.append(hidden_state)

         return out
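Note on the hunk above: the old loop reshaped each stage's hidden state separately, falling back to int(math.sqrt(sequence_length)) when the patch dimensions were missing (hence the removed import math); the vectorized version makes patch_height and patch_width required, concatenates all stages along the batch dimension, reshapes once, and splits back per stage for the stage-specific readout projections. A minimal standalone sketch of that pattern, with illustrative shapes and names (not part of the commit):

import torch

batch_size, seq_len_with_cls, hidden_size, num_stages = 2, 5, 4, 3
patch_height = patch_width = 2  # 2 * 2 == seq_len_with_cls - 1
hidden_states = [torch.randn(batch_size, seq_len_with_cls, hidden_size) for _ in range(num_stages)]

# looped version: one slice + reshape per stage
looped = [h[:, 1:].reshape(batch_size, patch_height, patch_width, hidden_size) for h in hidden_states]

# vectorized version: one cat, one reshape, one split
stacked = torch.cat(hidden_states, dim=0)  # (num_stages * batch_size, seq, hidden)
patches = stacked[:, 1:].reshape(num_stages * batch_size, patch_height, patch_width, hidden_size)
split_back = patches.split(batch_size, dim=0)  # tuple of num_stages tensors

# both routes produce identical per-stage results
for loop_result, vec_result in zip(looped, split_back):
    assert torch.equal(loop_result, vec_result)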
@@ -329,7 +331,7 @@ def __init__(self, config):
         # fusion
         self.fusion_stage = ZoeDepthFeatureFusionStage(config)

-    def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]:
+    def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) -> List[torch.Tensor]:
         """
         Args:
             hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`):
@@ -460,7 +462,7 @@ def __init__(
         min_temp=1e-7,
         act=torch.softmax,
     ):
-        """Conditional Log Binomial distribution.
+        """Per-pixel MLP followed by a Conditional Log Binomial distribution.

         Args:
             in_features (`int`):
@@ -482,10 +484,7 @@
         """
         super().__init__()
-        self.p_eps = p_eps
-        self.max_temp = max_temp
-        self.min_temp = min_temp
-        self.log_binomial_transform = LogBinomial(n_classes, act=act)

         bottleneck = (in_features + condition_dim) // bottleneck_factor
         self.mlp = nn.Sequential(
             nn.Conv2d(in_features + condition_dim, bottleneck, kernel_size=1, stride=1, padding=0),
@@ -495,6 +494,11 @@
             nn.Softplus(),
         )

+        self.p_eps = p_eps
+        self.max_temp = max_temp
+        self.min_temp = min_temp
+        self.log_binomial_transform = LogBinomial(n_classes, act=act)
+
     def forward(self, main_feature, condition_feature):
         """
         Args:
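Note on the last two hunks: moving the p_eps/max_temp/min_temp/log_binomial_transform assignments below the MLP construction only changes the order in which submodules are registered, and hence the iteration order of state_dict keys; because those keys are name-based, behavior and checkpoint loading are unchanged. A toy sketch of that property, with made-up class names (not from the commit):

import torch.nn as nn

class HeadFirst(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(4, 2)  # registered before mlp
        self.mlp = nn.Linear(8, 4)

class MlpFirst(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Linear(8, 4)  # registered before head
        self.head = nn.Linear(4, 2)

# same keys, different order; weights load across the two layouts
print(list(HeadFirst().state_dict()))  # ['head.weight', 'head.bias', 'mlp.weight', 'mlp.bias']
print(list(MlpFirst().state_dict()))   # ['mlp.weight', 'mlp.bias', 'head.weight', 'head.bias']
MlpFirst().load_state_dict(HeadFirst().state_dict())  # loads fine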
