-
-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: added generic chat completions
- Loading branch information
Showing
10 changed files
with
314 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
import { FontAwesome, MaterialIcons } from '@expo/vector-icons' | ||
import { Global, Logger, Style } from '@globals' | ||
import { useEffect, useState } from 'react' | ||
import { View, Text, StyleSheet, TextInput, TouchableOpacity } from 'react-native' | ||
import { Dropdown } from 'react-native-element-dropdown' | ||
import { useMMKVObject, useMMKVString } from 'react-native-mmkv' | ||
|
||
import HeartbeatButton from './HeartbeatButton' | ||
import { OpenAIModel } from './OpenAI' | ||
|
||
const ChatCompletions = () => { | ||
const [endpoint, setEndpoint] = useMMKVString(Global.ChatCompletionsEndpoint) | ||
const [chatCompletionsKey, setChatCompletionsKey] = useMMKVString(Global.ChatCompletionsKey) | ||
const [chatCompletionsModel, setChatCompletionsModel] = useMMKVObject<OpenAIModel>( | ||
Global.ChatCompletionsModel | ||
) | ||
const [keyInput, setKeyInput] = useState('') | ||
|
||
const [modelList, setModelList] = useState<OpenAIModel[]>([]) | ||
|
||
useEffect(() => { | ||
getModelList() | ||
}, []) | ||
|
||
const getModelList = async () => { | ||
if (!endpoint) return | ||
|
||
try { | ||
const url = new URL('v1/models', endpoint).toString() | ||
|
||
const response = await fetch(url, { | ||
headers: { | ||
accept: 'application/json', | ||
Authorization: `Bearer ${chatCompletionsKey}`, | ||
}, | ||
}) | ||
if (response.status !== 200) { | ||
Logger.log(`Error with response: ${response.status}`, true) | ||
return | ||
} | ||
const { data } = await response.json() | ||
setModelList(data) | ||
} catch (e) { | ||
setModelList([]) | ||
} | ||
} | ||
|
||
return ( | ||
<View style={styles.mainContainer}> | ||
<Text style={styles.title}>API Key</Text> | ||
<Text style={styles.subtitle}>Key will not be shown</Text> | ||
<View style={{ flexDirection: 'row', alignItems: 'center' }}> | ||
<TextInput | ||
style={styles.input} | ||
value={keyInput} | ||
onChangeText={setKeyInput} | ||
placeholder="Press save to confirm key" | ||
placeholderTextColor={Style.getColor('primary-text2')} | ||
secureTextEntry | ||
/> | ||
<TouchableOpacity | ||
style={styles.button} | ||
onPress={() => { | ||
if (keyInput === '') { | ||
Logger.log('No key entered!', true) | ||
return | ||
} | ||
setChatCompletionsKey(keyInput) | ||
setKeyInput('') | ||
Logger.log('Key saved!', true) | ||
}}> | ||
<FontAwesome name="save" color={Style.getColor('primary-text1')} size={28} /> | ||
</TouchableOpacity> | ||
</View> | ||
|
||
<Text style={styles.title}>Endpoint</Text> | ||
<Text style={styles.subtitle}> | ||
This endpoint should be cross compatible with many different services. Be sure to | ||
end with ' / ' | ||
</Text> | ||
<TextInput | ||
style={styles.input} | ||
value={endpoint} | ||
onChangeText={(value) => { | ||
setEndpoint(value) | ||
}} | ||
placeholder="eg. https://127.0.0.1:5000" | ||
placeholderTextColor={Style.getColor('primary-text2')} | ||
/> | ||
|
||
{endpoint && ( | ||
<HeartbeatButton | ||
api={endpoint} | ||
headers={{ Authorization: `Bearer ${chatCompletionsKey}` }} | ||
/> | ||
)} | ||
|
||
<View style={styles.dropdownContainer}> | ||
<Text style={styles.title}>Models</Text> | ||
|
||
<Text style={styles.subtitle}>API Key must be provided to get model list.</Text> | ||
|
||
<View style={{ flexDirection: 'row', alignItems: 'center' }}> | ||
<Dropdown | ||
value={chatCompletionsModel} | ||
data={modelList} | ||
labelField="id" | ||
valueField="id" | ||
onChange={(item: OpenAIModel) => { | ||
if (item.id === chatCompletionsModel?.id) return | ||
setChatCompletionsModel(item) | ||
}} | ||
{...Style.drawer.default} | ||
placeholder={ | ||
modelList.length === 0 ? 'No Models Available' : 'Select Model' | ||
} | ||
/> | ||
<TouchableOpacity | ||
style={styles.button} | ||
onPress={() => { | ||
getModelList() | ||
}}> | ||
<MaterialIcons | ||
name="refresh" | ||
color={Style.getColor('primary-text1')} | ||
size={28} | ||
/> | ||
</TouchableOpacity> | ||
</View> | ||
</View> | ||
{chatCompletionsModel?.id && ( | ||
<View style={styles.modelInfo}> | ||
<Text style={{ ...styles.title, marginBottom: 8 }}> | ||
{chatCompletionsModel.id} | ||
</Text> | ||
<View style={{ flexDirection: 'row' }}> | ||
<View> | ||
<Text style={{ color: Style.getColor('primary-text2') }}>Id</Text> | ||
<Text style={{ color: Style.getColor('primary-text2') }}>Object</Text> | ||
<Text style={{ color: Style.getColor('primary-text2') }}>Created</Text> | ||
<Text style={{ color: Style.getColor('primary-text2') }}>Owned By</Text> | ||
</View> | ||
<View style={{ marginLeft: 8 }}> | ||
<Text style={{ color: Style.getColor('primary-text2') }}> | ||
: {chatCompletionsModel.id} | ||
</Text> | ||
<Text style={{ color: Style.getColor('primary-text2') }}> | ||
: {chatCompletionsModel.object} | ||
</Text> | ||
<Text style={{ color: Style.getColor('primary-text2') }}> | ||
: {chatCompletionsModel.created} | ||
</Text> | ||
<Text style={{ color: Style.getColor('primary-text2') }}> | ||
: {chatCompletionsModel.owned_by} | ||
</Text> | ||
</View> | ||
</View> | ||
</View> | ||
)} | ||
</View> | ||
) | ||
} | ||
|
||
export default ChatCompletions | ||
|
||
const styles = StyleSheet.create({ | ||
mainContainer: { | ||
marginVertical: 16, | ||
paddingVertical: 16, | ||
paddingHorizontal: 20, | ||
}, | ||
|
||
title: { | ||
paddingTop: 8, | ||
color: Style.getColor('primary-text1'), | ||
fontSize: 16, | ||
}, | ||
|
||
subtitle: { | ||
color: Style.getColor('primary-text2'), | ||
}, | ||
|
||
input: { | ||
flex: 1, | ||
color: Style.getColor('primary-text1'), | ||
borderColor: Style.getColor('primary-brand'), | ||
borderWidth: 1, | ||
paddingVertical: 4, | ||
paddingHorizontal: 8, | ||
marginVertical: 8, | ||
borderRadius: 8, | ||
}, | ||
|
||
button: { | ||
padding: 5, | ||
borderColor: Style.getColor('primary-brand'), | ||
borderWidth: 1, | ||
borderRadius: 4, | ||
marginLeft: 8, | ||
}, | ||
|
||
dropdownContainer: { | ||
marginTop: 16, | ||
}, | ||
|
||
modelInfo: { | ||
borderRadius: 8, | ||
backgroundColor: Style.getColor('primary-surface2'), | ||
flex: 1, | ||
paddingHorizontal: 16, | ||
paddingTop: 12, | ||
paddingBottom: 24, | ||
}, | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import { Logger } from '@constants/Logger' | ||
import { SamplerID } from '@constants/SamplerData' | ||
|
||
import { APIBase, APISampler } from './BaseAPI' | ||
import { Global } from '../GlobalValues' | ||
|
||
class ChatCompletionsAPI extends APIBase { | ||
samplers: APISampler[] = [ | ||
{ externalName: 'max_tokens', samplerID: SamplerID.GENERATED_LENGTH }, | ||
{ externalName: 'stream', samplerID: SamplerID.STREAMING }, | ||
{ externalName: 'temperature', samplerID: SamplerID.TEMPERATURE }, | ||
{ externalName: 'top_p', samplerID: SamplerID.TOP_P }, | ||
{ externalName: 'presence_penalty', samplerID: SamplerID.PRESENCE_PENALTY }, | ||
{ externalName: 'frequency_penalty', samplerID: SamplerID.FREQUENCY_PENALTY }, | ||
{ externalName: 'seed', samplerID: SamplerID.SEED }, | ||
] | ||
|
||
buildPayload = () => { | ||
const payloadFields = this.getSamplerFields() | ||
const max_length = (payloadFields?.['max_tokens'] ?? 0) as number | ||
const messages = this.buildChatCompletionContext(max_length) | ||
const model = this.getObject(Global.ChatCompletionsModel) | ||
|
||
return { | ||
...payloadFields, | ||
messages: messages, | ||
model: model.id, | ||
stop: this.constructStopSequence(), | ||
} | ||
} | ||
|
||
inference = async () => { | ||
const endpoint = this.getString(Global.ChatCompletionsEndpoint) | ||
const key = this.getString(Global.ChatCompletionsKey) | ||
|
||
Logger.log(`Using endpoint: Chat Completions`) | ||
this.readableStreamResponse( | ||
new URL('v1/chat/completions', endpoint).toString(), | ||
JSON.stringify(this.buildPayload()), | ||
(item) => { | ||
console.log(item) | ||
const output = JSON.parse(item) | ||
return ( | ||
output?.choices?.[0]?.text ?? | ||
output?.choices?.[0]?.delta?.content ?? | ||
output?.content ?? | ||
'' | ||
) | ||
}, | ||
() => {}, | ||
{ Authorization: `Bearer ${key}` } | ||
) | ||
} | ||
} | ||
|
||
const chatCompletionsAPI = new ChatCompletionsAPI() | ||
export default chatCompletionsAPI |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.