Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: better "generate more" functionality #891

Open
wants to merge 6 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions common/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -422,10 +422,6 @@ export async function buildPromptParts(

const post = createPostPrompt(opts)

if (opts.continue) {
post.unshift(`${char.name}: ${opts.continue}`)
}

const linesForMemory = [...lines].reverse()
const books: AppSchema.MemoryBook[] = []
if (replyAs.characterBook) books.push(replyAs.characterBook)
Expand Down
25 changes: 24 additions & 1 deletion common/template-parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,30 @@ export async function parseTemplate(
parts.ujb = render(parts.ujb, opts)
}

const ast = parser.parse(template, {}) as PNode[]
let ast = parser.parse(template, {}) as PNode[]

/**
* Continuing the previous message:
* In this case our goal is to end the prompt as close to the
* last message as possible.
*/
if (opts.continue) {
const historyIndex = ast.findIndex(
(node) =>
typeof node !== 'string' &&
((node.kind === 'placeholder' && node.value === 'history') ||
(node.kind === 'each' && node.value === 'history'))
)
if (historyIndex !== -1) {
const node = ast[historyIndex] as PlaceHolder | IteratorNode
// replace iterator with normalized history
if (node.kind === 'each' && node.value === 'history') {
ast[historyIndex] = { kind: 'placeholder', value: 'history' } as PlaceHolder
}
ast = ast.slice(0, historyIndex + 1)
}
}

readInserts(opts, ast)
let output = render(template, opts, ast)
let unusedTokens = 0
Expand Down
9 changes: 9 additions & 0 deletions common/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,15 @@ export function trimSentence(text: string) {
return index === -1 ? text.trimEnd() : text.slice(0, index + 1).trimEnd()
}

export function concatenateSentence(text: string, next: string) {
if (!text || !next) return `${text}${next}`
if (next.startsWith('\n')) {
return `${text.trimEnd()}\n${next.trimStart()}`
}

return `${text.trimEnd()}${next}`
}

export function slugify(str: string) {
return str
.toLowerCase()
Expand Down
8 changes: 6 additions & 2 deletions srv/api/chat/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { v4 } from 'uuid'
import { Response } from 'express'
import { publishMany } from '../ws/handle'
import { getScenarioEventType } from '/common/scenario'
import { concatenateSentence } from '/common/util'

type GenRequest = UnwrapBody<typeof genValidator>

Expand Down Expand Up @@ -331,7 +332,9 @@ export const generateMessageV2 = handle(async (req, res) => {
return
}

const responseText = body.kind === 'continue' ? `${body.continuing.msg} ${generated}` : generated
const responseText =
body.kind === 'continue' ? concatenateSentence(body.continuing.msg, generated) : generated

const actions: AppSchema.ChatAction[] = []

switch (body.kind) {
Expand Down Expand Up @@ -568,7 +571,8 @@ async function handleGuestGenerate(body: GenRequest, req: AppRequest, res: Respo

if (error) return

const responseText = body.kind === 'continue' ? `${body.continuing.msg} ${generated}` : generated
const responseText =
body.kind === 'continue' ? concatenateSentence(body.continuing.msg, generated) : generated

const characterId = body.kind === 'self' ? undefined : body.replyAs?._id || body.char?._id
const senderId = body.kind === 'self' ? 'anon' : undefined
Expand Down
3 changes: 0 additions & 3 deletions tests/__snapshots__/prompt.spec.js.snap
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ How MainChar speaks: SAMPLECHAT MainChar

<START>
MainChar: FIRST
MainChar: ORIGINAL
MainChar:"
`;

Expand Down Expand Up @@ -99,7 +98,6 @@ Scenario: MAIN MainChar
This is how OtherBot should talk: SAMPLECHAT OtherBot
MainChar: FIRST
ChatOwner: SECOND
MainChar: ORIGINAL
OtherBot:"
`;

Expand Down Expand Up @@ -165,7 +163,6 @@ SAMPLECHAT OtherBot
System: New conversation started. Previous conversations are examples only.
MainChar: FIRST
ChatOwner: SECOND
MainChar: ORIGINAL
OtherBot:"
`;

Expand Down
Loading