Commit 62a45d1
rerank : cleanup + comments
ggerganov committed Sep 25, 2024
1 parent 6916ed1 commit 62a45d1
Showing 5 changed files with 27 additions and 14 deletions.
examples/embedding/embedding.cpp (2 changes: 1 addition & 1 deletion)

@@ -236,7 +236,7 @@ int main(int argc, char ** argv) {
                 }
             } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
                 for (int j = 0; j < n_embd_count; j++) {
-                    LOG("rank score %d: %8.3f\n", j, emb[j * n_embd]);
+                    LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
                 }
             } else {
                 // print the first part of the embeddings or for a single prompt, the full embedding
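A note for readers (not part of the commit): with RANK pooling each prompt produces a single float, but the output buffer keeps its per-prompt stride of n_embd floats, which is why the score of prompt j is read at emb[j * n_embd]. A minimal sketch of collecting the scores, assuming the same emb, n_embd_count and n_embd variables as in the surrounding code:

    // gather one rerank score per prompt from the pooled output buffer
    std::vector<float> scores(n_embd_count); // needs <vector>
    for (int j = 0; j < n_embd_count; j++) {
        scores[j] = emb[j * n_embd]; // first float of prompt j's slot
    }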
examples/server/server.cpp (16 changes: 11 additions & 5 deletions)

@@ -1419,7 +1419,7 @@ struct server_context {
         queue_results.send(res);
     }

-    void send_rank(const server_slot & slot, const llama_batch & batch) {
+    void send_rerank(const server_slot & slot, const llama_batch & batch) {
         server_task_result res;
         res.id    = slot.id_task;
         res.error = false;
@@ -1440,19 +1440,19 @@ struct server_context {

                 res.data = json {
                     {"index", slot.index},
-                    {"rank", -1e6},
+                    {"score", -1e6},
                 };

                 continue;
             }

             res.data = json {
                 {"index", slot.index},
-                {"rank", embd[0]},
+                {"score", embd[0]},
             };
         }

-        SLT_DBG(slot, "sending rank, res = '%s'\n", res.data.dump().c_str());
+        SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str());

         queue_results.send(res);
     }
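Context for the fields above (a reading aid, not part of the commit): with RANK pooling the backend writes exactly one float per sequence, so embd[0] is the whole result for a slot, now reported as "score" instead of "rank"; -1e6 is a very-low sentinel emitted when a slot produced no embeddings. A client can treat each per-slot payload as:

    // illustrative shape of one per-slot rerank result after this change
    struct rerank_result {
        int   index; // position of the document in the request
        float score; // embd[0], or -1e6 on failure
    };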
@@ -1493,6 +1493,9 @@ struct server_context {
         else if (prompt.is_array()) {
             std::vector<json> prompts = prompt;
             if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                // prompts[0] is the question
+                // the rest are the answers/documents
+                SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
                 for (size_t i = 1; i < prompts.size(); i++) {
                     json qd;
                     qd.push_back(prompts[0]);
@@ -1501,6 +1504,7 @@ struct server_context {
                     create_task(data, true, qd);
                 }
             } else {
+                SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
                 for (size_t i = 0; i < prompts.size(); i++) {
                     const auto & e = prompts[i];
                     if (e.is_string() || json_is_array_of_numbers(e)) {
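For illustration (a standalone sketch, not part of the commit): the rerank branch above fans a request whose prompt array is [question, doc_0, doc_1, ...] out into one task per (question, document) pair. Assuming nlohmann::json, which server.cpp already uses:

    #include <cstdio>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    int main() {
        // prompts[0] is the question, the rest are the answers/documents
        json prompts = json::array({ "what is panda?", "hi", "it is a bear" });

        for (size_t i = 1; i < prompts.size(); i++) {
            json qd;
            qd.push_back(prompts[0]); // the question
            qd.push_back(prompts[i]); // one document
            printf("task %zu: %s\n", i - 1, qd.dump().c_str());
        }
        return 0;
    }

Each pair becomes its own slot task, so every document is scored independently against the same question.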
@@ -1965,6 +1969,7 @@ struct server_context {
         // track if this is an embedding or non-embedding batch
         // if we've added sampled tokens above, we are in non-embedding mode
         // -1: none, 0: non-embedding, 1: embedding
+        // TODO: make enum
         int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;

         // next, batch any pending prompts without exceeding n_batch
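One way the new TODO could be resolved (an assumption, not part of the commit) is to replace the raw int32_t flag with a small enum:

    // hypothetical replacement for the int32_t batch_type flag
    enum server_batch_type {
        SERVER_BATCH_TYPE_NONE          = -1, // no tokens batched yet
        SERVER_BATCH_TYPE_NON_EMBEDDING =  0, // regular completion tokens
        SERVER_BATCH_TYPE_EMBEDDING     =  1, // embedding/rerank tokens
    };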
@@ -2133,6 +2138,7 @@ struct server_context {
                     slot.n_prompt_tokens_processed = 0;
                 }

+                // non-causal tasks require to fit the entire prompt in the physical batch
                 if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                     // cannot fit the prompt in the current batch - will try next iter
                     if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
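The new comment documents an invariant worth restating (illustrative sketch, not part of the commit): a non-causal prompt must land in a single physical batch, so it is deferred whenever the tokens already queued plus the prompt would exceed n_batch.

    // standalone restatement of the check above, with made-up sizes (needs <cstdint>)
    static bool prompt_fits(int32_t n_batch, int32_t batch_tokens, int32_t prompt_tokens) {
        return batch_tokens + prompt_tokens <= n_batch;
    }
    // prompt_fits(2048, 1500, 600) == false -> the rerank prompt waits for the next iteration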
@@ -2318,7 +2324,7 @@ struct server_context {
                 }

                 if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
-                    send_rank(slot, batch_view);
+                    send_rerank(slot, batch_view);
                     slot.release();
                     slot.i_batch = -1;
                     continue; // continue loop of slots
examples/server/utils.hpp (2 changes: 1 addition & 1 deletion)

@@ -553,7 +553,7 @@ static json format_response_rerank(const json & request, const json & ranks) {
     for (const auto & rank : ranks) {
         data.push_back(json{
             {"index", i++},
-            {"relevance_score", json_value(rank, "rank", 0.0)},
+            {"relevance_score", json_value(rank, "score", 0.0)},
         });
     }

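For context (an illustrative, self-contained sketch, not part of the commit): format_response_rerank consumes the per-slot results shown earlier, so after this change the "score" field feeds the OpenAI-style "relevance_score". With nlohmann::json and a simplified stand-in for the json_value helper from utils.hpp:

    #include <cstdio>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    // simplified stand-in for utils.hpp's json_value helper
    static double json_value(const json & body, const char * key, double def) {
        return body.contains(key) ? body.at(key).get<double>() : def;
    }

    int main() {
        // per-slot results in the new format
        json ranks = json::array({
            { {"index", 0}, {"score",  3.17} },
            { {"index", 1}, {"score", -1e6} }, // sentinel for a failed slot
        });

        json data = json::array();
        int i = 0;
        for (const auto & rank : ranks) {
            data.push_back(json{
                {"index", i++},
                {"relevance_score", json_value(rank, "score", 0.0)},
            });
        }
        printf("%s\n", data.dump().c_str()); // [{"index":0,"relevance_score":3.17}, ...]
        return 0;
    }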
include/llama.h (11 changes: 6 additions & 5 deletions)

@@ -192,7 +192,7 @@ extern "C" {
         LLAMA_POOLING_TYPE_MEAN = 1,
         LLAMA_POOLING_TYPE_CLS  = 2,
         LLAMA_POOLING_TYPE_LAST = 3,
-        LLAMA_POOLING_TYPE_RANK = 4,
+        LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph
     };

     enum llama_attention_type {
@@ -202,9 +202,9 @@ extern "C" {
     };

     enum llama_split_mode {
-        LLAMA_SPLIT_MODE_NONE    = 0, // single GPU
-        LLAMA_SPLIT_MODE_LAYER   = 1, // split layers and KV across GPUs
-        LLAMA_SPLIT_MODE_ROW     = 2, // split rows across GPUs
+        LLAMA_SPLIT_MODE_NONE  = 0, // single GPU
+        LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
+        LLAMA_SPLIT_MODE_ROW   = 2, // split rows across GPUs
     };

     // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
@@ -872,7 +872,8 @@ extern "C" {

     // Get the embeddings for a sequence id
     // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
-    // shape: [n_embd] (1-dimensional)
+    // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
+    // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);

//
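A minimal caller sketch for the documented behavior (assumes a ctx created with LLAMA_POOLING_TYPE_RANK that has already decoded a (query, document) pair; error handling elided):

    // with RANK pooling the returned buffer holds exactly one float: the rerank score
    const float * embd = llama_get_embeddings_seq(ctx, seq_id);
    if (embd != NULL) {
        printf("rerank score of seq %d: %8.3f\n", seq_id, embd[0]);
    }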
src/llama.cpp (10 changes: 8 additions & 2 deletions)

@@ -17009,7 +17009,7 @@ static int llama_decode_internal(
                     } break;
                 case LLAMA_POOLING_TYPE_RANK:
                     {
-                        // extract the rank score - a single float per sequence
+                        // extract the rerank score - a single float per sequence
                         auto & embd_seq_out = lctx.embd_seq;

                         for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
@@ -17211,7 +17211,6 @@ static int llama_encode_internal(
             case LLAMA_POOLING_TYPE_MEAN:
             case LLAMA_POOLING_TYPE_CLS:
             case LLAMA_POOLING_TYPE_LAST:
-            case LLAMA_POOLING_TYPE_RANK:
                 {
                     // extract sequence embeddings
                     auto & embd_seq_out = lctx.embd_seq;
@@ -17228,6 +17227,13 @@ static int llama_encode_internal(
                         ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                     }
                 } break;
+            case LLAMA_POOLING_TYPE_RANK:
+                {
+                    // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
+                    // wait for an encoder model that requires this pooling type in order to test it
+                    // https://github.com/ggerganov/llama.cpp/pull/9510
+                    GGML_ABORT("RANK pooling not implemented yet");
+                }
             case LLAMA_POOLING_TYPE_UNSPECIFIED:
                 {
                     GGML_ABORT("unknown pooling type");
