Skip to content

Commit

Permalink
feat: added last_output_prefix to instruct
Browse files Browse the repository at this point in the history
  • Loading branch information
Vali-98 committed Sep 22, 2024
1 parent 8afde1b commit a52f90d
Show file tree
Hide file tree
Showing 9 changed files with 904 additions and 11 deletions.
22 changes: 17 additions & 5 deletions app/Instruct.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@ import CheckboxTitle from '@components/CheckboxTitle'
import SliderItem from '@components/SliderItem'
import TextBox from '@components/TextBox'
import TextBoxModal from '@components/TextBoxModal'
import { InstructListItem } from 'app/constants/Instructs'
import { FontAwesome } from '@expo/vector-icons'
import { Global, Instructs, saveStringExternal, Logger, Style } from '@globals'
import Slider from '@react-native-community/slider'
import { InstructListItem } from 'app/constants/Instructs'
import { Stack } from 'expo-router'
import { useState, useEffect } from 'react'
import { useAutosave } from 'react-autosave'
Expand Down Expand Up @@ -45,7 +44,6 @@ const Instruct = () => {
const targetID = id === -1 ? instructID : id
const currentitem = list.filter((item) => item.id === targetID)
if (currentitem.length === 0) {
// item no longer exists
setSelectedItem(list[0])
loadInstruct(list[0].id)
} else {
Expand Down Expand Up @@ -121,8 +119,6 @@ const Instruct = () => {
await loadInstruct(newid)
loadInstructList(newid)
})

//Instructs.saveFile(text, { ...currentInstruct, name: text })
}}
/>

Expand Down Expand Up @@ -316,6 +312,22 @@ const Instruct = () => {
/>
</View>
*/}
<View style={{ flexDirection: 'row' }}>
<TextBox
name="Last Output Prefix"
varname="last_output_prefix"
body={currentInstruct}
setValue={setCurrentInstruct}
multiline
/>
{/*<TextBox
name="Separator Sequence"
varname="separator_sequence"
body={currentInstruct}
setValue={setCurrentInstruct}
multiline
/>*/}
</View>
<View style={{ flexDirection: 'row' }}>
<TextBox
name="Stop Sequence"
Expand Down
11 changes: 8 additions & 3 deletions app/constants/APIState/BaseAPI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ export abstract class APIBase implements IAPIBase {

let instruct_len = message.is_user
? instructCache.input_prefix_length
: instructCache.output_suffix_length
: is_last
? instructCache.last_output_prefix_length
: instructCache.output_suffix_length

// for last message, we want to skip the end token to allow the LLM to generate

Expand All @@ -143,14 +145,17 @@ export abstract class APIBase implements IAPIBase {
swipe_len + instruct_len + name_length + timestamp_length + wrap_length

// check if within context window

if (message_acc_length + payload_length + shard_length > max_length) {
break
}

// apply strings

let message_shard = `${message.is_user ? currentInstruct.input_prefix : currentInstruct.output_prefix}`
let message_shard = message.is_user
? currentInstruct.input_prefix
: is_last
? currentInstruct.last_output_prefix
: currentInstruct.output_prefix

if (currentInstruct.timestamp) message_shard += timestamp_string

Expand Down
33 changes: 31 additions & 2 deletions app/constants/Instructs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const defaultInstructs: InstructType[] = [
input_prefix: '### Instruction: ',
input_suffix: '\n',
output_prefix: '### Response: ',
last_output_prefix: '### Response: ',
output_suffix: '\n',
stop_sequence: '### Instruction',
user_alignment_message: '',
Expand All @@ -40,6 +41,7 @@ const defaultInstructs: InstructType[] = [
input_prefix: '<|start_header_id|>user<|end_header_id|>\n\n',
input_suffix: '<|eot_id|>',
output_prefix: '<|start_header_id|>assistant<|end_header_id|>\n\n',
last_output_prefix: '<|start_header_id|>assistant<|end_header_id|>\n\n',
output_suffix: '<|eot_id|>',
stop_sequence: '<|eot_id|>',
user_alignment_message: '',
Expand All @@ -60,6 +62,7 @@ const defaultInstructs: InstructType[] = [
input_prefix: '<|im_start|>user\n',
input_suffix: '<|im_end|>\n',
output_prefix: '<|im_start|>assistant\n',
last_output_prefix: '<|im_start|>assistant\n',
output_suffix: '<|im_end|>\n',
stop_sequence: '<|im_end|>',
user_alignment_message: '',
Expand All @@ -80,6 +83,7 @@ const defaultInstructs: InstructType[] = [
input_prefix: '<|user|>\n',
input_suffix: '<|endoftext|>\n',
output_prefix: '<|assistant|>\n',
last_output_prefix: '<|assistant|>\n',
output_suffix: '<|endoftext|>\n',
stop_sequence: '<|endoftext|>\n',
user_alignment_message: '',
Expand All @@ -100,6 +104,7 @@ const defaultInstructs: InstructType[] = [
input_prefix: '<|user|>\n',
input_suffix: '<|end|>\n',
output_prefix: '<|assistant|>\n',
last_output_prefix: '<|assistant|>\n',
output_suffix: '<|end|>\n',
stop_sequence: '<|end|>\n',
user_alignment_message: '',
Expand All @@ -120,6 +125,7 @@ const defaultInstructs: InstructType[] = [
input_prefix: '<start_of_turn>user\n',
input_suffix: '<end_of_turn>\n',
output_prefix: '<start_of_turn>model',
last_output_prefix: '<start_of_turn>model',
output_suffix: '<end_of_turn>\n',
stop_sequence: '<end_of_turn>',
user_alignment_message: '',
Expand Down Expand Up @@ -158,6 +164,7 @@ type InstructTokenCache = {
input_prefix_length: number
input_suffix_length: number
output_prefix_length: number
last_output_prefix_length: number
output_suffix_length: number
user_alignment_message_length: number
}
Expand All @@ -170,6 +177,7 @@ export namespace Instructs {
input_prefix: '### Instruction: ',
input_suffix: '\n',
output_prefix: '### Response: ',
last_output_prefix: '### Response: ',
output_suffix: '\n',
stop_sequence: '### Instruction',
user_alignment_message: '',
Expand Down Expand Up @@ -212,6 +220,7 @@ export namespace Instructs {
input_prefix_length: 0,
input_suffix_length: 0,
output_prefix_length: 0,
last_output_prefix_length: 0,
output_suffix_length: 0,
user_alignment_message_length: 0,
}
Expand All @@ -229,9 +238,11 @@ export namespace Instructs {
input_prefix_length: getTokenCount(instruct.input_prefix),
input_suffix_length: getTokenCount(instruct.input_suffix),
output_prefix_length: getTokenCount(instruct.output_prefix),
last_output_prefix_length: getTokenCount(instruct.last_output_prefix),
output_suffix_length: getTokenCount(instruct.output_suffix),
user_alignment_message_length: getTokenCount(instruct.system_prompt),
}
console.log('cache created')
set((state) => ({ ...state, tokenCache: newCache }))
return newCache
},
Expand All @@ -254,14 +265,30 @@ export namespace Instructs {
name: 'instruct-storage',
storage: createJSONStorage(() => mmkvStorage),
partialize: (state) => ({ data: state.data }),
version: 1,
migrate: (persistedState: any, version) => {
version: 2,
migrate: async (persistedState: any, version) => {
if (!version) {
persistedState.timestamp = false
persistedState.examples = true
persistedState.format_type = 0
Logger.log('[INSTRUCT] Migrated to v1')
}
if (version === 1) {
const entries = await database.query.instructs.findMany({
columns: {
id: true,
output_prefix: true,
},
})
entries.forEach(async (item) => {
if (item?.id === persistedState?.id)
persistedState.last_output_prefix = item.output_prefix
await database
.update(instructs)
.set({ last_output_prefix: item.output_prefix })
.where(eq(instructs.id, item.id))
})
}

return persistedState
},
Expand Down Expand Up @@ -344,4 +371,6 @@ export type InstructType = {
timestamp: boolean
examples: boolean
format_type: number

last_output_prefix: string
}
23 changes: 23 additions & 0 deletions app/constants/Utils.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import { API } from './API'
import { Characters } from './Characters'
import { Global } from './GlobalValues'
import { mmkv } from './MMKV'

export const humanizedISO8601DateTime = (date = '') => {
const baseDate = typeof date === 'number' ? new Date(date) : new Date()
Expand Down Expand Up @@ -50,3 +53,23 @@ export const replaceMacros = (text: string) => {
for (const rule of rules) newtext = newtext.replaceAll(rule.macro, rule.value)
return newtext
}

const getMMKVObjectModel = (mmkvKey: string, field: string) => {
const data = mmkv.getString(mmkvKey)
if (!data) return 'undefined'
const model = JSON.parse(data)[field]
return model
}

export const getCurrentModel = () => {
const api = mmkv.getString(Global.APIType)
switch (api) {
case API.CHATCOMPLETIONS: {
return getMMKVObjectModel(Global.ChatCompletionsModel, 'id')
}
case API.CLAUDE: {
return getMMKVObjectModel(Global.ClaudeModel, 'name')
}
// TODO: Finish this - need data for KAI api
}
}
1 change: 1 addition & 0 deletions db/migrations/0003_violet_firelord.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE `instructs` ADD `last_output_prefix` text DEFAULT '' NOT NULL;
Loading

0 comments on commit a52f90d

Please sign in to comment.