Skip to content

Commit

Permalink
embedding : show full embedding for single prompt (ggerganov#6342)
Browse files Browse the repository at this point in the history
* embedding : show full embedding for single prompt

To support the use case of creating an embedding for a given prompt, the entire embedding and not just the first part needed to be printed.

Also, show cosine similarity matrix only if there is more than one prompt, as the cosine similarity matrix for a single prompt is always `1.00`.

* Update examples/embedding/embedding.cpp

---------

Co-authored-by: Georgi Gerganov <[email protected]>
  • Loading branch information
howlger and ggerganov authored Mar 27, 2024
1 parent e82f9e2 commit 1e13987
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions examples/embedding/embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,25 +178,27 @@ int main(int argc, char ** argv) {
float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd);

// print the first part of the embeddings
// print the first part of the embeddings or for a single prompt, the full embedding
fprintf(stdout, "\n");
for (int j = 0; j < n_prompts; j++) {
fprintf(stdout, "embedding %d: ", j);
for (int i = 0; i < std::min(16, n_embd); i++) {
for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
}
fprintf(stdout, "\n");
}

// print cosine similarity matrix
fprintf(stdout, "\n");
printf("cosine similarity matrix:\n\n");
for (int i = 0; i < n_prompts; i++) {
for (int j = 0; j < n_prompts; j++) {
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
fprintf(stdout, "%6.2f ", sim);
}
if (n_prompts > 1) {
fprintf(stdout, "\n");
printf("cosine similarity matrix:\n\n");
for (int i = 0; i < n_prompts; i++) {
for (int j = 0; j < n_prompts; j++) {
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
fprintf(stdout, "%6.2f ", sim);
}
fprintf(stdout, "\n");
}
}

// clean up
Expand Down

0 comments on commit 1e13987

Please sign in to comment.