diff --git a/components/Endpoint/Local.tsx b/components/Endpoint/Local.tsx
index f9174e2..d92078e 100644
--- a/components/Endpoint/Local.tsx
+++ b/components/Endpoint/Local.tsx
@@ -1,4 +1,4 @@
-import { Llama } from '@constants/LlamaLocal'
+import { Llama, LlamaPreset } from '@constants/LlamaLocal'
 import { AppSettings, Global, Logger, Style } from '@globals'
 import { useEffect, useState } from 'react'
 import {
@@ -16,6 +16,12 @@
 import { useMMKVBoolean, useMMKVObject, useMMKVString } from 'react-native-mmkv'
 import { SliderItem } from '..'
 const Local = () => {
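+    // Pull the load/unload actions and the current model name out of the shared
+    // llama store; the zustand selector keeps this screen in sync with the store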
+    const { loadModel, unloadModel, modelName } = Llama.useLlama((state) => ({
+        loadModel: state.load,
+        unloadModel: state.unload,
+        modelName: state.modelname,
+    }))
+
     const [modelLoading, setModelLoading] = useState(false)
     const [modelList, setModelList] = useState<string[]>([])
     const dropdownValues = modelList.map((item) => {
@@ -23,12 +29,12 @@
     })
     const [currentModel, setCurrentModel] = useMMKVString(Global.LocalModel)
     const [downloadLink, setDownloadLink] = useState('')
-    const [preset, setPreset] = useMMKVObject<Llama.LlamaPreset>(Global.LocalPreset)
-    const [loadedModel, setLoadedModel] = useState(Llama.getModelname())
+    const [preset, setPreset] = useMMKVObject<LlamaPreset>(Global.LocalPreset)
+    const [loadedModel, setLoadedModel] = useState(modelName)
     const [saveKV, setSaveKV] = useMMKVBoolean(AppSettings.SaveLocalKV)
     const [kvSize, setKVSize] = useState(-1)
 
     const getModels = async () => {
-        setModelList(await Llama.getModels())
+        setModelList(await Llama.getModelList())
     }
     useEffect(() => {
@@ -40,20 +46,20 @@
 
     const handleLoad = async () => {
         setModelLoading(true)
-        await Llama.loadModel(currentModel ?? '', preset).then(() => {
-            setLoadedModel(Llama.getModelname())
+        await loadModel(currentModel ?? '', preset).then(() => {
+            setLoadedModel(Llama.useLlama.getState().modelname)
         })
         setModelLoading(false)
         getModels()
     }
-
+    /*
     const handleLoadExternal = async () => {
         setModelLoading(true)
         await Llama.loadModel('', preset, false).then(() => {
            setLoadedModel(Llama.getModelname())
        })
        setModelLoading(false)
-    }
+    }*/
 
     const handleDelete = async () => {
         if (!(await Llama.modelExists(currentModel ?? ''))) {
@@ -70,7 +76,7 @@
         Llama.deleteModel(currentModel ?? '')
             .then(() => {
                 Logger.log('Model Deleted Successfully', true)
-                setLoadedModel(Llama.getModelname())
+                setLoadedModel(Llama.useLlama.getState().modelname)
                 getModels()
             })
             .catch(() => Logger.log('Could Not Delete Model', true))
@@ -80,8 +86,8 @@
     }
 
     const handleUnload = async () => {
-        await Llama.unloadModel()
-        setLoadedModel(Llama.getModelname())
+        await unloadModel()
+        setLoadedModel(Llama.useLlama.getState().modelname)
     }
 
     const handleDownload = () => {
diff --git a/constants/APIState/LocalAPI.ts b/constants/APIState/LocalAPI.ts
index 9ac9a72..38a1b90 100644
--- a/constants/APIState/LocalAPI.ts
+++ b/constants/APIState/LocalAPI.ts
@@ -1,6 +1,6 @@
 import { Chats, useInference } from '@constants/Chat'
 import { AppSettings, Global } from '@constants/GlobalValues'
-import { Llama } from '@constants/LlamaLocal'
+import { Llama, LlamaPreset } from '@constants/LlamaLocal'
 import { Logger } from '@constants/Logger'
 import { SamplerID } from '@constants/Samplers'
 import { mmkv } from '@constants/mmkv'
@@ -29,7 +29,7 @@
     buildPayload = () => {
         const payloadFields = this.getSamplerFields()
         const rep_pen = payloadFields?.['penalty_repeat']
-        const localPreset: Llama.LlamaPreset = this.getObject(Global.LocalPreset)
+        const localPreset: LlamaPreset = this.getObject(Global.LocalPreset)
         return {
             ...payloadFields,
             penalize_nl: typeof rep_pen === 'number' && rep_pen > 1,
@@ -41,13 +41,15 @@
     }
 
     inference = async () => {
-        if (!Llama.isModelLoaded(false) && mmkv.getBoolean(AppSettings.AutoLoadLocal)) {
+        const context = Llama.useLlama.getState().context
+        if (!context && mmkv.getBoolean(AppSettings.AutoLoadLocal)) {
             const model = mmkv.getString(Global.LocalModel)
             const params = this.getObject(Global.LocalPreset)
-            if (model && params) await Llama.loadModel(model, params)
+            if (model && params) await Llama.useLlama.getState().load(model ?? '', params)
         }
 
-        if (!Llama.isModelLoaded()) {
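+        // Re-read the store here: the auto-load above may have just created the context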
+        if (!Llama.useLlama.getState().context) {
+            Logger.log('No Model Loaded', true)
             this.stopGenerating()
             return
         }
@@ -56,7 +58,7 @@
             mmkv.getBoolean(AppSettings.SaveLocalKV) &&
             !mmkv.getBoolean(Global.LocalSessionLoaded)
 
         if (loadKV) {
-            await Llama.loadKV()
+            await Llama.useLlama.getState().loadKV()
             mmkv.set(Global.LocalSessionLoaded, true)
         }
@@ -68,7 +70,7 @@
         )
 
         useInference.getState().setAbort(async () => {
-            Llama.stopCompletion()
+            await Llama.useLlama.getState().stopCompletion()
         })
 
         const payload = this.buildPayload()
@@ -80,9 +82,12 @@
 
         const outputCompleted = (text: string) => {
             Chats.useChat.getState().setBuffer(text.replaceAll(replace, ''))
+            if (mmkv.getBoolean(AppSettings.PrintContext)) Logger.log(`Completion Output:\n${text}`)
         }
 
-        Llama.completion(payload, outputStream, outputCompleted)
+        Llama.useLlama
+            .getState()
+            .completion(payload, outputStream, outputCompleted)
             .then(() => {
                 this.stopGenerating()
             })
diff --git a/constants/LlamaLocal.ts b/constants/LlamaLocal.ts
index 351d004..f51a7de 100644
--- a/constants/LlamaLocal.ts
+++ b/constants/LlamaLocal.ts
@@ -2,6 +2,7 @@ import * as FS from 'expo-file-system'
 import { CompletionParams, LlamaContext, initLlama } from 'llama.rn'
 import { Platform } from 'react-native'
 import DocumentPicker from 'react-native-document-picker'
+import { create } from 'zustand'
 
 import { AppSettings, Global } from './GlobalValues'
 import { Logger } from './Logger'
@@ -24,26 +25,189 @@ type CompletionOutput = {
     timings: CompletionTimings
 }
 
+type LlamaState = {
+    context: LlamaContext | undefined
+    modelname: string | undefined
+    load: (name: string, preset?: LlamaPreset, usecache?: boolean) => Promise<void>
+    unload: () => Promise<void>
+    saveKV: () => Promise<void>
+    loadKV: () => Promise<void>
+    completion: (
+        params: CompletionParams,
+        callback: (text: string) => void,
+        completed: (text: string) => void
+    ) => Promise<void>
+    stopCompletion: () => Promise<void>
+    tokenLength: (text: string) => Promise<number>
+}
+
+export type LlamaPreset = {
+    context_length: number
+    threads: number
+    gpu_layers: number
+    batch: number
+}
+
+const sessionFileDir = `${FS.documentDirectory}llama/`
+const sessionFile = `${sessionFileDir}llama-session.bin`
+
+const default_preset = {
+    context_length: 2048,
+    threads: 1,
+    gpu_layers: 0,
+    batch: 512,
+}
+
 export namespace Llama {
     const model_dir = `${FS.documentDirectory}models/`
 
-    export type LlamaPreset = {
-        context_length: number
-        threads: number
-        gpu_layers: number
-        batch: number
-    }
+    export const useLlama = create<LlamaState>()((set, get) => ({
+        context: undefined,
+        modelname: undefined,
+        load: async (
+            name: string,
+            preset: LlamaPreset = default_preset,
+            usecache: boolean = true
+        ) => {
+            const dir = `${model_dir}${name}`
+
+            switch (name) {
+                case '':
+                    return Logger.log('No Model Chosen', true)
+                case get().modelname:
+                    return Logger.log('Model Already Loaded!', true)
+            }
+
+            if (!(await modelExists(name))) {
+                Logger.log('Model Does Not Exist!', true)
+                return
+            }
+
+            if (get().context !== undefined) {
+                Logger.log('Unloading current model', true)
+                await get().context?.release()
+                set((state) => ({ ...state, context: undefined, modelname: undefined }))
+            }
+
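+            // Build the llama.rn init params from the preset; gpu_layers is
+            // forced to 0 on Android, where this code only runs on CPU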
+            const params = {
+                model: dir,
+                n_ctx: preset.context_length,
+                n_threads: preset.threads,
+                n_batch: preset.batch,
+                n_gpu_layers: Platform.OS === 'ios' ? preset.gpu_layers : 0,
+            }
+
+            mmkv.set(Global.LocalSessionLoaded, false)
+            Logger.log(`Loading Model: ${name}`, true)
+            Logger.log(JSON.stringify(params))
+
+            const llamaContext = await initLlama(params).catch((error) => {
+                Logger.log(`Could Not Load Model: ${error} `, true)
+            })
 
-    const default_preset = {
-        context_length: 2048,
-        threads: 1,
-        gpu_layers: 0,
-        batch: 512,
+            if (llamaContext) {
+                set((state) => ({ ...state, context: llamaContext, modelname: name }))
+                Logger.log('Model Loaded', true)
+            }
+        },
+        unload: async () => {
+            await get().context?.release()
+            set((state) => ({ ...state, context: undefined, modelname: undefined }))
+        },
+        completion: async (
+            params: CompletionParams,
+            callback = (text: string) => {},
+            completed = (text: string) => {}
+        ) => {
+            const llamaContext = get().context
+            if (llamaContext === undefined) {
+                Logger.log('No Model Loaded', true)
+                return
+            }
+
+            return llamaContext
+                .completion(params, (data: any) => {
+                    callback(data.token)
+                })
+                .then(async ({ text, timings }: CompletionOutput) => {
+                    completed(text)
+                    Logger.log(textTimings(timings))
+                    if (mmkv.getBoolean(AppSettings.SaveLocalKV)) {
+                        await get().saveKV()
+                    }
+                })
+        },
+        stopCompletion: async () => {
+            await get().context?.stopCompletion()
+        },
+        saveKV: async () => {
+            const llamaContext = get().context
+            if (!llamaContext) {
+                Logger.log('No Model Loaded', true)
+                return
+            }
+            if (!(await FS.getInfoAsync(sessionFileDir)).exists) {
+                await FS.makeDirectoryAsync(sessionFileDir)
+            }
+
+            if (!(await FS.getInfoAsync(sessionFile)).exists) {
+                await FS.writeAsStringAsync(sessionFile, '', { encoding: 'base64' })
+            }
+
+            const now = performance.now()
+            const data = await llamaContext.saveSession(sessionFile.replace('file://', ''))
+            Logger.log(
+                data === -1
+                    ? 'Failed to save KV cache'
+                    : `Saved KV in ${Math.floor(performance.now() - now)}ms with ${data} tokens`
+            )
+            Logger.log(`Current KV Size is: ${await getKVSizeMB()}MB`)
+        },
+        loadKV: async () => {
+            const llamaContext = get().context
+            if (!llamaContext) {
+                Logger.log('No Model Loaded', true)
+                return
+            }
+            const data = await FS.getInfoAsync(sessionFile)
+            if (!data.exists) {
+                Logger.log('No cache found')
+                return
+            }
+            await llamaContext
+                .loadSession(sessionFile.replace('file://', ''))
+                .then(() => {
+                    Logger.log('Session loaded from KV cache')
+                })
+                .catch(() => {
+                    Logger.log('Session could not be loaded from KV cache')
+                })
+        },
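+        // Stub for now: always reports -1 tokens; the old tokenize-based
+        // implementation is preserved in the commented-out block below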
+        tokenLength: async () => {
+            return -1
+        },
+    }))
+
+    const textTimings = (timings: CompletionTimings) => {
+        return (
+            `\n[Prompt Timings]` +
+            `\nPrompt Per Token: ${timings.prompt_per_token_ms} ms/token` +
+            `\nPrompt Per Second: ${timings.prompt_per_second?.toFixed(2) ?? 0} tokens/s` +
+            `\nPrompt Time: ${(timings.prompt_ms / 1000).toFixed(2)}s` +
+            `\nPrompt Tokens: ${timings.prompt_n} tokens` +
+            `\n\n[Predicted Timings]` +
+            `\nPredicted Per Token: ${timings.predicted_per_token_ms} ms/token` +
+            `\nPredicted Per Second: ${timings.predicted_per_second?.toFixed(2) ?? 0} tokens/s` +
+            `\nPrediction Time: ${(timings.predicted_ms / 1000).toFixed(2)}s` +
+            `\nPredicted Tokens: ${timings.predicted_n} tokens`
+        )
     }
 
+    /*
     let llamaContext: LlamaContext | void = undefined
     let modelname: string | undefined = undefined
 
+    // Model Functions
+
     export const loadModel = async (
         name: string,
         preset: LlamaPreset = default_preset,
@@ -120,58 +284,16 @@
         })
     }
 
-    const textTimings = (timings: CompletionTimings) => {
-        return (
-            `\n[Prompt Timings]` +
-            `\nPrompt Per Token: ${timings.prompt_per_token_ms} ms/token` +
-            `\nPrompt Per Second: ${timings.prompt_per_second?.toFixed(2) ?? 0} tokens/s` +
-            `\nPrompt Time: ${(timings.prompt_ms / 1000).toFixed(2)}s` +
-            `\nPrompt Tokens: ${timings.prompt_n} tokens` +
-            `\n\n[Predicted Timings]` +
-            `\nPredicted Per Token: ${timings.predicted_per_token_ms} ms/token` +
-            `\nPredicted Per Second: ${timings.predicted_per_second?.toFixed(2) ?? 0} tokens/s` +
-            `\nPrediction Time: ${(timings.predicted_ms / 1000).toFixed(2)}s` +
-            `\nPredicted Tokens: ${timings.predicted_n} tokens`
-        )
-    }
+
 
     export const stopCompletion = async () => {
         return await llamaContext?.stopCompletion()
     }
 
-    export const downloadModel = async (url: string) => {
-        const modelName = nameFromURL(url)
-        const modelList = await Llama.getModels()
-        if (modelList.includes(modelName)) {
-            Logger.log('Model already exists!', true)
-            return
-        }
-        Logger.log('Downloading Model...', true)
-        await FS.downloadAsync(url, `${model_dir}${modelName}`)
-            .then(() => {
-                Logger.log('Model downloaded!', true)
-            })
-            .catch(() => {
-                Logger.log('Download failed', true)
-            })
-    }
-
     export const getDetails = () => {
         if (llamaContext) return llamaContext.model
     }
 
-    export const getModels = async () => {
-        return await FS.readDirectoryAsync(model_dir)
-    }
-
-    export const modelExists = async (modelName: string) => {
-        return (await getModels()).includes(modelName)
-    }
-
-    export const nameFromURL = (url: string) => {
-        return url.split('resolve/main/')[1].replace('?download=true', '')
-    }
-
     export const unloadModel = async () => {
         Logger.log('Unloading Model', true)
         await llamaContext?.release()
@@ -180,36 +302,6 @@
         llamaContext = undefined
         modelname = undefined
         Logger.log('Model Unloaded', true)
     }
 
-    export const deleteModel = async (name: string) => {
-        if (!(await modelExists(name))) return
-        if (name === modelname) modelname = ''
-        return await FS.deleteAsync(`${model_dir}${name}`)
-    }
-
-    export const importModel = async () => {
-        return DocumentPicker.pickSingle()
-            .then(async (result: any) => {
-                if (DocumentPicker.isCancel(result)) return false
-                const name = result.name
-                Logger.log('Importing file...', true)
-                await FS.copyAsync({
-                    from: result.uri,
-                    to: `${model_dir}${name}`,
-                })
-                    .then(() => {
-                        Logger.log('File Imported!', true)
-                    })
-                    .catch((error) => {
-                        Logger.log(`Import Failed: ${error.message}`, true)
-                    })
-
-                return false
-            })
-            .catch(() => {
-                Logger.log('No Model Chosen', true)
-            })
-    }
-
     export const isModelLoaded = (showmessage = true) => {
         if (showmessage && llamaContext === undefined) {
             Logger.log('No Model Loaded', true)
@@ -221,9 +313,6 @@
         return modelname
     }
 
-    const sessionFileDir = `${FS.documentDirectory}llama/`
-    const sessionFile = `${sessionFileDir}llama-session.bin`
-
     export const saveKV = async () => {
         if (!llamaContext) {
             Logger.log('No Model Loaded', true)
@@ -267,13 +356,79 @@
             })
     }
 
-    export const kvInfo = async () => {
-        const data = await FS.getInfoAsync(sessionFile)
-        if (!data.exists) {
-            Logger.log('No KV Cache found')
+    export const tokenLength = async (text: string) => {
+        if (!llamaContext) return -1
+        return (await llamaContext.tokenize(text)).tokens.length
+    }*/
+
+    // Presets
+
+    export const setLlamaPreset = () => {
+        const presets = mmkv.getString(Global.LocalPreset)
+        if (presets === undefined) mmkv.set(Global.LocalPreset, JSON.stringify(default_preset))
+    }
+
+    // Downloaders
+
+    export const downloadModel = async (url: string) => {
+        const modelName = nameFromURL(url)
+        const modelList = await Llama.getModelList()
+        if (modelList.includes(modelName)) {
+            Logger.log('Model already exists!', true)
             return
         }
-        Logger.log(`Size of KV cache: ${Math.floor(data.size * 0.000001)} MB`)
+        Logger.log('Downloading Model...', true)
+        await FS.downloadAsync(url, `${model_dir}${modelName}`)
+            .then(() => {
+                Logger.log('Model downloaded!', true)
+            })
+            .catch(() => {
+                Logger.log('Download failed', true)
+            })
+    }
+
+    export const nameFromURL = (url: string) => {
+        return url.split('resolve/main/')[1].replace('?download=true', '')
+    }
+
+    // Filesystem
+
+    export const getModelList = async () => {
+        return await FS.readDirectoryAsync(model_dir)
+    }
+
+    export const modelExists = async (modelName: string) => {
+        return (await getModelList()).includes(modelName)
+    }
+
+    export const deleteModel = async (name: string) => {
+        if (!(await modelExists(name))) return
+        if (name === useLlama.getState().modelname) await useLlama.getState().unload()
+        return await FS.deleteAsync(`${model_dir}${name}`)
+    }
+
+    export const importModel = async () => {
+        return DocumentPicker.pickSingle()
+            .then(async (result: any) => {
+                if (DocumentPicker.isCancel(result)) return false
+                const name = result.name
+                Logger.log('Importing file...', true)
+                await FS.copyAsync({
+                    from: result.uri,
+                    to: `${model_dir}${name}`,
+                })
+                    .then(() => {
+                        Logger.log('File Imported!', true)
+                    })
+                    .catch((error) => {
+                        Logger.log(`Import Failed: ${error.message}`, true)
+                    })
+
+                return false
+            })
+            .catch(() => {
+                Logger.log('No Model Chosen', true)
+            })
     }
 
     export const getKVSizeMB = async () => {
@@ -290,13 +445,12 @@
         }
     }
 
-    export const tokenize = async (text: string) => {
-        if (!llamaContext) return -1
-        return (await llamaContext.tokenize(text)).tokens.length
-    }
-
-    export const setLlamaPreset = () => {
-        const presets = mmkv.getString(Global.LocalPreset)
-        if (presets === undefined) mmkv.set(Global.LocalPreset, JSON.stringify(default_preset))
+    export const kvInfo = async () => {
+        const data = await FS.getInfoAsync(sessionFile)
+        if (!data.exists) {
+            Logger.log('No KV Cache found')
+            return
+        }
+        Logger.log(`Size of KV cache: ${Math.floor(data.size * 0.000001)} MB`)
     }
 }
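
Usage sketch for the migrated store (the model filename below is a placeholder,
and the store shape follows LlamaState above): components subscribe through the
useLlama hook the way Local.tsx does, while non-React code reads the store
imperatively via getState(), the way LocalAPI.ts does.

    import { CompletionParams } from 'llama.rn'
    import { Llama } from '@constants/LlamaLocal'

    const runCompletion = async (params: CompletionParams) => {
        const llama = Llama.useLlama.getState()
        // one-time load; the preset argument falls back to default_preset
        if (!llama.context) await llama.load('model.gguf')
        await llama.completion(
            params,
            (token) => console.log(token), // streamed per-token callback
            (text) => console.log(`Final output:\n${text}`) // completed text
        )
    }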