cosmetic: num_params helper fn
i-gao committed Sep 11, 2023
1 parent 7261639 commit 41f288f
Showing 2 changed files with 118 additions and 56 deletions.
8 changes: 8 additions & 0 deletions open_flamingo/src/utils.py
@@ -86,3 +86,11 @@ def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
        )
        padded_tensors.append(padded_tensor)
    return torch.stack(padded_tensors)
+
+
+def num_params(module, filter_to_trainable=False):
+    """Returns the number of parameters in the module, or optionally only the trainable parameters"""
+    if filter_to_trainable:
+        return sum(p.numel() for p in module.parameters() if p.requires_grad)
+    else:
+        return sum(p.numel() for p in module.parameters())
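
A minimal sketch of how the new helper behaves, using a stand-in torch.nn.Linear module (the import path assumes the open_flamingo package layout shown above):

import torch.nn as nn
from open_flamingo.src.utils import num_params

layer = nn.Linear(10, 20)                           # 10*20 weights + 20 biases = 220 params
print(num_params(layer))                            # 220
layer.weight.requires_grad = False                  # freeze the weight matrix
print(num_params(layer, filter_to_trainable=True))  # 20 (only the bias remains trainable)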
166 changes: 110 additions & 56 deletions open_flamingo/src/vlm.py
Expand Up @@ -2,7 +2,7 @@
from einops import rearrange
from torch import nn
from typing import List, Optional, Tuple, Union
-from .utils import extend_instance, stack_with_padding
+from .utils import extend_instance, stack_with_padding, num_params
from .cross_attn_lm import CrossAttentionMixin
from .helpers import DecoupledEmbedding, DecoupledLinear, VLMOutputWithPast
from transformers.modeling_outputs import CausalLMOutputWithPast
@@ -179,6 +179,39 @@ def _encode_vision_x(self, vision_x: torch.Tensor):
        vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
        return vision_x

+    def _concat_vision_cache(
+        self, lang_x, vision_tokens, past_vision_tokens, past_media_locations, use_cache
+    ):
+        """Helper function to include the past vision tokens and past media locations in the output"""
+        if use_cache:
+            if past_media_locations is not None and past_vision_tokens is not None:
+                if vision_tokens is not None:
+                    updated_vision_tokens = torch.cat(
+                        [
+                            past_vision_tokens,
+                            vision_tokens,
+                        ],
+                        dim=1,
+                    )
+                else:
+                    updated_vision_tokens = past_vision_tokens
+                updated_media_locations = torch.cat(
+                    [
+                        past_media_locations,
+                        lang_x == self.media_token_id,
+                    ],
+                    dim=1,
+                )
+            else:
+                updated_vision_tokens = vision_tokens
+                updated_media_locations = lang_x == self.media_token_id
+
+        else:
+            updated_vision_tokens = None
+            updated_media_locations = None
+
+        return updated_vision_tokens, updated_media_locations
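
For intuition, a self-contained sketch of the caching pattern this helper implements; media_token_id and the tensor shapes below are made up for illustration:

import torch

media_token_id = 32000                                        # hypothetical <image> token id
past_media_locations = torch.tensor([[False, True, False]])   # from earlier decoding steps
lang_x = torch.tensor([[42, media_token_id]])                 # two newly generated token ids

# mirrors the torch.cat branch above: media locations grow along dim=1
updated = torch.cat([past_media_locations, lang_x == media_token_id], dim=1)
print(updated)  # tensor([[False,  True, False, False,  True]])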

    def generate(
        self,
        vision_x: torch.Tensor,
@@ -239,7 +272,7 @@ def generate(
    @property
    def num_trainable_params(self):
        """Return the number of trainable parameters"""
-        return sum(p.numel() for p in self.parameters() if p.requires_grad)
+        return num_params(self, filter_to_trainable=True)

    def set_trainable(self):
        """
@@ -337,33 +370,13 @@ def _postprocess_outputs_from_forward(
        use_cache: bool = False,
    ):
        """Include the past vision tokens and past media locations in the output"""
-        if use_cache:
-            if past_media_locations is not None and past_vision_tokens is not None:
-                if vision_tokens is not None:
-                    updated_vision_tokens = torch.cat(
-                        [
-                            past_vision_tokens,
-                            vision_tokens,
-                        ],
-                        dim=1,
-                    )
-                else:
-                    updated_vision_tokens = past_vision_tokens
-                updated_media_locations = torch.cat(
-                    [
-                        past_media_locations,
-                        lang_x == self.media_token_id,
-                    ],
-                    dim=1,
-                )
-            else:
-                updated_vision_tokens = vision_tokens
-                updated_media_locations = lang_x == self.media_token_id
-
-        else:
-            updated_vision_tokens = None
-            updated_media_locations = None
-
+        updated_vision_tokens, updated_media_locations = self._concat_vision_cache(
+            lang_x=lang_x,
+            vision_tokens=vision_tokens,
+            past_vision_tokens=past_vision_tokens,
+            past_media_locations=past_media_locations,
+            use_cache=use_cache,
+        )
        output = VLMOutputWithPast(
            loss=output.loss,
            logits=output.logits,
@@ -380,6 +393,41 @@ def _post_forward_hook(self):
        # clear the conditioned layers
        self.lang_model.clear_conditioned_layers()

+    @property
+    def num_params_per_module(self):
+        """Return the number of parameters per module in the model"""
+        num_xattn_params = num_params(self.lang_model.gated_cross_attn_layers)
+        return "\n".join(
+            [
+                "Vision encoder: " + str(num_params(self.vision_encoder)),
+                "Vision tokenizer: " + str(num_params(self.vision_tokenizer)),
+                "Cross attention: " + str(num_xattn_params),
+                "Language model: "
+                + str(num_params(self.lang_model) - num_xattn_params),
+            ]
+        )
+
+    @property
+    def num_trainable_params_per_module(self):
+        """Return the number of trainable parameters per module in the model"""
+        num_xattn_params = num_params(
+            self.lang_model.gated_cross_attn_layers, filter_to_trainable=True
+        )
+        return "\n".join(
+            [
+                "Vision encoder: "
+                + str(num_params(self.vision_encoder, filter_to_trainable=True)),
+                "Vision tokenizer: "
+                + str(num_params(self.vision_tokenizer, filter_to_trainable=True)),
+                "Cross attention: " + str(num_xattn_params),
+                "Language model: "
+                + str(
+                    num_params(self.lang_model, filter_to_trainable=True)
+                    - num_xattn_params
+                ),
+            ]
+        )
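
A brief usage sketch for these reporting helpers; `model` is assumed to be an already-constructed instance of a VLM subclass exposing these properties:

print("Trainable params:", model.num_trainable_params)
print(model.num_params_per_module)            # one "Module: count" line per module
print(model.num_trainable_params_per_module)  # same breakdown, trainable params only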


class VLMWithLanguageStream(VLM):
"""
@@ -428,7 +476,7 @@ def _prepare_inputs_for_forward(

        # get the language embeddings
        lang_embeds = self.lang_model.get_input_embeddings()(lang_x)

        # build up the multimodal embeddings
        B = lang_x.shape[0]
        has_labels = labels is not None
@@ -515,32 +563,13 @@ def _postprocess_outputs_from_forward(
        use_cache: bool = False,
    ):
        # Include the past vision tokens and past media locations in the output
-        if use_cache:
-            if past_media_locations is not None and past_vision_tokens is not None:
-                if vision_tokens is not None:
-                    updated_vision_tokens = torch.cat(
-                        [
-                            past_vision_tokens,
-                            vision_tokens,
-                        ],
-                        dim=1,
-                    )
-                else:
-                    updated_vision_tokens = past_vision_tokens
-                updated_media_locations = torch.cat(
-                    [
-                        past_media_locations,
-                        lang_x == self.media_token_id,
-                    ],
-                    dim=1,
-                )
-            else:
-                updated_vision_tokens = vision_tokens
-                updated_media_locations = lang_x == self.media_token_id
-
-        else:
-            updated_vision_tokens = None
-            updated_media_locations = None
+        updated_vision_tokens, updated_media_locations = self._concat_vision_cache(
+            lang_x=lang_x,
+            vision_tokens=vision_tokens,
+            past_vision_tokens=past_vision_tokens,
+            past_media_locations=past_media_locations,
+            use_cache=use_cache,
+        )

        # return logits that are the same shape as the original input_ids
        logits = output.logits
@@ -580,3 +609,28 @@

    def _post_forward_hook(self):
        pass

+    @property
+    def num_params_per_module(self):
+        """Return the number of parameters per module in the model"""
+        return "\n".join(
+            [
+                "Vision encoder: " + str(num_params(self.vision_encoder)),
+                "Vision tokenizer: " + str(num_params(self.vision_tokenizer)),
+                "Language model: " + str(num_params(self.lang_model)),
+            ]
+        )
+
+    @property
+    def num_trainable_params_per_module(self):
+        """Return the number of trainable parameters per module in the model"""
+        return "\n".join(
+            [
+                "Vision encoder: "
+                + str(num_params(self.vision_encoder, filter_to_trainable=True)),
+                "Vision tokenizer: "
+                + str(num_params(self.vision_tokenizer, filter_to_trainable=True)),
+                "Language model: "
+                + str(num_params(self.lang_model, filter_to_trainable=True)),
+            ]
+        )
