Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mllama flash version #2585

Merged
merged 24 commits into from
Oct 2, 2024
Merged

Mllama flash version #2585

merged 24 commits into from
Oct 2, 2024

Conversation

Narsil
Copy link
Collaborator

@Narsil Narsil commented Sep 30, 2024

What does this PR do?

Fixes: #2598

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Narsil Narsil force-pushed the mllama_flash branch 2 times, most recently from d407659 to 1f52f1c Compare September 30, 2024 11:41
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Narsil Narsil mentioned this pull request Oct 1, 2024
5 tasks
Comment on lines +501 to +461
self.layernorm_pre = nn.LayerNorm.load(
prefix=f"{prefix}.layernorm_pre",
weights=weights,
# torch default
eps=1e-05,
)
self.layernorm_post = nn.LayerNorm.load(
prefix=f"{prefix}.layernorm_post",
weights=weights,
# torch default
eps=1e-05,
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can use FastLayerNorms in place of the native LayerNorm?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had more divergence than without, same for rotary.
In any case the vision heads have minimal overhead (compared to the decode).

Given we have pixel values variance (PIL vs Rust image loader).

Comment on lines 92 to 99
if config.model_type == "idefics":
model = IdeficsForVisionText2Text(config, weights)
elif config.model_type == "mllama":
model = MllamaForConditionalGeneration(
prefix="", config=config, weights=weights
)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if we want to update the name of this class from IDEFICSSharded to something like VLMShared since it seems that mllama will use this path too

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No it doesn't this is old code that needs to be removed, I just fused idefics.py and idefics_causal_lm.py.



# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we remove this? I'm cant seem to find where its called



# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similar to above, only used in apply_rotary_pos_emb



# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similar to above

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, removed them

@Narsil Narsil merged commit d18ed5c into main Oct 2, 2024
13 of 14 checks passed
@Narsil Narsil deleted the mllama_flash branch October 2, 2024 09:22
yuanwu2017 pushed a commit to yuanwu2017/tgi-gaudi that referenced this pull request Oct 27, 2024
* Working loading state.

* Preprocessing.

* Working state ? (Broke idefics1 temporarily).

* Cleaner condition.

* Fix idefics.

* Updating config, removing TODO

* Mllama

* Ugrade transformers 4.45

* Flashing mllama.

* Starting to get there.

* Working state.

* Integrations tests for mllama (cutting to 10 tokens because there seems'
to be instability after (meaning size of the batch matters.

* Updating model link.

* Earlier assert.

* Fix vlm ?

* remove log.

* Force ignore all images but last.

* Default dtype bfloat16.

* Update integration test after switch to bf16.

* Remove dead code.

* Removed dead code.

* Upgrade the flake to latest transformers/tokenizers

* Move to hf tgi-nix

* Upgrade to 0.5.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add support for Llama 3.2 vision / Mllama
3 participants