
[Neural Speed] Support continuous batching + beam search inference in LLAMA #145

Merged (16 commits) on Mar 4, 2024
add model_scratch_enlarge_scale
Signed-off-by: Yu, Zhentao <[email protected]>
zhentaoyu committed Mar 4, 2024
commit 521875752ebb3d453a78127adde1bd2dd24c283d
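In short: this commit gives `llama_mem_req` an `enlarge_scale` parameter (default `1.0f`) and has `Llama::init` forward the new `model_scratch_enlarge_scale` field from the model context, so the per-model scratch buffers can be grown proportionally when continuous batching with beam search needs more working memory than single-sequence decoding. A worked sketch of the effect follows the diffs below.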
32 changes: 26 additions & 6 deletions neural_speed/models/llama/llama.h
```diff
@@ -26,18 +26,38 @@ enum llama_model {
   LLAMA_65B,
 };
 
-static const model_scratch llama_mem_req(int n_layers) {
+static const model_scratch llama_mem_req(int n_layers, float enlarge_scale = 1.0f) {
   switch (n_layers) {
     case 32:
-      return {1024ull * MB, 1024ull * MB, 1608ull * MB};
+      return {
+          static_cast<unsigned long long>(enlarge_scale * 1024) * MB,
+          static_cast<unsigned long long>(enlarge_scale * 1024) * MB,
+          static_cast<unsigned long long>(enlarge_scale * 1608) * MB,
+      };
     case 40:
-      return {512ull * MB, 512ull * MB, 1608ull * MB};
+      return {
+          static_cast<unsigned long long>(enlarge_scale * 512) * MB,
+          static_cast<unsigned long long>(enlarge_scale * 512) * MB,
+          static_cast<unsigned long long>(enlarge_scale * 1608) * MB,
+      };
     case 48:
-      return {512ull * MB, 512ull * MB, 2366ull * MB};
+      return {
+          static_cast<unsigned long long>(enlarge_scale * 512) * MB,
+          static_cast<unsigned long long>(enlarge_scale * 512) * MB,
+          static_cast<unsigned long long>(enlarge_scale * 2366) * MB,
+      };
     case 60:
-      return {512ull * MB, 512ull * MB, 3124ull * MB};
+      return {
+          static_cast<unsigned long long>(enlarge_scale * 512) * MB,
+          static_cast<unsigned long long>(enlarge_scale * 512) * MB,
+          static_cast<unsigned long long>(enlarge_scale * 3124) * MB,
+      };
     case 80:
-      return {2048ull * MB, 2048ull * MB, 10240ull * MB};
+      return {
+          static_cast<unsigned long long>(enlarge_scale * 2048) * MB,
+          static_cast<unsigned long long>(enlarge_scale * 2048) * MB,
+          static_cast<unsigned long long>(enlarge_scale * 10240) * MB,
+      };
     default:
       MODEL_ASSERT(false);
   }
```
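To see the effect concretely, here is a minimal, self-contained sketch of the 32-layer branch above. The `model_scratch` field names and the value of `MB` are illustrative stand-ins for the real Neural Speed definitions, and `llama_mem_req_32` is a hypothetical extraction of one switch case:

```cpp
#include <cstdio>

// Simplified stand-ins for the real Neural Speed definitions.
constexpr unsigned long long MB = 1024ull * 1024ull;
struct model_scratch {
  unsigned long long scratch0, scratch1, eval;  // field names are illustrative
};

// Mirrors the 32-layer branch of llama_mem_req after this commit.
static model_scratch llama_mem_req_32(float enlarge_scale = 1.0f) {
  return {
      static_cast<unsigned long long>(enlarge_scale * 1024) * MB,
      static_cast<unsigned long long>(enlarge_scale * 1024) * MB,
      static_cast<unsigned long long>(enlarge_scale * 1608) * MB,
  };
}

int main() {
  // Default scale reproduces the pre-commit sizes (1024 / 1024 / 1608 MB).
  model_scratch base = llama_mem_req_32();
  // Scaled 2x, e.g. to leave headroom for continuous batching + beam search.
  model_scratch big = llama_mem_req_32(2.0f);
  std::printf("base eval scratch: %llu MB\n", base.eval / MB);  // 1608
  std::printf("2x   eval scratch: %llu MB\n", big.eval / MB);   // 3216
  return 0;
}
```

One detail worth noting: the `static_cast` truncates the scaled megabyte count toward zero before the multiply by `MB`, so a fractional scale always yields a whole number of megabytes rather than an arbitrary byte count.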
2 changes: 1 addition & 1 deletion neural_speed/models/llama/llama_utils.cpp
```diff
@@ -81,7 +81,7 @@ void Llama::init(const char* path_model, model_context* ctx, int n_gpu_layer_, b
   n_head = hparams.n_head;
   n_expert = hparams.n_experts;
   n_expert_used = hparams.n_experts_used;
-  scratch = llama_mem_req(n_layer);
+  scratch = llama_mem_req(n_layer, lctx.model_scratch_enlarge_scale);
   model.scratchs = scratch;
 }
```
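Downstream, `Llama::init` now picks the scratch sizes from the context. This hunk does not show how the field is populated, so the following is only a sketch under the assumption that callers set `model_scratch_enlarge_scale` on the `model_context` before init; everything except that one field (which the diff above does show) is hypothetical:

```cpp
// Hypothetical sketch: requesting 2x scratch memory from the caller side.
// Only model_scratch_enlarge_scale is real (added by this commit); the
// surrounding context-construction API is simplified for illustration.
struct model_context {
  float model_scratch_enlarge_scale = 1.0f;  // field read in Llama::init
  // ... many other fields omitted ...
};

void run_with_larger_scratch() {
  model_context ctx;
  // Beam search over batched requests needs bigger scratch buffers, so
  // enlarge them 2x before init computes sizes via llama_mem_req.
  ctx.model_scratch_enlarge_scale = 2.0f;
  // Llama::init(path_model, &ctx, ...) would now allocate scaled buffers.
}
```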