Commit
feat: added token caching to character card and instruct format
Vali-98 committed Apr 23, 2024
1 parent 626894a commit 1113c34
Showing 3 changed files with 189 additions and 53 deletions.
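
The gist of the change: buildContext previously re-ran LlamaTokenizer.encode over the card description, the example messages, and every instruct prefix and suffix on each generation. This commit memoizes those token counts in the zustand stores and reuses them until the loaded card or the counterpart's name changes. A minimal sketch of the pattern in isolation (the helper and cache names here are illustrative, not part of the commit):

    // Illustrative sketch only: memoize token counts for strings that are
    // expensive to re-tokenize. `encode` stands in for LlamaTokenizer.encode.
    declare function encode(text: string): number[]

    const tokenLengthCache = new Map<string, number>()

    function cachedTokenLength(text: string): number {
        const hit = tokenLengthCache.get(text)
        if (hit !== undefined) return hit
        const length = encode(text).length // the expensive call, made once per string
        tokenLengthCache.set(text, length)
        return length
    }

The stores below implement this idea, but key the cache on the other participant's name rather than the raw text, because {{user}}/{{char}} macro substitution changes the token count.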
78 changes: 76 additions & 2 deletions constants/Characters.ts
@@ -10,41 +10,76 @@ import { db } from '@db'
import { characterGreetings, characterTags, characters, tags } from 'db/schema'
import { eq, inArray, notExists, notInArray, sql } from 'drizzle-orm'
import { create } from 'zustand'
+ import { LlamaTokenizer } from './tokenizer'
+ import { mmkv } from './mmkv'
+ import { Global } from './GlobalValues'

+ type CharacterTokenCache = {
+ otherName: string
+ description_length: number
+ examples_length: number
+ }

type CharacterCardState = {
card: CharacterCardV2 | undefined
+ tokenCache: CharacterTokenCache | undefined
id: number | undefined
setCard: (id: number) => Promise<string | undefined>
unloadCard: () => void
getImage: () => string
+ getCache: (otherName: string) => CharacterTokenCache
}

export namespace Characters {
export const useUserCard = create<CharacterCardState>((set, get: () => CharacterCardState) => ({
id: undefined,
card: undefined,
+ tokenCache: undefined,
setCard: async (id: number) => {
let start = performance.now()
const card = await readCard(id)
Logger.debug(`[User] time for db query: ${performance.now() - start}`)
start = performance.now()
- set((state) => ({ ...state, card: card, id: id }))
+ set((state) => ({ ...state, card: card, id: id, tokenCache: undefined }))

Logger.debug(`[User] time for zustand set: ${performance.now() - start}`)
return card?.data.name
},
unloadCard: () => {
- set((state) => ({ ...state, id: undefined, card: undefined }))
+ set((state) => ({ ...state, id: undefined, card: undefined, tokenCache: undefined }))
},
getImage: () => {
return getImageDir(get().card?.data.image_id ?? 0)
},
+ getCache: (userName: string) => {
+ const cache = get().tokenCache
+ if (cache) return cache
+
+ const card = get().card
+ if (!card)
+ return {
+ otherName: userName,
+ description_length: 0,
+ examples_length: 0,
+ }
+ const description = replaceMacros(card.data.description)
+ const examples = replaceMacros(card.data.mes_example)
+ const newCache = {
+ otherName: userName,
+ description_length: LlamaTokenizer.encode(description).length,
+ examples_length: LlamaTokenizer.encode(examples).length,
+ }
+
+ set((state) => ({ ...state, tokenCache: newCache }))
+ return newCache
+ },
}))

export const useCharacterCard = create<CharacterCardState>(
(set, get: () => CharacterCardState) => ({
id: undefined,
card: undefined,
+ tokenCache: undefined,
setCard: async (id: number) => {
let start = performance.now()
const card = await readCard(id)
@@ -61,6 +96,28 @@ export namespace Characters {
getImage: () => {
return getImageDir(get().card?.data.image_id ?? 0)
},
+ getCache: (charName: string) => {
+ const cache = get().tokenCache
+ if (cache && cache.otherName && cache.otherName === charName) return cache
+
+ const card = get().card
+ if (!card)
+ return {
+ otherName: charName,
+ description_length: 0,
+ examples_length: 0,
+ }
+ const description = replaceMacros(card.data.description)
+ const examples = replaceMacros(card.data.mes_example)
+ const newCache = {
+ otherName: charName,
+ description_length: LlamaTokenizer.encode(description).length,
+ examples_length: LlamaTokenizer.encode(examples).length,
+ }
+
+ set((state) => ({ ...state, tokenCache: newCache }))
+ return newCache
+ },
})
)
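
Note the asymmetry between the two stores: useCharacterCard.getCache re-validates the cache with cache.otherName === charName, while useUserCard.getCache returns any existing cache unconditionally and relies on setCard/unloadCard resetting tokenCache. Also, per the CharacterCardState interface the argument is really otherName, the counterpart whose macros affect the counts, which is why buildContext below passes the user's name to the character store's getCache. A guard mirroring the character-card variant would make the user-card cache robust to a renamed counterpart; a hypothetical one-line tightening:

    // Hypothetical stricter early return for useUserCard.getCache:
    if (cache && cache.otherName === userName) return cache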

@@ -445,3 +502,20 @@ const TavernCardV2 = (name: string) => {
},
}
}
+ type Rule = {
+ macro: string
+ value: string
+ }
+
+ export const replaceMacros = (text: string) => {
+ if (text == undefined) return ''
+ let newtext: string = text
+ const charName = Characters.useCharacterCard.getState().card?.data.name
+ const userName = mmkv.getString(Global.CurrentUser)
+ const rules: Rule[] = [
+ { macro: '{{user}}', value: userName ?? '' },
+ { macro: '{{char}}', value: charName ?? '' },
+ ]
+ for (const rule of rules) newtext = newtext.replaceAll(rule.macro, rule.value)
+ return newtext
+ }
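
replaceMacros resolves {{user}} and {{char}} from the active stores at call time, which is why the token caches above key on the counterpart's name. A usage sketch with assumed values:

    // Assuming the current user is 'Alice' and the loaded character card is 'Sherlock':
    replaceMacros('Hi {{user}}, I am {{char}}.') // => 'Hi Alice, I am Sherlock.'
    // The loose == undefined guard also catches null:
    replaceMacros(undefined as unknown as string) // => ''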
95 changes: 46 additions & 49 deletions constants/Inference.ts
@@ -105,8 +105,8 @@ export const hordeHeader = () => {
* System
* - System Prefix
* - System Prompt
- * - User Card
* - Character Card
+ * - User Card
* - Example Messages (if max_length allows after Context)
* - System Suffix
* Context
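
The -/+ pair above is a reordering: the commit moves the character card ahead of the user card in the assembled system block. The prompt buildContext produces now has this shape:

    system_prefix
    system_prompt
    <character card description>
    <user card description>
    <example messages, if the budget allows>
    system_suffix
    <chat history, wrapped in input/output prefixes and suffixes>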
@@ -123,30 +123,41 @@ const buildContext = (max_length: number) => {
const delta = performance.now()
const messages = [...(Chats.useChat.getState().data?.messages ?? [])]

- const currentInstruct = instructReplaceMacro()
+ const currentInstruct = Instructs.useInstruct.getState().replacedMacros()

const userCard = getObject(Global.CurrentUserCard)
+ const userName = getString(Global.CurrentUser)
const currentCard = { ...Characters.useCharacterCard.getState().card }
+ const characterCache = Characters.useCharacterCard.getState().getCache(userName)

+ const instructCache = Instructs.useInstruct
+ .getState()
+ .getCache(currentCard.data?.name ?? '', userName)
const user_card_data = (userCard?.description ?? '').trim()
const char_card_data = (currentCard?.data?.description ?? '').trim()

let payload = ``
- if (currentInstruct.system_prefix) payload += currentInstruct.system_prefix
- if (currentInstruct.system_prompt) payload += `${currentInstruct.system_prompt}\n`
- if (user_card_data) payload += user_card_data + '\n'
- if (char_card_data) payload += char_card_data + '\n'
+ let payload_length = 0
+ if (currentInstruct.system_prefix) {
+ payload += currentInstruct.system_prefix
+ payload_length += instructCache.system_prefix_length
+ }
+ if (currentInstruct.system_prompt) {
+ payload += `${currentInstruct.system_prompt}`
+ payload_length += instructCache.system_prompt_length
+ }
+ if (char_card_data) {
+ payload += char_card_data
+ payload_length += characterCache.description_length
+ }
+ if (user_card_data) {
+ payload += user_card_data
+ payload_length += LlamaTokenizer.encode(user_card_data).length
+ }
- // suffix must be delayed for example messages
- const payload_length =
- LlamaTokenizer.encode(payload).length + currentInstruct.system_suffix
- ? LlamaTokenizer.encode(currentInstruct.system_suffix).length
- : 0

let message_acc = ``
- let message_acc_length = LlamaTokenizer.encode(message_acc).length

- const input_prefix_length = LlamaTokenizer.encode(currentInstruct.input_prefix).length
- const input_suffix_length = LlamaTokenizer.encode(currentInstruct.input_suffix).length
- const output_prefix_length = LlamaTokenizer.encode(currentInstruct.output_prefix).length
- const output_suffix_length = LlamaTokenizer.encode(currentInstruct.output_suffix).length
+ let message_acc_length = 0

let is_last = true
let index = messages.length - 1
@@ -156,8 +167,8 @@ const buildContext = (max_length: number) => {
: 0
// for last message, we want to skip the end token to allow the LLM to generate
const instruct_len = message.is_user
- ? input_prefix_length + (is_last ? 0 : input_suffix_length)
- : output_prefix_length + (is_last ? 0 : output_suffix_length)
+ ? instructCache.input_prefix_length + (is_last ? 0 : instructCache.input_suffix_length)
+ : instructCache.output_prefix_length + (is_last ? 0 : instructCache.output_suffix_length)
const shard_length = swipe_len + instruct_len
if (message_acc_length + payload_length + shard_length > max_length) {
break
@@ -180,18 +191,19 @@ const buildContext = (max_length: number) => {

const examples = currentCard.data?.mes_example
if (examples) {
- const examples_length = LlamaTokenizer.encode(examples).length
- if (message_acc_length + payload_length + examples_length < max_length) {
+ if (message_acc_length + payload_length + characterCache.examples_length < max_length) {
payload += examples
- message_acc_length += examples_length
+ message_acc_length += characterCache.examples_length
}
}

- if (currentInstruct.system_suffix) payload += ' ' + currentInstruct.system_suffix
+ if (currentInstruct.system_suffix) {
+ payload += ' ' + currentInstruct.system_suffix
+ message_acc_length += instructCache.system_suffix_length
+ }
payload = replaceMacros(payload + message_acc)

- //Logger.log(`Payload size: ${LlamaTokenizer.encode(payload).length}`)
- Logger.log(`Approximate Payload Size: ${message_acc_length + payload_length}`)
+ Logger.log(`Approximate Context Size: ${message_acc_length + payload_length} tokens`)
Logger.log(`${(performance.now() - delta).toFixed(2)}ms taken to build context`)
return payload
}
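
The loop above walks the chat from newest to oldest, charging each message its swipe token count plus the instruct wrapping, and stops once the running total would exceed max_length. The user card description is the one system-block piece still tokenized on every build; everything else now comes from the caches. A worked example with assumed counts:

    // Assumed numbers, for illustration only:
    //   payload_length (system prefix + prompt + both cards) = 600 tokens
    //   each message = 90 swipe tokens + 10 tokens of instruct wrapping = 100
    //   max_length = 2048
    // History budget: 2048 - 600 = 1448 tokens, so the loop keeps the 14 most
    // recent messages (1400 <= 1448) and breaks on the 15th (1500 > 1448),
    // leaving 48 tokens of headroom for the example-messages check.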
@@ -200,7 +212,7 @@ const buildChatCompletionContext = (max_length: number) => {
const messages = [...(Chats.useChat.getState().data?.messages ?? [])]
const userCard = getObject(Global.CurrentUserCard)
const currentCard = { ...Characters.useCharacterCard.getState().card }
- const currentInstruct = instructReplaceMacro()
+ const currentInstruct = Instructs.useInstruct.getState().replacedMacros()

const initial = `${currentInstruct.system_prefix}
${currentInstruct.system_prompt}
@@ -232,25 +244,11 @@ const constructStopSequence = (instruct: InstructType): Array<string> => {
return sequence
}

- const instructReplaceMacro = (): InstructType => {
- const rawinstruct = Instructs.useInstruct.getState().data
-
- if (!rawinstruct) return Instructs.defaultInstruct
- const instruct = { ...rawinstruct }
- const keys = Object.keys(instruct) as (keyof typeof instruct)[]
-
- keys.forEach((key) => {
- if (typeof instruct[key] === 'string') replaceMacros(instruct[key] as string)
- })
-
- return instruct
- }
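
instructReplaceMacro is deleted outright; note the removed version also carried a latent bug, calling replaceMacros(instruct[key]) without assigning the result, so the instruct fields were never actually substituted. Its replacement, Instructs.useInstruct.getState().replacedMacros(), along with the instructCache used in buildContext, lives in the commit's third changed file, which this page does not reproduce. Inferred from the call sites in this diff, its shape is presumably something like:

    // Hypothetical shape, inferred from call sites here; the actual
    // constants/Instructs.ts change is the third file of the commit, not shown.
    type InstructTokenCache = {
        charName: string
        userName: string
        system_prompt_length: number
        system_prefix_length: number
        system_suffix_length: number
        input_prefix_length: number
        input_suffix_length: number
        output_prefix_length: number
        output_suffix_length: number
    }
    // replacedMacros(): the instruct with {{user}}/{{char}} substituted (and assigned)
    // getCache(charName, userName): memoized InstructTokenCache, rebuilt when
    // either name changes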

// Payloads

const constructKAIPayload = () => {
const preset = getObject(Global.PresetData)
- const currentInstruct = instructReplaceMacro()
+ const currentInstruct = Instructs.useInstruct.getState().replacedMacros()

return {
prompt: buildContext(preset.max_length),
@@ -283,7 +281,7 @@ const constructKAIPayload = () => {

const constructHordePayload = () => {
const preset = getObject(Global.PresetData)
- const currentInstruct = instructReplaceMacro()
+ const currentInstruct = Instructs.useInstruct.getState().replacedMacros()
const hordeModels = getObject(Global.HordeModels)
const hordeWorkers = getObject(Global.HordeWorkers)

@@ -344,7 +342,7 @@ const constructHordePayload = () => {

const constructTGWUIPayload = () => {
const preset = getObject(Global.PresetData)
- const currentInstruct = instructReplaceMacro()
+ const currentInstruct = Instructs.useInstruct.getState().replacedMacros()
return {
stream: true,
prompt: buildContext(preset.max_length),
@@ -392,7 +390,7 @@ const constructTGWUIPayload = () => {

const constructMancerPayload = () => {
const preset = getObject(Global.PresetData)
- const currentInstruct = instructReplaceMacro()
+ const currentInstruct = Instructs.useInstruct.getState().replacedMacros()
const mancerModel = getObject(Global.MancerModel)

const context_len = Math.min(preset.max_length, mancerModel.limits.context)
@@ -420,7 +418,7 @@ const constructMancerPayload = () => {
const constructCompletionsPayload = () => {
const completionsModel = getObject(Global.CompletionsModel)
const preset = getObject(Global.PresetData)
- const currentInstruct = instructReplaceMacro()
+ const currentInstruct = Instructs.useInstruct.getState().replacedMacros()
return {
stream: true,
max_context_length: preset.max_length,
@@ -457,7 +455,7 @@ const constructCompletionsPayload = () => {

const constructLocalPayload = () => {
const preset = getObject(Global.PresetData)
- const currentInstruct = instructReplaceMacro()
+ const currentInstruct = Instructs.useInstruct.getState().replacedMacros()
const localPreset = getObject(Global.LocalPreset)
return {
prompt: buildContext(preset.max_length),
@@ -485,7 +483,7 @@ const constructLocalPayload = () => {

const constructOpenRouterPayload = () => {
const openRouterModel = getObject(Global.OpenRouterModel)
- const currentInstruct = instructReplaceMacro()
+ const currentInstruct = Instructs.useInstruct.getState().replacedMacros()
const preset = getObject(Global.PresetData)

return {
@@ -509,7 +507,7 @@ const constructOpenRouterPayload = () => {

const constructOpenAIPayload = () => {
const openAIModel = getObject(Global.OpenAIModel)
- const currentInstruct = instructReplaceMacro()
+ const currentInstruct = Instructs.useInstruct.getState().replacedMacros()
const preset = getObject(Global.PresetData)
return {
messages: buildChatCompletionContext(preset.max_length),
@@ -738,7 +736,7 @@ const openAIResponseStream = async (setAbortFunction: AbortFunction) => {
}

const constructReplaceStrings = (): Array<string> => {
- const currentInstruct: InstructType = instructReplaceMacro()
+ const currentInstruct: InstructType = Instructs.useInstruct.getState().replacedMacros()
const userName: string = mmkv.getString(Global.CurrentUser) ?? ''
const charName: string = Characters.useCharacterCard.getState()?.card?.data?.name ?? ''
const stops: Array<string> = constructStopSequence(currentInstruct)
@@ -825,7 +823,6 @@ const readableStreamResponse = async (
es.close()
return
}
- console.log(event.data)
const text = jsonreader(event.data)
const output = Chats.useChat.getState().buffer + text
Chats.useChat.getState().setBuffer(output.replaceAll(replace, ''))