From ce433fc6941c76e0d48288ff5cdfbf56ba2329e5 Mon Sep 17 00:00:00 2001 From: Thomas Mello Date: Tue, 9 Apr 2024 01:50:44 +0300 Subject: [PATCH] fix: cut prompt to history's last message --- common/prompt.ts | 2 +- common/template-parser.ts | 30 ++++++++++++++++++++---------- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/common/prompt.ts b/common/prompt.ts index f64ad3ac5..999eb9ece 100644 --- a/common/prompt.ts +++ b/common/prompt.ts @@ -506,7 +506,7 @@ function createPostPrompt( if (opts.kind === 'chat-query') { post.push(`Query Response:`) - } else if (opts.kind !== 'continue') { + } else { post.push(`${opts.replyAs.name}:`) } diff --git a/common/template-parser.ts b/common/template-parser.ts index 2e6a5cc21..a89d035fb 100644 --- a/common/template-parser.ts +++ b/common/template-parser.ts @@ -138,16 +138,22 @@ export async function parseTemplate( parts.ujb = render(parts.ujb, opts) } - const ast = parser.parse(template, {}) as PNode[] + let ast = parser.parse(template, {}) as PNode[] - // hack: when we continue, remove the post along with the last newline from tree - if (opts.continue && ast.length > 1) { - const last = ast[ast.length - 1] - if (typeof last !== 'string' && last.kind === 'placeholder' && last.value === 'post') { - ast.pop() - } - if (ast[ast.length - 1] === '\n') { - ast.pop() + /** + * Continuing the previous message: + * In this case our goal is to end the prompt as close to the + * last message as possible. + */ + if (opts.continue) { + const historyIndex = ast.findIndex( + (node) => + typeof node !== 'string' && + ((node.kind === 'placeholder' && node.value === 'history') || + (node.kind === 'each' && node.value === 'history')) + ) + if (historyIndex !== -1) { + ast = ast.slice(0, historyIndex + 1) } } @@ -442,7 +448,7 @@ function renderIterator(holder: IterableHolder, children: CNode[], opts: Templat let i = 0 for (const entity of entities) { let curr = '' - for (const child of children) { + children_loop: for (const child of children) { if (typeof child === 'string') { curr += child continue @@ -468,6 +474,8 @@ function renderIterator(holder: IterableHolder, children: CNode[], opts: Templat case 'history-prop': { const result = renderProp(child, opts, entity, i) if (result) curr += result + // when continuing, cut the first node (last response) to its message + if (opts.continue && i === 0 && isHistory && child.prop === 'message') break children_loop break } @@ -477,6 +485,8 @@ function renderIterator(holder: IterableHolder, children: CNode[], opts: Templat if (!prop) break const result = renderEntityCondition(child.children, opts, entity, i) curr += result + // when continuing, cut the first node (last response) to its message + if (opts.continue && i === 0 && isHistory) break children_loop break } }