From ea395ef77ecd0f25b0a6240a582fa7fe415dce27 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Tue, 5 Mar 2024 18:08:40 +0900
Subject: [PATCH] add example generation + peft

---
 docs/source/en/model_doc/mamba.md | 54 +++++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+)

diff --git a/docs/source/en/model_doc/mamba.md b/docs/source/en/model_doc/mamba.md
index 3ce869d3204e51..7378f79f94df7f 100644
--- a/docs/source/en/model_doc/mamba.md
+++ b/docs/source/en/model_doc/mamba.md
@@ -37,6 +37,60 @@ Tips:
 
 This model was contributed by [ArthurZ](https://huggingface.co/ArthurZ).
 The original code can be found [here](https://github.com/state-spaces/mamba).
 
+## Usage
+
+### A simple generation example:
+```python
+from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
+import torch
+
+tokenizer = AutoTokenizer.from_pretrained("ArthurZ/mamba-130m")
+tokenizer.pad_token = tokenizer.eos_token
+
+model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", vocab_size=50280, num_hidden_layers=24, torch_dtype=torch.float32)
+model.config.use_cache = True
+input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]
+
+out = model.generate(input_ids, max_new_tokens=10)
+print(tokenizer.batch_decode(out))
+```
+
+### PEFT fine-tuning
+The slow path is not very stable for training, and the fast path requires the model to be in `float32`!
+
+```python
+from datasets import load_dataset
+from trl import SFTTrainer
+from peft import LoraConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
+model_id = "ArthurZ/mamba-2.8b"
+tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="<s>")
+model = AutoModelForCausalLM.from_pretrained(model_id)
+dataset = load_dataset("Abirate/english_quotes", split="train")
+training_args = TrainingArguments(
+    output_dir="./results",
+    num_train_epochs=3,
+    per_device_train_batch_size=4,
+    logging_dir='./logs',
+    logging_steps=10,
+    learning_rate=2e-3
+)
+lora_config = LoraConfig(
+    r=8,
+    target_modules="all-linear",
+    task_type="CAUSAL_LM",
+    bias="none"
+)
+trainer = SFTTrainer(
+    model=model,
+    tokenizer=tokenizer,
+    args=training_args,
+    peft_config=lora_config,
+    train_dataset=dataset,
+    dataset_text_field="quote",
+)
+trainer.train()
+```
 
 ## MambaConfig