From a4a710fd9ba115d34d54c0bf5b3319bf70308fff Mon Sep 17 00:00:00 2001 From: Vali98 Date: Fri, 27 Dec 2024 19:20:48 +0800 Subject: [PATCH] refactor: moved hook definitions for chat --- app/CharacterEditor.tsx | 5 +- .../CharacterMenu/CharacterListing.tsx | 2 +- .../CharacterMenu/CharacterNewMenu.tsx | 2 +- app/components/ChatMenu/ChatEditPopup.tsx | 11 +-- app/components/ChatMenu/ChatInput.tsx | 12 ++- app/components/ChatMenu/ChatMenu.tsx | 7 +- .../ChatMenu/ChatWindow/ChatBody.tsx | 9 +-- .../ChatMenu/ChatWindow/ChatFrame.tsx | 11 +-- .../ChatMenu/ChatWindow/ChatText.tsx | 10 +-- .../ChatMenu/ChatWindow/ChatTextLast.tsx | 22 +----- .../ChatMenu/ChatWindow/ChatWindow.tsx | 6 +- .../ChatMenu/ChatWindow/EditorModal.tsx | 21 ++---- app/components/ChatMenu/ChatWindow/Swipes.tsx | 37 +++------- app/components/ChatMenu/ChatWindow/TTS.tsx | 13 ++-- app/components/ChatMenu/ChatsDrawer.tsx | 7 +- app/components/ChatMenu/OptionsMenu.tsx | 2 +- constants/API/APIBuilder.ts | 24 +++--- constants/API/ContextBuilder.ts | 12 +-- constants/APIState/BaseAPI.ts | 18 ++--- constants/APIState/HordeAPI.ts | 2 +- constants/APIState/LocalAPI.ts | 8 +- constants/Chat.ts | 73 +++++++++++++++++-- constants/Global.ts | 4 +- constants/Inference.ts | 18 ++--- 24 files changed, 165 insertions(+), 171 deletions(-) diff --git a/app/CharacterEditor.tsx b/app/CharacterEditor.tsx index b75b9cb..5abac71 100644 --- a/app/CharacterEditor.tsx +++ b/app/CharacterEditor.tsx @@ -31,10 +31,7 @@ const ChracterEditor = () => { const getTokenCount = Tokenizer.useTokenizer((state) => state.getTokenCount) const [characterCard, setCharacterCard] = useState(currentCard) - const { chat, unloadChat } = Chats.useChat((state) => ({ - chat: state.data, - unloadChat: state.reset, - })) + const { chat, unloadChat } = Chats.useChat() const setShowViewer = useViewerState((state) => state.setShow) diff --git a/app/components/CharacterMenu/CharacterListing.tsx b/app/components/CharacterMenu/CharacterListing.tsx index 63eb11d..2d4af80 100644 --- a/app/components/CharacterMenu/CharacterListing.tsx +++ b/app/components/CharacterMenu/CharacterListing.tsx @@ -34,7 +34,7 @@ const CharacterListing: React.FC = ({ setCurrentCard: state.setCard, })) - const loadChat = Chats.useChat((state) => state.load) + const { loadChat } = Chats.useChat() const setCurrentCharacter = async (charId: number) => { if (nowLoading) return diff --git a/app/components/CharacterMenu/CharacterNewMenu.tsx b/app/components/CharacterMenu/CharacterNewMenu.tsx index 7a750ba..16a8a75 100644 --- a/app/components/CharacterMenu/CharacterNewMenu.tsx +++ b/app/components/CharacterMenu/CharacterNewMenu.tsx @@ -68,7 +68,7 @@ const CharacterNewMenu: React.FC = ({ id: state.id, })) ) - const loadChat = Chats.useChat((state) => state.load) + const { loadChat } = Chats.useChat() const router = useRouter() const [showNewChar, setShowNewChar] = useState(false) diff --git a/app/components/ChatMenu/ChatEditPopup.tsx b/app/components/ChatMenu/ChatEditPopup.tsx index d023382..ea02aba 100644 --- a/app/components/ChatMenu/ChatEditPopup.tsx +++ b/app/components/ChatMenu/ChatEditPopup.tsx @@ -28,12 +28,7 @@ const ChatEditPopup: React.FC = ({ item, setNowLoading, nowL charName: state.card?.name ?? 'Unknown', })) - const { deleteChat, loadChat, currentChatId, unloadChat } = Chats.useChat((state) => ({ - deleteChat: state.delete, - loadChat: state.load, - currentChatId: state.data?.id, - unloadChat: state.reset, - })) + const { deleteChat, loadChat, chatId, unloadChat } = Chats.useChat() const handleDeleteChat = (menuRef: MenuRef) => { Alert.alert({ @@ -45,13 +40,13 @@ const ChatEditPopup: React.FC = ({ item, setNowLoading, nowL label: 'Delete Chat', onPress: async () => { await deleteChat(item.id) - if (charId && currentChatId === item.id) { + if (charId && chatId === item.id) { const returnedChatId = await Chats.db.query.chatNewestId(charId) const chatId = returnedChatId ? returnedChatId : await Chats.db.mutate.createChat(charId) chatId && (await loadChat(chatId)) - } else if (item.id === currentChatId) { + } else if (item.id === chatId) { Logger.log(`Something went wrong with creating a default chat`, true) unloadChat() } diff --git a/app/components/ChatMenu/ChatInput.tsx b/app/components/ChatMenu/ChatInput.tsx index c83f8c6..79cb29b 100644 --- a/app/components/ChatMenu/ChatInput.tsx +++ b/app/components/ChatMenu/ChatInput.tsx @@ -1,18 +1,16 @@ import { MaterialIcons } from '@expo/vector-icons' -import { AppSettings, Characters, Chats, Logger, Style } from 'constants/Global' import { useInference } from 'constants/Chat' +import { AppSettings, Characters, Chats, Logger, Style } from 'constants/Global' import { generateResponse } from 'constants/Inference' import React, { useState } from 'react' -import { View, StyleSheet, TextInput, TouchableOpacity } from 'react-native' +import { StyleSheet, TextInput, TouchableOpacity, View } from 'react-native' import { useMMKVBoolean } from 'react-native-mmkv' import { useShallow } from 'zustand/react/shallow' const ChatInput = () => { const [sendOnEnter, setSendOnEnter] = useMMKVBoolean(AppSettings.SendOnEnter) - const { insertEntry } = Chats.useChat((state) => ({ - insertEntry: state.addEntry, - })) + const { addEntry } = Chats.useEntry() const { nowGenerating, abortFunction } = useInference((state) => ({ nowGenerating: state.nowGenerating, @@ -37,8 +35,8 @@ const ChatInput = () => { } const handleSend = async () => { - if (newMessage.trim() !== '') await insertEntry(userName ?? '', true, newMessage) - const swipeId = await insertEntry(charName ?? '', false, '') + if (newMessage.trim() !== '') await addEntry(userName ?? '', true, newMessage) + const swipeId = await addEntry(charName ?? '', false, '') setNewMessage((message) => '') if (swipeId) generateResponse(swipeId) } diff --git a/app/components/ChatMenu/ChatMenu.tsx b/app/components/ChatMenu/ChatMenu.tsx index 1bd8d4b..49620c6 100644 --- a/app/components/ChatMenu/ChatMenu.tsx +++ b/app/components/ChatMenu/ChatMenu.tsx @@ -23,12 +23,7 @@ const ChatMenu = () => { })) ) - const { chat, unloadChat } = Chats.useChat( - useShallow((state) => ({ - chat: state?.data?.id, - unloadChat: state.reset, - })) - ) + const { chat, unloadChat } = Chats.useChat() const [showDrawer, setShowDrawer] = useState(false) const [showChats, setShowChats] = useState(false) diff --git a/app/components/ChatMenu/ChatWindow/ChatBody.tsx b/app/components/ChatMenu/ChatWindow/ChatBody.tsx index e4fad62..40e1f6d 100644 --- a/app/components/ChatMenu/ChatWindow/ChatBody.tsx +++ b/app/components/ChatMenu/ChatWindow/ChatBody.tsx @@ -1,7 +1,6 @@ import { Chats, Style } from 'constants/Global' import React, { useState } from 'react' -import { View, TouchableOpacity, StyleSheet } from 'react-native' -import { useShallow } from 'zustand/react/shallow' +import { StyleSheet, TouchableOpacity, View } from 'react-native' import ChatText from './ChatText' import ChatTextLast from './ChatTextLast' @@ -16,11 +15,7 @@ type ChatTextProps = { } const ChatBody: React.FC = ({ id, nowGenerating, isLastMessage, isGreeting }) => { - const { message } = Chats.useChat( - useShallow((state) => ({ - message: state?.data?.messages?.[id] ?? Chats.dummyEntry, - })) - ) + const message = Chats.useEntryData(id) const [editMode, setEditMode] = useState(false) const handleEnableEdit = () => { diff --git a/app/components/ChatMenu/ChatWindow/ChatFrame.tsx b/app/components/ChatMenu/ChatWindow/ChatFrame.tsx index e8edd63..ac48a28 100644 --- a/app/components/ChatMenu/ChatWindow/ChatFrame.tsx +++ b/app/components/ChatMenu/ChatWindow/ChatFrame.tsx @@ -1,11 +1,10 @@ import Avatar from '@components/Avatar' import { useViewerState } from 'constants/AvatarViewer' -import { Characters, Global, Style } from 'constants/Global' import { Chats } from 'constants/Chat' +import { Characters, Global, Style } from 'constants/Global' import { ReactNode } from 'react' -import { View, Text, StyleSheet, TouchableOpacity } from 'react-native' +import { StyleSheet, Text, TouchableOpacity, View } from 'react-native' import { useMMKVBoolean } from 'react-native-mmkv' -import { useShallow } from 'zustand/react/shallow' import TTSMenu from './TTS' @@ -17,11 +16,7 @@ type ChatFrameProps = { } const ChatFrame: React.FC = ({ children, id, nowGenerating, isLast }) => { - const { message } = Chats.useChat( - useShallow((state) => ({ - message: state?.data?.messages?.[id] ?? Chats.dummyEntry, - })) - ) + const message = Chats.useEntryData(id) const setShowViewer = useViewerState((state) => state.setShow) diff --git a/app/components/ChatMenu/ChatWindow/ChatText.tsx b/app/components/ChatMenu/ChatWindow/ChatText.tsx index 6ce96ac..eea0054 100644 --- a/app/components/ChatMenu/ChatWindow/ChatText.tsx +++ b/app/components/ChatMenu/ChatWindow/ChatText.tsx @@ -9,11 +9,7 @@ type ChatTextProps = { } const ChatText: React.FC = ({ nowGenerating, id }) => { - const mes = Chats.useChat( - (state) => - state?.data?.messages?.[id]?.swipes?.[state?.data?.messages?.[id].swipe_id ?? -1] - .swipe ?? '' - ) + const { swipeText } = Chats.useSwipeData(id) const viewRef = useRef(null) const animHeight = useAnimatedValue(-1) @@ -47,7 +43,7 @@ const ChatText: React.FC = ({ nowGenerating, id }) => { return } requestAnimationFrame(() => updateHeight()) - }, [mes]) + }, [swipeText]) return ( @@ -56,7 +52,7 @@ const ChatText: React.FC = ({ nowGenerating, id }) => { markdownit={MarkdownStyle.Rules} rules={MarkdownStyle.RenderRules} style={MarkdownStyle.Styles}> - {mes.trim()} + {swipeText.trim()} diff --git a/app/components/ChatMenu/ChatWindow/ChatTextLast.tsx b/app/components/ChatMenu/ChatWindow/ChatTextLast.tsx index 3bed226..ac25195 100644 --- a/app/components/ChatMenu/ChatWindow/ChatTextLast.tsx +++ b/app/components/ChatMenu/ChatWindow/ChatTextLast.tsx @@ -12,25 +12,11 @@ type ChatTextProps = { } const ChatTextLast: React.FC = ({ nowGenerating, id }) => { - const { mes, swipeId } = Chats.useChat((state) => ({ - mes: - state?.data?.messages?.[id]?.swipes?.[state?.data?.messages?.[id].swipe_id ?? -1] - .swipe ?? '', + const { swipeText, swipeId } = Chats.useSwipeData(id) + const { buffer } = Chats.useBuffer() - swipeId: - state?.data?.messages?.[id]?.swipes?.[state?.data?.messages?.[id].swipe_id ?? -1].id ?? - -1, - })) const viewRef = useRef(null) - const currentSwipeId = useInference((state) => state.currentSwipeId) - - const { buffer } = Chats.useChat( - useShallow((state) => ({ - buffer: state.buffer, - })) - ) - const animHeight = useAnimatedValue(-1) const targetHeight = useRef(-1) const firstRender = useRef(true) @@ -66,7 +52,7 @@ const ChatTextLast: React.FC = ({ nowGenerating, id }) => { return } requestAnimationFrame(() => updateHeight()) - }, [buffer, mes, nowGenerating]) + }, [buffer, swipeText, nowGenerating]) return ( @@ -78,7 +64,7 @@ const ChatTextLast: React.FC = ({ nowGenerating, id }) => { markdownit={MarkdownStyle.Rules} rules={MarkdownStyle.RenderRules} style={MarkdownStyle.Styles}> - {nowGenerating && swipeId === currentSwipeId ? buffer.trim() : mes.trim()} + {nowGenerating && swipeId === currentSwipeId ? buffer.trim() : swipeText.trim()} diff --git a/app/components/ChatMenu/ChatWindow/ChatWindow.tsx b/app/components/ChatMenu/ChatWindow/ChatWindow.tsx index 373f364..f7bd745 100644 --- a/app/components/ChatMenu/ChatWindow/ChatWindow.tsx +++ b/app/components/ChatMenu/ChatWindow/ChatWindow.tsx @@ -12,15 +12,15 @@ type ListItem = { } const ChatWindow = () => { - const data = Chats.useChat((state) => state.data) + const { chat } = Chats.useChat() const [autoScroll, setAutoScroll] = useMMKVBoolean(AppSettings.AutoScroll) - const list: ListItem[] = (data?.messages ?? []) + const list: ListItem[] = (chat?.messages ?? []) .map((item, index) => ({ index: index, key: item.id.toString(), isGreeting: index === 0, - isLastMessage: !!data?.messages && index === data?.messages.length - 1, + isLastMessage: !!chat?.messages && index === chat?.messages.length - 1, })) .reverse() diff --git a/app/components/ChatMenu/ChatWindow/EditorModal.tsx b/app/components/ChatMenu/ChatWindow/EditorModal.tsx index 60471db..78544f0 100644 --- a/app/components/ChatMenu/ChatWindow/EditorModal.tsx +++ b/app/components/ChatMenu/ChatWindow/EditorModal.tsx @@ -47,20 +47,15 @@ type EditorProps = { } const EditorModal: React.FC = ({ id, isLastMessage, setEditMode, editMode }) => { - const { updateChat, deleteChat } = Chats.useChat( - useShallow((state) => ({ - updateChat: state.updateEntry, - deleteChat: state.deleteEntry, - })) - ) - const message = Chats.useChat((state) => state?.data?.messages?.[id]) + const { deleteChat } = Chats.useChat() + const { updateEntry } = Chats.useEntry() + const { swipeText } = Chats.useSwipeData(id) + const entry = Chats.useEntryData(id) - const [placeholderText, setPlaceholderText] = useState( - message?.swipes[message?.swipe_id]?.swipe ?? '' - ) + const [placeholderText, setPlaceholderText] = useState(swipeText) const handleEditMessage = () => { - updateChat(id, placeholderText, false) + updateEntry(id, placeholderText, false) setEditMode(false) } @@ -99,9 +94,9 @@ const EditorModal: React.FC = ({ id, isLastMessage, setEditMode, ed - {message?.name} + {entry?.name} - {message?.swipes[message.swipe_id].send_date.toLocaleTimeString()} + {entry?.swipes[entry.swipe_id].send_date.toLocaleTimeString()} diff --git a/app/components/ChatMenu/ChatWindow/Swipes.tsx b/app/components/ChatMenu/ChatWindow/Swipes.tsx index 78fb63e..4a2a3b5 100644 --- a/app/components/ChatMenu/ChatWindow/Swipes.tsx +++ b/app/components/ChatMenu/ChatWindow/Swipes.tsx @@ -1,10 +1,9 @@ import { AntDesign } from '@expo/vector-icons' -import { Style } from 'constants/Global' import { Chats } from 'constants/Chat' +import { Style } from 'constants/Global' import { continueResponse, generateResponse, regenerateResponse } from 'constants/Inference' import React from 'react' -import { View, Text, StyleSheet, TouchableHighlight } from 'react-native' -import { useShallow } from 'zustand/react/shallow' +import { StyleSheet, Text, TouchableHighlight, View } from 'react-native' type SwipesProps = { nowGenerating: boolean @@ -13,16 +12,8 @@ type SwipesProps = { } const Swipes: React.FC = ({ nowGenerating, isGreeting, index }) => { - const { swipeChat, addSwipe } = Chats.useChat( - useShallow((state) => ({ - swipeChat: state.swipe, - addSwipe: state.addSwipe, - })) - ) - - const { message } = Chats.useChat((state) => ({ - message: state?.data?.messages?.[index] ?? Chats.dummyEntry, - })) + const { swipeChat, addSwipe } = Chats.useSwipes() + const entry = Chats.useEntryData(index) const handleSwipeLeft = () => { swipeChat(index, -1) @@ -38,19 +29,19 @@ const Swipes: React.FC = ({ nowGenerating, isGreeting, index }) => } } - const isLastAltGreeting = isGreeting && message.swipe_id === message.swipes.length - 1 + const isLastAltGreeting = isGreeting && entry.swipe_id === entry.swipes.length - 1 return ( + disabled={nowGenerating || entry.swipe_id === 0}> = ({ nowGenerating, isGreeting, index }) => {index !== 0 && ( regenerateResponse(message.swipes[message.swipe_id].id)} - onLongPress={() => - regenerateResponse(message.swipes[message.swipe_id].id, false) - } + onPress={() => regenerateResponse(entry.swipes[entry.swipe_id].id)} + onLongPress={() => regenerateResponse(entry.swipes[entry.swipe_id].id, false)} disabled={nowGenerating} style={styles.swipeButton}> = ({ nowGenerating, isGreeting, index }) => )} - {message.swipe_id + 1} / {message.swipes.length} + {entry.swipe_id + 1} / {entry.swipes.length} {index !== 0 && ( continueResponse(message.swipes[message.swipe_id].id)} + onPress={() => continueResponse(entry.swipes[entry.swipe_id].id)} disabled={nowGenerating} style={styles.swipeButton}> = ({ nowGenerating, isGreeting, index }) => handleSwipeRight('')} - onLongPress={() => - handleSwipeRight(message?.swipes?.[message.swipe_id]?.swipe ?? '') - } + onLongPress={() => handleSwipeRight(entry?.swipes?.[entry.swipe_id]?.swipe ?? '')} disabled={nowGenerating || isLastAltGreeting}> = ({ id, isLast }) => { const [start, setStart] = useMMKVBoolean(Global.TTSAutoStart) const nowGenerating = useInference((state) => state.nowGenerating) - const { message } = Chats.useChat((state) => ({ - message: - state.data?.messages?.[id]?.swipes[state.data?.messages?.[id].swipe_id].swipe ?? '', - })) + const { swipeText } = Chats.useSwipeData(id) useEffect(() => { if (nowGenerating && isSpeaking) handleStopSpeaking() @@ -42,14 +39,14 @@ const TTS: React.FC = ({ id, isLast }) => { setIsSpeaking(true) const filter = /([!?.,*"])/ const filteredchunks: string[] = [] - const chunks = message.split(filter) + const chunks = swipeText.split(filter) chunks.forEach((item, index) => { if (!filter.test(item) && item) return filteredchunks.push(item) if (index > 0) filteredchunks[filteredchunks.length - 1] = filteredchunks[filteredchunks.length - 1] + item }) - if (filteredchunks.length === 0) filteredchunks.push(message) + if (filteredchunks.length === 0) filteredchunks.push(swipeText) const cleanedchunks = filteredchunks.map((item) => item.replaceAll(/[*"]/g, '').trim()) Logger.debug('TTS started with ' + cleanedchunks.length + ' chunks') diff --git a/app/components/ChatMenu/ChatsDrawer.tsx b/app/components/ChatMenu/ChatsDrawer.tsx index 3301e94..54a0e8d 100644 --- a/app/components/ChatMenu/ChatsDrawer.tsx +++ b/app/components/ChatMenu/ChatsDrawer.tsx @@ -25,10 +25,7 @@ const ChatsDrawer: React.FC = ({ booleans: [showModal, setShow const [nowLoading, setNowLoading] = useState(false) const { data } = useLiveQuery(Chats.db.query.chatListQuery(charId ?? 0)) - const { loadChat, currentChatId } = Chats.useChat((state) => ({ - loadChat: state.load, - currentChatId: state.data?.id, - })) + const { loadChat, chatId } = Chats.useChat() const handleLoadChat = async (chatId: number) => { await loadChat(chatId) @@ -45,7 +42,7 @@ const ChatsDrawer: React.FC = ({ booleans: [showModal, setShow const renderChat = (item: ListItem, index: number) => { const date = new Date(item.last_modified ?? 0) return ( - + handleLoadChat(item.id)}> diff --git a/app/components/ChatMenu/OptionsMenu.tsx b/app/components/ChatMenu/OptionsMenu.tsx index 1c1105d..92a8d7c 100644 --- a/app/components/ChatMenu/OptionsMenu.tsx +++ b/app/components/ChatMenu/OptionsMenu.tsx @@ -30,7 +30,7 @@ const OptionsMenu: React.FC = ({ menuRef, showChats }) => { unloadCharacter: state.unloadCard, })) - const unloadChat = Chats.useChat((state) => state.reset) + const { unloadChat } = Chats.useChat() const menuoptions: MenuData[] = [ { diff --git a/constants/API/APIBuilder.ts b/constants/API/APIBuilder.ts index c9c22c1..fbd9795 100644 --- a/constants/API/APIBuilder.ts +++ b/constants/API/APIBuilder.ts @@ -18,7 +18,7 @@ export const buildAndSendRequest = async () => { if (!requestValues) { Logger.log(`No Active API`, true) - Chats.useChat.getState().stopGenerating() + Chats.useChatState.getState().stopGenerating() return } @@ -30,7 +30,7 @@ export const buildAndSendRequest = async () => { const config = configs[0] if (!config) { Logger.log(`Configuration "${requestValues?.configName}" found`, true) - Chats.useChat.getState().stopGenerating() + Chats.useChatState.getState().stopGenerating() return } @@ -41,7 +41,7 @@ export const buildAndSendRequest = async () => { if (!payload) { Logger.log('Something Went Wrong With Payload Construction', true) - Chats.useChat.getState().stopGenerating() + Chats.useChatState.getState().stopGenerating() return } @@ -108,7 +108,7 @@ const readableStreamResponse = async ( const closeStream = () => { Logger.debug('Running Close Stream') - Chats.useChat.getState().stopGenerating() + Chats.useChatState.getState().stopGenerating() } useInference.getState().setAbort(async () => { @@ -118,8 +118,8 @@ const readableStreamResponse = async ( sse.setOnEvent((data) => { const text = jsonreader(data) ?? '' - const output = Chats.useChat.getState().buffer + text - Chats.useChat.getState().setBuffer(output.replaceAll(replace, '')) + const output = Chats.useChatState.getState().buffer + text + Chats.useChatState.getState().setBuffer(output.replaceAll(replace, '')) }) sse.setOnError(() => { @@ -167,7 +167,7 @@ const hordeResponse = async ( }).catch((error) => { Logger.log(error) }) - Chats.useChat.getState().stopGenerating() + Chats.useChatState.getState().stopGenerating() }) Logger.log(`Using Horde`) @@ -185,12 +185,12 @@ const hordeResponse = async ( if (request.status === 401) { Logger.log(`Invalid API Key`, true) - Chats.useChat.getState().stopGenerating() + Chats.useChatState.getState().stopGenerating() return } if (request.status !== 202) { Logger.log(`Request failed.`) - Chats.useChat.getState().stopGenerating() + Chats.useChatState.getState().stopGenerating() const body = await request.json() Logger.log(JSON.stringify(body)) for (const e of body.errors) Logger.log(e) @@ -217,7 +217,7 @@ const hordeResponse = async ( if (response.status === 400) { Logger.log(`Response failed.`) - Chats.useChat.getState().stopGenerating() + Chats.useChatState.getState().stopGenerating() Logger.log((await response.json())?.message) return } @@ -234,8 +234,8 @@ const hordeResponse = async ( 'g' ) - Chats.useChat.getState().setBuffer(result.generations[0].text.replaceAll(replace, '')) - Chats.useChat.getState().stopGenerating() + Chats.useChatState.getState().setBuffer(result.generations[0].text.replaceAll(replace, '')) + Chats.useChatState.getState().stopGenerating() } const constructReplaceStrings = (): string[] => { diff --git a/constants/API/ContextBuilder.ts b/constants/API/ContextBuilder.ts index 1c0e5fe..3486898 100644 --- a/constants/API/ContextBuilder.ts +++ b/constants/API/ContextBuilder.ts @@ -1,8 +1,8 @@ import { replaceMacros } from 'constants/Characters' +import { Characters, Chats, Global, Instructs, Logger, mmkv } from 'constants/Global' import { AppMode, AppSettings } from 'constants/GlobalValues' import { Llama } from 'constants/LlamaLocal' import { Tokenizer } from 'constants/Tokenizer' -import { Characters, Chats, Global, Instructs, Logger, mmkv } from 'constants/Global' import { APIConfiguration, APIValues } from './APIBuilder.types' @@ -14,7 +14,7 @@ export const buildTextCompletionContext = (max_length: number) => { ? Llama.useLlama.getState().tokenLength : Tokenizer.useTokenizer.getState().getTokenCount - const messages = [...(Chats.useChat.getState().data?.messages ?? [])] + const messages = [...(Chats.useChatState.getState().data?.messages ?? [])] const currentInstruct = Instructs.useInstruct.getState().replacedMacros() @@ -76,7 +76,7 @@ export const buildTextCompletionContext = (max_length: number) => { // we require lengths for names if use_names is enabled for (const message of messages.reverse()) { - const swipe_len = Chats.useChat.getState().getTokenCount(index) + const swipe_len = Chats.useChatState.getState().getTokenCount(index) const swipe_data = message.swipes[message.swipe_id] /** Accumulate total string length @@ -176,7 +176,7 @@ export const buildChatCompletionContext = ( ? Llama.useLlama.getState().tokenLength : Tokenizer.useTokenizer.getState().getTokenCount - const messages = [...(Chats.useChat.getState().data?.messages ?? [])] + const messages = [...(Chats.useChatState.getState().data?.messages ?? [])] const userCard = { ...Characters.useUserCard.getState().card } const currentCard = { ...Characters.useCharacterCard.getState().card } const currentInstruct = Instructs.useInstruct.getState().replacedMacros() @@ -188,7 +188,7 @@ export const buildChatCompletionContext = ( const userCache = Characters.useUserCard.getState().getCache(charName) const instructCache = Instructs.useInstruct.getState().getCache(charName, userName) - const buffer = Chats.useChat.getState().buffer + const buffer = Chats.useChatState.getState().buffer // Logic here is that if the buffer is empty, this is not a regen, hence can popped if (!buffer) messages.pop() @@ -230,7 +230,7 @@ export const buildChatCompletionContext = ( const name_length = currentInstruct.names ? tokenizer(name_string) : 0 const len = - Chats.useChat.getState().getTokenCount(index) + + Chats.useChatState.getState().getTokenCount(index) + total_length + name_length + timestamp_length diff --git a/constants/APIState/BaseAPI.ts b/constants/APIState/BaseAPI.ts index 20bbffa..bee467a 100644 --- a/constants/APIState/BaseAPI.ts +++ b/constants/APIState/BaseAPI.ts @@ -66,7 +66,7 @@ export abstract class APIBase implements IAPIBase { ? Llama.useLlama.getState().tokenLength : Tokenizer.useTokenizer.getState().getTokenCount - const messages = [...(Chats.useChat.getState().data?.messages ?? [])] + const messages = [...(Chats.useChatState.getState().data?.messages ?? [])] const currentInstruct = Instructs.useInstruct.getState().replacedMacros() @@ -131,7 +131,7 @@ export abstract class APIBase implements IAPIBase { // we require lengths for names if use_names is enabled for (const message of messages.reverse()) { - const swipe_len = Chats.useChat.getState().getTokenCount(index) + const swipe_len = Chats.useChatState.getState().getTokenCount(index) const swipe_data = message.swipes[message.swipe_id] /** Accumulate total string length @@ -233,7 +233,7 @@ export abstract class APIBase implements IAPIBase { ? Llama.useLlama.getState().tokenLength : Tokenizer.useTokenizer.getState().getTokenCount - const messages = [...(Chats.useChat.getState().data?.messages ?? [])] + const messages = [...(Chats.useChatState.getState().data?.messages ?? [])] const userCard = { ...Characters.useUserCard.getState().card } const currentCard = { ...Characters.useCharacterCard.getState().card } const currentInstruct = Instructs.useInstruct.getState().replacedMacros() @@ -245,7 +245,7 @@ export abstract class APIBase implements IAPIBase { const userCache = Characters.useUserCard.getState().getCache(charName) const instructCache = Instructs.useInstruct.getState().getCache(charName, userName) - const buffer = Chats.useChat.getState().buffer + const buffer = Chats.useChatState.getState().buffer // Logic here is that if the buffer is empty, this is not a regen, hence can popped if (!buffer) messages.pop() @@ -285,7 +285,7 @@ export abstract class APIBase implements IAPIBase { const name_length = currentInstruct.names ? tokenizer(name_string) : 0 const len = - Chats.useChat.getState().getTokenCount(index) + + Chats.useChatState.getState().getTokenCount(index) + total_length + name_length + timestamp_length @@ -321,7 +321,7 @@ export abstract class APIBase implements IAPIBase { const closeStream = () => { Logger.debug('Running Close Stream') - Chats.useChat.getState().stopGenerating() + Chats.useChatState.getState().stopGenerating() } useInference.getState().setAbort(async () => { @@ -331,8 +331,8 @@ export abstract class APIBase implements IAPIBase { sse.setOnEvent((data) => { const text = jsonreader(data) ?? '' - const output = Chats.useChat.getState().buffer + text - Chats.useChat.getState().setBuffer(output.replaceAll(replace, '')) + const output = Chats.useChatState.getState().buffer + text + Chats.useChatState.getState().setBuffer(output.replaceAll(replace, '')) }) sse.setOnError(() => { @@ -380,7 +380,7 @@ export abstract class APIBase implements IAPIBase { return mmkv.getString(key) ?? '' } stopGenerating = () => { - Chats.useChat.getState().stopGenerating() + Chats.useChatState.getState().stopGenerating() } } diff --git a/constants/APIState/HordeAPI.ts b/constants/APIState/HordeAPI.ts index b8b84c8..b40a432 100644 --- a/constants/APIState/HordeAPI.ts +++ b/constants/APIState/HordeAPI.ts @@ -180,7 +180,7 @@ class HordeAPI extends APIBase { 'g' ) - Chats.useChat.getState().setBuffer(result.generations[0].text.replaceAll(replace, '')) + Chats.useChatState.getState().setBuffer(result.generations[0].text.replaceAll(replace, '')) this.stopGenerating() } } diff --git a/constants/APIState/LocalAPI.ts b/constants/APIState/LocalAPI.ts index e265503..91bf43d 100644 --- a/constants/APIState/LocalAPI.ts +++ b/constants/APIState/LocalAPI.ts @@ -102,13 +102,13 @@ class LocalAPI extends APIBase { const payload = this.buildPayload() const outputStream = (text: string) => { - const output = Chats.useChat.getState().buffer + text - Chats.useChat.getState().setBuffer(output.replaceAll(replace, '')) + const output = Chats.useChatState.getState().buffer + text + Chats.useChatState.getState().setBuffer(output.replaceAll(replace, '')) } const outputCompleted = (text: string) => { - const regenCache = Chats.useChat.getState().getRegenCache() - Chats.useChat.getState().setBuffer((regenCache + text).replaceAll(replace, '')) + const regenCache = Chats.useChatState.getState().getRegenCache() + Chats.useChatState.getState().setBuffer((regenCache + text).replaceAll(replace, '')) if (mmkv.getBoolean(AppSettings.PrintContext)) Logger.log(`Completion Output:\n${text}`) this.stopGenerating() } diff --git a/constants/Chat.ts b/constants/Chat.ts index a3826af..3bce3c7 100644 --- a/constants/Chat.ts +++ b/constants/Chat.ts @@ -3,6 +3,7 @@ import { chatEntries, chatSwipes, chats } from 'db/schema' import { count, desc, eq, getTableColumns } from 'drizzle-orm' import * as Notifications from 'expo-notifications' import { create } from 'zustand' +import { useShallow } from 'zustand/react/shallow' import { API } from './API' import { Characters } from './Characters' @@ -89,7 +90,7 @@ export const sendGenerateCompleteNotification = async () => { : 'Response Complete' const notificationText = showMessage - ? Chats.useChat.getState().buffer.trim() + ? Chats.useChatState.getState().buffer.trim() : 'ChatterUI has finished a response.' Notifications.setNotificationHandler({ @@ -136,12 +137,13 @@ export const useInference = create((set, get) => ({ })) export namespace Chats { - export const useChat = create((set, get: () => ChatState) => ({ + export const useChatState = create((set, get: () => ChatState) => ({ data: undefined, buffer: '', startGenerating: (swipeId: number) => { useInference.getState().startGenerating(swipeId) }, + // TODO : Replace this function stopGenerating: async () => { const cachedSwipeId = useInference.getState().currentSwipeId Logger.log(`Saving Chat`) @@ -419,11 +421,13 @@ export namespace Chats { } } export namespace mutate { - //TODO : refactor this, the requirement to pull charID is not needed, no error handling either - //TODO : perhaps pull data from DB instead of useCharacterCard, currently reliable BUT may fail in future export const createChat = async (charId: number) => { - const card = { ...Characters.useCharacterCard.getState().card } - const charName = card?.name + const card = await Characters.db.query.card(charId) + if (!card) { + Logger.error('Character does not exist!') + return + } + const charName = card.name return await database.transaction(async (tx) => { if (!card || !charName) return const [{ chatId }, ..._] = await tx @@ -617,6 +621,63 @@ export namespace Chats { } } + export const useEntryData = (index: number) => { + // TODO: Investigate if dummyEntry is dangerous + const entry = useChatState((state) => state?.data?.messages?.[index] ?? dummyEntry) + return entry + } + + export const useSwipes = () => { + const { swipeChat, addSwipe } = Chats.useChatState( + useShallow((state) => ({ + swipeChat: state.swipe, + addSwipe: state.addSwipe, + })) + ) + return { swipeChat, addSwipe } + } + + export const useSwipeData = (index: number) => { + const message = useEntryData(index) + const swipeId = message.swipe_id + const swipe = message.swipes?.[swipeId] + const swipeText = swipe?.swipe + return { swipeId, swipe, swipeText } + } + + export const useChat = () => { + const { loadChat, unloadChat, chat, chatId, deleteChat } = Chats.useChatState( + useShallow((state) => ({ + loadChat: state.load, + unloadChat: state.reset, + chat: state.data, + chatId: state.data?.id, + deleteChat: state.delete, + })) + ) + return { chat, loadChat, unloadChat, deleteChat, chatId } + } + + export const useEntry = () => { + const { addEntry, deleteEntry, updateEntry } = Chats.useChatState( + useShallow((state) => ({ + addEntry: state.addEntry, + deleteEntry: state.deleteEntry, + updateEntry: state.updateEntry, + })) + ) + return { addEntry, deleteEntry, updateEntry } + } + + export const useBuffer = () => { + const { buffer } = Chats.useChatState( + useShallow((state) => ({ + buffer: state.buffer, + })) + ) + return { buffer } + } + export const dummyEntry: ChatEntry = { id: 0, chat_id: -1, diff --git a/constants/Global.ts b/constants/Global.ts index 2ef6096..0177981 100644 --- a/constants/Global.ts +++ b/constants/Global.ts @@ -91,7 +91,7 @@ const loadChatOnInit = async () => { const newestChat = await Chats.db.query.chatNewest() if (!newestChat) return await Characters.useCharacterCard.getState().setCard(newestChat.character_id) - await Chats.useChat.getState().load(newestChat.id) + await Chats.useChatState.getState().load(newestChat.id) } const createDefaultUserData = async () => { @@ -117,7 +117,7 @@ export const unlockScreenOrientation = async () => { export const startupApp = () => { console.log('[APP STARTED]: T1APT') // Only for dev to properly reset - Chats.useChat.getState().reset() + Chats.useChatState.getState().reset() Characters.useCharacterCard.getState().unloadCard() // Resets horde state, may be better if left active diff --git a/constants/Inference.ts b/constants/Inference.ts index 70ce91b..cd3e573 100644 --- a/constants/Inference.ts +++ b/constants/Inference.ts @@ -11,29 +11,29 @@ import { mmkv } from './MMKV' export const regenerateResponse = async (swipeId: number, regenCache: boolean = true) => { const charName = Characters.useCharacterCard.getState().card?.name - const messagesLength = Chats.useChat.getState()?.data?.messages?.length ?? -1 - const message = Chats.useChat.getState()?.data?.messages?.[messagesLength - 1] + const messagesLength = Chats.useChatState.getState()?.data?.messages?.length ?? -1 + const message = Chats.useChatState.getState()?.data?.messages?.[messagesLength - 1] Logger.log('Regenerate Response' + (regenCache ? '' : ' , Resetting Message')) if (message?.is_user) { - await Chats.useChat.getState().addEntry(charName ?? '', true, '') + await Chats.useChatState.getState().addEntry(charName ?? '', true, '') } else if (messagesLength && messagesLength !== 1) { let replacement = '' if (regenCache) replacement = message?.swipes[message.swipe_id].regen_cache ?? '' - else Chats.useChat.getState().resetRegenCache() + else Chats.useChatState.getState().resetRegenCache() - if (replacement) Chats.useChat.getState().setBuffer(replacement) - await Chats.useChat.getState().updateEntry(messagesLength - 1, replacement, true, true) + if (replacement) Chats.useChatState.getState().setBuffer(replacement) + await Chats.useChatState.getState().updateEntry(messagesLength - 1, replacement, true, true) } await generateResponse(swipeId) } export const continueResponse = async (swipeId: number) => { Logger.log(`Continuing Response`) - Chats.useChat.getState().setRegenCache() - Chats.useChat.getState().insertLastToBuffer() + Chats.useChatState.getState().setRegenCache() + Chats.useChatState.getState().insertLastToBuffer() await generateResponse(swipeId) } @@ -59,7 +59,7 @@ export const generateResponse = async (swipeId: number) => { Logger.log('Generation already in progress', true) return } - Chats.useChat.getState().startGenerating(swipeId) + Chats.useChatState.getState().startGenerating(swipeId) Logger.log(`Obtaining response.`) const data = performance.now() const appMode = getString(Global.AppMode)