feat: adding mplugdocowl #31059

Draft · wants to merge 55 commits into base: main

Conversation

danaaubakirova

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@danaaubakirova danaaubakirova marked this pull request as draft May 27, 2024 12:39
```python
self.vocab_size = config.text_config.vocab_size
# initialize LlamaAttention
# replace_llama_modality_adaptive()
transformers.models.llama.modeling_llama.LlamaAttention = MultiwayAttention
```

Author

I am not sure whether this changes the model architecture, because it still doesn't have the modules that I need, such as MultiwayAttention or a new decoder.

Contributor

@NielsRogge NielsRogge May 27, 2024

Hi, I saw this PR, so I thought I might help here :)

Usually we avoid using inheritance or classes from different models. Each model is typically implemented in a single modeling_xxx.py script which is self-contained and does not depend on any other modeling files.

This means that we'll need to define a MPLUGDocOwlMultiwayAttention class, a MPLUGDocOwlTextDecoderLayer class, and so on.

One can leverage # Copied from statements above a class or method when it is identical to another one, e.g. MistralAttention copies from LlamaAttention since they're the same.
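
As a rough illustration of that convention (the class below is only a placeholder sketch, not the final implementation), such a self-contained copy would carry a header like:

```python
from torch import nn


# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->MPLUGDocOwl
class MPLUGDocOwlAttention(nn.Module):
    """Attention duplicated from Llama so that modeling_mplugdocowl.py stays self-contained."""
    # ... body identical to LlamaAttention, with the names substituted accordingly
```

The repository's consistency tooling (e.g. `make fix-copies`) then keeps such copies in sync with the Llama originals.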

Author

Hello Niels! @molbap provided similar feedback :) Yes, I just created a separate file for the language model and defined separate classes. This indeed helps! Thank you!

Contributor

Great! No worries :)

Contributor

@molbap molbap left a comment

Nice progress!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

Small comment: this should be named with ...modeling... somewhere, so it's clear at a glance what the file is.

Contributor

@molbap molbap left a comment

Some comments on RoPE; I will add another review later.

Contributor

Need a checkout here to clear the diff.

Comment on lines 87 to 104
```python
class MPLUGDocOwlRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # For BC we register cos and sin cached
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        t = t / self.scaling_factor
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
        self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
```

Contributor

This seems to be an older implementation of RoPE. The class name suggests this is basic RoPE, but it applies linear scaling:

    t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
    t = t / self.scaling_factor

This is now done on the position_ids, and in the correct class. Better to use the more recent implementation from the Llama codebase, since this model is Llama-based.
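
For reference, a minimal sketch of that newer pattern (modeled on the recent Llama code; the MPLUGDocOwl class names are placeholders and details are trimmed, so treat this as illustrative rather than the exact upstream implementation): the base class derives cos/sin from the position_ids it receives, and linear scaling is isolated in a dedicated subclass.

```python
import torch
from torch import nn


class MPLUGDocOwlRotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        # cos/sin are computed from the position_ids passed in at call time; nothing is pre-scaled or cached
        inv_freq = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        positions = position_ids[:, None, :].float()
        freqs = (inv_freq @ positions).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)


class MPLUGDocOwlLinearScalingRotaryEmbedding(MPLUGDocOwlRotaryEmbedding):
    def __init__(self, dim, base=10000, device=None, scaling_factor=1.0):
        super().__init__(dim, base=base, device=device)
        self.scaling_factor = scaling_factor

    def forward(self, x, position_ids):
        # linear scaling lives only in this subclass and is applied to the position ids themselves
        position_ids = position_ids.float() / self.scaling_factor
        return super().forward(x, position_ids)
```

With this split, _init_rope would instantiate the plain class when config.rope_scaling is None and the scaling variant only when scaling is explicitly configured, matching the Llama behaviour.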

Contributor

And from the Omni configuration, it looks like the default RoPE scaling is None, which in this case would instantiate this class, which has linear scaling.

Contributor

Make sure this is not impacting generation; it could change a few things if the positional embeddings are messed up.
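
One way to sanity-check this once a converted checkpoint exists (everything below is a placeholder sketch: the checkpoint path, prompt, and the assumption of a Llava-style MPLUGDocOwlForConditionalGeneration class): run a short greedy generation and diff the output against the original repository.

```python
import torch
from PIL import Image
from transformers import AutoProcessor
from transformers import MPLUGDocOwlForConditionalGeneration  # assumes the Llava-style naming this PR seems to follow

checkpoint = "path/to/converted-mplugdocowl"  # placeholder: local output of the conversion script
processor = AutoProcessor.from_pretrained(checkpoint)
model = MPLUGDocOwlForConditionalGeneration.from_pretrained(checkpoint, torch_dtype=torch.float16, device_map="auto")

image = Image.open("document_page.png")  # any test document image
inputs = processor(text="What is the title of this document?", images=image, return_tensors="pt").to(model.device)

# greedy decoding keeps the output deterministic, so it can be diffed against the original repo's output
generated = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(processor.batch_decode(generated, skip_special_tokens=True)[0])
```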

Comment on lines 21 to 33
The mPLUGDocOwl model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
<INSERT SHORT SUMMARY HERE>

The abstract from the paper is the following:

*<INSERT PAPER ABSTRACT HERE>*

Tips:

<INSERT TIPS ABOUT MODEL HERE>

This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
Contributor

Todo

Contributor

Still todo: add the paper authors, abstract, tips, your contributor HF handle, and the original GitHub repo.

Author

Almost done; the tips still need to be added.

Contributor

@molbap molbap left a comment

My comments weren't published as I had poor connectivity... most are outdated, but I'm submitting them in case something slips by.

"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"pixel_values": pixel_values,
Contributor

A bit confused here: why are pixel_values passed? Images are probably not needed here, since inputs_embeds are created from the images and text and are then passed to the model's forward() call.

Author

We need to keep them because, to get the image_features, pixel_values must not be None. The same thing happens in the Llava code.
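
For context, this is roughly the Llava-style pattern being referred to (a simplified sketch, not the exact upstream code; the method is shown outside its class for brevity): pixel_values are forwarded so that forward() can compute image_features and merge them with the text embeddings.

```python
# simplified sketch of a Llava-style prepare_inputs_for_generation
def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, pixel_values=None, **kwargs
):
    if past_key_values is not None:
        # after the first step the image tokens are already merged into the cache,
        # so only the newly generated token is fed back in
        input_ids = input_ids[:, -1:]

    return {
        "input_ids": input_ids,
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache"),
        "attention_mask": attention_mask,
        # kept not-None so that forward() can compute image_features from the images
        "pixel_values": pixel_values,
    }
```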


```python
def _init_rope(self):
    if self.config.rope_scaling is None:
        self.rotary_emb = MPLUGDocOwlRotaryEmbedding(
```

Contributor

Here, IIUC, this is not the correct RoPE class; it adds scaling.

Comment on lines +55 to +64
```python
def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
```

Contributor

only useful for FA2
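
For intuition (a toy example, not code from this PR): the helper just converts a padding mask into the indices / cu_seqlens format that the varlen flash-attention kernels consume, which is why it has no use outside the FA2 path.

```python
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # two sequences of lengths 3 and 2

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)              # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()   # positions of real tokens
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 5])

# flash_attn_varlen_* kernels take exactly these unpadded indices and cumulative lengths;
# the eager and SDPA paths work on the padded mask directly, so they never need this helper.
print(indices, cu_seqlens, seqlens_in_batch.max().item())
```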

danaaubakirova and others added 27 commits June 27, 2024 13:18
@danaaubakirova danaaubakirova mentioned this pull request Jul 16, 2024
@huggingface huggingface deleted a comment from github-actions bot Jul 29, 2024
@huggingface huggingface deleted a comment from github-actions bot Sep 24, 2024