support Qwen2, change dashinfer model extensions
- support Qwen2, add model_type Qwen_v20
- change dashinfer model extensions (asgraph, asparam -> dimodel, ditensors)
- remove xxx_quantize.json config file, use command line arg instead
laiwenzh committed May 29, 2024
1 parent add989c commit 1b9b010
Showing 67 changed files with 867 additions and 714 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -94,12 +94,12 @@ During inference, the quantized weight is recovered as bfloat16 for matrix multi

![Workflow and Dependency](documents/resources/image/workflow-deps.jpg?row=true)

1. **Model Loading and Serialization**: This procedure involves loading model weights, setting up transformation parameters, and quantization settings. Based on this information, the model is serialized and converted into the DashInfer format (.asparam, .asgraph). This functionality is accessible exclusively through a Python interface and relies on the PyTorch and transformers libraries to access the weights. The version requirements for PyTorch and transformers may vary from model to model. DashInfer itself does not impose any specific version constraints.
1. **Model Loading and Serialization**: This procedure involves loading model weights, setting up transformation parameters, and quantization settings. Based on this information, the model is serialized and converted into the DashInfer format (.dimodel, .ditensors). This functionality is accessible exclusively through a Python interface and relies on the PyTorch and transformers libraries to access the weights. The version requirements for PyTorch and transformers may vary from model to model. DashInfer itself does not impose any specific version constraints.

2. **Model Inference**: This step is responsible for executing the model inference using the serialized model with DashInfer, without depending on components like PyTorch. DashInfer employs [DLPack](https://github.com/dmlc/dlpack) format tensors to facilitate interaction with external frameworks, such as PyTorch. Tensors in DLPack format can be manually created or generated through tensor conversion functions provided by deep learning frameworks. Regarding the C++ interface, since most dependencies have been statically linked, it primarily relies on the OpenMP runtime library and C++ system libraries. We applied [control over symbol exports](https://anadoxin.org/blog/control-over-symbol-exports-in-gcc.html/) to ensure that only DashInfer's API interface symbols are visible, thereby preventing version conflicts with existing libraries in the user's system, such as protobuf.
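
To make the DLPack interop described above concrete, here is a minimal C++ sketch, not taken from the DashInfer sources and with the engine-side call deliberately omitted, that wraps a CPU buffer of token ids in a `DLTensor` using only the public `dlpack.h` definitions:

```cpp
// Minimal sketch (not DashInfer code): describing a CPU int64 buffer as a
// DLPack DLTensor, as one would before handing it to a DLPack-consuming API.
#include <dlpack/dlpack.h>

#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> token_ids = {1, 2, 3, 4};  // placeholder token ids
  std::vector<int64_t> shape = {1, static_cast<int64_t>(token_ids.size())};

  DLTensor input{};
  input.data = token_ids.data();
  input.device = {kDLCPU, 0};     // CPU tensor, device id 0
  input.ndim = 2;                 // [batch, seq_len]
  input.dtype = {kDLInt, 64, 1};  // 64-bit signed integers, 1 lane
  input.shape = shape.data();
  input.strides = nullptr;        // nullptr means compact row-major layout
  input.byte_offset = 0;

  // `input` can now be passed to any API that accepts DLPack tensors; the
  // memory remains owned by `token_ids`, so it must outlive that call.
  return 0;
}
```

From Python, such structures are normally produced by a framework's conversion helpers (for example `torch.utils.dlpack.to_dlpack`) rather than filled in by hand.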

> Note:
> - .asparam, .asgraph is a special model format defined by DashInfer kernel (allspark).
> - .dimodel and .ditensors are a special model format defined by the DashInfer kernel.
> - When utilizing the Python interface, you can combine the code from steps 1 and 2. However, due to the lack of functionality for loading Huggingface models at the C++ level, the C++ interface is limited to conducting inferences with models in the DashInfer format. Therefore, it's essential to serialize the model first using the Python interface before proceeding with the C++ interface.
## Single-NUMA Architecture
21 changes: 17 additions & 4 deletions README_CN.md
@@ -1,6 +1,19 @@
# Introduction
<div align="center">

[![PyPI](https://img.shields.io/pypi/v/dashinfer)](https://pypi.org/project/dashinfer/)
<!-- [![Documentation Status](https://readthedocs.org/projects/easy-cv/badge/?version=latest)](https://easy-cv.readthedocs.io/en/latest/) -->

<h4 align="center">
<p>
<a href="https://github.com/modelscope/dash-infer/blob/main/README.md">English</a> |
<b>中文</b>
</p>
</h4>

DashInfer is an inference engine for pre-trained large language models (LLMs).

</div>

# Introduction

DashInfer is implemented in a C++ runtime and provides both C++ and Python interfaces. It delivers production-grade high performance across a variety of CPU architectures, including x86 and ARMv9. DashInfer supports Continuous Batching and NUMA-aware multi-NUMA inference, fully exploiting the compute power of server-grade CPUs and offering more hardware choices for serving LLMs of 14B parameters and below.

@@ -82,12 +95,12 @@ $$ x_{u8} = x_{fp32} / scale + zeropoint $$

![Workflow and Dependency](documents/resources/image/workflow-deps.jpg?row=true)

1. **Model Loading and Serialization**: This step reads the model weights, configures the model conversion and quantization parameters, serializes the model accordingly, and produces a model in the DashInfer format (.asparam, .asgraph). This functionality is only available through the Python interface and relies on the PyTorch and transformers libraries to access the weights. The required versions of PyTorch and transformers may vary from model to model; DashInfer itself has no particular version requirements.
1. **Model Loading and Serialization**: This step reads the model weights, configures the model conversion and quantization parameters, serializes the model accordingly, and produces a model in the DashInfer format (.dimodel, .ditensors). This functionality is only available through the Python interface and relies on the PyTorch and transformers libraries to access the weights. The required versions of PyTorch and transformers may vary from model to model; DashInfer itself has no particular version requirements.

2. **Model Inference**: This step runs inference on the serialized model with DashInfer, without depending on components such as PyTorch. DashInfer uses tensors in the [DLPack](https://github.com/dmlc/dlpack) format to interact with external frameworks such as PyTorch; DLPack tensors can be created manually or produced by the tensor conversion functions of deep learning frameworks. As for the C++ interface, since almost all dependencies are statically linked, it only depends on the OpenMP runtime library and the C++ system libraries. We applied [control over symbol exports](https://anadoxin.org/blog/control-over-symbol-exports-in-gcc.html/) to ensure that only DashInfer's API symbols are visible, avoiding version conflicts with common libraries (such as protobuf) already present on the user's system.

> Note:
> - .asparam and .asgraph are a special model format defined by the DashInfer kernel (allspark).
> - .dimodel and .ditensors are a special model format defined by the DashInfer kernel.
> - When using the Python interface, the code for steps 1 and 2 can be combined. Since there is no C++-level functionality for loading Huggingface models, the C++ interface can only run inference on models in the DashInfer format, so the model must first be serialized with the Python interface before the C++ interface is used.
## Single-NUMA Architecture
65 changes: 25 additions & 40 deletions csrc/common/as_engine.cpp
@@ -245,10 +245,6 @@ AsStatus AsEngineImpl::SetNumThreads(int num_threads) {
DLOG(INFO) << "AsEngineImpl::SetNumThreads()" << std::endl;
AsStatus ret = AsStatus::ALLSPARK_SUCCESS;
device_ctx_->SetNumThreads(num_threads);
if (nranks_ > threadpool_size_) {
threadpool_size_ = nranks_ * 2;
threadpool_ = std::make_unique<ThreadPool>(threadpool_size_);
}
std::future<AsStatus> result[nranks_];
for (int i = 0; i < workers_.size(); ++i) {
result[i] = threadpool_->enqueue([this, i, &num_threads]() {
@@ -561,10 +557,6 @@ AsStatus AsEngineImpl::UnloadModelFromDeviceMemory(const char* model_name) {
DLOG(INFO) << "[" << model_name << "] "
<< "AsEngineImpl::UnloadModelFromDeviceMemory()" << std::endl;
AsStatus ret = AsStatus::ALLSPARK_SUCCESS;
if (nranks_ > threadpool_size_) {
threadpool_size_ = nranks_ * 2;
threadpool_ = std::make_unique<ThreadPool>(threadpool_size_);
}
std::future<AsStatus> result[nranks_];
for (int i = 0; i < nranks_; ++i) {
result[i] = threadpool_->enqueue(
@@ -635,10 +627,6 @@ AsStatus AsEngineImpl::StartModel(const char* model_name, bool do_warmup) {
int64_t min_bytes_available = std::numeric_limits<int64_t>::max();
int64_t rank_0_bytes_available{0};
if (use_adaptive_cache_) {
if (nranks_ > threadpool_size_) {
threadpool_size_ = nranks_ * 2;
threadpool_ = std::make_unique<ThreadPool>(threadpool_size_);
}
std::future<int64_t> result[nranks_];
for (int i = 0; i < nranks_; ++i) {
result[i] = threadpool_->enqueue([this, i]() -> int64_t {
@@ -850,6 +838,7 @@ AsStatus AsEngineImpl::StopModel(const char* model_name) {
model_state->cond_var->notify_all();

auto ret = reply_promise->get_future().get();
model_state->model_stopping = true;

if (ret != AsStatus::ALLSPARK_SUCCESS) {
LOG(ERROR) << "[" << model_name << "] "
@@ -885,10 +874,6 @@ AsStatus AsEngineImpl::ReloadModelFromDeviceMemory(const char* model_name) {
}

AsStatus ret = AsStatus::ALLSPARK_SUCCESS;
if (nranks_ > threadpool_size_) {
threadpool_size_ = nranks_ * 2;
threadpool_ = std::make_unique<ThreadPool>(threadpool_size_);
}
std::future<AsStatus> result[nranks_];
for (int i = 0; i < nranks_; ++i) {
result[i] = threadpool_->enqueue([this, i, &model_ir]() {
@@ -1159,10 +1144,6 @@ AsStatus AsEngineImpl::RunTextGenerationContinue(const char* model_name) {
return AsStatus::ALLSPARK_INVALID_CALL_ERROR;
}
AsStatus ret = AsStatus::ALLSPARK_SUCCESS;
if (nranks_ > threadpool_size_) {
threadpool_size_ = nranks_ * 2;
threadpool_ = std::make_unique<ThreadPool>(threadpool_size_);
}
std::future<AsStatus> result[nranks_];
for (int i = 0; i < nranks_; ++i) {
result[i] = threadpool_->enqueue([this, i]() {
@@ -1236,10 +1217,6 @@ AsStatus AsEngineImpl::RunTextGenerationContext(const char* model_name) {
return AsStatus::ALLSPARK_INVALID_CALL_ERROR;
}
AsStatus ret = AsStatus::ALLSPARK_SUCCESS;
if (nranks_ > threadpool_size_) {
threadpool_size_ = nranks_ * 2;
threadpool_ = std::make_unique<ThreadPool>(threadpool_size_);
}
std::future<AsStatus> result[nranks_];
for (int i = 0; i < nranks_; ++i) {
result[i] = threadpool_->enqueue([this, i]() {
@@ -1293,18 +1270,15 @@ AsStatus AsEngineImpl::StopRequestByRequestID(const char* model_name,
LOG(ERROR) << "Invalid model name : " << model_name << std::endl;
return AsStatus::ALLSPARK_PARAM_ERROR;
}
AsStatus ret = AsStatus::ALLSPARK_SUCCESS;
if (nranks_ > threadpool_size_) {
threadpool_size_ = nranks_ * 2;
threadpool_ = std::make_unique<ThreadPool>(threadpool_size_);
}

std::future<AsStatus> result[nranks_];
for (int i = 0; i < nranks_; ++i) {
result[i] = threadpool_->enqueue([this, i, request_id]() {
return workers_[i]->StopRequest(request_id);
});
}

AsStatus ret = AsStatus::ALLSPARK_SUCCESS;
// Even on failure or exception, let every worker thread run to completion to preserve atomicity; in recoverable cases this guarantees a clean environment for the next request.
AsStatus failed_ret = AsStatus::ALLSPARK_SUCCESS;
for (int i = 0; i < nranks_; ++i) {
@@ -1333,18 +1307,15 @@ AsStatus AsEngineImpl::ReleaseRequestByRequestID(const char* model_name,
LOG(ERROR) << "Invalid model name : " << model_name << std::endl;
return AsStatus::ALLSPARK_PARAM_ERROR;
}
AsStatus ret = AsStatus::ALLSPARK_SUCCESS;
if (nranks_ > threadpool_size_) {
threadpool_size_ = nranks_ * 2;
threadpool_ = std::make_unique<ThreadPool>(threadpool_size_);
}

std::future<AsStatus> result[nranks_];
for (int i = 0; i < nranks_; ++i) {
result[i] = threadpool_->enqueue([this, i, request_id]() {
return workers_[i]->ReleaseRequest(request_id);
});
}

AsStatus ret = AsStatus::ALLSPARK_SUCCESS;
// Even on failure or exception, let every worker thread run to completion to preserve atomicity; in recoverable cases this guarantees a clean environment for the next request.
AsStatus failed_ret = AsStatus::ALLSPARK_SUCCESS;
for (int i = 0; i < nranks_; ++i) {
@@ -1461,10 +1432,6 @@ AsStatus AsEngineImpl::StartRequestImpl(
{out_name, std::make_shared<AsTensor>(out_name, DeviceType::CPU,
DataType::INT64, DataMode::DENSE,
Shape{1, engine_max_length_})});
if (nranks_ > threadpool_size_) {
threadpool_size_ = nranks_ * 2;
threadpool_ = std::make_unique<ThreadPool>(threadpool_size_);
}
std::future<AsStatus> result[nranks_];
for (int i = 0; i < nranks_; ++i) {
result[i] = threadpool_->enqueue(
@@ -1668,7 +1635,6 @@ void AsEngineImpl::ModelRunningThread(
s.c_str()); // set the name (pthread_self() returns the
// pthread_t of the current thread)
bool looping = true;
long loop_cnt = 0;
bool graceful_stop_phase = false;
bool graceful_final_released = false;
std::unique_ptr<EngineControlMessage> graceful_stop_msg = nullptr;
@@ -1682,7 +1648,6 @@

while (looping) {
util::Timer time_outer;
loop_cnt++;
UpdateAsEngineStat();
// print the engine state for easier service trace.
// for multiple numa, only print this info on node 0.
@@ -1957,6 +1922,7 @@ void AsEngineImpl::ModelRunningThread(
if (graceful_final_released) {
assert(graceful_stop_msg != nullptr);
graceful_stop_msg->promise->set_value(AsStatus::ALLSPARK_SUCCESS);
model_state->model_stopped = true;
DLOG(INFO) << "All done, gracefully stopped!";
break;
}
@@ -2144,4 +2110,23 @@ std::string AsEngineStat::ToString() const {
return result;
}

std::map<std::string, std::string> AsEngineStat::ToMap() const {
std::map<std::string, std::string> engine_stat_map;
engine_stat_map["free_token"] = std::to_string(free_token);
engine_stat_map["total_token"] = std::to_string(total_token);
engine_stat_map["pendding_request"] = std::to_string(pendding_request);
engine_stat_map["running_request"] = std::to_string(running_request);
engine_stat_map["total_device_memory_pool_size"] =
std::to_string(total_device_memory_pool_size);
engine_stat_map["used_device_memory_pool_size"] =
std::to_string(used_device_memory_pool_size);
engine_stat_map["total_generated_token"] =
std::to_string(total_generated_token);
engine_stat_map["total_prefill_token"] = std::to_string(total_prefill_token);
engine_stat_map["generate_token_persec"] =
std::to_string(generate_token_persec);
engine_stat_map["process_token_persec"] =
std::to_string(process_token_persec);
return engine_stat_map;
}
} // namespace allspark
1 change: 1 addition & 0 deletions csrc/common/engine_runtime.h
@@ -35,6 +35,7 @@ class ModelControlState final {
result_queue_map;

bool model_stopping = false; // after GracefulStopModel called...
bool model_stopped = false; // after GracefulStopModel is done.

ModelControlState(const std::string& name) : model_name(name) {
lock = std::make_unique<std::mutex>();
1 change: 1 addition & 0 deletions csrc/core/model/qwen/qwen.cpp
@@ -24,4 +24,5 @@ AsStatus QwenModel::Init(const TransformerProto& model_proto,
REGISTER_MODEL("Qwen", QwenModel)
REGISTER_MODEL("Qwen_v10", QwenModel_v10)
REGISTER_MODEL("Qwen_v15", QwenModel_v15)
REGISTER_MODEL("Qwen_v20", QwenModel_v20)
} // namespace allspark
6 changes: 6 additions & 0 deletions csrc/core/model/qwen/qwen.h
@@ -31,4 +31,10 @@ class QwenModel_v15 : public QwenModel {
: QwenModel(model_type){};
};

class QwenModel_v20 : public QwenModel {
public:
explicit QwenModel_v20(const std::string& model_type = "")
: QwenModel(model_type){};
};

} // namespace allspark
2 changes: 1 addition & 1 deletion csrc/core/operator/general/get_last_line/get_last_line.h
@@ -14,7 +14,7 @@ namespace allspark {
class GetLastLineOp : public AsOperator {
public:
explicit GetLastLineOp(const std::string& op_type = "")
: AsOperator(op_type), batch_(0), seq_(0), hid_(0) {}
: AsOperator(op_type) {}
AsStatus Init(const OperatorProto& op_proto, const DeviceContext& ctx,
const TensorMap& weights_map, TensorMap* tensor_map);
AsStatus Reshape() override;
45 changes: 26 additions & 19 deletions csrc/core/operator/general/rotary/rotary_op.cpp
@@ -98,24 +98,33 @@ AsStatus RotaryOp::Init(const OperatorProto& op_proto, const DeviceContext& ctx,
if (attr_map.find("rotary_base") != attr_map.end()) {
base_ = *(float*)(attr_map.at("rotary_base").c_str());
}
xlogn_ = -1;
if (attr_map.find("logn_model_embedding") != attr_map.end()) {
xlogn_ = *(int*)(attr_map.at("logn_model_embedding").c_str());

// attr.8 multi_query_group_num
if (attr_map.find("multi_query_group_num") != attr_map.end()) {
group_num_ = *(int*)(attr_map.at("multi_query_group_num").c_str());
} else {
group_num_ = num_heads_;
}
size_per_head_ = ctx.GetSizePerHead();

// backend switch
DeviceType backend = ctx.GetDeviceType();
switch (backend) {
case DeviceType::CPU: {
const CPUContext* cpu_ctx = static_cast<const CPUContext*>(ctx_);
num_heads_ /= cpu_ctx->GetNranks();
if (group_num_ != 1) {
group_num_ /= cpu_ctx->GetNranks();
}
break;
}
default:
LOG(ERROR) << op_type_ << " Operator does not support "
<< DeviceType_Name(backend) << " device type" << std::endl;
return AsStatus::ALLSPARK_RUNTIME_ERROR;
}
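// With grouped-query attention (as in Qwen2), K and V carry group_num_ heads
// rather than num_heads_, so kv_stride_ can differ from the Q stride hidden_size_.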
kv_stride_ = size_per_head_ * group_num_;
hidden_size_ = size_per_head_ * num_heads_;
return AsStatus::ALLSPARK_SUCCESS;
}
AsStatus RotaryOp::Reshape(RuntimeContext* runtime_ctx) {
@@ -124,42 +133,40 @@ AsStatus RotaryOp::Reshape(RuntimeContext* runtime_ctx) {
Shape y_shape(x_shape);
batch_size_ = y_shape[0];
seq_len_ = y_shape[1];
hidden_size_ = y_shape[2] / 3;
tensor_map_->at(out_names_[0])->SetShape(std::move(y_shape));
// set variable
if (hidden_size_ % num_heads_) {
LOG(ERROR) << "Invalid attribute in RotaryOp. hidden_size : "
<< hidden_size_ << ", num_heads : " << num_heads_ << std::endl;
qkv_stride_ = y_shape[2];
if (qkv_stride_ != hidden_size_ + 2 * kv_stride_) {
LOG(ERROR) << "Invalid qkv_stride_ in RotaryOp"
<< ", qkv_strde = " << qkv_stride_
<< ", hidden_size = " << hidden_size_
<< ", kv_stride = " << kv_stride_ << std::endl;
return AsStatus::ALLSPARK_RUNTIME_ERROR;
}
size_per_head_ = hidden_size_ / num_heads_;
gemm_batch_ = batch_size_ * num_heads_;
tensor_map_->at(out_names_[0])->SetShape(std::move(y_shape));
return AsStatus::ALLSPARK_SUCCESS;
}
AsStatus RotaryOp::RunRotary(int run_batch_size, AsTensor* rotary_step,
AsTensor* rotary_inv_freq) {
int* run_step = (int*)rotary_step->GetDataPtr();
float* inv_freq = (float*)rotary_inv_freq->GetDataPtr();
int qkv_stride = 3 * hidden_size_;
int qkv_stride = qkv_stride_;
int* batch_offset = nullptr;
int offset = hidden_size_ * SizeofType(dtype_);
void* q_buf = (char*)tensor_map_->at(in_names_[0])->GetDataPtr();
void* k_buf = (char*)q_buf + offset;
void* v_buf = (char*)k_buf + offset;
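// Fused QKV layout per token: [Q: hidden_size_][K: kv_stride_][V: kv_stride_].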
void* k_buf = (char*)q_buf + hidden_size_ * SizeofType(dtype_);
void* v_buf = (char*)k_buf + kv_stride_ * SizeofType(dtype_);
void* outq_buf = (char*)tensor_map_->at(out_names_[0])->GetDataPtr();
void* outk_buf = (char*)outq_buf + offset;
void* outv_buf = (char*)outk_buf + offset;
void* outk_buf = (char*)outq_buf + hidden_size_ * SizeofType(dtype_);
void* outv_buf = (char*)outk_buf + kv_stride_ * SizeofType(dtype_);

rotary_launcher(dtype_, outq_buf, q_buf, inv_freq, batch_offset,
run_batch_size, seq_len_, run_step, hidden_size_, num_heads_,
size_per_head_, 0, qkv_stride, rotary_type_, rotary_pct_,
xlogn_, ctx_);
rotary_launcher(dtype_, outk_buf, k_buf, inv_freq, batch_offset,
run_batch_size, seq_len_, run_step, hidden_size_, num_heads_,
run_batch_size, seq_len_, run_step, hidden_size_, group_num_,
size_per_head_, 0, qkv_stride, rotary_type_, rotary_pct_, -1,
ctx_);
rotary_launcher(dtype_, outv_buf, v_buf, nullptr, batch_offset,
run_batch_size, seq_len_, run_step, hidden_size_, num_heads_,
run_batch_size, seq_len_, run_step, hidden_size_, group_num_,
size_per_head_, 0, qkv_stride, rotary_type_, rotary_pct_, -1,
ctx_);
return AsStatus::ALLSPARK_SUCCESS;
3 changes: 3 additions & 0 deletions csrc/core/operator/general/rotary/rotary_op.h
@@ -82,6 +82,9 @@ class RotaryOp : public AsOperator {
float rotary_pct_;
float seqlen_extrapolation_;
int ntk_model_embed_;
int group_num_ = 0;
int qkv_stride_ = 0;
int kv_stride_ = 0;
};

} // namespace allspark