llama : add "rank" pooling type
ggml-ci
ggerganov committed Sep 19, 2024
1 parent 152e903 commit f03bcd8
Showing 4 changed files with 34 additions and 17 deletions.
3 changes: 2 additions & 1 deletion common/arg.cpp
@@ -1098,8 +1098,9 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params, const std::string & value) {
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
else if (value == "cls")  { params.pooling_type = LLAMA_POOLING_TYPE_CLS;  }
else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; }
else { throw std::invalid_argument("invalid value"); }
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
4 changes: 4 additions & 0 deletions examples/embedding/embedding.cpp
@@ -234,6 +234,10 @@ int main(int argc, char ** argv) {
}
LOG("\n");
}
} else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
for (int j = 0; j < n_embd_count; j++) {
LOG("rank score %d: %8.3f\n", j, emb[j * n_embd]);
}
} else {
// print the first part of the embeddings or for a single prompt, the full embedding
for (int j = 0; j < n_prompts; j++) {
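Note that the loop above prints emb[j * n_embd], i.e. only the first component of each pooled embedding: with the RANK changes in src/llama.cpp below, the scalar score is broadcast across the embedding size, so every component holds the same value. A minimal sketch (not a complete program) of how a caller might read the score for one sequence; the sequence id 0 and the printf call are illustrative:

    // assumes a context created with embeddings enabled and
    // pooling_type == LLAMA_POOLING_TYPE_RANK, after llama_decode()
    // has processed a batch containing sequence 0
    const float * emb = llama_get_embeddings_seq(ctx, 0);
    if (emb != nullptr) {
        printf("rank score: %8.3f\n", emb[0]); // all n_embd components carry the same score
    }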
1 change: 1 addition & 0 deletions include/llama.h
@@ -192,6 +192,7 @@ extern "C" {
LLAMA_POOLING_TYPE_MEAN = 1,
LLAMA_POOLING_TYPE_CLS = 2,
LLAMA_POOLING_TYPE_LAST = 3,
LLAMA_POOLING_TYPE_RANK = 4,
};

enum llama_attention_type {
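Besides the --pooling rank value wired up in common/arg.cpp above, the new enum member can be selected directly through the context parameters. A minimal sketch under the assumption of a GGUF model on disk; the file name is illustrative and error handling is omitted:

    // sketch: create a context that produces rank scores
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("model.gguf", mparams);

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                    // pooled per-sequence outputs are required
    cparams.pooling_type = LLAMA_POOLING_TYPE_RANK; // value added by this commit
    llama_context * ctx = llama_new_context_with_model(model, cparams);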
43 changes: 27 additions & 16 deletions src/llama.cpp
@@ -10066,6 +10066,10 @@ struct llm_build_context {
struct ggml_tensor * cur;

switch (pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
cur = inp;
} break;
case LLAMA_POOLING_TYPE_MEAN:
{
struct ggml_tensor * inp_mean = build_inp_mean();
@@ -10077,9 +10081,24 @@
struct ggml_tensor * inp_cls = build_inp_cls();
cur = ggml_get_rows(ctx0, inp, inp_cls);
} break;
case LLAMA_POOLING_TYPE_NONE:
case LLAMA_POOLING_TYPE_RANK:
{
cur = inp;
struct ggml_tensor * inp_cls = build_inp_cls();
inp = ggml_get_rows(ctx0, inp, inp_cls);

// classification head
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
GGML_ASSERT(model.cls != nullptr);
GGML_ASSERT(model.cls_b != nullptr);
GGML_ASSERT(model.cls_out != nullptr);
GGML_ASSERT(model.cls_out_b != nullptr);

cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b);
cur = ggml_tanh(ctx0, cur);
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);

// broadcast across the embedding size to make it compatible with the llama_get_embeddings API
cur = ggml_repeat(ctx0, cur, inp);
} break;
default:
{
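In scalar terms, the new RANK branch applies the RoBERTa-style classification head to the pooled CLS-token state h_cls (a sketch of the computation implied by the tensors asserted above):

    score = cls_out * tanh(cls * h_cls + cls_b) + cls_out_b

The resulting scalar is then repeated across n_embd by ggml_repeat so that the existing llama_get_embeddings API can return it in the usual embedding layout.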
@@ -11293,18 +11312,6 @@ struct llm_build_context {

cur = inpL;

// classification head
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
// TODO: become pooling layer?
if (model.cls) {
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls, cur), model.cls_b);

cur = ggml_tanh(ctx0, cur);

cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
// TODO: cur is now a scalar - what to do?
}

cb(cur, "result_embd", -1);

ggml_build_forward_expand(gf, cur);
@@ -16280,7 +16287,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
}
}

if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
if (cparams.embeddings && (
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
const int64_t n_tokens = batch.n_tokens;
const int64_t n_seq_tokens = batch.n_seq_tokens;
const int64_t n_seqs = batch.n_seqs;
@@ -16295,7 +16304,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
const llama_seq_id seq_id = batch.seq_id[s][0];

// TODO: adapt limits to n_seqs when batch.equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");

for (int i = 0; i < n_seq_tokens; ++i) {
const llama_pos pos = batch.pos[s*n_seq_tokens + i];
@@ -16822,6 +16831,7 @@ static int llama_decode_internal(
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
case LLAMA_POOLING_TYPE_RANK:
{
// extract sequence embeddings (cleared before processing each batch)
auto & embd_seq_out = lctx.embd_seq;
@@ -17025,6 +17035,7 @@ static int llama_encode_internal(
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
case LLAMA_POOLING_TYPE_RANK:
{
// extract sequence embeddings
auto & embd_seq_out = lctx.embd_seq;
