Skip to content

Commit

Permalink
lookup : add prompt lookup decoding example (ggerganov#4484)
Browse files Browse the repository at this point in the history
* initial commit, going through initializations

* main loop finished, starting to debug

* BUG: generates gibberish/repeating tokens after a while

* kv_cache management

* Added colors to distinguish drafted tokens (--color). Updated README

* lookup : fix token positions in the draft batch

* lookup : use n_draft from CLI params

* lookup : final touches

---------

Co-authored-by: Leon Ericsson <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
  • Loading branch information
3 people authored Dec 22, 2023
1 parent ba66175 commit 7082d24
Show file tree
Hide file tree
Showing 7 changed files with 256 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ models-mnt
/llama-bench
/llava-cli
/lookahead
/lookup
/main
/metal
/perplexity
Expand Down
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
BUILD_TARGETS = \
main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead tests/test-c.o
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup tests/test-c.o

# Binaries only useful for tests
TEST_TARGETS = \
Expand Down Expand Up @@ -664,6 +664,9 @@ parallel: examples/parallel/parallel.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
lookahead: examples/lookahead/lookahead.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

lookup: examples/lookup/lookup.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

ifdef LLAMA_METAL
metal: examples/metal/metal.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
Expand Down
3 changes: 2 additions & 1 deletion common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ struct gpt_params {
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_draft = 16; // number of tokens to draft during speculative decoding
int32_t n_draft = 8; // number of tokens to draft during speculative decoding
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_parallel = 1; // number of parallel sequences to decode
int32_t n_sequences = 1; // number of sequences to decode
Expand Down Expand Up @@ -240,3 +240,4 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);

// Dump the KV cache view showing individual sequences in each cell (long output).
void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);

1 change: 1 addition & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ else()
add_subdirectory(simple)
add_subdirectory(speculative)
add_subdirectory(lookahead)
add_subdirectory(lookup)
add_subdirectory(train-text-from-scratch)
if (LLAMA_METAL)
add_subdirectory(metal)
Expand Down
5 changes: 5 additions & 0 deletions examples/lookup/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
set(TARGET lookup)
add_executable(${TARGET} lookup.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
13 changes: 13 additions & 0 deletions examples/lookup/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# llama.cpp/examples/lookup

Demonstration of Prompt Lookup Decoding

https://github.com/apoorvumang/prompt-lookup-decoding

The key parameters for lookup decoding are `ngram_min`, `ngram_max` and `n_draft`. The first two determine the size of the ngrams to search for in the prompt for a match. The latter specifies how many subsequent tokens to draft if a match is found.

More info:

https://github.com/ggerganov/llama.cpp/pull/4484
https://github.com/ggerganov/llama.cpp/issues/4226

230 changes: 230 additions & 0 deletions examples/lookup/lookup.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
#include "common.h"
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <string>
#include <vector>

int main(int argc, char ** argv){
gpt_params params;

if (!gpt_params_parse(argc, argv, params)) {
return 1;
}

// max/min n-grams size to search for in prompt
const int ngram_max = 4;
const int ngram_min = 1;

// length of the candidate / draft sequence, if match is found
const int n_draft = params.n_draft;

const bool dump_kv_cache = params.dump_kv_cache;

#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("lookup", "log"));
LOG_TEE("Log start\n");
log_dump_cmdline(argc, argv);
#endif // LOG_DISABLE_LOGS

// init llama.cpp
llama_backend_init(params.numa);

llama_model * model = NULL;
llama_context * ctx = NULL;

// load the model
std::tie(model, ctx) = llama_init_from_gpt_params(params);

// tokenize the prompt
const bool add_bos = llama_should_add_bos_token(model);
LOG("add_bos tgt: %d\n", add_bos);

std::vector<llama_token> inp;
inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);

const int max_context_size = llama_n_ctx(ctx);
const int max_tokens_list_size = max_context_size - 4;

if ((int) inp.size() > max_tokens_list_size) {
fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
return 1;
}

fprintf(stderr, "\n\n");

for (auto id : inp) {
fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
}

fflush(stderr);

const int n_input = inp.size();

const auto t_enc_start = ggml_time_us();

llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));

const auto t_enc_end = ggml_time_us();

int n_predict = 0;
int n_drafted = 0;
int n_accept = 0;

int n_past = inp.size();

bool has_eos = false;

struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);

std::vector<llama_token> draft;

llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);

// debug
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1);

const auto t_dec_start = ggml_time_us();

while (true) {
// debug
if (dump_kv_cache) {
llama_kv_cache_view_update(ctx, &kvc_view);
dump_kv_cache_view_seqs(kvc_view, 40);
}

// print current draft sequence
LOG("drafted %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, draft).c_str());

int i_dft = 0;
while (true) {
// sample from the target model
llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);

llama_sampling_accept(ctx_sampling, ctx, id, true);

const std::string token_str = llama_token_to_piece(ctx, id);

if (!params.use_color) {
printf("%s", token_str.c_str());
}

if (id == llama_token_eos(model)) {
has_eos = true;
}

++n_predict;

// check if the target token matches the draft
if (i_dft < (int) draft.size() && id == draft[i_dft]) {
LOG("the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n", i_dft, id, token_str.c_str());
++n_accept;
++n_past;
++i_dft;
inp.push_back(id);

if (params.use_color) {
// color accepted draft token
printf("\033[34m%s\033[0m", token_str.c_str());
fflush(stdout);
}
continue;
}

if (params.use_color) {
printf("%s", token_str.c_str());
}
fflush(stdout);


LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());

draft.clear();
draft.push_back(id);
inp.push_back(id);
break;
}

if ((params.n_predict > 0 && n_predict > params.n_predict) || has_eos) {
break;
}

// KV cache management
// clean the cache of draft tokens that weren't accepted
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);

llama_batch_clear(batch_tgt);
llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);

// generate n_pred tokens through prompt lookup
auto prompt_lookup = [&]() -> void {
int inp_size = inp.size();
for (int ngram_size = ngram_max ; ngram_size > ngram_min; --ngram_size){
const llama_token * ngram = &inp[inp_size - ngram_size];

for (int i = 0; i <= (int) inp_size - (ngram_size * 2); ++i) {
bool match = true;
for (int j = 0; j < ngram_size; ++j) {
if (inp[i + j] != ngram[j]) {
match = false;
break;
}
}

if (match) {
const int startIdx = i + ngram_size;
const int endIdx = startIdx + n_draft;
if (endIdx < inp_size) {
for (int j = startIdx; j < endIdx; ++j) {
LOG(" - draft candidate %d: %d\n", j, inp[j]);
draft.push_back(inp[j]);
llama_batch_add(batch_tgt, inp[j], n_past + (j - startIdx) + 1, { 0 }, true);
++n_drafted;
}
return;
}
}
}
}
return;
};

prompt_lookup();

llama_decode(ctx, batch_tgt);
++n_past;

draft.erase(draft.begin());
}

auto t_dec_end = ggml_time_us();

LOG_TEE("\n\n");

LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));

LOG_TEE("\n");
LOG_TEE("n_draft = %d\n", n_draft);
LOG_TEE("n_predict = %d\n", n_predict);
LOG_TEE("n_drafted = %d\n", n_drafted);
LOG_TEE("n_accept = %d\n", n_accept);
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);

LOG_TEE("\ntarget:\n");
llama_print_timings(ctx);

llama_sampling_free(ctx_sampling);
llama_batch_free(batch_tgt);

llama_free(ctx);
llama_free_model(model);

llama_backend_free();

fprintf(stderr, "\n\n");

return 0;
}

0 comments on commit 7082d24

Please sign in to comment.