Commit 2db84e3

llama: implement Grouped Query Attention

cgli authored and TylunasLi committed Mar 4, 2024
1 parent 97db288 commit 2db84e3

Showing 6 changed files with 61 additions and 41 deletions.
4 changes: 2 additions & 2 deletions include/fastllm.h
@@ -474,9 +474,9 @@ namespace fastllm {

void CatDirect(Data &input0, const Data &input1, int axis); // copy input1's data directly after input0 (input0 must be pre-expanded with enough space)

-void MatMul(const Data &input0, const Data &input1, Data &output, float alpha = 1.0);
+void MatMul(const Data &input0, const Data &input1, Data &output, float alpha = 1.0, int group = 1);

-void MatMulTransB(const Data &input0, const Data &input1, Data &output, float alpha = 1.0);
+void MatMulTransB(const Data &input0, const Data &input1, Data &output, float alpha = 1.0, int group = 1);

void Softmax(const Data &input, Data &output, int axis);
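
The new `group` parameter generalizes both batched ops for Grouped Query Attention: input0 may carry `group` times as many matrices as input1, and each input1 matrix is shared by `group` consecutive input0 matrices (group = 1 reproduces the old behavior). Below is a minimal reference loop for the MatMulTransB case, a sketch of the intended semantics rather than the library's optimized kernel (all names are local to the example):

    #include <cstddef>

    // C = alpha * A * B^T, grouped: `a` holds batch1 * group matrices of
    // shape n x m (contiguous), `b` holds batch1 matrices of shape k x m.
    // Each group of `group` consecutive A-matrices shares one B-matrix.
    void GroupedMatMulTransBRef(const float *a, const float *b, float *c,
                                int batch1, int group, int n, int m, int k,
                                float alpha) {
        for (int i = 0; i < batch1; i++) {
            const float *bi = b + (size_t)i * k * m;
            for (int g = 0; g < group; g++) {
                const float *ai = a + ((size_t)i * group + g) * n * m;
                float *ci = c + ((size_t)i * group + g) * n * k;
                for (int x = 0; x < n; x++)
                    for (int y = 0; y < k; y++) {
                        float s = 0.0f;
                        for (int l = 0; l < m; l++)
                            s += ai[x * m + l] * bi[y * m + l];
                        ci[x * k + y] = alpha * s;
                    }
            }
        }
    }

Because the inner g-loop walks contiguous memory, it can be folded into a single (n * group) x m multiply per B-matrix, which is exactly the trick the device implementations below use.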
2 changes: 2 additions & 0 deletions include/models/llama.h
@@ -84,6 +84,8 @@ namespace fastllm {
float rope_base = 10000.f;

float rope_factor = 1.f;
+
+int num_key_value_heads = num_attention_heads;
};
}
24 changes: 15 additions & 9 deletions src/devices/cpu/cpudevice.cpp
@@ -1738,7 +1738,9 @@ namespace fastllm {
int input1Spatial = input1.Count(input1.dims.size() - 2);
int batch0 = input0.Count(0) / input0Spatial;
int batch1 = input1.Count(0) / input1Spatial;
-AssertInFastLLM(batch0 == batch1, "MatMul's shape error.\n");
+int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
+AssertInFastLLM(batch0 == batch1 * group, "MatMul: input0.dims[1] should be equal to input1.dims[0] * group.\n");
+// AssertInFastLLM(batch0 == batch1, "MatMul's shape error.\n");

std::vector <int> dims = input0.dims;
dims.back() = input1.dims[input1.dims.size() - 1];
@@ -1755,18 +1757,19 @@

output.Allocate();

-float alpha = floatParams.find("alpha") != floatParams.end() ? floatParams.find("alpha")->second : 1.0;
-int input0Spatial = input0.Count(input0.dims.size() - 2);
+float alpha = floatParams.find("alpha") != floatParams.end() ? floatParams.find("alpha")->second : 1.0f;
+int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
+int input0Spatial = input0.Count(input0.dims.size() - 2) * group;
int input1Spatial = input1.Count(input1.dims.size() - 2);
int input0Stride = input0.strides[input0.dims.size() - 2];
int input1Stride = input1.strides[input1.dims.size() - 2];
-int n = input0.dims[input0.dims.size() - 2];
+int n = input0.dims[input0.dims.size() - 2] * group;
int m = input0.dims.back();
int k = input1.dims[input1.dims.size() - 1];
int batch0 = input0.Count(0) / input0Spatial;
int batch1 = input1.Count(0) / input1Spatial;

-int outputSpatial = output.Count(output.dims.size() - 2);
+int outputSpatial = output.Count(output.dims.size() - 2) * group;
int threadNum = GetThreads();
#ifdef _WIN64
threadNum = 1;
@@ -1831,7 +1834,9 @@ namespace fastllm {
int input1Spatial = input1.Count(input1.dims.size() - 2);
int batch0 = input0.Count(0) / input0Spatial;
int batch1 = input1.Count(0) / input1Spatial;
-AssertInFastLLM(batch0 == batch1, "MatMulTransB's shape error.\n");
+int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
+AssertInFastLLM(batch0 == batch1 * group, "MatMulTransB: input0.dims[0] should be equal to input1.dims[0] * group.\n");
+// AssertInFastLLM(batch0 == batch1, "MatMulTransB's shape error.\n");

std::vector <int> dims = input0.dims;
dims.back() = input1.dims[input1.dims.size() - 2];
@@ -1848,17 +1853,18 @@
output.Allocate();

float alpha = floatParams.find("alpha") != floatParams.end() ? floatParams.find("alpha")->second : 1.0;
-int input0Spatial = input0.Count(input0.dims.size() - 2);
+int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
+int input0Spatial = input0.Count(input0.dims.size() - 2) * group;
int input1Spatial = input1.Count(input1.dims.size() - 2);
int input0Stride = input0.strides[input0.dims.size() - 2];
int input1Stride = input1.strides[input1.dims.size() - 2];
-int n = input0.dims[input0.dims.size() - 2];
+int n = input0.dims[input0.dims.size() - 2] * group;
int m = input0.dims.back();
int k = input1.dims[input1.dims.size() - 2];
int batch0 = input0.Count(0) / input0Spatial;
int batch1 = input1.Count(0) / input1Spatial;

-int outputSpatial = output.Count(output.dims.size() - 2);
+int outputSpatial = output.Count(output.dims.size() - 2) * group;
int threadNum = GetThreads();
#ifdef _WIN64
threadNum = 1;
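
To see why `input0Spatial`, `n`, and `outputSpatial` are scaled by `group` while the shape check still compares `batch0` against `batch1 * group`: the query heads belonging to one KV head are contiguous, so the op can treat them as a single (n * group) x m matrix per KV head. A quick walkthrough with hypothetical sizes (32 query heads and 8 KV heads are assumptions for illustration; the CUDA path below applies the same reindexing):

    #include <cstdio>

    int main() {
        // Assumed head counts, for illustration only.
        int heads_q = 32, heads_kv = 8;
        int seqlen = 5, head_dim = 128;
        int group = heads_q / heads_kv;              // 4 query heads per KV head

        // Reshape step (spatials not yet scaled): batch0 == batch1 * group.
        int count0 = heads_q * seqlen * head_dim;    // total elements of q
        int batch0 = count0 / (seqlen * head_dim);   // 32 query matrices
        int batch1 = heads_kv;                       // 8 KV matrices; 32 == 8 * 4

        // Run step: input0Spatial and n are scaled by group, so the op sees
        // batch1 grouped multiplies over a (seqlen * group) x head_dim block --
        // which is also why the CUDA launches below now pass batch1.
        int input0Spatial = seqlen * head_dim * group;  // 2560
        int n = seqlen * group;                         // 20 rows per multiply
        int batches = count0 / input0Spatial;           // 20480 / 2560 = 8
        printf("group=%d batch0=%d batch1=%d n=%d batches=%d\n",
               group, batch0, batch1, n, batches);
        return 0;
    }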
34 changes: 20 additions & 14 deletions src/devices/cuda/cudadevice.cpp
@@ -337,7 +337,9 @@ namespace fastllm {
int input1Spatial = input1.Count(input1.dims.size() - 2);
int batch0 = input0.Count(0) / input0Spatial;
int batch1 = input1.Count(0) / input1Spatial;
-AssertInFastLLM(batch0 == batch1, "MatMul's shape error.\n");
+int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
+AssertInFastLLM(batch0 == batch1 * group, "MatMul: input0.dims[1] should be equal to input1.dims[0] * group.\n");
+// AssertInFastLLM(batch0 == batch1, "MatMul's shape error.\n");

std::vector <int> dims = input0.dims;
dims.back() = input1.dims[input1.dims.size() - 1];
@@ -354,21 +356,22 @@

output.Allocate();

-float alpha = floatParams.find("alpha") != floatParams.end() ? floatParams.find("alpha")->second : -1;
-int input0Spatial = input0.Count(input0.dims.size() - 2);
+float alpha = floatParams.find("alpha") != floatParams.end() ? floatParams.find("alpha")->second : 1.0f;
+int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
+int input0Spatial = input0.Count(input0.dims.size() - 2) * group;
int input1Spatial = input1.Count(input1.dims.size() - 2);
int input0Stride = input0.strides[input0.dims.size() - 2];
int input1Stride = input1.strides[input1.dims.size() - 2];
-int n = input0.dims[input0.dims.size() - 2];
+int n = input0.dims[input0.dims.size() - 2] * group;
int m = input0.dims.back();
int k = input1.dims[input1.dims.size() - 1];
int batch0 = input0.Count(0) / input0Spatial;
int batch1 = input1.Count(0) / input1Spatial;

-int outputSpatial = output.Count(output.dims.size() - 2);
+int outputSpatial = output.Count(output.dims.size() - 2) * group;
FastllmCudaBatchMatMul(input0, input1, output,
input0Spatial, input1Spatial, outputSpatial, input0Stride, input1Stride,
-batch0, n, m, k, alpha);
+batch1, n, m, k, alpha);
}

void CudaMatMulTransBOp::Reshape(const std::string &opType, const fastllm::DataDict &datas,
@@ -388,7 +391,9 @@ namespace fastllm {
int input1Spatial = input1.Count(input1.dims.size() - 2);
int batch0 = input0.Count(0) / input0Spatial;
int batch1 = input1.Count(0) / input1Spatial;
-AssertInFastLLM(batch0 == batch1, "MatMulTransB's shape error.\n");
+int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
+AssertInFastLLM(batch0 == batch1 * group, "MatMulTransB: input0.dims[0] should be equal to input1.dims[0] * group.\n");
+// AssertInFastLLM(batch0 == batch1, "MatMulTransB's shape error.\n");

std::vector <int> dims = input0.dims;
dims.back() = input1.dims[input1.dims.size() - 2];
@@ -404,21 +409,22 @@

output.Allocate();

-float alpha = floatParams.find("alpha") != floatParams.end() ? floatParams.find("alpha")->second : -1;
-int input0Spatial = input0.Count(input0.dims.size() - 2);
+float alpha = floatParams.find("alpha") != floatParams.end() ? floatParams.find("alpha")->second : 1.0f;
+int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
+int input0Spatial = input0.Count(input0.dims.size() - 2) * group;
int input1Spatial = input1.Count(input1.dims.size() - 2);
int input0Stride = input0.strides[input0.dims.size() - 2];
int input1Stride = input1.strides[input1.dims.size() - 2];
-int n = input0.dims[input0.dims.size() - 2];
+int n = input0.dims[input0.dims.size() - 2] * group;
int m = input0.dims.back();
int k = input1.dims[input1.dims.size() - 2];
int batch0 = input0.Count(0) / input0Spatial;
int batch1 = input1.Count(0) / input1Spatial;

-int outputSpatial = output.Count(output.dims.size() - 2);
+int outputSpatial = output.Count(output.dims.size() - 2) * group;
FastllmCudaBatchMatMulTransB(input0, input1, output,
input0Spatial, input1Spatial, outputSpatial, input0Stride, input1Stride,
-batch0, n, m, k, alpha);
+batch1, n, m, k, alpha);
}

bool CudaSoftMaxOp::CanRun(const std::string &opType, const fastllm::DataDict &datas,
8 changes: 4 additions & 4 deletions src/fastllm.cpp
@@ -1938,16 +1938,16 @@ namespace fastllm {
}, {}, {{"axis", axis}});
}

-void MatMul(const Data &input0, const Data &input1, Data &output, float alpha) {
+void MatMul(const Data &input0, const Data &input1, Data &output, float alpha, int group) {
curExecutor->Run("MatMul", {
{"input0", (Data*)&input0}, {"input1", (Data*)&input1}, {"output", &output}
-}, {{"alpha", alpha}}, {});
+}, {{"alpha", alpha}}, {{"group", group}});
}

-void MatMulTransB(const Data &input0, const Data &input1, Data &output, float alpha) {
+void MatMulTransB(const Data &input0, const Data &input1, Data &output, float alpha, int group) {
curExecutor->Run("MatMulTransB", {
{"input0", (Data*)&input0}, {"input1", (Data*)&input1}, {"output", &output}
-}, {{"alpha", alpha}}, {{"group", group}});
+}, {{"alpha", alpha}}, {{"group", group}});
}

void Softmax(const Data &input, Data &output, int axis) {
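
On the caller side, these wrappers simply forward `group` through the executor's intParams. A hedged usage sketch of the extended API for one attention step (the head counts are illustrative, and the masking and 4D reshapes of the real forward pass below are omitted for brevity):

    #include <cmath>
    #include "fastllm.h"

    // Sketch only: q is [num_q_heads, seqlen, head_dim]; pastKey and
    // pastValue are [num_kv_heads, kvlen, head_dim], as in llama.cpp below.
    void GqaCoreSketch(fastllm::Data &q, fastllm::Data &pastKey,
                       fastllm::Data &pastValue, fastllm::Data &attenWeights,
                       fastllm::Data &attenOutput, int head_dim) {
        int group = q.dims[0] / pastKey.dims[0];   // e.g. 32 / 8 = 4
        fastllm::MatMulTransB(q, pastKey, attenWeights,
                              1.0f / sqrtf((float)head_dim), group);
        fastllm::Softmax(attenWeights, attenWeights, -1);
        fastllm::MatMul(attenWeights, pastValue, attenOutput, 1.0f, group);
    }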
30 changes: 18 additions & 12 deletions src/models/llama.cpp
@@ -60,6 +60,12 @@ namespace fastllm {

void LlamaModel::InitParams() {
basellm::InitParams();
+num_key_value_heads = num_attention_heads;
+if (this->weight.dicts.find("num_key_value_heads") != this->weight.dicts.end()) {
+num_key_value_heads = atoi(this->weight.dicts["num_key_value_heads"].c_str());
+}
+head_dim = embed_dim / num_attention_heads;
+rotary_dim = head_dim;
if (this->weight.dicts.find("max_position_embeddings") != this->weight.dicts.end()) {
max_positions = atoi(this->weight.dicts["max_position_embeddings"].c_str());
}
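
InitParams now derives the GQA geometry from the checkpoint config: `num_key_value_heads` defaults to `num_attention_heads` (plain multi-head attention) and is overridden only when the key is present. A standalone sketch of that fallback (the real code reads this->weight.dicts, shown above):

    #include <cstdlib>
    #include <map>
    #include <string>

    // Returns the KV head count, falling back to one KV head per attention
    // head when the checkpoint predates GQA and omits the key.
    int ReadNumKeyValueHeads(const std::map<std::string, std::string> &dicts,
                             int num_attention_heads) {
        auto it = dicts.find("num_key_value_heads");
        return it != dicts.end() ? atoi(it->second.c_str()) : num_attention_heads;
    }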
@@ -162,7 +168,7 @@
Linear(attenInput, weight[vWeightName], vBias, v);
}

-std::vector <int> qkvSize = {bsz, seqlen, num_attention_heads, -1};
+std::vector <int> qkvSize = {bsz, seqlen, -1, head_dim};
q.Reshape(qkvSize);
k.Reshape(qkvSize);
v.Reshape(qkvSize);
@@ -188,7 +194,7 @@
fastllm::LlamaRotatePosition2D(k, positionIds, *sinDataPtr, *cosDataPtr, rotary_dim);
}

-qkvSize = {bsz * seqlen, num_attention_heads, -1};
+qkvSize = {bsz * seqlen, -1, head_dim};
q.Reshape(qkvSize);
k.Reshape(qkvSize);
v.Reshape(qkvSize);
@@ -228,7 +234,7 @@

// 1.2 Attention
// 1.2.0 q * k^T
-MatMulTransB(q, pastKey, attenWeights, 1.0 / sqrt(head_dim));
+MatMulTransB(q, pastKey, attenWeights, 1.0 / sqrt(head_dim), q.dims[0] / pastKey.dims[0]);
attenWeights.Reshape({1, attenWeights.dims[0], attenWeights.dims[1], attenWeights.dims[2]});
if (alibiData.dims.size() != 0) {
AlibiMask(attenWeights, alibiData, -10000);
@@ -237,7 +243,7 @@
}

Softmax(attenWeights, attenWeights, -1);
-MatMul(attenWeights, pastValue, attenOutput);
+MatMul(attenWeights, pastValue, attenOutput, 1.f, attenWeights.dims[1] / pastValue.dims[0]);

attenOutput.Reshape({attenOutput.dims[1], attenOutput.dims[2], attenOutput.dims[3]});
PermuteSelf(attenOutput, {1, 0, 2});
@@ -345,7 +351,7 @@
Linear(attenInput, weight[vWeightName], vBias, v);
}

-std::vector <int> qkvSize = {bsz, seqlen, num_attention_heads, -1};
+std::vector <int> qkvSize = {bsz, seqlen, -1, head_dim};
q.Reshape(qkvSize);
k.Reshape(qkvSize);
v.Reshape(qkvSize);
@@ -375,7 +381,7 @@
PermuteSelf(k, {0, 2, 1, 3});
PermuteSelf(v, {0, 2, 1, 3});

-qkvSize = {bsz * num_attention_heads, seqlen, -1};
+qkvSize = {-1, seqlen, head_dim};
q.Reshape(qkvSize);
k.Reshape(qkvSize);
v.Reshape(qkvSize);
@@ -412,7 +418,7 @@

// 1.2 Attention
// 1.2.0 q * k^T
-MatMulTransB(q, pastKey, attenWeights, 1.0 / sqrt(head_dim));
+MatMulTransB(q, pastKey, attenWeights, 1.0 / sqrt(head_dim), q.dims[0] / pastKey.dims[0]);
attenWeights.Reshape({1, attenWeights.dims[0], attenWeights.dims[1], attenWeights.dims[2]});
if (alibiData.dims.size() != 0) {
attenWeights.Reshape({-1, num_attention_heads, attenWeights.dims[2], attenWeights.dims[3]});
@@ -422,7 +428,7 @@
AttentionMask(attenWeights, attentionMask, -10000);
}
Softmax(attenWeights, attenWeights, -1);
-MatMul(attenWeights, pastValue, attenOutput);
+MatMul(attenWeights, pastValue, attenOutput, 1.f, attenWeights.dims[1] / pastValue.dims[0]);

attenOutput.Reshape({attenOutput.dims[1], attenOutput.dims[2], attenOutput.dims[3]});
PermuteSelf(attenOutput, {1, 0, 2});
@@ -552,7 +558,7 @@
for (int b = 0; b < batch; b++) {
auto &q = curQs[b], &k = curKs[b], &v = curVs[b];

-std::vector<int> qkvSize = {bsz, seqLens[b], num_attention_heads, -1};
+std::vector<int> qkvSize = {bsz, seqLens[b], -1, head_dim};
q.Reshape(qkvSize);
k.Reshape(qkvSize);
v.Reshape(qkvSize);
@@ -582,7 +588,7 @@
PermuteSelf(k, {0, 2, 1, 3});
PermuteSelf(v, {0, 2, 1, 3});

-qkvSize = {bsz * num_attention_heads, seqLens[b], -1};
+qkvSize = {-1, seqLens[b], head_dim};
q.Reshape(qkvSize);
k.Reshape(qkvSize);
v.Reshape(qkvSize);
@@ -621,7 +627,7 @@

// 1.2 Attention
// 1.2.0 q * k^T
-MatMulTransB(q, pastKey, attenWeights, 1.0 / sqrt(head_dim));
+MatMulTransB(q, pastKey, attenWeights, 1.0 / sqrt(head_dim), q.dims[0] / pastKey.dims[0]);
attenWeights.Reshape({1, attenWeights.dims[0], attenWeights.dims[1], attenWeights.dims[2]});
if (alibiData.dims.size() != 0) {
AlibiMask(attenWeights, alibiData, -10000);
@@ -630,7 +636,7 @@
}

Softmax(attenWeights, attenWeights, -1);
-MatMul(attenWeights, pastValue, curAttenOutput);
+MatMul(attenWeights, pastValue, curAttenOutput, 1.f, attenWeights.dims[1] / pastValue.dims[0]);
curAttenOutput.Reshape({curAttenOutput.dims[1], curAttenOutput.dims[2], curAttenOutput.dims[3]});
PermuteSelf(curAttenOutput, {1, 0, 2});
curAttenOutput.Reshape({seqLens[b], bsz, -1});
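
Putting the forward-pass changes together: the qkv reshapes switch from assuming `num_attention_heads` heads for all three tensors to inferring the head count from `head_dim`, so q keeps num_attention_heads heads while k and v get num_key_value_heads, and the `group` passed to MatMulTransB/MatMul falls out of the tensor shapes themselves. Worked numbers under a Llama-2-70B-style configuration (64 query heads and 8 KV heads are assumptions for illustration, not values from this repo):

    #include <cstdio>

    int main() {
        // Assumed config: embed_dim 8192, 64 attention heads, 8 KV heads.
        int embed_dim = 8192, num_attention_heads = 64, num_key_value_heads = 8;
        int head_dim = embed_dim / num_attention_heads;  // 128, set in InitParams

        // After the {bsz * seqlen, -1, head_dim} reshapes and permutes:
        // q is [64, seqlen, 128]; pastKey/pastValue are [8, kvlen, 128].
        int group = num_attention_heads / num_key_value_heads;  // 8
        printf("head_dim=%d group=%d\n", head_dim, group);
        // MatMulTransB receives group = q.dims[0] / pastKey.dims[0] = 64 / 8,
        // and MatMul the same ratio from attenWeights vs. pastValue.
        return 0;
    }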
