Skip to content

Commit

Permalink
feat: updated sampler type checking, added DRY samplers
Browse files Browse the repository at this point in the history
  • Loading branch information
Vali-98 committed Aug 8, 2024
1 parent 3dee6b5 commit a5eb86c
Show file tree
Hide file tree
Showing 17 changed files with 227 additions and 281 deletions.
8 changes: 3 additions & 5 deletions app/SamplerMenu.tsx
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import AnimatedView from '@components/AnimatedView'
import { APIState } from 'app/constants/APIState'
import { SamplerPreset } from 'app/constants/Presets'
import { SamplerID, Samplers } from 'app/constants/Samplers'
import { FontAwesome } from '@expo/vector-icons'
import { Global, Presets, saveStringExternal, Logger, Style, API } from '@globals'
import { APIState } from 'app/constants/APIState'
import { Samplers, SamplerPreset } from 'app/constants/SamplerData'
import { Stack } from 'expo-router'
import { useState, useEffect } from 'react'
import {
Expand All @@ -16,7 +15,6 @@ import {
Alert,
} from 'react-native'
import { Dropdown } from 'react-native-element-dropdown'
import { TextInput } from 'react-native-gesture-handler'
import { useMMKVObject, useMMKVString } from 'react-native-mmkv'

import { TextBoxModal, SliderItem, TextBox, CheckboxTitle } from './components'
Expand Down Expand Up @@ -235,7 +233,7 @@ const SamplerMenu = () => {
name={samplerItem.friendlyName}
/>
)
case 'custom':
//case 'custom':
default:
return (
<Text style={styles.warningText}>
Expand Down
3 changes: 1 addition & 2 deletions app/constants/APIState/BaseAPI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@ import { Chats, useInference } from 'app/constants/Chat'
import { InstructType, Instructs } from 'app/constants/Instructs'
import { Logger } from 'app/constants/Logger'
import { mmkv } from 'app/constants/MMKV'
import { SamplerPreset } from 'app/constants/Presets'
import { replaceMacros } from 'app/constants/Utils'
import EventSource from 'react-native-sse'

import { SamplerID, Samplers } from '../Samplers'
import { SamplerID, Samplers, SamplerPreset } from '../SamplerData'

export type APISampler = {
samplerID: SamplerID
Expand Down
2 changes: 1 addition & 1 deletion app/constants/APIState/ClaudeAPI.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Global } from '@constants/GlobalValues'
import { Logger } from 'app/constants/Logger'
import { SamplerID } from 'app/constants/Samplers'
import { SamplerID } from 'app/constants/SamplerData'

import { APIBase, APISampler } from './BaseAPI'

Expand Down
4 changes: 2 additions & 2 deletions app/constants/APIState/HordeAPI.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Chats, useInference } from 'app/constants/Chat'
import { Global } from '@constants/GlobalValues'
import { Chats, useInference } from 'app/constants/Chat'
import { Logger } from 'app/constants/Logger'
import { SamplerID } from 'app/constants/Samplers'
import { SamplerID } from 'app/constants/SamplerData'
import { nativeApplicationVersion } from 'expo-application'

import { APIBase, APISampler } from './BaseAPI'
Expand Down
10 changes: 9 additions & 1 deletion app/constants/APIState/KoboldAPI.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Global } from '@constants/GlobalValues'
import { Logger } from 'app/constants/Logger'
import { SamplerID } from 'app/constants/Samplers'
import { mmkv } from 'app/constants/MMKV'
import { SamplerID } from 'app/constants/SamplerData'
import axios from 'axios'

import { APIBase, APISampler } from './BaseAPI'
Expand All @@ -27,16 +27,24 @@ class KoboldAPI extends APIBase {
{ externalName: 'dynatemp_range', samplerID: SamplerID.DYNATEMP_RANGE },
{ externalName: 'smooth_range', samplerID: SamplerID.SMOOTHING_FACTOR },
{ externalName: 'sampler_seed', samplerID: SamplerID.SEED },
{ externalName: 'dry_multiplier', samplerID: SamplerID.DRY_MULTIPLIER },
{ externalName: 'dry_base', samplerID: SamplerID.DRY_BASE },
{ externalName: 'dry_allowed_length', samplerID: SamplerID.DRY_ALLOWED_LENGTH },
{ externalName: 'dry_sequence_break', samplerID: SamplerID.DRY_SEQUENCE_BREAK },
]
buildPayload = () => {
const payloadFields = this.getSamplerFields()
const length = payloadFields?.['max_context_length']
const dry_sequence_break = payloadFields?.['dry_sequence_break'] as string

const seq_break_array = dry_sequence_break ? dry_sequence_break.split(',') : []

return {
...payloadFields,
samplerOrder: [6, 0, 1, 3, 4, 2, 5],
prompt: this.buildTextCompletionContext(typeof length === 'number' ? length : 0),
stop_sequence: this.constructStopSequence(),
dry_sequence_break: seq_break_array,
}
}
inference = async () => {
Expand Down
2 changes: 1 addition & 1 deletion app/constants/APIState/LocalAPI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { Chats, useInference } from 'app/constants/Chat'
import { Llama, LlamaPreset } from 'app/constants/LlamaLocal'
import { Logger } from 'app/constants/Logger'
import { mmkv } from 'app/constants/MMKV'
import { SamplerID } from 'app/constants/Samplers'
import { SamplerID } from 'app/constants/SamplerData'
import BackgroundService from 'react-native-background-actions'

import { APIBase, APISampler } from './BaseAPI'
Expand Down
2 changes: 1 addition & 1 deletion app/constants/APIState/MancerAPI.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Global } from '@constants/GlobalValues'
import { Logger } from 'app/constants/Logger'
import { SamplerID } from 'app/constants/Samplers'
import { SamplerID } from 'app/constants/SamplerData'

import { APIBase, APISampler } from './BaseAPI'

Expand Down
3 changes: 1 addition & 2 deletions app/constants/APIState/OllamaAPI.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { Global } from '@constants/GlobalValues'
import { Logger } from 'app/constants/Logger'
import { SamplerID } from 'app/constants/Samplers'
import axios from 'axios'
import { SamplerID } from 'app/constants/SamplerData'

import { APIBase, APISampler } from './BaseAPI'

Expand Down
2 changes: 1 addition & 1 deletion app/constants/APIState/OpenAIAPI.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Global } from '@constants/GlobalValues'
import { Logger } from 'app/constants/Logger'
import { SamplerID } from 'app/constants/Samplers'
import { SamplerID } from 'app/constants/SamplerData'

import { APIBase, APISampler } from './BaseAPI'

Expand Down
2 changes: 1 addition & 1 deletion app/constants/APIState/OpenRouterAPI.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Global } from '@constants/GlobalValues'
import { Logger } from 'app/constants/Logger'
import { SamplerID } from 'app/constants/Samplers'
import { SamplerID } from 'app/constants/SamplerData'

import { APIBase, APISampler } from './BaseAPI'

Expand Down
2 changes: 1 addition & 1 deletion app/constants/APIState/TGWUIAPI.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Global } from '@constants/GlobalValues'
import { SamplerID } from 'app/constants/Samplers'
import { SamplerID } from 'app/constants/SamplerData'

import { APIBase, APISampler } from './BaseAPI'

Expand Down
2 changes: 1 addition & 1 deletion app/constants/APIState/TextCompletionAPI.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { OpenAIModel } from '@components/Endpoint/OpenAI'
import { Global } from '@constants/GlobalValues'
import { Logger } from 'app/constants/Logger'
import { SamplerID } from 'app/constants/Samplers'
import { SamplerID } from 'app/constants/SamplerData'

import { APIBase, APISampler } from './BaseAPI'

Expand Down
12 changes: 8 additions & 4 deletions app/constants/Global.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ import { Global, AppSettings } from './GlobalValues'
import { Instructs } from './Instructs'
import { Llama } from './LlamaLocal'
import { Logger } from './Logger'
import { mmkv } from './MMKV'
import { MarkdownStyle } from './Markdown'
import { Presets } from './Presets'
import { RecentEntry, RecentMessages } from './RecentMessages'
import { Style } from './Style'
import { humanizedISO8601DateTime } from './Utils'
import { mmkv } from './MMKV'
export {
mmkv,
Presets,
Expand Down Expand Up @@ -125,7 +125,11 @@ export const startupApp = () => {
// This was in case of initializing new data into Presets, may change with SQL migration
mmkv.set(
Global.PresetData,
Presets.fixPreset(JSON.parse(mmkv.getString(Global.PresetData) ?? '{}'))
Presets.fixPreset(
JSON.parse(mmkv.getString(Global.PresetData) ?? '{}'),
Global.PresetName,
true
)
)

// default horde [0000000000] key is needed
Expand Down Expand Up @@ -166,8 +170,8 @@ export const initializeApp = async () => {
await Presets.getFileList()
.then((files) => {
if (files.length > 0) return
mmkv.set(Global.PresetData, JSON.stringify(Presets.defaultPreset()))
Presets.saveFile('Default', Presets.defaultPreset())
mmkv.set(Global.PresetData, JSON.stringify(Presets.defaultPreset))
Presets.saveFile('Default', Presets.defaultPreset)
Logger.log('Created default Preset')
})
.catch((error) => Logger.log(`Could not generate default Preset. Reason: ${error}`))
Expand Down
Loading

0 comments on commit a5eb86c

Please sign in to comment.