
Adding Qwen2.5 #1834

Merged: 18 commits into Lightning-AI:main on Nov 27, 2024

Conversation

@ysjprojects (Contributor) commented Nov 20, 2024

see #1709

Qwen2.5

0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B

Qwen2.5 Coder

0.5B, 1.5B, 3B, 7B, 14B, 32B

Both base and instruct models.

Motivation:

  • Proven SOTA coding performance, especially for the Qwen2.5-Coder series.
  • One of the more recent open-source LM releases with strong results on general benchmarks, competitive with proprietary models.
  • Notably strong performance on Chinese benchmarks.
  • A SOTA model family that goes as small as 0.5B, which is rare and will serve many small-LM use cases.
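
For reference, a minimal usage sketch of one of the added checkpoints once this PR is in. It assumes the weights are published under the usual Hugging Face repo id (Qwen/Qwen2.5-0.5B-Instruct) and that litgpt's Python API (LLM.load / generate) picks up the new configs; treat it as an illustration, not part of this PR.

from litgpt import LLM

# assumes the repo id Qwen/Qwen2.5-0.5B-Instruct and that the Qwen2.5 configs
# added in this PR are registered; the checkpoint is downloaded on first load
llm = LLM.load("Qwen/Qwen2.5-0.5B-Instruct")
print(llm.generate("What is the capital of France?", max_new_tokens=32))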

@Andrei-Aksionov (Collaborator) left a comment

Hello @ysjprojects 👋

Thanks for the PR.
I believe we've wanted to implement Qwen models for quite a while, but never did.

Overall, it looks wonderful.
I just added a couple of nits.

Review comments (resolved):
  • litgpt/prompts.py (outdated)
  • tests/test_tokenizer.py (outdated)
  • litgpt/model.py
@Andrei-Aksionov (Collaborator)

One more thing: please update the description of the PR with more info about the model.

@Andrei-Aksionov (Collaborator)

I did a quick check of the 0.5B and 1.5B instruct versions (with the fix for the conversion script).
Don't know about other languages, but in English even 0.5B performs surprisingly well 🙂.

@ysjprojects After you apply the fix that I've mentioned in the comment, I'll be happy to merge the PR.

@ysjprojects (Contributor, Author)

> One more thing: please update the description of the PR with more info about the model.

Are there specific details that should be included?

@ysjprojects (Contributor, Author)

> I did a quick check of the 0.5B and 1.5B instruct versions (with the fix for the conversion script). Don't know about other languages, but in English even 0.5B performs surprisingly well 🙂.
>
> @ysjprojects After you apply the fix that I've mentioned in the comment, I'll be happy to merge the PR.

fixed

@Andrei-Aksionov (Collaborator)

> One more thing: please update the description of the PR with more info about the model.
>
> Are there specific details that should be included?

It would be nice if you explained why you decided to add this exact model: coding abilities, excellent Chinese support, and so on.
Just a bit of context.

@Andrei-Aksionov (Collaborator)

Hey @ysjprojects
Forgot to mention that you also need to write tests for this model, since it has a custom conversion script.

Unfortunately, I cannot push changes to your branch.
But it should be easy.
Something like this for test_convert_lit_checkpoint.py:

@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B"))
@pytest.mark.parametrize(
    ("device", "dtype"),
    [
        (torch.device("cpu"), torch.float32),
        pytest.param(
            torch.device("cuda"),
            torch.float16,
            marks=[
                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
                # is slightly different
                pytest.mark.xfail(raises=AssertionError, strict=False),
                RunIf(min_cuda_gpus=1),
            ],
        ),
    ],
)
def test_against_original_qwen_2_5(model_name, device, dtype):
    torch.set_default_dtype(dtype)

    T = 20
    ours_config = Config.from_name(
        model_name,
        block_size=T,
        n_layer=2,
        n_head=16,
        n_embd=32,
        intermediate_size=86,
    )
    theirs_config = Qwen2Config(
        vocab_size=ours_config.padded_vocab_size,
        hidden_size=ours_config.n_embd,
        head_dim=ours_config.head_size,
        num_attention_heads=ours_config.n_head,
        num_hidden_layers=ours_config.n_layer,
        intermediate_size=ours_config.intermediate_size,
        max_position_embeddings=ours_config.block_size,
        rms_norm_eps=ours_config.norm_eps,
        num_key_value_heads=ours_config.n_query_groups,
        rope_theta=ours_config.rope_base,
        attention_bias=ours_config.attn_bias,
        tie_word_embeddings=True,
    )

    assert ours_config.intermediate_size == theirs_config.intermediate_size

    ours_model = GPT(ours_config).to(device)
    # tie weights
    ours_model.lm_head.weight = ours_model.transformer.wte.weight
    ours_state_dict = ours_model.state_dict()
    theirs_state_dict = {}
    copy_weights_qwen_2_5(ours_config, theirs_state_dict, ours_state_dict, untie_weights=True)
    theirs_model = Qwen2ForCausalLM(theirs_config).to(device)
    keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)
    assert not keys.unexpected_keys

    # test end to end
    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
    assert x.size(1) == T
    ours_y = ours_model(x)
    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
    torch.testing.assert_close(ours_y, theirs_y)

and for test_model.py:

@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B"))
@pytest.mark.parametrize(
    ("device", "dtype"),
    [
        (torch.device("cpu"), torch.float32),
        pytest.param(
            torch.device("cuda"),
            torch.float16,
            marks=[
                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
                # is slightly different
                pytest.mark.xfail(raises=AssertionError, strict=False),
                RunIf(min_cuda_gpus=1),
            ],
        ),
    ],
)
def test_against_original_qwen_2_5(model_name, device, dtype):
    torch.set_default_dtype(dtype)

    T = 20
    ours_config = Config.from_name(
        model_name,
        block_size=T,
        n_layer=2,
        n_head=16,
        n_embd=32,
        intermediate_size=86,
    )
    theirs_config = Qwen2Config(
        vocab_size=ours_config.padded_vocab_size,
        hidden_size=ours_config.n_embd,
        head_dim=ours_config.head_size,
        num_attention_heads=ours_config.n_head,
        num_hidden_layers=ours_config.n_layer,
        intermediate_size=ours_config.intermediate_size,
        max_position_embeddings=ours_config.block_size,
        rms_norm_eps=ours_config.norm_eps,
        num_key_value_heads=ours_config.n_query_groups,
        rope_theta=ours_config.rope_base,
        attention_bias=ours_config.attn_bias,
        tie_word_embeddings=True,
    )

    theirs_model = Qwen2ForCausalLM(theirs_config).to(device)
    theirs_state_dict = theirs_model.state_dict()
    # tied embeddings: HF Qwen2.5 checkpoints ship without a separate `lm_head.weight`, so mimic that here
    theirs_state_dict.pop("lm_head.weight")
    state_dict = {}
    copy_weights_qwen_2_5(ours_config, {}, state_dict, theirs_state_dict)
    ours_model = GPT(ours_config).to(device)
    ours_model.load_state_dict(state_dict)

    # test end to end
    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
    assert x.size(1) == T
    ours_y = ours_model(x)
    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
    torch.testing.assert_close(ours_y, theirs_y)
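
Both snippets assume the imports already present at the top of those test files. For completeness, a hedged sketch of what they rely on; the exact module paths are guesses and may differ from where this PR actually puts things:

import pytest
import torch
from transformers import Qwen2Config, Qwen2ForCausalLM

from litgpt.config import Config
from litgpt.model import GPT
# assumed locations: the lit->HF copy function for the first snippet lives in the
# lit-checkpoint conversion script, and its HF->lit counterpart (same name) in
# convert_hf_checkpoint.py for the second; adjust to wherever this PR defines them
from litgpt.scripts.convert_lit_checkpoint import copy_weights_qwen_2_5
from conftest import RunIf  # helper assumed to be defined in tests/conftest.py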

@Andrei-Aksionov (Collaborator)

The comments above should fix the CI issues.
(Apart from the macOS one; I don't know why exactly that is happening.)

@ysjprojects (Contributor, Author)

> The comments above should fix the CI issues. (Apart from the macOS one; I don't know why exactly that is happening.)

Yep, everything is fixed except macOS.

@Andrei-Aksionov (Collaborator)

Hey @ysjprojects
I think it's worth increasing the tolerance for the failing test, like it was done here.
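
For illustration, that usually just means passing explicit tolerances to the final comparison; the values below are placeholders, and the actual ones should match whatever the linked change used:

# relax the comparison only as much as the float16 run needs (illustrative values)
torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5)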

@Andrei-Aksionov merged commit ff8b1b6 into Lightning-AI:main on Nov 27, 2024 (8 of 9 checks passed).
@Andrei-Aksionov (Collaborator)

Hello @ysjprojects
The issue with the litGPT + Thunder workflow is, I believe, not related to this PR and thus not a blocker.

I'm merging the PR.
Thanks for the patience 😊.
