feat: use tokenizer from llama.rn
Vali-98 committed Jul 18, 2024
1 parent 852f48a commit e953d48
Showing 14 changed files with 3,473 additions and 868 deletions.
7 changes: 7 additions & 0 deletions app.json
@@ -35,6 +35,13 @@
"favicon": "./assets/images/adaptive-icon.png"
},
"plugins": [
[
"expo-custom-assets",
{
"assetsPaths": ["./assets/models"],
"assetsDirName": "appAssets"
}
],
"expo-router",
[
"expo-image-picker",
8 changes: 4 additions & 4 deletions app/CharInfo.tsx
@@ -1,8 +1,9 @@
import AnimatedView from '@components/AnimatedView'
import { CharacterCardV2 } from '@constants/Characters'
import { RecentMessages } from '@constants/RecentMessages'
import { Tokenizer } from '@constants/Tokenizer'
import { FontAwesome } from '@expo/vector-icons'
import { Characters, Llama3Tokenizer, Logger, Style } from '@globals'
import { Characters, Logger, Style } from '@globals'
import * as DocumentPicker from 'expo-document-picker'
import { Stack, useRouter } from 'expo-router'
import { useState } from 'react'
@@ -30,7 +31,7 @@ const CharInfo = () => {
charName: state.card?.data.name,
}))
)

const getTokenCount = Tokenizer.useTokenizer((state) => state.getTokenCount)
const [characterCard, setCharacterCard] = useState<CharacterCardV2 | undefined>(currentCard)

const imageDir = Characters.getImageDir(currentCard?.data.image_id ?? -1)
@@ -157,8 +158,7 @@ const CharInfo = () => {

<Text style={styles.boxText}>
Description Tokens:{' '}
{characterCard?.data?.description !== undefined &&
Llama3Tokenizer.encode(characterCard.data.description).length}
{getTokenCount(characterCard?.data?.description ?? '')}
</Text>

<ScrollView
Binary file added assets/models/llama3tokenizer.gguf
Binary file not shown.
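
Note: the new constants/Tokenizer module referenced by these imports is not expanded in this view, so the sketch below is an inference rather than the committed code. It is based on the call sites visible in the diff (Tokenizer.useTokenizer((state) => state.getTokenCount) in CharInfo.tsx, and Tokenizer.useTokenizer.getState().getTokenCount(...) in Characters.ts, Chat.ts and Instructs.ts) and on the llama3tokenizer.gguf asset added above; the store shape, the placeholder body, and the rough token estimate are assumptions.

import { create } from 'zustand'

type TokenizerState = {
    // Call sites use the result synchronously, so this presumably returns a number directly.
    getTokenCount: (text: string) => number
}

export namespace Tokenizer {
    export const useTokenizer = create<TokenizerState>()(() => ({
        getTokenCount: (text: string) => {
            // Placeholder only: the committed store presumably tokenizes via llama.rn
            // using the bundled assets/models/llama3tokenizer.gguf model.
            return Math.ceil(text.length / 4) // rough stand-in, not the real logic
        },
    }))
}
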
2 changes: 1 addition & 1 deletion babel.config.js
@@ -16,7 +16,7 @@ module.exports = function (api) {
'module-resolver',
{
root: ['.'],
extension: ['.js', '.ts', '.tsx'],
extension: ['.js', '.jsx', '.ts', '.tsx'],
alias: {
'@globals': './constants/global',
'@components': './components',
242 changes: 121 additions & 121 deletions constants/Characters.ts
@@ -13,7 +13,7 @@ import { create } from 'zustand'

import { Global } from './GlobalValues'
import { Logger } from './Logger'
import { Llama3Tokenizer } from './Tokenizer/tokenizer'
import { Tokenizer } from './Tokenizer'
import { mmkv } from './mmkv'

type CharacterTokenCache = {
@@ -23,7 +23,7 @@
}

type CharacterCardState = {
card: CharacterCardV2 | undefined
card?: CharacterCardV2
tokenCache: CharacterTokenCache | undefined
id: number | undefined
setCard: (id: number) => Promise<string | undefined>
@@ -34,129 +34,130 @@
}

export namespace Characters {
export const useUserCard = create<CharacterCardState>()(
(set, get: () => CharacterCardState) => ({
id: undefined,
card: undefined,
tokenCache: undefined,
setCard: async (id: number) => {
let start = performance.now()
const card = await db.query.card(id)
Logger.debug(`[User] time for database query: ${performance.now() - start}`)
start = performance.now()
set((state) => ({ ...state, card: card, id: id, tokenCache: undefined }))
Logger.debug(`[User] time for zustand set: ${performance.now() - start}`)
mmkv.set(Global.UserID, id)
return card?.data.name
},
unloadCard: () => {
set((state) => ({
...state,
id: undefined,
card: undefined,
tokenCache: undefined,
}))
},
getImage: () => {
return getImageDir(get().card?.data.image_id ?? 0)
},
updateImage: async (sourceURI: string) => {
const id = get().id
const oldImageID = get().card?.data.image_id
const card = get().card
if (!id || !oldImageID || !card) {
Logger.log('Could not get data, something very wrong has happned!', true)
return
}
const imageID = new Date().getTime()
await db.mutate.updateCardField('image_id', imageID, id)
await deleteImage(oldImageID)
await copyImage(sourceURI, imageID)
card.data.image_id = imageID
set((state) => ({ ...state, card: card }))
},
getCache: (userName: string) => {
const cache = get().tokenCache
if (cache && cache?.otherName === useCharacterCard.getState().card?.data.name)
return cache

const card = get().card
if (!card)
return {
otherName: userName,
description_length: 0,
examples_length: 0,
}
const description = replaceMacros(card.data.description)
const examples = replaceMacros(card.data.mes_example)
const newCache = {
export const useUserCard = create<CharacterCardState>()((set, get) => ({
id: undefined,
card: undefined,
tokenCache: undefined,
setCard: async (id: number) => {
let start = performance.now()
const card = await db.query.card(id)
Logger.debug(`[User] time for database query: ${performance.now() - start}`)
start = performance.now()
set((state) => ({ ...state, card: card, id: id, tokenCache: undefined }))
Logger.debug(`[User] time for zustand set: ${performance.now() - start}`)
mmkv.set(Global.UserID, id)
return card?.data.name
},
unloadCard: () => {
set((state) => ({
...state,
id: undefined,
card: undefined,
tokenCache: undefined,
}))
},
getImage: () => {
return getImageDir(get().card?.data.image_id ?? 0)
},
updateImage: async (sourceURI: string) => {
const id = get().id
const oldImageID = get().card?.data.image_id
const card = get().card
if (!id || !oldImageID || !card) {
Logger.log('Could not get data, something very wrong has happned!', true)
return
}
const imageID = new Date().getTime()
await db.mutate.updateCardField('image_id', imageID, id)
await deleteImage(oldImageID)
await copyImage(sourceURI, imageID)
card.data.image_id = imageID
set((state) => ({ ...state, card: card }))
},
getCache: (userName: string) => {
const cache = get().tokenCache
if (cache && cache?.otherName === userName) return cache

const card = get().card
if (!card)
return {
otherName: userName,
description_length: Llama3Tokenizer.encode(description).length,
examples_length: Llama3Tokenizer.encode(examples).length,
description_length: 0,
examples_length: 0,
}
const description = replaceMacros(card.data.description)
const examples = replaceMacros(card.data.mes_example)
const getTokenCount = Tokenizer.useTokenizer.getState().getTokenCount
const newCache = {
otherName: userName,
description_length: getTokenCount(description),
examples_length: getTokenCount(examples),
}

set((state) => ({ ...state, tokenCache: newCache }))
return newCache
},
})
)

export const useCharacterCard = create<CharacterCardState>()(
(set, get: () => CharacterCardState) => ({
id: undefined,
card: undefined,
tokenCache: undefined,
setCard: async (id: number) => {
let start = performance.now()
const card = await db.query.card(id)
Logger.debug(`[Characters] time for database query: ${performance.now() - start}`)
start = performance.now()
set((state) => ({ ...state, card: card, id: id, tokenCache: undefined }))

Logger.debug(`[Characters] time for zustand set: ${performance.now() - start}`)
return card?.data.name
},
unloadCard: () => {
set((state) => ({
...state,
id: undefined,
card: undefined,
tokenCache: undefined,
}))
},
getImage: () => {
return getImageDir(get().card?.data.image_id ?? 0)
},
updateImage: async (sourceURI: string) => {
const imageID = get().card?.data.image_id
if (imageID) return copyImage(sourceURI, imageID)
},
getCache: (charName: string) => {
const cache = get().tokenCache
if (cache?.otherName && cache.otherName === useUserCard.getState().card?.data.name)
return cache

const card = get().card
if (!card)
return {
otherName: charName,
description_length: 0,
examples_length: 0,
}
const description = replaceMacros(card.data.description)
const examples = replaceMacros(card.data.mes_example)

const newCache = {
set((state) => ({ ...state, tokenCache: newCache }))
return newCache
},
}))

export const useCharacterCard = create<CharacterCardState>()((set, get) => ({
id: undefined,
card: undefined,
tokenCache: undefined,
setCard: async (id: number) => {
let start = performance.now()
const card = await db.query.card(id)
Logger.debug(`[Characters] time for database query: ${performance.now() - start}`)
start = performance.now()
set((state) => {
return { ...state, card: card, id: id, tokenCache: undefined }
})
Logger.debug(`[Characters] time for zustand set: ${performance.now() - start}`)
return card?.data.name
},
unloadCard: () => {
set((state) => ({
...state,
id: undefined,
card: undefined,
tokenCache: undefined,
}))
},
getImage: () => {
return getImageDir(get().card?.data.image_id ?? 0)
},
updateImage: async (sourceURI: string) => {
const imageID = get().card?.data.image_id
if (imageID) return copyImage(sourceURI, imageID)
},
getCardTest: () => {
return get().card
},
getCache: (charName: string) => {
const cache = get().tokenCache
console.log(cache)
const card = get().card
if (cache?.otherName && cache.otherName === useUserCard.getState().card?.data.name)
return cache

console.log(card)
if (!card)
return {
otherName: charName,
description_length: Llama3Tokenizer.encode(description).length,
examples_length: Llama3Tokenizer.encode(examples).length,
description_length: 0,
examples_length: 0,
}

set((state) => ({ ...state, tokenCache: newCache }))
return newCache
},
})
)
const description = replaceMacros(card.data.description)
const examples = replaceMacros(card.data.mes_example)
const getTokenCount = Tokenizer.useTokenizer.getState().getTokenCount
const newCache = {
otherName: charName,
description_length: getTokenCount(description),
examples_length: getTokenCount(examples),
}
set((state) => ({ ...state, tokenCache: newCache }))
return newCache
},
}))

export namespace db {
export namespace query {
@@ -449,7 +450,6 @@ export namespace Characters {
const param = new URLSearchParams(text)
const character_id = param.get('id')?.replaceAll(`"`, '')
const path = url.pathname.replace('/character/', '')
console.log(path)
if (character_id) return importCharacterFromPyg(character_id)
else if (uuidRegex.test(path)) return importCharacterFromPyg(path)
else {
8 changes: 4 additions & 4 deletions constants/Chat.ts
Expand Up @@ -8,7 +8,7 @@ import { AppSettings, Global } from './GlobalValues'
import { Logger } from './Logger'
import { RecentMessages } from './RecentMessages'
import { convertToFormatInstruct } from './TextFormat'
import { Llama3Tokenizer } from './Tokenizer/tokenizer'
import { Tokenizer } from './Tokenizer'
import { replaceMacros } from './Utils'
import { mmkv } from './mmkv'

@@ -240,9 +240,9 @@ export namespace Chats {
const swipe_id = messages[index].swipe_id
const cached_token_count = messages[index].swipes[swipe_id].token_count
if (cached_token_count) return cached_token_count
const token_count = Llama3Tokenizer.encode(
messages[index].swipes[swipe_id].swipe
).length
const token_count = Tokenizer.useTokenizer
.getState()
.getTokenCount(messages[index].swipes[swipe_id].swipe)
messages[index].swipes[swipe_id].token_count = token_count
set((state: ChatState) => ({
...state,
21 changes: 10 additions & 11 deletions constants/Instructs.ts
@@ -6,7 +6,7 @@ import { createJSONStorage, persist } from 'zustand/middleware'

import { Global } from './GlobalValues'
import { Logger } from './Logger'
import { Llama3Tokenizer } from './Tokenizer/tokenizer'
import { Tokenizer } from './Tokenizer'
import { replaceMacros } from './Utils'
import { mmkv, mmkvStorage } from './mmkv'

@@ -193,19 +193,18 @@ export namespace Instructs {
output_suffix_length: 0,
user_alignment_message_length: 0,
}
const getTokenCount = Tokenizer.useTokenizer.getState().getTokenCount
const newCache: InstructTokenCache = {
charName: charName,
userName: userName,
system_prompt_length: Llama3Tokenizer.encode(instruct.system_prompt).length,
system_prefix_length: Llama3Tokenizer.encode(instruct.system_prefix).length,
system_suffix_length: Llama3Tokenizer.encode(instruct.system_suffix).length,
input_prefix_length: Llama3Tokenizer.encode(instruct.input_prefix).length,
input_suffix_length: Llama3Tokenizer.encode(instruct.input_suffix).length,
output_prefix_length: Llama3Tokenizer.encode(instruct.output_prefix).length,
output_suffix_length: Llama3Tokenizer.encode(instruct.output_suffix).length,
user_alignment_message_length: Llama3Tokenizer.encode(
instruct.system_prompt
).length,
system_prompt_length: getTokenCount(instruct.system_prompt),
system_prefix_length: getTokenCount(instruct.system_prefix),
system_suffix_length: getTokenCount(instruct.system_suffix),
input_prefix_length: getTokenCount(instruct.input_prefix),
input_suffix_length: getTokenCount(instruct.input_suffix),
output_prefix_length: getTokenCount(instruct.output_prefix),
output_suffix_length: getTokenCount(instruct.output_suffix),
user_alignment_message_length: getTokenCount(instruct.system_prompt),
}
set((state) => ({ ...state, tokenCache: newCache }))
return newCache
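
For orientation, these are the two zustand access patterns the commit uses consistently for token counting. The TokenBadge component and countTokens helper below are made-up examples, not code from the repository.

import { Text } from 'react-native'
import { Tokenizer } from '@constants/Tokenizer'

// Inside a React component: subscribe through a selector so the component
// re-renders if the store changes (as CharInfo.tsx does).
const TokenBadge = ({ text }: { text: string }) => {
    const getTokenCount = Tokenizer.useTokenizer((state) => state.getTokenCount)
    return <Text>{getTokenCount(text)} tokens</Text>
}

// Outside React, e.g. inside another store's actions: read the current state
// once without subscribing (as Characters.ts, Chat.ts and Instructs.ts do).
const countTokens = (text: string) =>
    Tokenizer.useTokenizer.getState().getTokenCount(text)
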
(Remaining changed files not expanded in this view.)
