[feat] support 01-ai Yi-6B/34B #899

Open · wants to merge 6 commits into `main`
1 change: 1 addition & 0 deletions README.md
@@ -28,6 +28,7 @@ Supports the following popular model checkpoints:

| Model and usage | Model size | Reference |
|-----------------------------------------------------------------------------------|------------------------------------------|------------------------------------------------------------------------------------------------------------------------------|
| 01-ai [Yi](tutorials/download_yi.md) | 6B, 34B | [01-ai 2023](https://github.com/01-ai/Yi) |
| EleutherAI [Pythia](tutorials/download_pythia.md) | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) |
| LMSYS [LongChat](tutorials/download_longchat.md) | 7B, 13B | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) |
| LMSYS [Vicuna](tutorials/download_vicuna.md) | 7B, 13B, 33B | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) |
61 changes: 53 additions & 8 deletions chat/base.py
@@ -70,10 +70,13 @@ def generate(

def decode(fabric: L.Fabric, tokenizer: Tokenizer, token_stream: Iterator[torch.Tensor]) -> int:
    tokens_generated = 0
    resp = ""
    if tokenizer.backend == "huggingface":
        try:
            for token in token_stream:
-                fabric.print(tokenizer.decode(token), end="", flush=True)
                _decode = tokenizer.decode(token)
                fabric.print(_decode, end="", flush=True)
                resp += _decode
                tokens_generated += 1
        except KeyboardInterrupt:
            # support stopping generation
@@ -89,14 +92,15 @@ def decode(fabric: L.Fabric, tokenizer: Tokenizer, token_stream: Iterator[torch.
                so_far = torch.cat((so_far, token.view(-1)))
                decoded_new = tokenizer.decode(so_far)
                fabric.print(decoded_new[len(decoded_so_far) :], end="", flush=True)
                resp += decoded_new[len(decoded_so_far) :]
                decoded_so_far = decoded_new
                tokens_generated += 1
        except KeyboardInterrupt:
            # support stopping generation
            return tokens_generated
    else:
        raise NotImplementedError(tokenizer.backend)
-    return tokens_generated
    return tokens_generated, resp


@torch.inference_mode()
@@ -108,6 +112,8 @@ def main(
    quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
    precision: Optional[str] = None,
    compile: bool = False,
    max_seq_length: Optional[int] = None,
    history_length: int = 10,
) -> None:
    """Starts a conversation with a tuned GPT model.

@@ -122,6 +128,8 @@
            for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md
        precision: Indicates the Fabric precision setting to use.
        compile: Whether to use compilation to speed up token generation. Will increase startup time.
        max_seq_length: The maximum sequence length (prompt plus generated tokens). If not specified, the model's default is used.
        history_length: The number of previous messages to keep in the chat history. Set to 0 to disable history and to -1 to keep all messages.
    """
    precision = precision or get_default_supported_precision(training=False)

@@ -144,8 +152,11 @@
fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
with fabric.init_module(empty_init=True):
model = GPT(config)
if max_seq_length is not None:
model.max_seq_length = max_seq_length
# enable the kv cache
model.set_kv_cache(batch_size=1)
fabric.print(f"Model's max_seq_length: {model.max_seq_length}", file=sys.stderr)
load_checkpoint(fabric, model, checkpoint_path)
model.eval()

@@ -159,37 +170,49 @@
    model = fabric.setup_module(model)

    tokenizer = Tokenizer(checkpoint_dir)
-    system_prompt, stop_tokens = prompt_config(checkpoint_dir, tokenizer)

    L.seed_everything(1234)
    history = []
    while True:
        system_prompt, stop_tokens = prompt_config(checkpoint_dir, tokenizer, history)
        try:
            prompt = input(">> Prompt: ")
        except KeyboardInterrupt:
            break
        if not prompt:
            break
-        prompt = system_prompt.format(prompt=prompt)
-        encoded_prompt = tokenizer.encode(prompt, device=fabric.device)
        if prompt == "/reset":
            history = []
            continue
        encoded_prompt = tokenizer.encode(system_prompt.format(prompt=prompt),
                                          device=fabric.device)
        y = generate(
            model, encoded_prompt, model.max_seq_length, temperature=temperature, top_k=top_k, stop_tokens=stop_tokens
        )
        fabric.print(">> Reply: ", end="")
        t0 = time.perf_counter()
-        tokens_generated = decode(fabric, tokenizer, y)
        tokens_generated, reply = decode(fabric, tokenizer, y)
        if history_length:
            history.append({"role": "user", "content": prompt})
            history.append({"role": "assistant", "content": reply})
        if history_length > 0:
            history = history[-history_length:]
        t = time.perf_counter() - t0
        for block in model.transformer.h:
            block.attn.kv_cache.reset_parameters()
        fabric.print(
            f"\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec,"
-            f" {tokens_generated} tokens",
            f" {tokens_generated} tokens, prompt length {len(encoded_prompt)}",
            file=sys.stderr,
        )
        fabric.print()


-def prompt_config(checkpoint_dir: Path, tokenizer: Tokenizer) -> Tuple[str, Tuple[List[int], ...]]:
def prompt_config(checkpoint_dir: Path, tokenizer: Tokenizer,
                  history: Optional[List] = None) -> Tuple[str, Tuple[List[int], ...]]:
    checkpoint_name = str(checkpoint_dir)
    if history is None:
        history = []
    if re.search(r"stabilityai.*tuned-alpha", checkpoint_name):
        system_prompt = (
            "<|SYSTEM|># StableLM Tuned (Alpha version)\n- StableLM is a helpful and harmless open-source AI language"
@@ -361,6 +384,28 @@ def prompt_config(checkpoint_dir: Path, tokenizer: Tokenizer) -> Tuple[str, Tupl
        stop_tokens = ([tokenizer.eos_id],)
        return system_prompt, stop_tokens

    if re.search(r"yi-.*b-chat", checkpoint_name.lower()):
        '''
        <|im_start|>system
        {system_message}<|im_end|>
        <|im_start|>user
        {prompt}<|im_end|>
        <|im_start|>assistant
        '''
        system_prompt = ""
        for item in history:
            if item["role"] == "user":
                system_prompt += f"<|im_start|>user\n{item['content']}<|im_end|>\n"
            elif item["role"] == "assistant":
                system_prompt += f"<|im_start|>assistant\n{item['content']}<|im_end|>\n"
        system_prompt += (
            "<|im_start|>user\n"
            "{prompt}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
        stop_tokens = ([tokenizer.token_to_id("<|im_end|>")],)
        return system_prompt, stop_tokens

    # default format
    return "{prompt}", ([tokenizer.eos_id],)

46 changes: 46 additions & 0 deletions lit_gpt/config.py
@@ -1330,4 +1330,50 @@ def norm_class(self) -> Type:

configs.extend(llama_2_function_calling)

###############
# 01/Yi
###############

# https://huggingface.co/01-ai/Yi-6B-Chat/blob/main/config.json
yi_6b = dict(
    name="yi-6b{}-hf",
    hf_config=dict(org="01-ai", name="Yi-6B{}"),
    vocab_size=64000,
    padding_multiple=64,
    n_layer=32,
    n_head=32,
    n_embd=4096,
    block_size=4096,
    n_query_groups=4,
    rotary_percentage=1.0,
    parallel_residual=False,
    bias=False,
    _norm_class="RMSNorm",
    _mlp_class="LLaMAMLP",
    intermediate_size=11008,
    rope_base=5000000,
    norm_eps=1e-5,
)

# https://huggingface.co/01-ai/Yi-34B-Chat/blob/main/config.json
yi_34b = deepcopy(yi_6b)
yi_34b.update(dict(
    name="yi-34b{}-hf",
    hf_config=dict(org="01-ai", name="Yi-34B{}"),
    n_layer=60,
    n_head=56,
    n_embd=7168,
    n_query_groups=8,
    intermediate_size=20480,
))

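# The loop below registers six Yi configs: "yi-6b-hf", "yi-6b-200k-hf", "yi-6b-chat-hf",
# plus the 34B equivalents, matching the HF repos 01-ai/Yi-6B{,-200K,-Chat} and 01-ai/Yi-34B{,-200K,-Chat};
# the -200K variants additionally raise block_size to 200000.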
for c in [yi_6b, yi_34b]:
    for postfix in ['', '-200K', '-Chat']:
        copy = deepcopy(c)
        if postfix == "-200K":
            copy["block_size"] = 200000
        copy["name"] = c["name"].format(postfix.lower())
        copy["hf_config"]["name"] = c["hf_config"]["name"].format(postfix)
        configs.append(copy)

name_to_config = {config["name"]: config for config in configs}
1 change: 1 addition & 0 deletions lit_gpt/tokenizer.py
@@ -1,6 +1,7 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import json
import re
from pathlib import Path
from typing import Optional, Union

2 changes: 1 addition & 1 deletion scripts/download.py
@@ -39,7 +39,7 @@ def download_from_hub(
" https://huggingface.co/settings/tokens"
)

download_files = ["tokenizer*", "generation_config.json"]
download_files = ["tokenizer*", "*config.json"]
if not tokenizer_only:
if from_safetensors:
if not _SAFETENSORS_AVAILABLE:
4 changes: 3 additions & 1 deletion tests/test_model.py
@@ -201,7 +201,9 @@ def test_against_original_open_llama_3b(device, dtype):
@torch.inference_mode()
@pytest.mark.parametrize(
"ours_kwargs",
[{"name": "Llama-2-7b-hf"}, {"name": "CodeLlama-7b-hf"}, {"name": "Llama-2-70b-chat-hf", "n_query_groups": 1}],
[{"name": "Llama-2-7b-hf"}, {"name": "CodeLlama-7b-hf"}, {"name": "Llama-2-70b-chat-hf", "n_query_groups": 1},
{"name": "yi-6b-chat-hf"}, {"name": "yi-6b-200k-hf"}, {"name": "yi-6b-hf"},
{"name": "yi-34b-chat-hf"}, {"name": "yi-34b-200k-hf"}, {"name": "yi-34b-hf"}],
)
@pytest.mark.parametrize(
("device", "dtype"),
14 changes: 11 additions & 3 deletions tests/test_tokenizer.py
@@ -1,6 +1,7 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import os
import re
import sys
from pathlib import Path

@@ -45,9 +46,16 @@ def test_tokenizer_against_hf(config):
    for file, hf_file in file_to_cache.items():
        (checkpoint_dir / file).symlink_to(hf_file)

-    theirs = AutoTokenizer.from_pretrained(
-        repo_id, cache_dir=cache_dir / "hf", local_files_only=True, token=access_token
-    )
    if re.search(r"yi-.*b", repo_id.lower()):
        # AutoTokenizer would resolve this to LlamaTokenizerFast; use the slow LlamaTokenizer instead
        from transformers import LlamaTokenizer
        theirs = LlamaTokenizer.from_pretrained(
            repo_id, cache_dir=cache_dir / "hf", local_files_only=True, token=access_token
        )
    else:
        theirs = AutoTokenizer.from_pretrained(
            repo_id, cache_dir=cache_dir / "hf", local_files_only=True, token=access_token
        )
    ours = Tokenizer(checkpoint_dir)

    assert ours.vocab_size == theirs.vocab_size
65 changes: 65 additions & 0 deletions tutorials/download_yi.md
@@ -0,0 +1,65 @@
## Download Yi weights

The Yi series models are the next generation of open-source large language models trained from scratch by 01.AI.
Targeted as bilingual (English and Chinese) models and trained on a 3T-token multilingual corpus, the Yi series models rank among the strongest open LLMs, showing promise in language understanding, commonsense reasoning, reading comprehension, and more.
For more details, see the official [README](https://github.com/01-ai/Yi).

To see all available versions, run:

```bash
python scripts/download.py | grep Yi
```

which will print

```text
01-ai/Yi-6B
01-ai/Yi-6B-200K
01-ai/Yi-6B-Chat
01-ai/Yi-34B
01-ai/Yi-34B-200K
01-ai/Yi-34B-Chat
```

Download the weights and convert the checkpoint to the lit-gpt format (e.g., `01-ai/Yi-6B-Chat`):

```bash
pip install huggingface_hub

python scripts/download.py --repo_id 01-ai/Yi-6B-Chat --from_safetensors=True

# for base version:
python scripts/download.py --repo_id 01-ai/Yi-6B

python scripts/convert_hf_checkpoint.py \
--checkpoint_dir checkpoints/01-ai/Yi-6B-Chat
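
# if you also downloaded the base version, convert it the same way:
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/01-ai/Yi-6B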
```

-----

You're done! To execute the model just run:

```bash
pip install sentencepiece

python chat/base.py --checkpoint_dir ./checkpoints/01-ai/Yi-6B-Chat --precision "bf16-true"

# for base version:
python generate/base.py --checkpoint_dir ./checkpoints/01-ai/Yi-6B --precision "bf16-true"
```
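
This PR's `chat/base.py` also adds `max_seq_length` and `history_length` arguments to `main()`. A minimal usage sketch, assuming the script keeps exposing its function arguments as command-line flags in the same way as the existing `--checkpoint_dir` and `--precision` options:

```bash
# cap the context window at 4096 tokens and keep only the last 4 history entries
python chat/base.py \
  --checkpoint_dir ./checkpoints/01-ai/Yi-6B-Chat \
  --precision "bf16-true" \
  --max_seq_length 4096 \
  --history_length 4
```

Per the `main()` docstring, `--history_length 0` disables the history and `-1` keeps all messages.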

Chat example (with chat history):

```bash
>> Prompt: hi
>> Reply: Hello! How can I assist you today?
Time for inference: 0.65 sec total, 13.93 tokens/sec, 9 tokens, prompt length 10

>> Prompt: 你是谁 (Who are you?)
>> Reply: My name is Yi, and I am a language model based on the transformers architecture developed by 01.AI. My purpose is to be a helpful resource for you, capable of answering questions and offering insightful information across a wide range of topics. How may I help you?
Time for inference: 1.55 sec total, 37.36 tokens/sec, 58 tokens, prompt length 32

>> Prompt: 床前明月光 (Moonlight before my bed)
>> Reply: 床前明月光 (Before the bed shines brightly) 常被理解为唐代诗人李白《静夜思》中的名句,表达了诗人夜晚独处时,看到窗户前洒满月光而产生的思乡之情。这句诗经常被用作中国人思念家乡或亲人时的表达方式。
(English: this is commonly understood as a famous line from the Tang-dynasty poet Li Bai's "Quiet Night Thoughts", expressing the homesickness the poet feels when, alone at night, he sees moonlight spilling in front of the window; it is often quoted when Chinese speakers miss their hometown or family.)
Time for inference: 1.62 sec total, 37.63 tokens/sec, 61 tokens, prompt length 105
```
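
Because the Yi chat branch of `prompt_config` rebuilds the prompt from the stored history on every turn, the reported prompt length grows as the conversation continues. Typing `/reset` at the prompt clears the history and starts a fresh context; an illustrative (not captured) exchange:

```bash
>> Prompt: /reset
>> Prompt: hi
>> Reply: Hello! How can I assist you today?
```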