From 365a71df0d15e32713d9ef3d8063e2b241d33219 Mon Sep 17 00:00:00 2001
From: geetu040
Date: Sun, 22 Dec 2024 02:44:27 +0500
Subject: [PATCH] update upsample and projection

---
 .../convert_depth_pro_weights_to_hf.py        |  33 +--
 .../models/depth_pro/modeling_depth_pro.py    | 220 ++++++++++--------
 2 files changed, 138 insertions(+), 115 deletions(-)

diff --git a/src/transformers/models/depth_pro/convert_depth_pro_weights_to_hf.py b/src/transformers/models/depth_pro/convert_depth_pro_weights_to_hf.py
index f4895f7730c1e6..15c063ca377a00 100644
--- a/src/transformers/models/depth_pro/convert_depth_pro_weights_to_hf.py
+++ b/src/transformers/models/depth_pro/convert_depth_pro_weights_to_hf.py
@@ -59,25 +59,25 @@
     r"fov.head.head.(\d+).(weight|bias)": r"fov_model.head.\1.\2",
 
     # upsamples (hard coded; regex is not very feasible here)
-    "encoder.upsample_latent0.0.weight": "depth_pro.encoder.upsample_intermediate.1.proj.weight",
-    "encoder.upsample_latent0.1.weight": "depth_pro.encoder.upsample_intermediate.1.upsample_blocks.0.weight",
-    "encoder.upsample_latent0.2.weight": "depth_pro.encoder.upsample_intermediate.1.upsample_blocks.1.weight",
-    "encoder.upsample_latent0.3.weight": "depth_pro.encoder.upsample_intermediate.1.upsample_blocks.2.weight",
-    "encoder.upsample_latent1.0.weight": "depth_pro.encoder.upsample_intermediate.0.proj.weight",
-    "encoder.upsample_latent1.1.weight": "depth_pro.encoder.upsample_intermediate.0.upsample_blocks.0.weight",
-    "encoder.upsample_latent1.2.weight": "depth_pro.encoder.upsample_intermediate.0.upsample_blocks.1.weight",
-    "encoder.upsample0.0.weight": "depth_pro.encoder.upsample_scaled_images.2.proj.weight",
-    "encoder.upsample0.1.weight": "depth_pro.encoder.upsample_scaled_images.2.upsample_blocks.0.weight",
-    "encoder.upsample1.0.weight": "depth_pro.encoder.upsample_scaled_images.1.proj.weight",
-    "encoder.upsample1.1.weight": "depth_pro.encoder.upsample_scaled_images.1.upsample_blocks.0.weight",
-    "encoder.upsample2.0.weight": "depth_pro.encoder.upsample_scaled_images.0.proj.weight",
-    "encoder.upsample2.1.weight": "depth_pro.encoder.upsample_scaled_images.0.upsample_blocks.0.weight",
-    "encoder.upsample_lowres.weight": "depth_pro.encoder.upsample_image.upsample_blocks.0.weight",
-    "encoder.upsample_lowres.bias": "depth_pro.encoder.upsample_image.upsample_blocks.0.bias",
+    "encoder.upsample_latent0.0.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.5.0.weight",
+    "encoder.upsample_latent0.1.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.5.1.weight",
+    "encoder.upsample_latent0.2.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.5.2.weight",
+    "encoder.upsample_latent0.3.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.5.3.weight",
+    "encoder.upsample_latent1.0.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.4.0.weight",
+    "encoder.upsample_latent1.1.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.4.1.weight",
+    "encoder.upsample_latent1.2.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.4.2.weight",
+    "encoder.upsample0.0.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.3.0.weight",
+    "encoder.upsample0.1.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.3.1.weight",
+    "encoder.upsample1.0.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.2.0.weight",
+    "encoder.upsample1.1.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.2.1.weight",
+    "encoder.upsample2.0.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.1.0.weight",
"encoder.upsample2.1.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.1.1.weight", + "encoder.upsample_lowres.weight": "depth_pro.encoder.feature_upsample.upsample_blocks.0.0.weight", + "encoder.upsample_lowres.bias": "depth_pro.encoder.feature_upsample.upsample_blocks.0.0.bias", # projections between encoder and fusion r"decoder.convs.(\d+).weight": lambda match: ( - f"projections.{4-int(match.group(1))}.weight" + f"depth_pro.encoder.feature_projection.projections.{4-int(match.group(1))}.weight" ), # fusion stage @@ -274,6 +274,7 @@ def main(): ) if args.push_to_hub: + print("Pushing to hub...") model.push_to_hub(args.hub_repo_id) image_processor.push_to_hub(args.hub_repo_id) diff --git a/src/transformers/models/depth_pro/modeling_depth_pro.py b/src/transformers/models/depth_pro/modeling_depth_pro.py index e23cfbdc9f5004..c24ffce7bf93e4 100644 --- a/src/transformers/models/depth_pro/modeling_depth_pro.py +++ b/src/transformers/models/depth_pro/modeling_depth_pro.py @@ -743,21 +743,60 @@ def forward( ) -class DepthProUpsampleBlock(nn.Module): - def __init__( - self, - input_dims, - intermediate_dims, - output_dims, - n_upsample_layers, - use_proj=True, - bias=False, - ): +class DepthProFeatureUpsample(nn.Module): + def __init__(self, config: DepthProConfig): super().__init__() + self.config = config + + self.upsample_blocks = nn.ModuleList() + + # for image_features + self.upsample_blocks.append( + self._create_upsample_block( + input_dims=config.hidden_size, + intermediate_dims=config.hidden_size, + output_dims=config.scaled_images_feature_dims[0], + n_upsample_layers=1, + use_proj=False, + bias=True, + ) + ) + + # for scaled_images_features + for i, feature_dims in enumerate(config.scaled_images_feature_dims): + upsample_block = self._create_upsample_block( + input_dims=config.hidden_size, + intermediate_dims=feature_dims, + output_dims=feature_dims, + n_upsample_layers=1, + ) + self.upsample_blocks.append(upsample_block) + + # for intermediate_features + for i, feature_dims in enumerate(config.intermediate_feature_dims): + intermediate_dims = config.fusion_hidden_size if i == 0 else feature_dims + upsample_block = self._create_upsample_block( + input_dims=config.hidden_size, + intermediate_dims=intermediate_dims, + output_dims=feature_dims, + n_upsample_layers=2 + i, + ) + self.upsample_blocks.append(upsample_block) - # create first projection block + def _create_upsample_block( + self, + input_dims: int, + intermediate_dims: int, + output_dims: int, + n_upsample_layers: int, + use_proj: bool = True, + bias: bool = False, + ) -> nn.Module: + upsample_block = nn.Sequential() + + # create first projection layer if use_proj: - self.proj = nn.Conv2d( + proj = nn.Conv2d( in_channels=input_dims, out_channels=intermediate_dims, kernel_size=1, @@ -765,11 +804,9 @@ def __init__( padding=0, bias=bias, ) - else: - self.proj = nn.Identity() + upsample_block.append(proj) - # create following upsample blocks - self.upsample_blocks = nn.Sequential() + # create following upsample layers for i in range(n_upsample_layers): in_channels = intermediate_dims if i == 0 else output_dims layer = nn.ConvTranspose2d( @@ -780,11 +817,47 @@ def __init__( padding=0, bias=bias, ) - self.upsample_blocks.append(layer) + upsample_block.append(layer) + + return upsample_block + + def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]: + upsampled_features = [] + for i, upsample_block in enumerate(self.upsample_blocks): + upsampled_feature = upsample_block(features[i]) + 
+            upsampled_features.append(upsampled_feature)
+        return upsampled_features
+
+
+class DepthProFeatureProjection(nn.Module):
+    def __init__(self, config: DepthProConfig):
+        super().__init__()
+        self.config = config
+
+        combined_feature_dims = config.scaled_images_feature_dims + config.intermediate_feature_dims
+        self.projections = nn.ModuleList()
+        for i, in_channels in enumerate(combined_feature_dims):
+            if i == len(combined_feature_dims) - 1 and in_channels == config.fusion_hidden_size:
+                # projection for the last layer can be skipped if input and output channels already match
+                self.projections.append(nn.Identity())
+            else:
+                self.projections.append(
+                    nn.Conv2d(
+                        in_channels=in_channels,
+                        out_channels=config.fusion_hidden_size,
+                        kernel_size=3,
+                        stride=1,
+                        padding=1,
+                        bias=False,
+                    )
+                )
 
-    def forward(self, features):
-        projected = self.proj(features)
-        return self.upsample_blocks(projected)
+    def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
+        projected_features = []
+        for i, projection in enumerate(self.projections):
+            projected_feature = projection(features[i])
+            projected_features.append(projected_feature)
+        return projected_features
 
 
 def interpolate(pixel_values, scale_factor):
@@ -944,38 +1017,8 @@ def __init__(self, config: DepthProConfig):
         # image encoder
         self.image_encoder = DepthProViT(config)
 
-        # upsampling patch features (high_res, med_res, low_res) - (3-5) in diagram
-        self.upsample_scaled_images = nn.ModuleList()
-        for i, feature_dims in enumerate(self.scaled_images_feature_dims):
-            upsample_block = DepthProUpsampleBlock(
-                input_dims=config.hidden_size,
-                intermediate_dims=feature_dims,
-                output_dims=feature_dims,
-                n_upsample_layers=1,
-            )
-            self.upsample_scaled_images.append(upsample_block)
-
-        # upsampling intermediate features - (1-2) in diagram
-        self.upsample_intermediate = nn.ModuleList()
-        for i, feature_dims in enumerate(self.intermediate_feature_dims):
-            intermediate_dims = self.fusion_hidden_size if i == 0 else feature_dims
-            upsample_block = DepthProUpsampleBlock(
-                input_dims=config.hidden_size,
-                intermediate_dims=intermediate_dims,
-                output_dims=feature_dims,
-                n_upsample_layers=2 + i,
-            )
-            self.upsample_intermediate.append(upsample_block)
-
-        # upsampling image features - (6) in diagram
-        self.upsample_image = DepthProUpsampleBlock(
-            input_dims=config.hidden_size,
-            intermediate_dims=config.hidden_size,
-            output_dims=config.scaled_images_feature_dims[0],
-            n_upsample_layers=1,
-            use_proj=False,
-            bias=True,
-        )
+        # upsample features
+        self.feature_upsample = DepthProFeatureUpsample(config)
 
         # for STEP 7: fuse low_res and image features
         self.fuse_image_with_low_res = nn.Conv2d(
@@ -987,6 +1030,9 @@ def __init__(self, config: DepthProConfig):
             bias=True,
         )
 
+        # project features
+        self.feature_projection = DepthProFeatureProjection(config)
+
     def forward(
         self,
         pixel_values: torch.Tensor,
@@ -1079,10 +1125,6 @@ def forward(
                 features, batch_size=B, merge_out_size=self.out_size * 2**i
             )  # (B, config.hidden_size, self.out_size*2**i, self.out_size*2**i)
 
-            # d. upsample
-            features = self.upsample_scaled_images[i](features)
-            # (B, self.scaled_images_feature_dims[i], self.out_size*2**(i+1), self.out_size*2**(i+1))
-
             scaled_images_features.append(features)
 
         # STEP 5: get intermediate features - (1-2) in diagram
@@ -1114,10 +1156,6 @@ def forward(
                 merge_out_size=self.out_size * 2 ** (self.n_scaled_images - 1),
             )  # (B, config.hidden_size, self.out_size*2**(self.n_scaled_images-1), self.out_size*2**(self.n_scaled_images-1))
 
-            # d. upsample
-            features = self.upsample_intermediate[i](features)
-            # (B, config.intermediate_feature_dims[i], self.out_size*2**(self.n_scaled_images+i+1), self.out_size*2**(self.n_scaled_images+i+1))
-
             intermediate_features.append(features)
 
         # STEP 6: get image features - (6) in diagram
@@ -1133,24 +1171,31 @@ def forward(
 
         # c. merge patches back together
         # no merge required for image_features as they are already in batches instead of patches
 
-        # d. upsample
-        image_features = self.upsample_image(
-            image_features
-        )  # (B, config.scaled_images_feature_dims[0], self.out_size*2**1, self.out_size*2**1)
-
-        # STEP 7: apply fusion (global_features = image_features + scaled_images_features[0])
-        # fuses image_features with lowest resolution features as they are of same size
-        scaled_images_features[0] = torch.cat((scaled_images_features[0], image_features), dim=1)
-        scaled_images_features[0] = self.fuse_image_with_low_res(scaled_images_features[0])
-
-        # STEP 8: return these features in order of increasing size as what fusion expects
+        # STEP 7: combine all features (still at pre-upsample resolutions)
         features = [
-            # (B, self.scaled_images_feature_dims[i], self.out_size*2**(i+1), self.out_size*2**(i+1))
+            # (B, config.hidden_size, self.out_size, self.out_size)
+            image_features,
+            # (B, config.hidden_size, self.out_size*2**i, self.out_size*2**i)
             *scaled_images_features,
-            # (B, config.intermediate_feature_dims[i], self.out_size*2**(self.n_scaled_images+i+1), self.out_size*2**(self.n_scaled_images+i+1))
+            # (B, config.hidden_size, self.out_size*2**(self.n_scaled_images-1), self.out_size*2**(self.n_scaled_images-1))
             *intermediate_features,
         ]
 
+        # STEP 8: upsample features
+        features = self.feature_upsample(features)
+
+        # STEP 9: apply fusion
+        # (global features = low res features + image features)
+        # fuses image_features with the lowest resolution features as they are of the same size
+        global_features = torch.cat((features[1], features[0]), dim=1)
+        global_features = self.fuse_image_with_low_res(global_features)
+        features = [global_features, *features[2:]]
+
+        # STEP 10: project features
+        features = self.feature_projection(features)
+
+        # STEP 11: return output
+
         last_hidden_state = patch_encodings.last_hidden_state
         hidden_states = patch_encodings.hidden_states if output_hidden_states else None
         attentions = patch_encodings.attentions if output_attentions else None
@@ -1380,11 +1424,13 @@ def forward(self, hidden_state: torch.Tensor, residual: Optional[torch.Tensor] =
 
 
 # Taken from transformers.models.dpt.modeling_dpt.DPTFeatureFusionStage with DPT->DepthPro
-# with num_layers, deconv and reversed layers
+# with deconv and reversed layers
 class DepthProFeatureFusionStage(nn.Module):
-    def __init__(self, config, num_layers):
+    def __init__(self, config):
         super().__init__()
-        self.num_layers = num_layers
+        self.config = config
+
+        self.num_layers = len(config.intermediate_hook_ids) + len(config.scaled_images_ratios)
         self.layers = nn.ModuleList()
         for _ in range(self.num_layers - 1):
             self.layers.append(DepthProFeatureFusionLayer(config))
@@ -1491,9 +1537,6 @@ def forward(
 
         # c. merge patches back together
        # no merge required for fov_features as they are already in batches instead of patches
 
-        # d. upsample
-        # no upsampling required for fov_features, the head later downsamples to create scalars
-
         global_features = self.global_neck(global_features)
 
         fov_features = fov_features + global_features
@@ -1548,28 +1591,8 @@ def __init__(self, config, use_fov_model=None):
         # dinov2 (vit) like encoders
         self.depth_pro = DepthProModel(config)
 
-        # project hidden states from encoder to match expected inputs in fusion stage
-        combined_feature_dims = config.scaled_images_feature_dims + config.intermediate_feature_dims
-        self.projections = nn.ModuleList()
-        for i, in_channels in enumerate(combined_feature_dims):
-            if i == len(combined_feature_dims) - 1 and in_channels == config.fusion_hidden_size:
-                # projection for last layer can be ignored if input and output channels already match
-                self.projections.append(nn.Identity())
-            else:
-                self.projections.append(
-                    nn.Conv2d(
-                        in_channels=in_channels,
-                        out_channels=config.fusion_hidden_size,
-                        kernel_size=3,
-                        stride=1,
-                        padding=1,
-                        bias=False,
-                    )
-                )
-
         # dpt (vit) like fusion stage
-        self.num_fusion_layers = len(config.intermediate_hook_ids) + len(config.scaled_images_ratios)
-        self.fusion_stage = DepthProFeatureFusionStage(config, num_layers=self.num_fusion_layers)
+        self.fusion_stage = DepthProFeatureFusionStage(config)
 
         # depth estimation head
         self.head = DepthProDepthEstimationHead(config)
@@ -1647,7 +1670,6 @@ def forward(
             return_dict=True,
         )
         features = depth_pro_outputs.features
-        features = [proj(feature) for proj, feature in zip(self.projections, features)]
         fused_hidden_states = self.fusion_stage(features)
 
         predicted_depth = self.head(fused_hidden_states[-1])
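
Reviewer note (not part of the patch): the conversion script above reverses the projection indices when remapping checkpoint keys. A quick way to sanity-check that rule is to replay it in isolation. The `remap_key` helper below is hypothetical — the real script applies its full mapping table — but the pattern and replacement lambda are copied verbatim from the diff:

    import re

    # Only the "decoder.convs" rule from the diff; the actual script has many more entries.
    MAPPING = {
        r"decoder.convs.(\d+).weight": lambda match: (
            f"depth_pro.encoder.feature_projection.projections.{4-int(match.group(1))}.weight"
        ),
    }

    def remap_key(old_key: str) -> str:
        # hypothetical helper: return the remapped key, or the key unchanged if no rule matches
        for pattern, replacement in MAPPING.items():
            match = re.fullmatch(pattern, old_key)
            if match:
                return replacement(match)
        return old_key

    # indices are reversed: decoder.convs.0.weight -> projections.4.weight
    assert remap_key("decoder.convs.0.weight") == "depth_pro.encoder.feature_projection.projections.4.weight"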
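
Reviewer note (not part of the patch): the refactor preserves a simple structural invariant that can be checked without running a forward pass — `DepthProFeatureUpsample` owns one block per incoming feature map (one for the image, one per scaled image, one per intermediate hook), and `DepthProFeatureProjection` owns one projection per feature map handed to the fusion stage. A minimal sketch, assuming a default `DepthProConfig` and the standard `configuration_depth_pro` module layout (both assumptions, not confirmed by the diff):

    from transformers.models.depth_pro.configuration_depth_pro import DepthProConfig
    from transformers.models.depth_pro.modeling_depth_pro import (
        DepthProFeatureProjection,
        DepthProFeatureUpsample,
    )

    config = DepthProConfig()  # assumed default config
    upsample = DepthProFeatureUpsample(config)
    projection = DepthProFeatureProjection(config)

    # one upsample block for image_features, plus one per scaled image and per intermediate feature
    assert len(upsample.upsample_blocks) == 1 + len(config.scaled_images_feature_dims) + len(
        config.intermediate_feature_dims
    )
    # one projection per feature map entering the fusion stage
    assert len(projection.projections) == len(config.scaled_images_feature_dims) + len(
        config.intermediate_feature_dims
    )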