Skip to content

Commit

Permalink
增加NPU的单算子Silu和MulTo
Browse files Browse the repository at this point in the history
  • Loading branch information
TylunasLi committed Aug 4, 2024
1 parent 2fc0bbb commit 948da1e
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 0 deletions.
13 changes: 13 additions & 0 deletions include/devices/ascend/ascenddevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,19 @@ namespace fastllm {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

// SiLU activation as an Ascend NPU single operator.
// fastllm's "Silu" is mapped onto the CANN built-in "Swish" operator
// (see AscendSiluOp's constructor in ascenddevice.cpp).
class AscendSiluOp : public BaseAscendOperator {
public:
AscendSiluOp();
// Executes SiLU on the NPU: reads datas["input"], writes datas["output"].
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

// In-place element-wise multiply (input0 *= input1) as an Ascend NPU single
// operator, mapped onto the CANN built-in "Mul" operator.
class AscendMulToOp : public BaseAscendOperator {
public:
AscendMulToOp();
// Returns false when the base check fails or when floatParams["alpha"] != 1,
// since the NPU "Mul" kernel has no scaling factor.
bool CanRun(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
// Executes the multiply on the NPU; the result is written back into datas["input0"].
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

}

#endif // FASTLLM_ASCEND_DEVICE_H
42 changes: 42 additions & 0 deletions src/devices/ascend/ascenddevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ namespace fastllm {
this->deviceType = "npu";
npu::FastllmAclInit();
this->ops["Linear"] = new AscendLinearOp();
this->ops["Silu"] = new AscendSiluOp();
this->ops["MulTo"] = new AscendMulToOp();
}

AscendNpuDevice::~AscendNpuDevice() {
Expand Down Expand Up @@ -224,4 +226,44 @@ namespace fastllm {
npu::FastllmAclDestoryTensors(inputTensors, inputBuffers, outputTensors, outputBuffers, &attr);
}

// fastllm's SiLU is the same function as Swish with scale 1, so the op is
// registered under the CANN operator type "Swish".
AscendSiluOp::AscendSiluOp() :
BaseAscendOperator("Swish") {}

// Runs SiLU (CANN "Swish" with scale == 1) on the NPU.
// Reads datas["input"], writes the activation into datas["output"].
// On dispatch failure, deviceOk is cleared so the framework can fall back.
void AscendSiluOp::Run(const std::string &opType, const fastllm::DataDict &datas,
                       const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
    Data *input = datas.find("input")->second;
    Data *output = datas.find("output")->second;
    // Validate dtype up front, mirroring AscendMulToOp::Run; previously an
    // unsupported dtype was passed straight to the NPU dispatch.
    AssertInFastLLM(input->dataType == DataType::FLOAT32 || input->dataType == DataType::FLOAT16,
                    "Silu error: Data's type should be float32 or float16.\n");
    // Both tensors share one dynamic-shape spec: dims 0 and 1 may vary within
    // [1, 128] x [1, 2048] without recompiling the single op.
    const auto shapeSpec = std::make_pair(std::vector<int32_t>({0, 1}),
                                          std::vector<std::vector<int64_t>>({{1, 128}, {1, 2048}}));
    DynamicShapeDict dynamicShapes;
    dynamicShapes["x"] = shapeSpec;
    dynamicShapes["y"] = shapeSpec;
    // scale == 1.0f makes Swish(x) == SiLU(x) == x * sigmoid(x).
    deviceOk = CompileAndRunSingleOp(this->name, {{"x", input}}, {{"y", output}}, dynamicShapes, {{"scale", 1.0f}}, {}, {});
}

// MulTo (in-place element-wise multiply) is registered under the CANN
// operator type "Mul"; the in-place semantics come from binding the output
// tensor to input0 in Run().
AscendMulToOp::AscendMulToOp() :
BaseAscendOperator("Mul") {}

// Reports whether this NPU op can handle the request. Besides the base
// capability check, MulTo is only supported without scaling: the CANN "Mul"
// kernel has no alpha factor, so any alpha != 1 must run elsewhere.
bool AscendMulToOp::CanRun(const std::string &opType, const fastllm::DataDict &datas,
                           const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
    if (!BaseAscendOperator::CanRun(opType, datas, floatParams, intParams)) {
        return false;
    }
    auto alphaIt = floatParams.find("alpha");
    const float alpha = (alphaIt == floatParams.end()) ? 1.0f : alphaIt->second;
    return alpha == 1.0f;
}

// Runs the element-wise product on the NPU: input0 = input0 * input1.
// Both operands must share dtype (float32 or float16) and shape.
// On dispatch failure, deviceOk is cleared so the framework can fall back.
void AscendMulToOp::Run(const std::string &opType, const fastllm::DataDict &datas,
                        const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
    Data *input0 = datas.find("input0")->second;
    Data *input1 = datas.find("input1")->second;

    const bool bothFp32 = input0->dataType == DataType::FLOAT32 && input1->dataType == DataType::FLOAT32;
    const bool bothFp16 = input0->dataType == DataType::FLOAT16 && input1->dataType == DataType::FLOAT16;
    AssertInFastLLM(bothFp32 || bothFp16,
                    "MulTo error: Data's type should be float32 or float16.\n");
    AssertInFastLLM(input0->dims == input1->dims, "MulTo error: input's shape should be same.\n");

    // All three tensors share one dynamic-shape spec: dims 0 and 1 may vary
    // within [1, 128] x [1, 2048] without recompiling the single op.
    const auto shapeSpec = std::make_pair(std::vector<int32_t>({0, 1}),
                                          std::vector<std::vector<int64_t>>({{1, 128}, {1, 2048}}));
    DynamicShapeDict dynamicShapes;
    dynamicShapes["x1"] = shapeSpec;
    dynamicShapes["x2"] = shapeSpec;
    dynamicShapes["y"] = shapeSpec;
    // The output binds to input0, making the multiply in-place.
    deviceOk = CompileAndRunSingleOp(this->name, {{"x1", input0}, {"x2", input1}}, {{"y", input0}}, dynamicShapes, {}, {}, {});
}

}
6 changes: 6 additions & 0 deletions src/models/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ namespace fastllm {
const fastllm::Data &positionIds, std::vector<std::pair<Data, Data>> &pastKeyValues,
const GenerationConfig &generationConfig, const LastTokensManager &lastTokens,
std::vector <std::vector <float>*> *retLogits) {
#ifdef USE_ASCEND_NPU
this->mergeSwiglu = true;
#endif
if (!mergeQKV) {
bool canMerge = true;
for (int i = 0; i < block_cnt; i++) {
Expand Down Expand Up @@ -246,6 +249,9 @@ namespace fastllm {

this->mergeSwiglu = canMerge;
}
#ifdef USE_ASCEND_NPU
this->mergeSwiglu = false;
#endif

Data alibiData;
if (this->weight.dicts["use_alibi"] == "1") {
Expand Down

0 comments on commit 948da1e

Please sign in to comment.