Skip to content

Commit

Permalink
fix: possible incorrect seed values
Browse files Browse the repository at this point in the history
  • Loading branch information
Vali-98 committed May 17, 2024
1 parent 2ce8de1 commit b2bb338
Showing 1 changed file with 30 additions and 23 deletions.
53 changes: 30 additions & 23 deletions constants/Inference.ts
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,27 @@ const constructStopSequence = (instruct: InstructType): Array<string> => {
return sequence
}

const getRandomSeed = () => {
return Math.floor(Math.random() * 99999)
}

const getSeed = (seed: string | number | undefined): number => {
if (!seed) return getRandomSeed()

if (typeof seed === 'string') {
const newSeed = parseInt(seed)
if (newSeed !== -1) return newSeed
return getRandomSeed()
}

if (typeof seed === 'number') {
if (seed === -1) return getRandomSeed()
return seed
}

return getRandomSeed()
}

// Payloads

const constructKAIPayload = () => {
Expand All @@ -275,10 +296,7 @@ const constructKAIPayload = () => {
top_p: parseFloat(preset.top_p),
typical: parseFloat(preset.typical),
sampler_order: [6, 0, 1, 3, 4, 2, 5],
sampler_seed:
parseInt(preset.seed) === -1
? Math.floor(Math.random() * 999999)
: parseInt(preset.seed),
sampler_seed: getSeed(preset?.seed),
stop_sequence: constructStopSequence(currentInstruct),
mirostat: parseInt(preset.mirostat_mode),
mirostat_tau: parseFloat(preset.mirostat_tau),
Expand Down Expand Up @@ -385,10 +403,7 @@ const constructTGWUIPayload = () => {
ban_eos_token: preset.ban_eos_token,
skip_special_tokens: preset.skip_special_tokens,
stopping_strings: constructStopSequence(currentInstruct),
seed:
preset?.seed === undefined || preset.seed === -1
? Math.floor(Math.random() * 999999)
: parseInt(preset.seed),
seed: getSeed(preset?.seed),
guidance_scale: preset.guidance_scale,
negative_prompt: preset.negative_prompt,
temperature_last: parseFloat(preset.min_p) !== 1,
Expand Down Expand Up @@ -431,6 +446,7 @@ const constructCompletionsPayload = () => {
const completionsModel = getObject(Global.CompletionsModel)
const preset = getObject(Global.PresetData)
const currentInstruct = Instructs.useInstruct.getState().replacedMacros()

return {
stream: true,
max_context_length: preset.max_length,
Expand All @@ -452,12 +468,9 @@ const constructCompletionsPayload = () => {
mirostat_tau: preset.mirostat_tau,
mirostat_eta: preset.mirostat_eta,
grammar: preset.grammar_string,
seed:
preset?.seed === undefined || preset.seed === -1
? Math.floor(Math.random() * 999999)
: parseInt(preset.seed),
seed: getSeed(preset?.seed),
sampler_order: [6, 0, 1, 3, 4, 2, 5],
stop: ['\n\n\n\n\n', currentInstruct.input_prefix],
stop: constructStopSequence(currentInstruct),
frequency_penalty: preset.freq_pen,
presence_penalty: preset.presence_pen,
smoothing_factor: preset.smoothing_factor,
Expand Down Expand Up @@ -488,26 +501,23 @@ const constructLocalPayload = () => {
tfs_z: preset.tfs,
typical_p: preset.typical,
min_p: preset.min_p,
seed: parseInt(preset.seed) ?? -1,
seed: getSeed(preset?.seed),
}
}

const constructOpenRouterPayload = () => {
const openRouterModel = getObject(Global.OpenRouterModel)
const currentInstruct = Instructs.useInstruct.getState().replacedMacros()
const preset = getObject(Global.PresetData)

console.log(getSeed(preset.seed))
return {
messages: buildChatCompletionContext(openRouterModel.context_length),
model: openRouterModel.id,
frequency_penalty: preset.freq_pen,
max_tokens: preset.genamt,
presence_penalty: preset.presence_pen,
response_format: { type: 'json_object' },
seed:
preset?.seed === undefined || preset.seed === -1
? Math.floor(Math.random() * 999999)
: parseInt(preset.seed),
seed: getSeed(preset?.seed),
stop: constructStopSequence(currentInstruct),
stream: true,
temperature: preset.temp,
Expand All @@ -526,10 +536,7 @@ const constructOpenAIPayload = () => {
max_tokens: preset.genamt,
frequency_penalty: preset.freq_pen,
presence_penalty: preset.presence_pen,
seed:
preset?.seed === undefined || preset.seed === -1
? Math.floor(Math.random() * 999999)
: parseInt(preset.seed),
seed: getSeed(preset?.seed),
stop: constructStopSequence(currentInstruct),
stream: true,
temperature: preset.temp,
Expand Down

0 comments on commit b2bb338

Please sign in to comment.