Skip to content

Commit

Permalink
Merge pull request #986 from martindevans/logit_bias
Browse files Browse the repository at this point in the history
Implemented `LogitBias` for `DefaultSamplingPipeline`
  • Loading branch information
martindevans authored Nov 15, 2024
2 parents f68c1f1 + 07ec3fc commit 1e1a131
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions LLama/Sampling/DefaultSamplingPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,28 @@ protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandl
{
var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default());

// Rent a temporary array and copy the biases into it
var biases = ArrayPool<LLamaLogitBias>.Shared.Rent(LogitBias.Count);
try
{
var index = 0;
foreach (var bias in LogitBias)
{
biases[index++] = new LLamaLogitBias
{
Token = bias.Key,
Bias = bias.Value
};
}

// Add the biases to the sampler
chain.AddLogitBias(context.ModelHandle.VocabCount, biases.AsSpan(0, LogitBias.Count));
}
finally
{
ArrayPool<LLamaLogitBias>.Shared.Return(biases);
}

if (Grammar != null)
chain.AddGrammar(context.ModelHandle, Grammar.Gbnf, Grammar.Root);

Expand Down

0 comments on commit 1e1a131

Please sign in to comment.