Skip to content

Commit

Permalink
feat: added generic chat completions
Browse files Browse the repository at this point in the history
  • Loading branch information
Vali-98 committed Aug 8, 2024
1 parent 8fb49a4 commit a8c0708
Show file tree
Hide file tree
Showing 10 changed files with 314 additions and 5 deletions.
3 changes: 3 additions & 0 deletions app/APIMenu.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
OpenAI,
Ollama,
Claude,
ChatCompletions,
} from '@components/Endpoint'
import { Global, API, Style } from '@globals'
import { Stack } from 'expo-router'
Expand All @@ -26,6 +27,7 @@ const APIMenu = () => {
{ label: 'Ollama', value: API.OLLAMA },
{ label: 'Text Generation Web UI', value: API.TGWUI },
{ label: 'Text Completions', value: API.COMPLETIONS },
{ label: 'Chat Completions', value: API.CHATCOMPLETIONS },
{ label: 'Horde', value: API.HORDE },
{ label: 'Mancer', value: API.MANCER },
{ label: 'Open Router', value: API.OPENROUTER },
Expand Down Expand Up @@ -80,6 +82,7 @@ const APIMenu = () => {
{APIType === API.OPENAI && <OpenAI />}
{APIType === API.OLLAMA && <Ollama />}
{APIType === API.CLAUDE && <Claude />}
{APIType === API.CHATCOMPLETIONS && <ChatCompletions />}
</ScrollView>
</SafeAreaView>
</AnimatedView>
Expand Down
214 changes: 214 additions & 0 deletions app/components/Endpoint/ChatCompletions.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
import { FontAwesome, MaterialIcons } from '@expo/vector-icons'
import { Global, Logger, Style } from '@globals'
import { useEffect, useState } from 'react'
import { View, Text, StyleSheet, TextInput, TouchableOpacity } from 'react-native'
import { Dropdown } from 'react-native-element-dropdown'
import { useMMKVObject, useMMKVString } from 'react-native-mmkv'

import HeartbeatButton from './HeartbeatButton'
import { OpenAIModel } from './OpenAI'

const ChatCompletions = () => {
const [endpoint, setEndpoint] = useMMKVString(Global.ChatCompletionsEndpoint)
const [chatCompletionsKey, setChatCompletionsKey] = useMMKVString(Global.ChatCompletionsKey)
const [chatCompletionsModel, setChatCompletionsModel] = useMMKVObject<OpenAIModel>(
Global.ChatCompletionsModel
)
const [keyInput, setKeyInput] = useState('')

const [modelList, setModelList] = useState<OpenAIModel[]>([])

useEffect(() => {
getModelList()
}, [])

const getModelList = async () => {
if (!endpoint) return

try {
const url = new URL('v1/models', endpoint).toString()

const response = await fetch(url, {
headers: {
accept: 'application/json',
Authorization: `Bearer ${chatCompletionsKey}`,
},
})
if (response.status !== 200) {
Logger.log(`Error with response: ${response.status}`, true)
return
}
const { data } = await response.json()
setModelList(data)
} catch (e) {
setModelList([])
}
}

return (
<View style={styles.mainContainer}>
<Text style={styles.title}>API Key</Text>
<Text style={styles.subtitle}>Key will not be shown</Text>
<View style={{ flexDirection: 'row', alignItems: 'center' }}>
<TextInput
style={styles.input}
value={keyInput}
onChangeText={setKeyInput}
placeholder="Press save to confirm key"
placeholderTextColor={Style.getColor('primary-text2')}
secureTextEntry
/>
<TouchableOpacity
style={styles.button}
onPress={() => {
if (keyInput === '') {
Logger.log('No key entered!', true)
return
}
setChatCompletionsKey(keyInput)
setKeyInput('')
Logger.log('Key saved!', true)
}}>
<FontAwesome name="save" color={Style.getColor('primary-text1')} size={28} />
</TouchableOpacity>
</View>

<Text style={styles.title}>Endpoint</Text>
<Text style={styles.subtitle}>
This endpoint should be cross compatible with many different services. Be sure to
end with ' / '
</Text>
<TextInput
style={styles.input}
value={endpoint}
onChangeText={(value) => {
setEndpoint(value)
}}
placeholder="eg. https://127.0.0.1:5000"
placeholderTextColor={Style.getColor('primary-text2')}
/>

{endpoint && (
<HeartbeatButton
api={endpoint}
headers={{ Authorization: `Bearer ${chatCompletionsKey}` }}
/>
)}

<View style={styles.dropdownContainer}>
<Text style={styles.title}>Models</Text>

<Text style={styles.subtitle}>API Key must be provided to get model list.</Text>

<View style={{ flexDirection: 'row', alignItems: 'center' }}>
<Dropdown
value={chatCompletionsModel}
data={modelList}
labelField="id"
valueField="id"
onChange={(item: OpenAIModel) => {
if (item.id === chatCompletionsModel?.id) return
setChatCompletionsModel(item)
}}
{...Style.drawer.default}
placeholder={
modelList.length === 0 ? 'No Models Available' : 'Select Model'
}
/>
<TouchableOpacity
style={styles.button}
onPress={() => {
getModelList()
}}>
<MaterialIcons
name="refresh"
color={Style.getColor('primary-text1')}
size={28}
/>
</TouchableOpacity>
</View>
</View>
{chatCompletionsModel?.id && (
<View style={styles.modelInfo}>
<Text style={{ ...styles.title, marginBottom: 8 }}>
{chatCompletionsModel.id}
</Text>
<View style={{ flexDirection: 'row' }}>
<View>
<Text style={{ color: Style.getColor('primary-text2') }}>Id</Text>
<Text style={{ color: Style.getColor('primary-text2') }}>Object</Text>
<Text style={{ color: Style.getColor('primary-text2') }}>Created</Text>
<Text style={{ color: Style.getColor('primary-text2') }}>Owned By</Text>
</View>
<View style={{ marginLeft: 8 }}>
<Text style={{ color: Style.getColor('primary-text2') }}>
: {chatCompletionsModel.id}
</Text>
<Text style={{ color: Style.getColor('primary-text2') }}>
: {chatCompletionsModel.object}
</Text>
<Text style={{ color: Style.getColor('primary-text2') }}>
: {chatCompletionsModel.created}
</Text>
<Text style={{ color: Style.getColor('primary-text2') }}>
: {chatCompletionsModel.owned_by}
</Text>
</View>
</View>
</View>
)}
</View>
)
}

export default ChatCompletions

const styles = StyleSheet.create({
mainContainer: {
marginVertical: 16,
paddingVertical: 16,
paddingHorizontal: 20,
},

title: {
paddingTop: 8,
color: Style.getColor('primary-text1'),
fontSize: 16,
},

subtitle: {
color: Style.getColor('primary-text2'),
},

input: {
flex: 1,
color: Style.getColor('primary-text1'),
borderColor: Style.getColor('primary-brand'),
borderWidth: 1,
paddingVertical: 4,
paddingHorizontal: 8,
marginVertical: 8,
borderRadius: 8,
},

button: {
padding: 5,
borderColor: Style.getColor('primary-brand'),
borderWidth: 1,
borderRadius: 4,
marginLeft: 8,
},

dropdownContainer: {
marginTop: 16,
},

modelInfo: {
borderRadius: 8,
backgroundColor: Style.getColor('primary-surface2'),
flex: 1,
paddingHorizontal: 16,
paddingTop: 12,
paddingBottom: 24,
},
})
2 changes: 1 addition & 1 deletion app/components/Endpoint/HeartbeatButton.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const HeartbeatButton: React.FC<HeartbeatButtonProps> = ({
buttonText = 'Test',
apiFormat = (url: string) => {
try {
const newurl = new URL('/v1/models', api)
const newurl = new URL('v1/models', api)
return newurl.toString()
} catch (e) {
return ''
Expand Down
15 changes: 14 additions & 1 deletion app/components/Endpoint/index.tsx
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ChatCompletions from './ChatCompletions'
import Claude from './Claude'
import Horde from './Horde'
import KAI from './KAI'
Expand All @@ -8,4 +9,16 @@ import OpenAI from './OpenAI'
import OpenRouter from './OpenRouter'
import TGWUI from './TGWUI'
import TextCompletions from './TextCompletions'
export { Horde, KAI, TGWUI, Mancer, TextCompletions, Local, OpenRouter, OpenAI, Ollama, Claude }
export {
Horde,
KAI,
TGWUI,
Mancer,
TextCompletions,
Local,
OpenRouter,
OpenAI,
Ollama,
Claude,
ChatCompletions,
}
1 change: 1 addition & 0 deletions app/constants/API.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ export const enum API {
OPENAI = 'openai',
OLLAMA = 'ollama',
CLAUDE = 'claude',
CHATCOMPLETIONS = 'chatcompletions',
}
17 changes: 16 additions & 1 deletion app/constants/APIState/BaseAPI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,11 @@ export abstract class APIBase implements IAPIBase {
}

buildChatCompletionContext = (max_length: number) => {
const tokenizer =
mmkv.getString(Global.APIType) === API.LOCAL
? Llama.useLlama.getState().tokenLength
: Tokenizer.useTokenizer.getState().getTokenCount

const messages = [...(Chats.useChat.getState().data?.messages ?? [])]
const userCard = { ...Characters.useUserCard.getState().card }
const currentCard = { ...Characters.useCharacterCard.getState().card }
Expand Down Expand Up @@ -226,9 +231,19 @@ export abstract class APIBase implements IAPIBase {

let index = messages.length - 1
for (const message of messages.reverse()) {
const swipe_data = message.swipes[message.swipe_id]
// special case for claude, prefill may be useful!
const timestamp_string = `[${swipe_data.send_date.toString().split(' ')[0]} ${swipe_data.send_date.toLocaleTimeString()}]\n`
const timestamp_length = currentInstruct.timestamp ? tokenizer(timestamp_string) : 0

const name_string = `${message.name} :`
const name_length = currentInstruct.names ? tokenizer(name_string) : 0

const len = Chats.useChat.getState().getTokenCount(index) + total_length
const len =
Chats.useChat.getState().getTokenCount(index) +
total_length +
name_length +
timestamp_length
if (len > max_length) break
messageBuffer.push({
role: message.is_user ? 'user' : 'assistant',
Expand Down
57 changes: 57 additions & 0 deletions app/constants/APIState/ChatCompletionsAPI.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import { Logger } from '@constants/Logger'
import { SamplerID } from '@constants/SamplerData'

import { APIBase, APISampler } from './BaseAPI'
import { Global } from '../GlobalValues'

class ChatCompletionsAPI extends APIBase {
samplers: APISampler[] = [
{ externalName: 'max_tokens', samplerID: SamplerID.GENERATED_LENGTH },
{ externalName: 'stream', samplerID: SamplerID.STREAMING },
{ externalName: 'temperature', samplerID: SamplerID.TEMPERATURE },
{ externalName: 'top_p', samplerID: SamplerID.TOP_P },
{ externalName: 'presence_penalty', samplerID: SamplerID.PRESENCE_PENALTY },
{ externalName: 'frequency_penalty', samplerID: SamplerID.FREQUENCY_PENALTY },
{ externalName: 'seed', samplerID: SamplerID.SEED },
]

buildPayload = () => {
const payloadFields = this.getSamplerFields()
const max_length = (payloadFields?.['max_tokens'] ?? 0) as number
const messages = this.buildChatCompletionContext(max_length)
const model = this.getObject(Global.ChatCompletionsModel)

return {
...payloadFields,
messages: messages,
model: model.id,
stop: this.constructStopSequence(),
}
}

inference = async () => {
const endpoint = this.getString(Global.ChatCompletionsEndpoint)
const key = this.getString(Global.ChatCompletionsKey)

Logger.log(`Using endpoint: Chat Completions`)
this.readableStreamResponse(
new URL('v1/chat/completions', endpoint).toString(),
JSON.stringify(this.buildPayload()),
(item) => {
console.log(item)
const output = JSON.parse(item)
return (
output?.choices?.[0]?.text ??
output?.choices?.[0]?.delta?.content ??
output?.content ??
''
)
},
() => {},
{ Authorization: `Bearer ${key}` }
)
}
}

const chatCompletionsAPI = new ChatCompletionsAPI()
export default chatCompletionsAPI
1 change: 1 addition & 0 deletions app/constants/APIState/TextCompletionAPI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class TextCompletionAPI extends APIBase {
inference = async () => {
const endpoint = this.getString(Global.CompletionsEndpoint)
const key = this.getString(Global.CompletionsKey)

Logger.log(`Using endpoint: Text Completions`)
this.readableStreamResponse(
new URL('/v1/completions', endpoint).toString(),
Expand Down
3 changes: 2 additions & 1 deletion app/constants/APIState/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { useInference } from 'app/constants/Chat'
import { Logger } from 'app/constants/Logger'

import { APIBase } from './BaseAPI'
import chatCompletionsAPI from './ChatCompletionsAPI'
import claudeAPI from './ClaudeAPI'
import hordeAPI from './HordeAPI'
import koboldAPI from './KoboldAPI'
Expand Down Expand Up @@ -37,7 +38,7 @@ export const APIState: Record<API, APIBase> = {
[API.OPENROUTER]: openRouterAPI,
[API.OLLAMA]: ollamaAPI,
[API.CLAUDE]: claudeAPI,

[API.CHATCOMPLETIONS]: chatCompletionsAPI,
//UNIMPLEMENTED
[API.NOVELAI]: unimplementedAPI,
[API.APHRODITE]: unimplementedAPI,
Expand Down
Loading

0 comments on commit a8c0708

Please sign in to comment.