Mllama flash version #2585
Conversation
d407659 to 1f52f1c
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
self.layernorm_pre = nn.LayerNorm.load(
    prefix=f"{prefix}.layernorm_pre",
    weights=weights,
    # torch default
    eps=1e-05,
)
self.layernorm_post = nn.LayerNorm.load(
    prefix=f"{prefix}.layernorm_post",
    weights=weights,
    # torch default
    eps=1e-05,
)
Maybe we can use FastLayerNorms in place of the native LayerNorm?
I had more divergence with it than without, same for rotary. In any case the vision heads have minimal overhead (compared to the decode), especially given that we already have pixel-value variance (PIL vs the Rust image loader).
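For reference, a minimal sketch of what the suggested swap could look like. The import path and the assumption that FastLayerNorm.load shares the same (prefix, weights, eps) signature, and that its forward returns a (hidden_states, residual) tuple, are mine, not part of this diff:

# Hypothetical swap to FastLayerNorm (import path assumed).
from text_generation_server.layers.layernorm import FastLayerNorm

self.layernorm_pre = FastLayerNorm.load(
    prefix=f"{prefix}.layernorm_pre",
    weights=weights,
    # torch default
    eps=1e-05,
)
self.layernorm_post = FastLayerNorm.load(
    prefix=f"{prefix}.layernorm_post",
    weights=weights,
    # torch default
    eps=1e-05,
)

# Call sites would change too, assuming forward returns a
# (hidden_states, residual) tuple rather than a single tensor:
# hidden_state, _ = self.layernorm_pre(hidden_state)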
if config.model_type == "idefics": | ||
model = IdeficsForVisionText2Text(config, weights) | ||
elif config.model_type == "mllama": | ||
model = MllamaForConditionalGeneration( | ||
prefix="", config=config, weights=weights | ||
) | ||
else: | ||
raise RuntimeError(f"Unsupported model type {config.model_type}") |
Not sure if we want to update the name of this class from IDEFICSSharded to something like VLMShared, since it seems that mllama will use this path too.
No it doesn't, this is old code that needs to be removed; I just fused idefics.py and idefics_causal_lm.py.
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
Can we remove this? I can't seem to find where it's called.
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
Similar to the above, it's only used in apply_rotary_pos_emb.
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
similar to above
Good point, removed them
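For context, the removed helpers mirror the reference implementations in transformers.models.llama.modeling_llama; this is a sketch from that source rather than from this diff, shown only to document what was dropped as unused:

import torch


# Rotates half the hidden dims of the input.
def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Applies rotary position embeddings to the query and key tensors.
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# Repeats KV heads n_rep times so grouped-query attention matches the number
# of query heads: (batch, num_kv_heads, seqlen, head_dim) ->
# (batch, num_kv_heads * n_rep, seqlen, head_dim).
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)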
* Working loading state.
* Preprocessing.
* Working state? (Broke idefics1 temporarily).
* Cleaner condition.
* Fix idefics.
* Updating config, removing TODO.
* Mllama.
* Upgrade transformers 4.45.
* Flashing mllama.
* Starting to get there.
* Working state.
* Integration tests for mllama (cutting to 10 tokens because there seems to be instability after, meaning the size of the batch matters).
* Updating model link.
* Earlier assert.
* Fix vlm?
* Remove log.
* Force ignore all images but last.
* Default dtype bfloat16.
* Update integration test after switch to bf16.
* Remove dead code.
* Removed dead code.
* Upgrade the flake to latest transformers/tokenizers.
* Move to hf tgi-nix.
* Upgrade to 0.5.0.
What does this PR do?
Fixes: #2598
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.