feat: moved local llama context state to zustand
Vali-98 committed Jul 4, 2024
1 parent a502bfb commit a2237e3
Showing 3 changed files with 284 additions and 119 deletions.
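The store itself is defined in the third changed file, whose diff does not render below. As orientation for reading the two diffs that do render, here is a minimal sketch of a zustand store with the shape those diffs consume. Every name and signature in it is inferred from the call sites (context, modelname, load, unload, loadKV, stopCompletion, completion); it is an assumption, not the actual contents of that file.

// Hypothetical sketch only: inferred from the call sites in the diffs
// below, not from the real store module, which is not rendered here.
import { create } from 'zustand'

// Placeholder for the real preset type exported by '@constants/LlamaLocal'
type LlamaPreset = Record<string, unknown>

interface LlamaState {
    context: unknown | undefined // native llama context handle once loaded
    modelname: string | undefined
    load: (modelPath: string, preset?: LlamaPreset) => Promise<void>
    unload: () => Promise<void>
    loadKV: () => Promise<void>
    stopCompletion: () => Promise<void>
    completion: (
        payload: object,
        onToken: (text: string) => void,
        onComplete: (text: string) => void
    ) => Promise<void>
}

export const useLlama = create<LlamaState>()((set, get) => ({
    context: undefined,
    modelname: undefined,
    load: async (modelPath, preset) => {
        // ...initialize the native context from the model file here...
        set({ modelname: modelPath })
    },
    unload: async () => {
        // ...release the native context here...
        set({ context: undefined, modelname: undefined })
    },
    loadKV: async () => {
        // ...restore a saved KV cache into get().context...
    },
    stopCompletion: async () => {
        // ...abort any in-flight completion...
    },
    completion: async (payload, onToken, onComplete) => {
        // ...stream tokens via onToken, then call onComplete with the full text...
    },
}))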
28 changes: 17 additions & 11 deletions components/Endpoint/Local.tsx
@@ -1,4 +1,4 @@
-import { Llama } from '@constants/LlamaLocal'
+import { Llama, LlamaPreset } from '@constants/LlamaLocal'
 import { AppSettings, Global, Logger, Style } from '@globals'
 import { useEffect, useState } from 'react'
 import {
@@ -16,19 +16,25 @@ import { useMMKVBoolean, useMMKVObject, useMMKVString } from 'react-native-mmkv'
 import { SliderItem } from '..'
 
 const Local = () => {
+    const { loadModel, unloadModel, modelName } = Llama.useLlama((state) => ({
+        loadModel: state.load,
+        unloadModel: state.unload,
+        modelName: state.modelname,
+    }))
+
     const [modelLoading, setModelLoading] = useState(false)
     const [modelList, setModelList] = useState<string[]>([])
     const dropdownValues = modelList.map((item) => {
         return { name: item }
     })
     const [currentModel, setCurrentModel] = useMMKVString(Global.LocalModel)
     const [downloadLink, setDownloadLink] = useState('')
-    const [preset, setPreset] = useMMKVObject<Llama.LlamaPreset>(Global.LocalPreset)
-    const [loadedModel, setLoadedModel] = useState(Llama.getModelname())
+    const [preset, setPreset] = useMMKVObject<LlamaPreset>(Global.LocalPreset)
+    const [loadedModel, setLoadedModel] = useState(modelName)
     const [saveKV, setSaveKV] = useMMKVBoolean(AppSettings.SaveLocalKV)
     const [kvSize, setKVSize] = useState<number>(-1)
     const getModels = async () => {
-        setModelList(await Llama.getModels())
+        setModelList(await Llama.getModelList())
     }
 
     useEffect(() => {
@@ -40,20 +46,20 @@ const Local = () => {

     const handleLoad = async () => {
         setModelLoading(true)
-        await Llama.loadModel(currentModel ?? '', preset).then(() => {
-            setLoadedModel(Llama.getModelname())
+        await loadModel(currentModel ?? '', preset).then(() => {
+            setLoadedModel(modelName)
         })
         setModelLoading(false)
         getModels()
     }
 
+    /*
     const handleLoadExternal = async () => {
         setModelLoading(true)
         await Llama.loadModel('', preset, false).then(() => {
             setLoadedModel(Llama.getModelname())
         })
         setModelLoading(false)
-    }
+    }*/
 
     const handleDelete = async () => {
         if (!(await Llama.modelExists(currentModel ?? ''))) {
@@ -70,7 +76,7 @@ const Local = () => {
         Llama.deleteModel(currentModel ?? '')
             .then(() => {
                 Logger.log('Model Deleted Successfully', true)
-                setLoadedModel(Llama.getModelname())
+                setLoadedModel(modelName)
                 getModels()
             })
             .catch(() => Logger.log('Could Not Delete Model', true))
@@ -80,8 +86,8 @@
     }
 
     const handleUnload = async () => {
-        await Llama.unloadModel()
-        setLoadedModel(Llama.getModelname())
+        await unloadModel()
+        setLoadedModel(modelName)
     }
 
     const handleDownload = () => {
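One note on the selector pattern this diff introduces: the selector returns a fresh object literal on every call, and zustand's default strict-equality check treats each as new, so the component re-renders on any update to the store. If that ever matters, zustand ships a shallow comparator; a sketch of the same subscription with it, assuming zustand v4.4+ and the same store as above:

// Sketch: same subscription as in Local.tsx, with shallow comparison so
// unrelated store updates do not force a re-render.
import { useShallow } from 'zustand/react/shallow'

const { loadModel, unloadModel, modelName } = Llama.useLlama(
    useShallow((state) => ({
        loadModel: state.load,
        unloadModel: state.unload,
        modelName: state.modelname,
    }))
)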
21 changes: 13 additions & 8 deletions constants/APIState/LocalAPI.ts
@@ -1,6 +1,6 @@
 import { Chats, useInference } from '@constants/Chat'
 import { AppSettings, Global } from '@constants/GlobalValues'
-import { Llama } from '@constants/LlamaLocal'
+import { Llama, LlamaPreset } from '@constants/LlamaLocal'
 import { Logger } from '@constants/Logger'
 import { SamplerID } from '@constants/Samplers'
 import { mmkv } from '@constants/mmkv'
@@ -29,7 +29,7 @@ class LocalAPI extends APIBase {
     buildPayload = () => {
         const payloadFields = this.getSamplerFields()
         const rep_pen = payloadFields?.['penalty_repeat']
-        const localPreset: Llama.LlamaPreset = this.getObject(Global.LocalPreset)
+        const localPreset: LlamaPreset = this.getObject(Global.LocalPreset)
         return {
             ...payloadFields,
             penalize_nl: typeof rep_pen === 'number' && rep_pen > 1,
@@ -41,13 +41,15 @@ }
     }
 
     inference = async () => {
-        if (!Llama.isModelLoaded(false) && mmkv.getBoolean(AppSettings.AutoLoadLocal)) {
+        const context = Llama.useLlama.getState().context
+        if (!context && mmkv.getBoolean(AppSettings.AutoLoadLocal)) {
             const model = mmkv.getString(Global.LocalModel)
             const params = this.getObject(Global.LocalPreset)
-            if (model && params) await Llama.loadModel(model, params)
+            if (model && params) await Llama.useLlama.getState().load(model ?? '', params)
         }
 
-        if (!Llama.isModelLoaded()) {
+        if (!context) {
             Logger.log('No Model Loaded', true)
             this.stopGenerating()
             return
         }
@@ -56,7 +58,7 @@
             mmkv.getBoolean(AppSettings.SaveLocalKV) && !mmkv.getBoolean(Global.LocalSessionLoaded)
 
         if (loadKV) {
-            await Llama.loadKV()
+            await Llama.useLlama.getState().loadKV()
             mmkv.set(Global.LocalSessionLoaded, true)
         }

@@ -68,7 +70,7 @@
         )
 
         useInference.getState().setAbort(async () => {
-            Llama.stopCompletion()
+            await Llama.useLlama.getState().stopCompletion()
         })
 
         const payload = this.buildPayload()
@@ -80,9 +82,12 @@

         const outputCompleted = (text: string) => {
             Chats.useChat.getState().setBuffer(text.replaceAll(replace, ''))
+            if (mmkv.getBoolean(AppSettings.PrintContext)) Logger.log(`Completion Output:\n${text}`)
         }
 
-        Llama.completion(payload, outputStream, outputCompleted)
+        Llama.useLlama
+            .getState()
+            .completion(payload, outputStream, outputCompleted)
             .then(() => {
                 this.stopGenerating()
             })
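Unlike the component, this class reads the store imperatively through Llama.useLlama.getState(). One thing worth flagging: getState() returns a plain snapshot, so the context captured at the top of inference will not reflect a model loaded by the auto-load branch a few lines later. If load() sets context on success (an assumption; the store file is not rendered here), re-reading the snapshot after the await avoids treating a freshly loaded model as missing:

// Sketch: refresh the snapshot after a potential auto-load, instead of
// reusing the context value captured before load() ran.
let context = Llama.useLlama.getState().context
if (!context && mmkv.getBoolean(AppSettings.AutoLoadLocal)) {
    const model = mmkv.getString(Global.LocalModel)
    const params = this.getObject(Global.LocalPreset)
    if (model && params) await Llama.useLlama.getState().load(model, params)
    context = Llama.useLlama.getState().context // re-read after loading
}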
(Diff for the third changed file did not load.)
