
Commit

fix merge conflict
Anas Awadalla committed Sep 5, 2023
2 parents ecc74ad + b91da53 commit 3230715
Showing 4 changed files with 37 additions and 12 deletions.
33 changes: 27 additions & 6 deletions open_flamingo/src/factory.py
@@ -49,6 +49,7 @@ def create_model_and_transforms(
tokenizer_path,
local_files_only=use_local_files,
cache_dir=cache_dir,
trust_remote_code=True
)
# add Flamingo special tokens to the tokenizer
text_tokenizer.add_special_tokens(
@@ -60,11 +61,17 @@
# modify labels for the loss.
text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
new_tokens += 1


ids_for_additional_special_tokens = text_tokenizer.convert_tokens_to_ids(
["<|endofchunk|>", "<image>", "<PAD>"] if new_tokens == 3 else ["<|endofchunk|>", "<image>"]
)
print(f"Added {new_tokens} new tokens to the tokenizer")

lang_encoder = AutoModelForCausalLM.from_pretrained(
lang_encoder_path,
local_files_only=use_local_files,
cache_dir=cache_dir,
trust_remote_code=True
)

# hacks for MPT-1B, which doesn't have a get_input_embeddings method
@@ -80,18 +87,30 @@ def set_input_embeddings(self, new_embeddings):
extend_instance(lang_encoder, EmbeddingFnMixin)

if not hasattr(lang_encoder, "get_output_embeddings"):
lang_encoder.get_output_embeddings = lambda: lang_encoder.lm_head
lang_encoder.set_output_embeddings = lambda x: setattr(
lang_encoder, "lm_head", x
)
if hasattr(lang_encoder, "lm_head"):
lang_encoder.get_output_embeddings = lambda: lang_encoder.lm_head
else:
raise ValueError(
"We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
)

if not hasattr(lang_encoder, "set_output_embeddings"):
if hasattr(lang_encoder, "lm_head"):
lang_encoder.set_output_embeddings = lambda x: setattr(
lang_encoder, "lm_head", x
)
else:
raise ValueError(
"We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
)

# convert LM to FlamingoLM
extend_instance(lang_encoder, FlamingoLMMixin)

if decoder_layers_attr_name is None:
decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)

model = Flamingo(
vision_encoder,
lang_encoder,
@@ -101,6 +120,8 @@ def set_input_embeddings(self, new_embeddings):
"width"
],
cross_attn_every_n_layers=cross_attn_every_n_layers,
# HACK: The tokenizer's size and model's vocab size sometimes don't match. We use this to find the smaller of the flamingo special tokens and use that as the vocab size (even though the true one might be smaller).
vocab_size=min(ids_for_additional_special_tokens),
new_tokens=new_tokens, # number of tokens embeddings to train
**flamingo_kwargs,
)
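For context on the HACK comment in factory.py: Hugging Face tokenizers append newly added tokens after the pretrained vocabulary, so the smallest id among the Flamingo special tokens marks where the base vocabulary ends. A minimal sketch of that behaviour, using GPT-2 purely as an example model (it is not part of this commit):

# Illustration only: tokens added with add_special_tokens get ids appended
# after the existing vocabulary, so min() over their ids recovers the size of
# the base vocabulary even when the LM's embedding matrix is padded larger.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # example model, not OpenFlamingo's default
base_vocab = len(tok)  # 50257 for GPT-2

tok.add_special_tokens(
    {"additional_special_tokens": ["<|endofchunk|>", "<image>"]}
)
tok.add_special_tokens({"pad_token": "<PAD>"})

ids = tok.convert_tokens_to_ids(["<|endofchunk|>", "<image>", "<PAD>"])
print(ids, min(ids) == base_vocab)  # new ids start right after the pretrained vocab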
7 changes: 5 additions & 2 deletions open_flamingo/src/flamingo.py
@@ -22,9 +22,10 @@ def __init__(
eoc_token_id: int,
media_token_id: int,
vis_dim: int,
vocab_size: int,
new_tokens: int,
cross_attn_every_n_layers: int = 1,
gradient_checkpointing: bool = False,
new_tokens: int = 2,
):
"""
Args:
@@ -34,9 +35,10 @@
media_token_id (int): Token id for <image>
vis_dim (int): Dimension of the visual features.
Visual features are projected to match this shape along the last dimension.
vocab_size (int): Size of the base vocabulary.
new_tokens (int): Number of new tokens added to the tokenizer.
cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1.
gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
new_tokens (int, optional): Number of new tokens added to the tokenizer. Defaults to 2.
"""
super().__init__()
self.eoc_token_id = eoc_token_id
@@ -56,6 +58,7 @@ def __init__(
vis_hidden_size=self.vis_dim,
cross_attn_every_n_layers=cross_attn_every_n_layers,
gradient_checkpointing=gradient_checkpointing,
vocab_size=vocab_size,
new_tokens=new_tokens,
)
self._use_gradient_checkpointing = gradient_checkpointing
5 changes: 3 additions & 2 deletions open_flamingo/src/flamingo_lm.py
@@ -91,6 +91,7 @@ def init_flamingo(
vis_hidden_size,
cross_attn_every_n_layers,
gradient_checkpointing,
vocab_size,
new_tokens,
):
"""
@@ -114,7 +115,7 @@
input_embed_weights = self.get_input_embeddings().weight
self.set_input_embeddings(
FlamingoDecoupledEmbedding(
num_embeddings=input_embed_weights.shape[0],
num_embeddings=vocab_size,
num_additional_embeddings=new_tokens,
embedding_dim=input_embed_weights.shape[1],
partially_freeze=True,
@@ -124,7 +125,7 @@

out_embeds = FlamingoDecoupledLinear(
in_features=input_embed_weights.shape[1],
out_features=input_embed_weights.shape[0],
out_features=vocab_size,
bias=getattr(self.get_output_embeddings(), "bias", None) is not None,
out_additional_features=new_tokens,
partially_freeze=True,
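The switch to vocab_size above sizes the frozen base table by the real vocabulary rather than by the language model's (possibly padded) embedding matrix. A rough, self-contained sketch of the decoupled-embedding idea, not the repository's FlamingoDecoupledEmbedding; shapes and ids are illustrative:

# Ids below vocab_size are looked up in the frozen pretrained table,
# ids >= vocab_size in a small trainable table for the new Flamingo tokens.
import torch
import torch.nn as nn


class DecoupledEmbeddingSketch(nn.Module):
    def __init__(self, num_embeddings, num_additional_embeddings, embedding_dim):
        super().__init__()
        self.num_embeddings = num_embeddings
        # frozen base rows (would be copied from the pretrained LM)
        self.weight = nn.Parameter(
            torch.randn(num_embeddings, embedding_dim), requires_grad=False
        )
        # trainable rows for the added special tokens
        self.additional_weight = nn.Parameter(
            torch.randn(num_additional_embeddings, embedding_dim)
        )

    def forward(self, input_ids):
        is_new = input_ids >= self.num_embeddings
        base_ids = input_ids.clamp(max=self.num_embeddings - 1)
        out = nn.functional.embedding(base_ids, self.weight)
        if is_new.any():
            out[is_new] = nn.functional.embedding(
                input_ids[is_new] - self.num_embeddings, self.additional_weight
            )
        return out


emb = DecoupledEmbeddingSketch(num_embeddings=50257, num_additional_embeddings=3, embedding_dim=16)
ids = torch.tensor([[10, 50257, 50259]])  # one base token plus two hypothetical new-token ids
print(emb(ids).shape)  # torch.Size([1, 3, 16])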
4 changes: 2 additions & 2 deletions open_flamingo/src/helpers.py
@@ -297,7 +297,7 @@ def __init__(
num_embeddings,
num_additional_embeddings,
embedding_dim,
partially_freeze=False,
partially_freeze=True,
device=None,
dtype=None,
padding_idx=None,
@@ -311,7 +311,7 @@
Number of additional embeddings. Only useful when `partially_freeze=True`.
embedding_dim (`int`):
The size of each embedding vector
partially_freeze: (`bool`, *optional*, defaults to `False`):
partially_freeze: (`bool`, *optional*, defaults to `True`):
If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen.
padding_idx (`int`, *optional*):
The padding index (needs to be less than num_embeddings)
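For the partially_freeze behaviour described in this docstring, a minimal sketch (assumed, not the repository's class) of what is left for the optimizer when the regular weight is frozen and only the rows for the new tokens stay trainable:

import torch
import torch.nn as nn

base = nn.Embedding(50257, 768)    # stands in for the pretrained input embeddings
extra = nn.Embedding(3, 768)       # rows for <|endofchunk|>, <image>, <PAD>
base.weight.requires_grad_(False)  # partially_freeze=True: the regular weight is frozen

params = list(base.parameters()) + list(extra.parameters())
trainable = [p for p in params if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)
print(sum(p.numel() for p in trainable))  # only 3 * 768 values get updated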
