From 07ec3fc9caf0e174b5bd9d372a19bce8df113bae Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Fri, 15 Nov 2024 02:33:05 +0000 Subject: [PATCH] Implemented `LogitBias` for `DefaultSamplingPipeline` --- LLama/Sampling/DefaultSamplingPipeline.cs | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs index ee339be1f..639a87b59 100644 --- a/LLama/Sampling/DefaultSamplingPipeline.cs +++ b/LLama/Sampling/DefaultSamplingPipeline.cs @@ -118,6 +118,28 @@ protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandl { var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default()); + // Rent a temporary array and copy the biases into it + var biases = ArrayPool.Shared.Rent(LogitBias.Count); + try + { + var index = 0; + foreach (var bias in LogitBias) + { + biases[index++] = new LLamaLogitBias + { + Token = bias.Key, + Bias = bias.Value + }; + } + + // Add the biases to the sampler + chain.AddLogitBias(context.ModelHandle.VocabCount, biases.AsSpan(0, LogitBias.Count)); + } + finally + { + ArrayPool.Shared.Return(biases); + } + if (Grammar != null) chain.AddGrammar(context.ModelHandle, Grammar.Gbnf, Grammar.Root);