Commit

updated sampling api
l3utterfly committed Jan 30, 2024
1 parent 83dd88a commit ffdfd13
Showing 2 changed files with 24 additions and 2 deletions.
19 changes: 17 additions & 2 deletions common/sampling.cpp
@@ -36,10 +36,10 @@ void llama_sampling_free(struct llama_sampling_context * ctx) {
     delete ctx;
 }
 
-void llama_sampling_reset(llama_sampling_context * ctx) {
+void llama_sampling_reset_grammar(struct llama_sampling_context * ctx) {
     if (ctx->grammar != NULL) {
         llama_grammar_free(ctx->grammar);
-        ctx->grammar = NULL;
+        ctx->grammar = nullptr;
     }
 
     if (!ctx->parsed_grammar.rules.empty()) {
@@ -49,6 +49,10 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
                 grammar_rules.data(),
                 grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
     }
+}
+
+void llama_sampling_reset(llama_sampling_context * ctx) {
+    llama_sampling_reset_grammar(ctx);
 
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
@@ -331,3 +335,14 @@ void llama_sampling_accept(
         llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
     }
 }
+
+
+void llama_sampling_rollback(
+        struct llama_sampling_context * ctx_sampling,
+        int rollback_num) {
+    if(rollback_num > ctx_sampling->prev.size()) {
+        rollback_num = ctx_sampling->prev.size();
+    }
+
+    ctx_sampling->prev.erase(ctx_sampling->prev.end() - rollback_num, ctx_sampling->prev.end());
+}
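
For orientation, here is a minimal sketch of how the new rollback call might be used from application code, e.g. to discard draft tokens that were accepted into the sampling history but later rejected during speculative decoding. The surrounding function, names, and the draft/verify flow are illustrative assumptions, not part of this commit; only llama_sampling_accept and llama_sampling_rollback come from the API shown above.

    // Hypothetical caller: accept a few draft tokens, then undo them if verification fails.
    // Assumes ctx_sampling was created elsewhere with llama_sampling_init().
    #include <vector>
    #include "sampling.h"

    void accept_then_undo_draft(struct llama_sampling_context * ctx_sampling,
                                struct llama_context * ctx_main,
                                const std::vector<llama_token> & draft) {
        // Push the draft tokens into the sampling history (prev) so penalties see them.
        for (const llama_token id : draft) {
            llama_sampling_accept(ctx_sampling, ctx_main, id, /* apply_grammar = */ false);
        }

        // Verification failed: trim the draft tokens back off the history.
        llama_sampling_rollback(ctx_sampling, (int) draft.size());
    }

Note that, as written in this commit, the rollback only erases entries from ctx_sampling->prev; grammar state is not rewound, which is why the sketch accepts the draft with apply_grammar set to false.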
7 changes: 7 additions & 0 deletions common/sampling.h
@@ -68,6 +68,9 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
 
 void llama_sampling_free(struct llama_sampling_context * ctx);
 
+// Reset the sampler grammar without resetting the context
+void llama_sampling_reset_grammar(struct llama_sampling_context * ctx);
+
 // Reset the sampler context
 // - clear prev tokens
 // - reset grammar
@@ -116,3 +119,7 @@ void llama_sampling_accept(
         struct llama_context * ctx_main,
         llama_token id,
         bool apply_grammar);
+
+void llama_sampling_rollback(
+        struct llama_sampling_context * ctx_sampling,
+        int rollback_num);
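
And a brief, hypothetical sketch of where the split-out grammar reset could help: re-arming a grammar between several constrained replies in one session while keeping the prev token history (and therefore repetition penalties) intact, which a full llama_sampling_reset would clear. The function name and scenario are assumptions for illustration only.

    // Hypothetical: start the next grammar-constrained reply in an ongoing chat.
    #include "sampling.h"

    void begin_next_constrained_reply(struct llama_sampling_context * ctx_sampling) {
        // Rewind only the grammar to its root rule; unlike llama_sampling_reset(),
        // this leaves ctx_sampling->prev (the recent-token history) untouched.
        llama_sampling_reset_grammar(ctx_sampling);
    }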
