Skip to content

Commit

Permalink
Add Nemotron GGUF Loading Support (#34725)
Browse files Browse the repository at this point in the history
* Add Nemotron GGUF Loading Support

* fix the Nemotron architecture assignation

---------

Co-authored-by: Marc Sun <[email protected]>
  • Loading branch information
farrosalferro and SunMarc authored Nov 21, 2024
1 parent d4e1acb commit c57eafd
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/en/gguf.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ For now the supported model architectures are the architectures that have been v
- Starcoder2
- T5
- Mamba
- Nemotron

## Example usage

Expand Down
27 changes: 27 additions & 0 deletions src/transformers/integrations/ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,20 @@
"output_norm": "backbone.norm_f",
"output.weight": "lm_head.weight",
},
"nemotron": {
"token_embd": "model.embed_tokens",
"blk": "model.layers",
"ffn_up": "mlp.up_proj",
"ffn_down": "mlp.down_proj",
"ffn_norm": "post_attention_layernorm",
"attn_norm": "input_layernorm",
"attn_q": "self_attn.q_proj",
"attn_v": "self_attn.v_proj",
"attn_k": "self_attn.k_proj",
"attn_output": "self_attn.o_proj",
"output.weight": "lm_head.weight",
"output_norm": "model.norm",
},
}


Expand Down Expand Up @@ -397,6 +411,18 @@
"ssm.time_step_rank": "time_step_rank",
"ssm.inner_size": "intermediate_size",
},
"nemotron": {
"context_length": "max_position_embeddings",
"block_count": "num_hidden_layers",
"feed_forward_length": "intermediate_size",
"embedding_length": "hidden_size",
"rope.dimension_count": None,
"rope.freq_base": "rope_theta",
"attention.head_count": "num_attention_heads",
"attention.head_count_kv": "num_key_value_heads",
"attention.layer_norm_rms_epsilon": "norm_eps",
"vocab_size": "vocab_size",
},
}

GGUF_TOKENIZER_MAPPING = {
Expand Down Expand Up @@ -793,6 +819,7 @@ def converted(self) -> Tokenizer:
"starcoder2": GGUFGPTConverter,
"t5": GGUFT5Converter,
"mamba": GGUFGPTConverter,
"nemotron": GGUFGPTConverter,
}


Expand Down
40 changes: 40 additions & 0 deletions tests/quantization/ggml/test_ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class GgufIntegrationTests(unittest.TestCase):
starcoder2_original_model_id = "bigcode/starcoder2-3b"
mamba_original_model_id = "state-spaces/mamba-2.8b-hf"
mamba_model_id = "jpodivin/mamba-2.8b-hf-GGUF"
nemotron_original_model_id = "nvidia/Nemotron-Mini-4B-Instruct"
nemotron_model_id = "bartowski/Nemotron-Mini-4B-Instruct-GGUF"

# standard quants
q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
Expand Down Expand Up @@ -106,6 +108,8 @@ class GgufIntegrationTests(unittest.TestCase):
fp16_starcoder2_gguf_model_id = "starcoder2-3b.fp16.gguf"
q6_k_mamba_model_id = "ggml-model-Q6_K.gguf"
fp16_mamba_model_id = "ggml-model-f16.gguf"
q6_k_nemotron_model_id = "Nemotron-Mini-4B-Instruct-Q6_K.gguf"
fp16_nemotron_model_id = "Nemotron-Mini-4B-Instruct-f16.gguf"

example_text = "Hello"

Expand Down Expand Up @@ -792,6 +796,42 @@ def test_mamba_q6_k(self):
EXPECTED_TEXT = "Hello,I answerthe question.\n\nA"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_nemotron_weights_conversion_fp16(self):
original_model = AutoModelForCausalLM.from_pretrained(
self.nemotron_original_model_id,
torch_dtype=torch.float16,
)

converted_model = AutoModelForCausalLM.from_pretrained(
self.nemotron_model_id,
gguf_file=self.fp16_nemotron_model_id,
torch_dtype=torch.float16,
)

converted_state_dict = converted_model.state_dict()
original_state_dict = original_model.state_dict()

for layer_name, original_params in original_state_dict.items():
if layer_name in converted_state_dict:
self.assertTrue(original_params.shape == converted_state_dict[layer_name].shape)
torch.testing.assert_close(original_params, converted_state_dict[layer_name])
else:
raise ValueError(f"Layer {layer_name} is not presented in GGUF model")

def test_nemotron_q6_k(self):
model = AutoModelForCausalLM.from_pretrained(
self.nemotron_model_id,
gguf_file=self.q6_k_nemotron_model_id,
torch_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(self.nemotron_model_id, gguf_file=self.q6_k_nemotron_model_id)
text = tokenizer(self.example_text, return_tensors="pt")["input_ids"]
out = model.generate(text, max_new_tokens=10)

EXPECTED_TEXT = "'Hello. hotmail.com.'"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_tokenization_xnli(self):
import tqdm
from datasets import load_dataset
Expand Down

0 comments on commit c57eafd

Please sign in to comment.