From ffdfd13abe602a721716d7cb15809bae5293b4ce Mon Sep 17 00:00:00 2001
From: l3utterfly
Date: Wed, 31 Jan 2024 00:54:24 +0900
Subject: [PATCH] updated sampling api

Split the grammar reset out of llama_sampling_reset() into a new
llama_sampling_reset_grammar() helper, and add llama_sampling_rollback(),
which removes the last `rollback_num` tokens from the sampler's `prev`
history. `rollback_num` is clamped to [0, prev.size()] so negative or
oversized requests are safe, and the size comparison is done on `int`
to avoid a signed/unsigned mismatch.

---
 common/sampling.cpp | 24 ++++++++++++++++++++++--
 common/sampling.h   |  8 ++++++++
 2 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index b2a249ad4d262..69bb6be74b2c4 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -36,10 +36,10 @@ void llama_sampling_free(struct llama_sampling_context * ctx) {
     delete ctx;
 }
 
-void llama_sampling_reset(llama_sampling_context * ctx) {
+void llama_sampling_reset_grammar(struct llama_sampling_context * ctx) {
     if (ctx->grammar != NULL) {
         llama_grammar_free(ctx->grammar);
-        ctx->grammar = NULL;
+        ctx->grammar = nullptr;
     }
 
     if (!ctx->parsed_grammar.rules.empty()) {
@@ -49,6 +49,10 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
                 grammar_rules.data(),
                 grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
     }
+}
+
+void llama_sampling_reset(struct llama_sampling_context * ctx) {
+    llama_sampling_reset_grammar(ctx);
 
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
@@ -331,3 +335,19 @@ void llama_sampling_accept(
         llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
     }
 }
+
+
+void llama_sampling_rollback(
+        struct llama_sampling_context * ctx_sampling,
+        int rollback_num) {
+    // Clamp to [0, prev.size()]: the cast keeps the comparison signed, and a
+    // negative count would otherwise make `end() - rollback_num` walk past end().
+    if (rollback_num < 0) {
+        rollback_num = 0;
+    }
+    if (rollback_num > (int) ctx_sampling->prev.size()) {
+        rollback_num = (int) ctx_sampling->prev.size();
+    }
+
+    ctx_sampling->prev.erase(ctx_sampling->prev.end() - rollback_num, ctx_sampling->prev.end());
+}
diff --git a/common/sampling.h b/common/sampling.h
index 78ced2c5fead2..a5277dc0107d6 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -68,6 +68,9 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
 
 void llama_sampling_free(struct llama_sampling_context * ctx);
 
+// Reset the sampler grammar without resetting the context
+void llama_sampling_reset_grammar(struct llama_sampling_context * ctx);
+
 // Reset the sampler context
 // - clear prev tokens
 // - reset grammar
@@ -116,3 +119,8 @@ void llama_sampling_accept(
         struct llama_context * ctx_main,
         llama_token id,
         bool apply_grammar);
+
+// Remove the last `rollback_num` tokens from the sampling history (clamped to [0, prev.size()])
+void llama_sampling_rollback(
+        struct llama_sampling_context * ctx_sampling,
+        int rollback_num);