Skip to content

Commit

Permalink
feat: added swipe id verification to generations, cleaned up chat win…
Browse files Browse the repository at this point in the history
…dow props
  • Loading branch information
Vali-98 committed Oct 1, 2024
1 parent d1757cb commit 70e8e03
Show file tree
Hide file tree
Showing 12 changed files with 108 additions and 73 deletions.
4 changes: 2 additions & 2 deletions app/components/ChatMenu/ChatInput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ const ChatInput = () => {

const handleSend = async () => {
if (newMessage.trim() !== '') await insertEntry(userName ?? '', true, newMessage)
await insertEntry(charName ?? '', false, '')
const swipeId = await insertEntry(charName ?? '', false, '')
setNewMessage((message) => '')
generateResponse()
if (swipeId) generateResponse(swipeId)
}

return (
Expand Down
10 changes: 6 additions & 4 deletions app/components/ChatMenu/ChatMenu.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ const ChatMenu = () => {
}))
)

const { chat, unloadChat } = Chats.useChat((state) => ({
chat: state.data,
unloadChat: state.reset,
}))
const { chat, unloadChat } = Chats.useChat(
useShallow((state) => ({
chat: state?.data?.id,
unloadChat: state.reset,
}))
)

const [showDrawer, setShowDrawer] = useState<boolean>(false)
const [showChats, setShowChats] = useState<boolean>(false)
Expand Down
7 changes: 3 additions & 4 deletions app/components/ChatMenu/ChatWindow/ChatBody.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@ import Swipes from './Swipes'
type ChatTextProps = {
id: number
nowGenerating: boolean
messagesLength: number
isLastMessage: boolean
isGreeting: boolean
}

const ChatBody: React.FC<ChatTextProps> = ({ id, nowGenerating, messagesLength }) => {
const isLastMessage = id === messagesLength - 1
const isGreeting = messagesLength === 1
const ChatBody: React.FC<ChatTextProps> = ({ id, nowGenerating, isLastMessage, isGreeting }) => {
const { message } = Chats.useChat(
useShallow((state) => ({
message: state?.data?.messages?.[id] ?? Chats.dummyEntry,
Expand Down
3 changes: 1 addition & 2 deletions app/components/ChatMenu/ChatWindow/ChatFrame.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Chats } from 'app/constants/Chat'
import { Characters, Global, Style } from '@globals'
import { Chats } from 'app/constants/Chat'
import { ReactNode, useEffect, useState } from 'react'
import { View, Text, Image, StyleSheet } from 'react-native'
import { useMMKVBoolean } from 'react-native-mmkv'
Expand All @@ -10,7 +10,6 @@ import TTSMenu from './TTS'
type ChatFrameProps = {
children?: ReactNode
id: number
charId: number
nowGenerating: boolean
isLast?: boolean
}
Expand Down
21 changes: 7 additions & 14 deletions app/components/ChatMenu/ChatWindow/ChatItem.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,22 @@ import ChatFrame from './ChatFrame'

type ChatItemProps = {
id: number
charId: number
messagesLength: number
isLastMessage: boolean
isGreeting: boolean
}

const ChatItem: React.FC<ChatItemProps> = ({ id, charId, messagesLength }) => {
const isLastMessage = id === messagesLength - 1
const ChatItem: React.FC<ChatItemProps> = ({ id, isLastMessage, isGreeting }) => {
const nowGenerating = useInference((state) => state.nowGenerating)

return (
<AnimatedView dy={100} fade={0} fduration={200} tduration={400}>
<View
style={{
...styles.chatItem,
}}>
<ChatFrame
charId={charId}
id={id}
nowGenerating={nowGenerating}
isLast={isLastMessage}>
<View style={styles.chatItem}>
<ChatFrame id={id} nowGenerating={nowGenerating} isLast={isLastMessage}>
<ChatBody
nowGenerating={nowGenerating}
id={id}
messagesLength={messagesLength}
isLastMessage={isLastMessage}
isGreeting={isGreeting}
/>
</ChatFrame>
</View>
Expand Down
17 changes: 10 additions & 7 deletions app/components/ChatMenu/ChatWindow/ChatTextLast.tsx
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { useInference } from '@constants/Chat'
import { Chats, Style, MarkdownStyle } from '@globals'
import React, { useEffect, useRef } from 'react'
import { StyleSheet, Animated, Easing, LayoutChangeEvent } from 'react-native'
Expand All @@ -16,12 +17,12 @@ const ChatTextLast: React.FC<ChatTextProps> = ({ nowGenerating, id }) => {
const animatedHeight = useRef(new Animated.Value(-1)).current
const height = useRef(-1)

const mes = Chats.useChat(
(state) =>
state?.data?.messages?.[id]?.swipes?.[state?.data?.messages?.[id].swipe_id ?? -1]
.swipe ?? ''
const swipe = Chats.useChat(
(state) => state?.data?.messages?.[id]?.swipes?.[state?.data?.messages?.[id].swipe_id ?? -1]
)

const currentSwipeId = useInference((state) => state.currentSwipeId)

const { buffer } = Chats.useChat(
useShallow((state) => ({
buffer: state.buffer,
Expand Down Expand Up @@ -57,7 +58,7 @@ const ChatTextLast: React.FC<ChatTextProps> = ({ nowGenerating, id }) => {
useEffect(() => {
if (!nowGenerating && height.current !== -1) {
handleAnimateHeight(height.current)
} else if (nowGenerating && !mes) {
} else if (nowGenerating && !swipe?.swipe) {
// NOTE: this assumes that mes is empty due to a swipe and may break, but unlikely
height.current = 0
handleAnimateHeight(height.current)
Expand All @@ -70,7 +71,7 @@ const ChatTextLast: React.FC<ChatTextProps> = ({ nowGenerating, id }) => {
height: __DEV__ ? 'auto' : animatedHeight, // dev fix for slow emulator animations
overflow: 'scroll',
}}>
{nowGenerating && buffer === '' && (
{swipe?.id === currentSwipeId && nowGenerating && buffer === '' && (
<AnimatedEllipsis
style={{
color: Style.getColor('primary-text2'),
Expand All @@ -84,7 +85,9 @@ const ChatTextLast: React.FC<ChatTextProps> = ({ nowGenerating, id }) => {
style={styles.messageText}
rules={{ rules: MarkdownStyle.Rules }}
styles={MarkdownStyle.Format}>
{nowGenerating ? buffer.trim() : mes.trim()}
{nowGenerating && swipe?.id === currentSwipeId
? buffer.trim()
: swipe?.swipe.trim()}
</Markdown>
</Animated.View>
)
Expand Down
19 changes: 13 additions & 6 deletions app/components/ChatMenu/ChatWindow/ChatWindow.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,31 @@ import { ChatItem } from './ChatItem'
type ListItem = {
index: number
key: string
isLastMessage: boolean
isGreeting: boolean
}

const ChatWindow = () => {
'use no memo'
const charId = Characters.useCharacterCard(useShallow((state) => state?.id))
const messages = Chats.useChat((state) => state.data?.messages)
const messagesLength = messages?.length ?? -1
const data = Chats.useChat((state) => state.data)
const [autoScroll, setAutoScroll] = useMMKVBoolean(AppSettings.AutoScroll)

const list: ListItem[] = (messages ?? [])
const list: ListItem[] = (data?.messages ?? [])
.map((item, index) => ({
index: index,
key: item.id.toString(),
isGreeting: index === 0,
isLastMessage: !!data?.messages && index === data?.messages.length - 1,
}))
.reverse()

const renderItems = ({ item, index }: { item: ListItem; index: number }) => {
return <ChatItem messagesLength={messagesLength} id={item.index} charId={charId ?? -1} />
return (
<ChatItem
id={item.index}
isLastMessage={item.isLastMessage}
isGreeting={item.isGreeting}
/>
)
}

return (
Expand Down
12 changes: 6 additions & 6 deletions app/components/ChatMenu/ChatWindow/Swipes.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Chats } from 'app/constants/Chat'
import { continueResponse, generateResponse, regenerateResponse } from 'app/constants/Inference'
import { AntDesign } from '@expo/vector-icons'
import { Style } from '@globals'
import { Chats } from 'app/constants/Chat'
import { continueResponse, generateResponse, regenerateResponse } from 'app/constants/Inference'
import React from 'react'
import { View, Text, StyleSheet, TouchableHighlight } from 'react-native'
import { useShallow } from 'zustand/react/shallow'
Expand Down Expand Up @@ -31,8 +31,8 @@ const Swipes: React.FC<SwipesProps> = ({ nowGenerating, isGreeting, index }) =>
const handleSwipeRight = async () => {
const atLimit = await swipeChat(index, 1)
if (atLimit && !isGreeting) {
await addSwipe(index)
generateResponse()
const id = await addSwipe(index)
if (id) generateResponse(id)
}
}

Expand All @@ -57,7 +57,7 @@ const Swipes: React.FC<SwipesProps> = ({ nowGenerating, isGreeting, index }) =>

{index !== 0 && (
<TouchableHighlight
onPress={regenerateResponse}
onPress={() => regenerateResponse(message.swipes[message.swipe_id].id)}
disabled={nowGenerating}
style={styles.swipeButton}>
<AntDesign
Expand All @@ -78,7 +78,7 @@ const Swipes: React.FC<SwipesProps> = ({ nowGenerating, isGreeting, index }) =>

{index !== 0 && (
<TouchableHighlight
onPress={continueResponse}
onPress={() => continueResponse(message.swipes[message.swipe_id].id)}
disabled={nowGenerating}
style={styles.swipeButton}>
<AntDesign
Expand Down
68 changes: 51 additions & 17 deletions app/constants/Chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,34 +51,36 @@ export interface ChatState {
buffer: string
load: (chatId: number) => Promise<void>
delete: (chatId: number) => Promise<void>
addEntry: (name: string, is_user: boolean, message: string) => Promise<void>
addEntry: (name: string, is_user: boolean, message: string) => Promise<number | void>
updateEntry: (
index: number,
message: string,
updateFinished?: boolean,
updateStarted?: boolean
updateStarted?: boolean,
verifySwipeId?: number
) => Promise<void>
deleteEntry: (index: number) => Promise<void>
reset: () => void
swipe: (index: number, direction: number) => Promise<boolean>
addSwipe: (index: number) => Promise<void>
addSwipe: (index: number) => Promise<number | void>
getTokenCount: (index: number) => number
setBuffer: (data: string) => void
insertBuffer: (data: string) => void
updateFromBuffer: () => Promise<void>
updateFromBuffer: (cachedSwipeId?: number) => Promise<void>
insertLastToBuffer: () => void
setRegenCache: () => void
getRegenCache: () => string
stopGenerating: () => void
startGenerating: () => void
startGenerating: (swipeId: number) => void
abortFunction: undefined | AbortFunction
setAbortFunction: SetAbortFunction
}

type AbortFunctionType = {
abortFunction: () => void
nowGenerating: boolean
startGenerating: () => void
currentSwipeId?: number
startGenerating: (swipeId: number) => void
stopGenerating: () => void
setAbort: (fn: () => void) => void
}
Expand All @@ -88,25 +90,35 @@ export const useInference = create<AbortFunctionType>((set, get) => ({
get().stopGenerating()
},
nowGenerating: false,
startGenerating: () => set((state) => ({ ...state, nowGenerating: true })),
stopGenerating: () => set((state) => ({ ...state, nowGenerating: false })),
currentSwipeId: undefined,
startGenerating: (swipeId: number) =>
set((state) => ({ ...state, nowGenerating: true, currentSwipeId: swipeId })),
stopGenerating: () =>
set((state) => ({ ...state, nowGenerating: false, currentSwipeId: undefined })),
setAbort: (fn) => {
Logger.debug('Setting abort function')
set((state) => ({ ...state, abortFunction: fn }))
set((state) => ({
...state,
abortFunction: () => {
fn()
get().stopGenerating()
},
}))
},
}))

export namespace Chats {
export const useChat = create<ChatState>((set, get: () => ChatState) => ({
data: undefined,
buffer: '',
startGenerating: () => {
useInference.getState().startGenerating()
startGenerating: (swipeId: number) => {
useInference.getState().startGenerating(swipeId)
},
stopGenerating: async () => {
const cachedSwipeId = useInference.getState().currentSwipeId
useInference.getState().stopGenerating()
Logger.log(`Saving Chat`)
await get().updateFromBuffer()
await get().updateFromBuffer(cachedSwipeId)
get().setBuffer('')

if (mmkv.getBoolean(Global.TTSEnable) && mmkv.getBoolean(Global.TTSAuto)) {
Expand Down Expand Up @@ -154,6 +166,7 @@ export namespace Chats {
...state,
data: state?.data ? { ...state.data, messages: messages } : state.data,
}))
return entry?.swipes[0].id
},
deleteEntry: async (index: number) => {
const messages = get().data?.messages
Expand All @@ -175,18 +188,32 @@ export namespace Chats {
index: number,
message: string,
updateFinished: boolean = true,
updateStarted: boolean = false
updateStarted: boolean = false,
verifySwipeId: number | undefined = undefined
) => {
const messages = get()?.data?.messages
if (!messages) return
const chatSwipeId = messages[index]?.swipes[messages[index].swipe_id].id

let chatSwipeId: number | undefined =
messages[index]?.swipes[messages[index].swipe_id].id
let updateState = true

if (verifySwipeId) {
updateState = verifySwipeId === chatSwipeId
if (!updateState) {
chatSwipeId = verifySwipeId
}
}

if (!chatSwipeId) return

const date = await db.mutate.updateChatSwipe(
chatSwipeId,
message,
updateStarted,
updateFinished
)
if (!updateState) return
messages[index].swipes[messages[index].swipe_id].swipe = message
messages[index].swipes[messages[index].swipe_id].token_count = undefined
if (updateFinished) messages[index].swipes[messages[index].swipe_id].gen_finished = date
Expand Down Expand Up @@ -233,6 +260,7 @@ export namespace Chats {
...state,
data: state?.data ? { ...state.data, messages: messages } : state.data,
}))
return swipe?.id
},

getTokenCount: (index: number) => {
Expand Down Expand Up @@ -261,10 +289,16 @@ export namespace Chats {
insertBuffer: (data: string) =>
set((state: ChatState) => ({ ...state, buffer: state.buffer + data })),

updateFromBuffer: async () => {
updateFromBuffer: async (cachedSwipeId) => {
const index = get().data?.messages?.length
if (!index) return
await get().updateEntry(index - 1, get().buffer)
if (!index) {
// this means there is no chat loaded, we need to update the db anyways
if (cachedSwipeId) {
await db.mutate.updateChatSwipe(cachedSwipeId, get().buffer, false, true)
}
return
}
await get().updateEntry(index - 1, get().buffer, true, false, cachedSwipeId)
},
insertLastToBuffer: () => {
const message = get()?.data?.messages?.at(-1)
Expand Down
Loading

0 comments on commit 70e8e03

Please sign in to comment.