This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Enable tiny_llama #270

Merged 6 commits on May 30, 2024
Changes from 4 commits
3 changes: 2 additions & 1 deletion docs/supported_models.md
@@ -38,7 +38,8 @@ Neural Speed supports the following models:
     <td>8192</td>
   </tr>
   <tr>
-    <td><a href="https://huggingface.co/meta-llama/Llama-2-7b-chat-hf" target="_blank" rel="noopener noreferrer">LLaMA2-7B</a>,
+    <td><a href="https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0" target="_blank" rel="noopener noreferrer">TinyLlama-1.1B</a>,
+    <a href="https://huggingface.co/meta-llama/Llama-2-7b-chat-hf" target="_blank" rel="noopener noreferrer">LLaMA2-7B</a>,
     <a href="https://huggingface.co/meta-llama/Llama-2-13b-chat-hf" target="_blank" rel="noopener noreferrer">LLaMA2-13B</a>,
     <a href="https://huggingface.co/meta-llama/Llama-2-70b-chat-hf" target="_blank" rel="noopener noreferrer">LLaMA2-70B</a></td>
     <td>✅</td>
2 changes: 2 additions & 0 deletions neural_speed/convert/convert_llama.py
@@ -1357,6 +1357,8 @@ def load_some_model(path: Path) -> ModelPlus:
     if path.is_dir():
         # Check if it's a set of safetensors files first
         files = list(path.glob("model-00001-of-*.safetensors"))
+        if not files:
+            files = list(path.glob("model*.safetensors"))  # for a single safetensors file
         if not files:
             # Try the PyTorch patterns too, with lower priority
             globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"]
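This fallback matters because some checkpoints, TinyLlama-1.1B among them, ship a single model.safetensors file rather than sharded model-00001-of-*.safetensors pieces, so the sharded pattern alone finds nothing. A minimal sketch of the resulting lookup order; the helper name and the loop over the PyTorch patterns are illustrative, not taken from the PR:

    from pathlib import Path

    def find_checkpoint_files(path: Path) -> list:
        # Sharded safetensors checkpoints: model-00001-of-00002.safetensors, ...
        files = list(path.glob("model-00001-of-*.safetensors"))
        if not files:
            # Single-file checkpoints ship just model.safetensors, which the
            # sharded pattern above never matches.
            files = list(path.glob("model*.safetensors"))
        if not files:
            # Lower priority: classic PyTorch checkpoint patterns.
            for pattern in ("consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"):
                files = list(path.glob(pattern))
                if files:
                    break
        return files

    print(find_checkpoint_files(Path("TinyLlama-1.1B-Chat-v1.0")))

Note that model*.safetensors would also match sharded files, but it is only consulted after the sharded pattern has already failed, so the priority order is preserved.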
7 changes: 7 additions & 0 deletions neural_speed/models/llama/llama.h
@@ -20,6 +20,7 @@

 enum llama_model {
   LLAMA_UNKNOWN,
+  TINY_LLAMA,
   LLAMA_7B,
   LLAMA_13B,
   LLAMA_30B,
@@ -28,6 +29,12 @@ enum llama_model {

 static const model_scratch llama_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
   switch (n_layers) {
+    case 22:
+      return {
+          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+      };
     case 32:
       return {
           static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
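The new case 22 branch exists because TinyLlama-1.1B has 22 transformer layers, whereas the LLaMA-7B family has 32: llama_mem_req keys the scratch-buffer sizes off the layer count, and without this branch a 22-layer model would not match any case in the switch. A quick Python check of the three sizes the new branch reserves; the helper below is illustrative, not part of the PR:

    MB = 1024 * 1024

    def tiny_llama_scratch_bytes(scratch_size_ratio: float = 1.0):
        # Mirrors the case 22 branch of llama_mem_req: the scaled value is
        # truncated to an integer (as static_cast does) before multiplying by MB.
        return (
            int(scratch_size_ratio * 4096) * MB,
            int(scratch_size_ratio * 2048) * MB,
            int(scratch_size_ratio * 4096) * MB,
        )

    print([n // MB for n in tiny_llama_scratch_bytes()])     # [4096, 2048, 4096]
    print([n // MB for n in tiny_llama_scratch_bytes(0.5)])  # [2048, 1024, 2048]

With the default ratio of 1.0 the branch reserves 4 GiB, 2 GiB, and 4 GiB for the three model_scratch buffers, the same sizes the visible case 32 hunk begins with for the 32-layer models.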