added missing features
l3utterfly committed Sep 8, 2024
1 parent 0bf5e80 commit dd3efc2
Showing 4 changed files with 197 additions and 2 deletions.
48 changes: 47 additions & 1 deletion common/sampling.cpp
@@ -57,6 +57,16 @@ struct ring_buffer {
        return value;
    }

    T pop_back() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        pos = (pos + capacity - 1) % capacity;
        T value = data[pos];
        sz--;
        return value;
    }

    const T & rat(size_t i) const {
        if (i >= sz) {
            throw std::runtime_error("ring buffer: index out of bounds");
@@ -165,6 +175,12 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
                    params.penalty_repeat,
                    params.penalty_freq,
                    params.penalty_present,
                    params.dry_penalty_last_n,
                    params.dry_base,
                    params.dry_multiplier,
                    params.dry_allowed_length,
                    params.dry_seq_breakers.data(),
                    params.dry_seq_breakers.size(),
                    params.penalize_nl,
                    params.ignore_eos));

@@ -239,6 +255,19 @@ void gpt_sampler_reset(struct gpt_sampler * gsmpl) {
    llama_sampler_reset(gsmpl->chain);
}

void gpt_sampler_reset_grmr(struct gpt_sampler * gsmpl) {
    llama_sampler_reset(gsmpl->grmr);
}

void gpt_sampler_reinit_grmr(struct gpt_sampler * gsmpl, const struct llama_model * model, std::string grammar) {
    // free first
    llama_sampler_free(gsmpl->grmr);

    // reinit
    gsmpl->params.grammar = grammar;
    gsmpl->grmr = llama_sampler_init_grammar(model, grammar.c_str(), "root");
}

struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
    return new gpt_sampler {
        /* .params = */ gsmpl->params,
@@ -313,6 +342,10 @@ llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl)
    return &gsmpl->cur_p;
}

std::vector<llama_token> gpt_sampler_get_prev(struct gpt_sampler * gsmpl) {
    return gsmpl->prev.to_vector();
}

llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) {
    return gsmpl->prev.rat(0);
}
@@ -440,4 +473,17 @@ std::vector<gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & c
    }

    return samplers;
}

void gpt_sampler_rollback(
        gpt_sampler * gsmpl,
        int rollback_num) {
    if (rollback_num > (int) gsmpl->prev.size()) {
        rollback_num = (int) gsmpl->prev.size();
    }

    // continuously pop the last token
    for (int i = 0; i < rollback_num; i++) {
        gsmpl->prev.pop_back();
    }
}
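
The new helpers are designed to be used together when previously sampled tokens need to be retracted, for example after rejecting part of a speculative draft. A minimal sketch of that pattern (not part of the commit; the function name, include path, and surrounding decode logic are illustrative assumptions):

#include <string>

#include "sampling.h" // assumed include path for common/sampling.h

// hypothetical helper: retract the last n_rejected sampled tokens and
// re-constrain future sampling with a fresh grammar
static void retract_tokens(struct gpt_sampler * gsmpl,
                           const struct llama_model * model,
                           int n_rejected,
                           const std::string & new_grammar) {
    // drop the rejected suffix from the sampler's accepted-token history,
    // so the repetition/DRY penalties no longer see the retracted tokens
    gpt_sampler_rollback(gsmpl, n_rejected);

    // the grammar sampler cannot be rewound token by token, so rebuild it
    // from scratch to match the rolled-back text
    gpt_sampler_reinit_grmr(gsmpl, model, new_grammar);
}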
11 changes: 10 additions & 1 deletion common/sampling.h
@@ -34,6 +34,11 @@ struct gpt_sampler_params {
    float penalty_repeat = 1.00f; // 1.0 = disabled
    float penalty_freq = 0.00f; // 0.0 = disabled
    float penalty_present = 0.00f; // 0.0 = disabled
    float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
    float dry_base = 1.75f;
    uint32_t dry_allowed_length = 2;
    std::vector<llama_token> dry_seq_breakers;
    int32_t dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)
    int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
    float mirostat_tau = 5.00f; // target entropy
    float mirostat_eta = 0.10f; // learning rate
Expand Down Expand Up @@ -93,6 +98,8 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl);
// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar);
void gpt_sampler_reset (struct gpt_sampler * gsmpl);
void gpt_sampler_reset_grmr(struct gpt_sampler * gsmpl);
void gpt_sampler_reinit_grmr(struct gpt_sampler * gsmpl, const struct llama_model * model, std::string grammar);
struct gpt_sampler * gpt_sampler_clone (struct gpt_sampler * gsmpl);

// arguments can be nullptr to skip printing
@@ -114,6 +121,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context

// access the internal list of current candidate tokens
llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);
std::vector<llama_token> gpt_sampler_get_prev(struct gpt_sampler * gsmpl);

// get the last accepted token
llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);
@@ -128,4 +136,5 @@ char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr);
std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr);

std::vector<enum gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
std::vector<enum gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & chars);
void gpt_sampler_rollback(gpt_sampler * gsmpl, int rollback_num);
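
The DRY penalty subtracted from a candidate's logit is dry_multiplier * dry_base^(match_length - dry_allowed_length). To get a feel for the defaults declared above, a small standalone sketch (not part of the commit) that tabulates the curve:

#include <cmath>
#include <cstdio>

int main() {
    const float dry_multiplier     = 0.8f;  // the recommended value from the comment above
    const float dry_base           = 1.75f; // default
    const int   dry_allowed_length = 2;     // default

    // penalty = dry_multiplier * dry_base^(match_length - dry_allowed_length)
    for (int match_length = 2; match_length <= 6; ++match_length) {
        const float penalty = dry_multiplier * (float) std::pow(dry_base, match_length - dry_allowed_length);
        std::printf("match_length = %d -> penalty = %.3f\n", match_length, penalty);
    }
    return 0;
}

At these defaults a repeated sequence of length 2 costs 0.8 logits, and each additional repeated token multiplies the penalty by 1.75 (length 5 already costs about 4.29), so longer verbatim repetitions are suppressed increasingly hard.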
6 changes: 6 additions & 0 deletions include/llama.h
@@ -1130,6 +1130,12 @@ extern "C" {
            float penalty_repeat, // 1.0 = disabled
            float penalty_freq, // 0.0 = disabled
            float penalty_present, // 0.0 = disabled
            int32_t dry_penalty_last_n, // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)
            float dry_base,
            float dry_multiplier,
            uint32_t dry_allowed_length,
            const llama_token * dry_seq_breakers,
            size_t dry_seq_breakers_size,
            bool penalize_nl, // consider newlines as a repeatable token
            bool ignore_eos); // ignore the end-of-sequence token
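
For orientation, a hedged call sketch (not from the commit) of the extended signature; the four leading arguments are the pre-existing parameters not shown in the hunk above (vocab size, EOS id, newline id, penalty window), and the concrete token ids and counts are placeholders:

#include "llama.h"

// hypothetical wrapper: build a penalties sampler with DRY enabled;
// the vocab size and token ids below are placeholders, not real values
static struct llama_sampler * make_penalties_with_dry() {
    static const llama_token seq_breakers[] = { 13, 2 }; // e.g. newline and EOS ids

    return llama_sampler_init_penalties(
            /* n_vocab          */ 32000,
            /* special_eos_id   */ 2,
            /* linefeed_id      */ 13,
            /* penalty_last_n   */ 64,
            /* penalty_repeat   */ 1.00f,
            /* penalty_freq     */ 0.00f,
            /* penalty_present  */ 0.00f,
            /* dry_penalty_last_n    */ -1,    // -1 = penalize over the whole context
            /* dry_base              */ 1.75f,
            /* dry_multiplier        */ 0.8f,
            /* dry_allowed_length    */ 2,
            /* dry_seq_breakers      */ seq_breakers,
            /* dry_seq_breakers_size */ 2,
            /* penalize_nl */ false,
            /* ignore_eos  */ false);
}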

134 changes: 134 additions & 0 deletions src/llama-sampling.cpp
@@ -11,6 +11,7 @@
#include <numeric>
#include <random>
#include <unordered_map>
#include <vector>

static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng, std::vector<float> & probs) {
    probs.resize(cur_p->size);
@@ -433,6 +434,104 @@ void llama_sampler_penalties_impl(
    cur_p->sorted = false;
}

void llama_sampler_dry_impl(
        llama_token_data_array * candidates,
        const llama_token * last_tokens,
        size_t last_tokens_size,
        float dry_base,
        float dry_multiplier,
        uint32_t dry_allowed_length,
        const llama_token * dry_seq_breakers,
        size_t dry_seq_breakers_size) {
    // skip the DRY sampler if we don't have a previous token
    if (last_tokens_size < 1) return;

    // get the last token
    auto last_token = last_tokens[last_tokens_size - 1];

    // if the last token is part of the sequence breakers, skip the whole sampler
    if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, last_token) != dry_seq_breakers + dry_seq_breakers_size) {
        return;
    }

    // create an unordered map of "next tokens" <-> max match length
    std::unordered_map<llama_token, size_t> match_lengths;

    // loop through each previous token (excluding the last token)
    for (size_t i = 0; i < last_tokens_size - 1; ++i) {
        // skip if the compared token is not the same as the last token
        if (last_tokens[i] != last_token) {
            continue;
        }

        // get the next token (i + 1 is always less than last_tokens_size)
        auto next_token = last_tokens[i + 1];

        // if the next token is part of the sequence breakers, skip it
        if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, next_token) != dry_seq_breakers + dry_seq_breakers_size) {
            continue;
        }

        // try to extend the match backwards (the match length starts at 1 because the last token is already matched)
        size_t match_length = 1;

        // loop through the previous tokens
        for (;; match_length++) {
            // if we have reached the start of our last tokens, break
            if (i < match_length) break;

            // the compared token starts at our prev index, going backwards by the match length
            auto compare_token = last_tokens[i - match_length];

            // the head token starts at the end of the last tokens, going backwards by the match length, minus 1 because we start at the last token itself
            auto head_token = last_tokens[last_tokens_size - 1 - match_length];

            // break out of the match if the tokens don't match
            if (compare_token != head_token) {
                break;
            }

            // if the compared token is part of the sequence breakers, break out of the match
            if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, compare_token) != dry_seq_breakers + dry_seq_breakers_size) {
                break;
            }
        }

        // check if the next token already exists in the map
        auto it = match_lengths.find(next_token);

        if (it == match_lengths.end()) {
            // the key does not exist, insert the new value
            match_lengths[next_token] = match_length;
        } else {
            // the key exists, keep the max of the new and the existing value
            it->second = std::max(it->second, match_length);
        }
    }

    // apply the penalties
    for (const auto & pair : match_lengths) {
        auto next_token = pair.first;
        auto match_length = pair.second;

        // if the match length is greater than or equal to the configured allowed length, apply the penalty
        if (match_length >= dry_allowed_length) {
            // find the next token in candidates->data
            for (size_t i = 0; i < candidates->size; ++i) {
                if (candidates->data[i].id == next_token) {
                    // calculate and apply the DRY penalty
                    float penalty = dry_multiplier * (float) std::pow(dry_base, (float) (match_length - dry_allowed_length));
                    candidates->data[i].logit -= penalty;
                    break;
                }
            }
        }
    }
}

// llama_sampler API

const char * llama_sampler_name(const struct llama_sampler * smpl) {
@@ -1216,6 +1315,12 @@ struct llama_sampler_penalties {
    const float penalty_freq;
    const float penalty_present;

    const int32_t dry_penalty_last_n;
    const float dry_base;
    const float dry_multiplier;
    const uint32_t dry_allowed_length;
    std::vector<llama_token> dry_seq_breakers;

    const bool penalize_nl;
    const bool ignore_eos;

@@ -1286,8 +1391,20 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
        token_count[ctx->prev.rat(i)]++;
    }

    // apply the repetition, frequency, and presence penalties
    llama_sampler_penalties_impl(cur_p, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present);

    // make the ring buffer of last tokens into a vector
    auto last_tokens = ctx->prev.to_vector();

    // keep only the last dry_penalty_last_n tokens
    if (last_tokens.size() > (size_t) ctx->dry_penalty_last_n) {
        last_tokens.erase(last_tokens.begin(), last_tokens.end() - ctx->dry_penalty_last_n);
    }

    // apply the DRY penalty
    llama_sampler_dry_impl(cur_p, last_tokens.data(), last_tokens.size(), ctx->dry_base, ctx->dry_multiplier, ctx->dry_allowed_length, ctx->dry_seq_breakers.data(), ctx->dry_seq_breakers.size());

    if (!ctx->penalize_nl && nl_found) {
        // restore the logit of the newline token if it was penalized
        cur_p->data[nl_idx].logit = nl_logit;
@@ -1307,6 +1424,12 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
            ctx->penalty_repeat,
            ctx->penalty_freq,
            ctx->penalty_present,
            ctx->dry_penalty_last_n,
            ctx->dry_base,
            ctx->dry_multiplier,
            ctx->dry_allowed_length,
            ctx->dry_seq_breakers.data(),
            ctx->dry_seq_breakers.size(),
            ctx->penalize_nl,
            ctx->ignore_eos);

@@ -1332,6 +1455,12 @@ struct llama_sampler * llama_sampler_init_penalties(
        float penalty_repeat,
        float penalty_freq,
        float penalty_present,
        int32_t dry_penalty_last_n,
        float dry_base,
        float dry_multiplier,
        uint32_t dry_allowed_length,
        const llama_token * dry_seq_breakers,
        size_t dry_seq_breakers_size,
        bool penalize_nl,
        bool ignore_eos) {
    if (linefeed_id == LLAMA_TOKEN_NULL) {
@@ -1352,6 +1481,11 @@ struct llama_sampler * llama_sampler_init_penalties(
            /* .penalty_repeat     = */ penalty_repeat,
            /* .penalty_freq       = */ penalty_freq,
            /* .penalty_present    = */ penalty_present,
            /* .dry_penalty_last_n = */ dry_penalty_last_n,
            /* .dry_base           = */ dry_base,
            /* .dry_multiplier     = */ dry_multiplier,
            /* .dry_allowed_length = */ dry_allowed_length,
            /* .dry_seq_breakers   = */ std::vector<llama_token>(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size),
            /* .penalize_nl        = */ penalize_nl,
            /* .ignore_eos         = */ ignore_eos,
            /* .prev               = */ ring_buffer<llama_token>(penalty_last_n),
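To make llama_sampler_dry_impl concrete, a hand-worked trace (not from the commit, and assuming no sequence breakers): with last_tokens = [A, B, C, A, B, C], the last token is C. The scan finds the earlier C at index 2 and records next_token = A, then extends the match backwards: B matches B, then A matches A, so match_length reaches 3 before the i < match_length guard stops it. Since 3 >= dry_allowed_length = 2, the candidate A is penalized by dry_multiplier * dry_base^(3 - 2) = 0.8 * 1.75 = 1.4 logits at the recommended settings, steering the sampler away from starting the "A B C" loop a third time.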
