Skip to content

Commit

Permalink
Updated names and comments for some sampling params
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyrcaxis committed Nov 24, 2024
1 parent 77bfb56 commit b943d1d
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 19 deletions.
8 changes: 4 additions & 4 deletions LLama.KernelMemory/LlamaSharpTextGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ private static InferenceParams OptionsToParams(TextGenerationOptions options, In
SamplingPipeline = new DefaultSamplingPipeline()
{
Temperature = (float)options.Temperature,
AlphaFrequency = (float)options.FrequencyPenalty,
AlphaPresence = (float)options.PresencePenalty,
FrequencyPenalty = (float)options.FrequencyPenalty,
PresencePenalty = (float)options.PresencePenalty,
TopP = (float)options.NucleusSampling,
}
};
Expand All @@ -107,8 +107,8 @@ private static InferenceParams OptionsToParams(TextGenerationOptions options, In
SamplingPipeline = new DefaultSamplingPipeline()
{
Temperature = (float)options.Temperature,
AlphaFrequency = (float)options.FrequencyPenalty,
AlphaPresence = (float)options.PresencePenalty,
FrequencyPenalty = (float)options.FrequencyPenalty,
PresencePenalty = (float)options.PresencePenalty,
TopP = (float)options.NucleusSampling,
}
};
Expand Down
4 changes: 2 additions & 2 deletions LLama.SemanticKernel/ExtensionMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ internal static LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this LL
{
Temperature = (float)requestSettings.Temperature,
TopP = (float)requestSettings.TopP,
AlphaPresence = (float)requestSettings.PresencePenalty,
AlphaFrequency = (float)requestSettings.FrequencyPenalty,
PresencePenalty = (float)requestSettings.PresencePenalty,
FrequencyPenalty = (float)requestSettings.FrequencyPenalty,
}
};
}
Expand Down
6 changes: 3 additions & 3 deletions LLama/Extensions/LLamaExecutorExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ private string CreatePrompt(IList<ChatMessage> messages)
MaxTokens = options?.MaxOutputTokens ?? 256, // arbitrary upper limit
SamplingPipeline = new DefaultSamplingPipeline()
{
AlphaFrequency = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.AlphaFrequency), out float af) is true ? af : s_defaultPipeline.AlphaFrequency,
AlphaPresence = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.AlphaPresence), out float ap) is true ? ap : s_defaultPipeline.AlphaPresence,
PenalizeEOS = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeEOS), out bool eos) is true ? eos : s_defaultPipeline.PenalizeEOS,
FrequencyPenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.FrequencyPenalty), out float af) is true ? af : s_defaultPipeline.FrequencyPenalty,
PresencePenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PresencePenalty), out float ap) is true ? ap : s_defaultPipeline.PresencePenalty,
PreventEOS = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PreventEOS), out bool eos) is true ? eos : s_defaultPipeline.PreventEOS,
PenalizeNewline = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeNewline), out bool pnl) is true ? pnl : s_defaultPipeline.PenalizeNewline,
RepeatPenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenalty), out float rp) is true ? rp : s_defaultPipeline.RepeatPenalty,
RepeatPenaltyCount = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenaltyCount), out int rpc) is true ? rpc : s_defaultPipeline.RepeatPenaltyCount,
Expand Down
20 changes: 10 additions & 10 deletions LLama/Sampling/DefaultSamplingPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ public sealed class DefaultSamplingPipeline
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
/// so far, decreasing the model's likelihood to repeat the same line verbatim.
/// </summary>
public float AlphaFrequency
public float FrequencyPenalty
{
get => _alphaFreq;
init
{
if (value < -2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(FrequencyPenalty)} must be greater than -2");
if (value > 2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(FrequencyPenalty)} must be less than 2");
_alphaFreq = value;
}
}
Expand All @@ -44,15 +44,15 @@ public float AlphaFrequency
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
/// text so far, increasing the model's likelihood to talk about new topics.
/// </summary>
public float AlphaPresence
public float PresencePenalty
{
get => _alphaPresence;
init
{
if (value < -2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(PresencePenalty)} must be greater than -2");
if (value > 2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(PresencePenalty)} must be less than 2");
_alphaPresence = value;
}
}
Expand All @@ -69,9 +69,9 @@ public float AlphaPresence
public bool PenalizeNewline { get; init; } = false;

/// <summary>
/// Whether the EOS token should be protected from being modified by penalty
/// Whether the EOS token should be suppressed. Setting this to 'true' prevents EOS from being sampled
/// </summary>
public bool PenalizeEOS { get; init; } = false;
public bool PreventEOS { get; init; } = false;

/// <summary>
/// Temperature to apply (higher temperature is more "creative")
Expand Down Expand Up @@ -147,8 +147,8 @@ protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandl
context.VocabCount,
context.ModelHandle.Tokens.EOS, context.ModelHandle.Tokens.Newline ?? 0,
RepeatPenaltyCount, RepeatPenalty,
AlphaFrequency, AlphaPresence,
PenalizeNewline, PenalizeEOS
FrequencyPenalty, PresencePenalty,
PenalizeNewline, PreventEOS
);

chain.AddTopK(TopK);
Expand Down

0 comments on commit b943d1d

Please sign in to comment.