-
Notifications
You must be signed in to change notification settings - Fork 10.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow pooled embeddings on any model #7477
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting and seems like this is going to be useful. Otherwise only BERT was affected by pooling types.
llama.cpp
Outdated
struct ggml_tensor * inp = gf->nodes[gf->n_nodes - 1]; | ||
if (strcmp(inp->name, "result_embd") != 0) { | ||
inp = gf->nodes[gf->n_nodes - 2]; | ||
GGML_ASSERT(strcmp(inp->name, "result_norm") == 0 && "embeddings tensor not found"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This probably won't work for Grok, Phi 2, MiniCPM, and Command R, as their "result_norm" is the 3rd (or sometimes 4th for Command R) last tensor.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed, brining back the backwards search for result_norm
.
} else if (!hparams.causal_attn) { | ||
res = nullptr; // do not extract logits for embedding models such as BERT | ||
|
||
// token or sequence embeddings | ||
embd = gf->nodes[gf->n_nodes - 1]; | ||
|
||
GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0); | ||
} else if (cparams.embeddings) { | ||
// the embeddings could be in the second to last tensor, or any of the previous tensors | ||
int i_embd = gf->n_nodes - 2; | ||
for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) { | ||
i_embd = gf->n_nodes - i; | ||
if (i_embd < 0) { break; } | ||
embd = gf->nodes[i_embd]; | ||
} | ||
GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor"); | ||
|
||
// TODO: use a per-batch flag to know when to skip logits while keeping embeddings | ||
if (!cparams.causal_attn) { | ||
res = nullptr; // do not extract logits when not needed | ||
// skip computing logits | ||
// TODO: is this safe? | ||
gf->n_nodes = i_embd + 1; | ||
res = nullptr; // do not extract logits for embedding case | ||
embd = gf->nodes[gf->n_nodes - 1]; | ||
if (strcmp(embd->name, "result_embd_pooled") != 0) { | ||
embd = gf->nodes[gf->n_nodes - 2]; | ||
} | ||
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor"); | ||
} else { | ||
embd = nullptr; // do not extract embeddings when not needed | ||
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So an embeddings model will crash on the first decode when cparams.embeddings
is set to false
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, though I can't think of any case where you'd use an embedding model without cparams.embeddings
. I guess there's nothing really indicating something is an embedding model other than the lack of a result_output
tensor, so it's hard to intercept this earlier and give an error.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, hparams.causal_attn
is false
for BERT at least, and it's the only embedding-only model architecture currently in llama.cpp
. All BERT-like architectures also set this key to false
when converted to GGUF. It's true
by default, and by extension, for all other models.
There might be a need for a dedicated metadata key-value pair for embedding-only models if non-causal text generation models are a thing. (T5? Or is it causal?) Anyway, cparams.causal_attn
can be used to get non-causal attention with any model, I think (I did not test this), except for recurrent models (Mamba).
I think there should at least be some abstraction (exported in llama.h
) to know whether or not a model can provide embeddings and/or logits. This would make things like #7448 easier, even if it initially relies on hparams.causal_attn
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, so at least for now, it looks like hparams.causal_attn
is a good indicator of whether a model is embedding-only. And I can't imagine a generative model with non-causal attention. I think T5 is causal, at least for the decoder part.
Then I guess we want to assert hparams.causal_attn || cparams.embeddings
at some point. That way we don't have to worry about divergence and the error is caught earlier.
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); | ||
if (embd == NULL) { | ||
embd = llama_get_embeddings_ith(ctx, i); | ||
if (embd == NULL) { | ||
fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i); | ||
continue; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why remove support for LLAMA_POOLING_TYPE_NONE
in the embedding
example?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mostly because we're not actually printing out the entire token level embeddings anyway. The way it was implemented before was essentially doing last token pooling (not necessarily the last position in the sequence though, just the last one in the order the batch was loaded), but now that last token pooling is an official option, may as well encourage the user to make that choice conciously.
// no output | ||
res = nullptr; | ||
embd = nullptr; | ||
} else if (!hparams.causal_attn) { | ||
res = nullptr; // do not extract logits for embedding models such as BERT | ||
|
||
// token or sequence embeddings | ||
embd = gf->nodes[gf->n_nodes - 1]; | ||
|
||
GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0); | ||
} else if (cparams.embeddings) { | ||
// the embeddings could be in the second to last tensor, or any of the previous tensors | ||
int i_embd = gf->n_nodes - 2; | ||
for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) { | ||
i_embd = gf->n_nodes - i; | ||
if (i_embd < 0) { break; } | ||
embd = gf->nodes[i_embd]; | ||
} | ||
GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor"); | ||
|
||
// TODO: use a per-batch flag to know when to skip logits while keeping embeddings | ||
if (!cparams.causal_attn) { | ||
res = nullptr; // do not extract logits when not needed | ||
// skip computing logits | ||
// TODO: is this safe? | ||
gf->n_nodes = i_embd + 1; | ||
res = nullptr; // do not extract logits for embedding case | ||
embd = gf->nodes[gf->n_nodes - 1]; | ||
if (strcmp(embd->name, "result_embd_pooled") != 0) { | ||
embd = gf->nodes[gf->n_nodes - 2]; | ||
} | ||
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor"); | ||
} else { | ||
embd = nullptr; // do not extract embeddings when not needed | ||
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are places that need to know when embeddings or logits will be output, like llama_output_reserve
Lines 11064 to 11065 in cd93a28
const bool has_logits = cparams.causal_attn; | |
const bool has_embd = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); |
This will need to be updated to reflect exactly how this affects what happens later in this function near the comments // extract logits
and // extract embeddings
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So can we get away with saying you're either getting logits or embeddings but never both, and that behavior is exclusively controlled by cparams.embeddings
? In that case we could just have
const bool has_logits = !cparams.embeddings;
const bool has_embd = cparams.embeddings;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I can't really think of a use-case where both would be needed at the same time. Except maybe for a server
serving both completions and embeddings out of the same model. So that's something to consider.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed, but for a given call to llama_decode
presumably you would never want both. For the gritlm
example, I actually just made two contexts, one for generation one for embeddings. Another option would be to add a llama_set_embeddings
function.
llama.cpp
Outdated
const bool has_logits = cparams.causal_attn; | ||
const bool has_embd = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); | ||
const bool has_logits = !cparams.embeddings; | ||
const bool has_embd = cparams.embeddings; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that lctx.embd
is not used by all pooled embeddings types, it's really only used with LLAMA_POOLING_TYPE_NONE
.
(This is all done near the end of llama_decode_internal
in a switch
statement)
So maybe the condition for has_embd
could be cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE
?
(See also the other places where hparams.causal_attn
was used to understand the assumptions that stem from it, to check if they need to be modified)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, that makes sense. I put in the pooling_type
check you suggested in there. I also changed the inp_out_ids
calculation to rely on cparams.embeddings
rather than hparams.causal_attn
.
llama.h
Outdated
@@ -275,7 +282,7 @@ extern "C" { | |||
|
|||
enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` | |||
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id | |||
// (ignored if no pooling layer) | |||
enum llama_attention_type attention_type; // causal, non-causal, or unspecified |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could be missing something, but it seems that attention_type
does not bring any value over the existing (h/c)params.causal_attn
+ llama_set_causal_attn()
. Do we need both?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, actually I don't think it gives you any new capabilities. Perhaps it's best to keep it the way it is and avoid breaking changes. Will switch back!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The only inconvenience is that it makes it slightly awkward to specify as an CLI flag in the examples.
One option would be to have both attention_type
in the constructor and llama_set_causal_attn
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, took out the attention_type
stuff, should work now. If server tests are failing, does that mean I have to rebase on master?
… last token pooling; update examples
@ngxson Just rebased onto master. Let's see if that server error persists. I'm on a weak laptop and internet connection right now, but will double check things later today. |
llama.cpp
Outdated
@@ -11754,7 +11779,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { | |||
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); | |||
} | |||
|
|||
if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { | |||
if (!cparams.embeddings || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm. Some outputs should still be skipped when embedding for some of the pooling types, no?
This will cause use of uninitialized lctx.inp_out_ids
when embedding with non-Bert models with pooling types other than NONE.
This condition was there originally for how BERT managed output skipping.
Line 8534 in f8ec887
if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) { |
Since batch.logits
is likely correctly set when using pooled embeddings (at least, how you wrote them seems correct), then should this condition instead always be true?
And if that is done, then inp_cls
would be redundant, since the correct rows would already be the only thing left.
Might be out of scope for this PR. What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, that makes sense. I'm guessing we want to avoid putting a pooling_type == LLAMA_POOLING_TYPE_NONE
in every single other model? In that case, I guess we have to actually require all logits be set when getting non-NONE embeddings from non-Bert models. The downside is that it results in a needless get_rows
on all the outputs.
In fact, it seems like batch.logits
isn't really used when pooling_type
is not NONE
, since we use all the outputs and the results are stored in embd_seq_out
. Or actually, all that's currently required is that at least one logit is requested so you go down the right branch when we check if lctx.n_outputs == 0
in llama_decode_internal
. It seems like in this case we might want to officially ignore batch.logits
and give priority to cparams.embeddings
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the simpler way to fix this in the meantime is to make n_outputs == n_tokens_all
in llama_decode_internal
for all non-NONE pooling types when cparams.embeddings
is true, even when batch.logits
is set. This would then re-use the same logic as logits_all
in the other places that use n_outputs
.
But I think the CLS and LAST pooling types could eventually skip computing the embeddings they don't need (but it's not necessary to do this in this PR).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I think this should do it. Basically bypass logits when doing non-NONE embeddings. Note that I'm using hparams.causal_attn
to decide if we're in a BERT model or not in llama_set_inputs
.
@compilade sorry to ping you again, but I think this is ready to go. I believe the issues with |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry to ping you again
Pinging is the right thing, because otherwise I tend to forget to go back and re-review, unless recent activity catches my attention enough to have another look at the changes.
I believe the issues with
n_outputs
are sorted
I believe this as well.
@@ -44,6 +44,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve | |||
|
|||
// clear previous kv_cache values (irrelevant for embeddings) | |||
llama_kv_cache_clear(ctx); | |||
llama_set_embeddings(ctx, true); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a small question here: in the case when both embeddings
and causal_attn
are enabled, will it still be correct?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general, it's possible to run with embeddings=true
and causal_attn=true
, as long as the underlying model supports causal attention. For the GritLM case, I just checked here, and it will run but give incorrect results since it expects the embeddings to be run non-causally.
* create append_pooling operation; allow to specify attention_type; add last token pooling; update examples * find result_norm/result_embd tensors properly; update output allocation logic * only use embd output for pooling_type NONE * get rid of old causal_attn accessor * take out attention_type; add in llama_set_embeddings * bypass logits when doing non-NONE pooling
This allows one to compute pooled embeddings on any model, not just classical embedding models. This is increasingly useful due to the rise of generative-type models in embedding benchmarks (most recently,
gte-Qwen1.5-7B-instruct
). The main changes are:append_pooling
function tollm_build_context
that grafts a pooling layer onto the last tensor of an existing graph. This makes some assumptions about how the underlying graph is laid out, but we're already doing that in a couple of places, and there are tensor name checks too.LLAMA_POOLING_TYPE_LAST
pooling type since this is a common type of pooling used with generative models. Works very similarly to CLS pooling.embedding
/retreival
examples to request correct logits depending onpooling_type
.