Add a basic single-operator execution wrapper to the base class; fall back to the CPU on errors.
TylunasLi committed Jul 31, 2024
1 parent e56280a commit 3f27e13
Showing 2 changed files with 34 additions and 9 deletions.
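The core addition is BaseAscendOperator::RunSingleOp(), which wraps the conversion of fastllm tensors and attributes into ACL descriptors and executes a single operator, together with a deviceOk flag that records failures. A minimal sketch of how a derived operator might use the wrapper follows; the AscendCastOp class, the "Cast" op type, and the dict keys are illustrative assumptions, not part of this commit.

// Hypothetical subclass, for illustration only.
class AscendCastOp : public BaseAscendOperator {
public:
    AscendCastOp() : BaseAscendOperator("Cast") {}

    void Run(const std::string &opType, const fastllm::DataDict &datas,
             const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
        fastllm::DataDict inputData = {{"x", datas.at("input")}};
        fastllm::DataDict outputData = {{"y", datas.at("output")}};
        // Record failures so that CanRun() rejects this device next time and
        // the executor can fall back to the CPU implementation.
        deviceOk = RunSingleOp(this->name, inputData, outputData,
                               floatParams, intParams, {});
    }
};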
4 changes: 4 additions & 0 deletions include/devices/ascend/ascenddevice.h
@@ -28,8 +28,12 @@ namespace fastllm {
public:
BaseAscendOperator() {}
BaseAscendOperator(std::string name) : name(name) {}
bool RunSingleOp(const std::string &opType, const fastllm::DataDict &inputData,
const fastllm::DataDict &outputData, const fastllm::FloatDict &floatParams,
const fastllm::IntDict &intParams, const std::map <std::string, bool> &boolParams);
protected:
bool warmUpMode;
bool deviceOk = true;
std::string name;
};

39 changes: 30 additions & 9 deletions src/devices/ascend/ascenddevice.cpp
@@ -37,6 +37,25 @@ namespace fastllm {
return result == 0;
}

bool BaseAscendOperator::RunSingleOp(const std::string &opType, const fastllm::DataDict &inputData,
const fastllm::DataDict &outputData, const fastllm::FloatDict &floatParams,
const fastllm::IntDict &intParams, const std::map <std::string, bool> &boolParams) {
std::vector<aclTensorDesc *> inputTensors;
std::vector<aclDataBuffer *> inputBuffers;
for (auto &pair : inputData)
npu::FastllmAclToTensor(pair, inputTensors, inputBuffers);
std::vector<aclTensorDesc *> outputTensors;
std::vector<aclDataBuffer *> outputBuffers;
for (auto &pair : outputData)
npu::FastllmAclToTensor(pair, outputTensors, outputBuffers);
aclopAttr *attr;
npu::FastllmAclToOpAttribute(floatParams, intParams, boolParams, &attr);
bool result = npu::FastllmAclExecute(opType, inputTensors, inputBuffers,
outputTensors, outputBuffers, attr);
npu::FastllmAclDestoryTensors(inputTensors, inputBuffers, outputTensors, outputBuffers, &attr);
return result;
}

AscendLinearOp::AscendLinearOp() :
BaseAscendOperator("BatchMatMulV2") {}

@@ -60,6 +79,8 @@ namespace fastllm {

bool AscendLinearOp::CanRun(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
if (!deviceOk)
return false;
Executor *executor = (Executor *) GetExecutor();
this->warmUpMode = executor->isWarmUpMode();

@@ -124,23 +145,23 @@ namespace fastllm {
if (input.dataType == DataType::FLOAT16) {
if (weight.dataType == DataType::FLOAT16) {
if (warmUpMode)
- npu::FastllmAclInitOp(this->name, inputTensorsForComplie, outputTensorsForComplie, attr);
- npu::FastllmAclExecuteAfterInit(this->name, inputTensors, inputBuffers,
- outputTensors, outputBuffers, attr);
+ deviceOk = npu::FastllmAclInitOp(this->name, inputTensorsForComplie, outputTensorsForComplie, attr);
+ deviceOk = npu::FastllmAclExecuteAfterInit(this->name, inputTensors, inputBuffers,
+ outputTensors, outputBuffers, attr);
} else {
ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
}
} else if (input.dataType == DataType::FLOAT32) {
if (weight.dataType == DataType::FLOAT32) {
if (warmUpMode)
- npu::FastllmAclInitOp(this->name, inputTensorsForComplie, outputTensorsForComplie, attr);
- npu::FastllmAclExecuteAfterInit(this->name, inputTensors, inputBuffers,
- outputTensors, outputBuffers, attr);
+ deviceOk = npu::FastllmAclInitOp(this->name, inputTensorsForComplie, outputTensorsForComplie, attr);
+ deviceOk = npu::FastllmAclExecuteAfterInit(this->name, inputTensors, inputBuffers,
+ outputTensors, outputBuffers, attr);
} else if (weight.dataType == DataType::FLOAT16) {
if (warmUpMode)
- npu::FastllmAclInitOp(this->name, inputTensorsForComplie, outputTensorsForComplie, attr);
- npu::FastllmAclExecuteAfterInit(this->name, inputTensors, inputBuffers,
- outputTensors, outputBuffers, attr);
+ deviceOk = npu::FastllmAclInitOp(this->name, inputTensorsForComplie, outputTensorsForComplie, attr);
+ deviceOk = npu::FastllmAclExecuteAfterInit(this->name, inputTensors, inputBuffers,
+ outputTensors, outputBuffers, attr);
} else {
ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
}

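Once a failed ACL call clears deviceOk, CanRun() keeps returning false, so the Ascend device stops claiming the operator. Assuming the executor probes CanRun() before dispatching (the usual fastllm flow), later calls then land on the CPU implementation. A rough sketch of that decision, with illustrative handles:

// ascendOp and cpuOp are illustrative; only CanRun()/Run() and deviceOk come from this commit.
if (ascendOp->CanRun(opType, datas, floatParams, intParams)) {
    ascendOp->Run(opType, datas, floatParams, intParams);   // NPU path
} else {
    cpuOp->Run(opType, datas, floatParams, intParams);      // CPU fallback on error
}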