From c30dfe560c0f661ff2ac5840809f35764d890170 Mon Sep 17 00:00:00 2001
From: Rusyaidi
Date: Tue, 6 Aug 2024 15:21:53 +0800
Subject: [PATCH] fix: inaccurate context length calculation when using timestamps and names

---
 app/components/Endpoint/Local.tsx |  4 +--
 app/constants/APIState/BaseAPI.ts | 54 +++++++++++++++++++++++++------
 2 files changed, 47 insertions(+), 11 deletions(-)

diff --git a/app/components/Endpoint/Local.tsx b/app/components/Endpoint/Local.tsx
index 42bdd2d..fcc4106 100644
--- a/app/components/Endpoint/Local.tsx
+++ b/app/components/Endpoint/Local.tsx
@@ -1,5 +1,5 @@
-import { Llama, LlamaPreset } from 'app/constants/LlamaLocal'
 import { AppSettings, Global, Logger, Style } from '@globals'
+import { Llama, LlamaPreset } from 'app/constants/LlamaLocal'
 import { useEffect, useState } from 'react'
 import {
     View,
@@ -227,7 +227,7 @@ const Local = () => {
                         varname="context_length"
                         min={512}
                         max={32768}
-                        step={32}
+                        step={512}
                     />

diff --git a/app/constants/APIState/BaseAPI.ts b/app/constants/APIState/BaseAPI.ts
--- a/app/constants/APIState/BaseAPI.ts
+++ b/app/constants/APIState/BaseAPI.ts
@@ export abstract class APIBase implements IAPIBase {
         const delta = performance.now()
+
+        const tokenizer =
+            mmkv.getString(Global.APIType) === API.LOCAL
+                ? Llama.useLlama.getState().tokenLength
+                : Tokenizer.useTokenizer.getState().getTokenCount
+
         const messages = [...(Chats.useChat.getState().data?.messages ?? [])]
         const currentInstruct = Instructs.useInstruct.getState().replacedMacros()
@@ -71,6 +80,7 @@ export abstract class APIBase implements IAPIBase {
         const user_card_data = (userCard?.data?.description ?? '').trim()
         const char_card_data = (currentCard?.data?.description ?? '').trim()
         let payload = ``
+        // set suffix length as it's always added
         let payload_length = instructCache.system_suffix_length
 
         if (currentInstruct.system_prefix) {
@@ -95,41 +105,67 @@ export abstract class APIBase implements IAPIBase {
         let message_acc_length = 0
         let is_last = true
         let index = messages.length - 1
+
+        const wrap_string = `\n`
+        const wrap_length = currentInstruct.wrap ? tokenizer(wrap_string) : 0
+
+        // we require lengths for names if use_names is enabled
         for (const message of messages?.reverse() ?? []) {
            const swipe_len = Chats.useChat.getState().getTokenCount(index)
-            // for last message, we want to skip the end token to allow the LLM to generate
+            const swipe_data = message.swipes[message.swipe_id]
+
+            /** Accumulate total string length
+             * The context builder MUST retain context length below the
+             * context limit, especially for local gens to prevent truncation
+             * **/
             let instruct_len = message.is_user
                 ? instructCache.input_prefix_length
                 : instructCache.output_suffix_length
+
+            // for last message, we want to skip the end token to allow the LLM to generate
             if (!is_last)
                 instruct_len += message.is_user
                     ? instructCache.input_suffix_length
                     : instructCache.output_suffix_length
-            const shard_length = swipe_len + instruct_len
+
+            const timestamp_string = `[${swipe_data.send_date.toString().split(' ')[0]} ${swipe_data.send_date.toLocaleTimeString()}]\n`
+            const timestamp_length = currentInstruct.timestamp ? tokenizer(timestamp_string) : 0
+
+            const name_string = `${message.name}: `
+            const name_length = currentInstruct.names ? tokenizer(name_string) : 0
+
+            const shard_length =
+                swipe_len + instruct_len + name_length + timestamp_length + wrap_length
+
+            // check if within context window
             if (message_acc_length + payload_length + shard_length > max_length) {
                 break
             }
+            // apply strings
             let message_shard = `${message.is_user ? currentInstruct.input_prefix : currentInstruct.output_prefix}`
-            const swipe_data = message.swipes[message.swipe_id]
+            if (currentInstruct.timestamp) message_shard += timestamp_string
-            if (currentInstruct.timestamp)
-                message_shard += `[${swipe_data.send_date.toString().split(' ')[0]} ${swipe_data.send_date.toLocaleTimeString()}]\n`
-            if (currentInstruct.names) message_shard += message.name + ': '
+            if (currentInstruct.names) message_shard += name_string
             message_shard += swipe_data.swipe
             if (!is_last) {
                 message_shard += `${message.is_user ? currentInstruct.input_suffix : currentInstruct.output_suffix}`
             }
-            // ensure no more is_last checks after this
-            is_last = false
             if (currentInstruct.wrap) {
-                message_shard += `\n`
+                message_shard += wrap_string
             }
+
+            // ensure no more is_last checks after this
+            is_last = false
+
             message_acc_length += shard_length
             message_acc = message_shard + message_acc
             index--
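
The patch above folds the optional timestamp, name, and wrap strings into the per-message
token count before the context-window check, using the same tokenizer that measures the
message body. The sketch below is a minimal, standalone illustration of that accounting;
fitMessages, Msg, Options, and the rough 4-characters-per-token tokenize helper are
simplified stand-ins for this example, not ChatterUI's real APIs (the app resolves its
tokenizer from the active endpoint, local llama.cpp or remote, and also counts instruct
prefix/suffix lengths, which are omitted here for brevity).

    // Simplified model of one chat message and the relevant instruct options
    type Msg = { name: string; sendDate: Date; text: string }
    type Options = { timestamps: boolean; names: boolean; wrap: boolean }

    // Crude stand-in tokenizer: roughly 4 characters per token
    const tokenize = (s: string): number => Math.ceil(s.length / 4)

    // Walk messages newest-first and keep only those that fit under maxTokens.
    // The point mirrored from the patch: timestamp, name and wrap strings are
    // measured with the same tokenizer as the message body, so the running total
    // matches what is actually sent to the model.
    function fitMessages(messages: Msg[], opts: Options, maxTokens: number): Msg[] {
        const wrapLength = opts.wrap ? tokenize('\n') : 0
        let used = 0
        const kept: Msg[] = []
        for (const m of [...messages].reverse()) {
            const timestamp = `[${m.sendDate.toLocaleDateString()} ${m.sendDate.toLocaleTimeString()}]\n`
            const name = `${m.name}: `
            const shard =
                tokenize(m.text) +
                (opts.timestamps ? tokenize(timestamp) : 0) +
                (opts.names ? tokenize(name) : 0) +
                wrapLength
            if (used + shard > maxTokens) break // this message and anything older is dropped
            used += shard
            kept.unshift(m)
        }
        return kept
    }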