diff --git a/src/sensemaker.ts b/src/sensemaker.ts index 8b05d2c..54a3145 100644 --- a/src/sensemaker.ts +++ b/src/sensemaker.ts @@ -31,6 +31,45 @@ import { getPrompt, hydrateCommentRecord } from "./sensemaker_utils"; import { Type } from "@sinclair/typebox"; import { ModelSettings, Model } from "./models/model"; import { groundSummary } from "./tasks/grounding"; +import { SummaryStats } from "./stats_util"; +import { summaryContainsStats } from "./tasks/stats_checker"; + +/** + * Rerun a function multiple times. + * @param func the function to attempt + * @param isValid checks that the response from func is valid + * @param maxRetries the maximum number of times to retry func + * @param errorMsg the error message to throw + * @param retryDelayMS how long to wait in miliseconds between calls + * @param funcArgs the args for func and isValid + * @returns the valid response from func + */ +/* eslint-disable @typescript-eslint/no-explicit-any */ +async function retryCall( + func: (...args: any[]) => Promise, + isValid: (response: T, ...args: any[]) => boolean, + maxRetries: number, + errorMsg: string, + retryDelayMS: number = RETRY_DELAY_MS, + ...funcArgs: any[] +) { + /* eslint-enable @typescript-eslint/no-explicit-any */ + for (let attempt = 1; attempt <= maxRetries; attempt++) { + try { + const response = await func(...funcArgs); + if (isValid(response, ...funcArgs)) { + return response; + } + console.error(`Attempt ${attempt} failed. Invalid response:`, response); + } catch (error) { + console.error(`Attempt ${attempt} failed:`, error); + } + + console.log(`Retrying in ${retryDelayMS / 1000} seconds`); + await new Promise((resolve) => setTimeout(resolve, retryDelayMS)); + } + throw new Error(`Failed after ${maxRetries} attempts: ${errorMsg}`); +} // Class to make sense of a deliberation. Uses LLMs to learn what topics were discussed and // categorize comments. Then these categorized comments can be used with optional Vote data to @@ -107,13 +146,20 @@ export class Sensemaker { } comments = await this.categorizeComments(comments, true, topics, additionalInstructions); } - - const summary = await summarizeByType( + const summary = await retryCall( + async function (model: Model, summaryStats: SummaryStats): Promise { + return summarizeByType(model, summaryStats, summarizationType, additionalInstructions); + }, + function (summary: string, summaryStats: SummaryStats): boolean { + return summaryContainsStats(summary, summaryStats, summarizationType); + }, + MAX_RETRIES, + "The statistics don't match what's in the summary.", + undefined, this.getModel("summarizationModel"), - comments, - summarizationType, - additionalInstructions + new SummaryStats(comments) ); + return groundSummary(this.getModel("groundingModel"), summary, comments); } @@ -139,23 +185,22 @@ export class Sensemaker { const commentTexts = comments.map((comment) => "```" + comment.text + "```"); // decide which schema to use based on includeSubtopics const schema = Type.Array(includeSubtopics ? NestedTopic : FlatTopic); - for (let attempt = 1; attempt <= MAX_RETRIES; attempt++) { - const response = (await this.getModel("categorizationModel").generateData( - getPrompt(instructions, commentTexts, additionalInstructions), - schema - )) as Topic[]; - if (learnedTopicsValid(response, topics)) { - return response; - } else { - console.warn( - `Learned topics failed validation, attempt ${attempt}. Retrying in ${RETRY_DELAY_MS / 1000} seconds...` - ); - await new Promise((resolve) => setTimeout(resolve, RETRY_DELAY_MS)); - } - } - - throw new Error("Topic modeling failed after multiple retries."); + return retryCall( + async function (model: Model): Promise { + return (await model.generateData( + getPrompt(instructions, commentTexts, additionalInstructions), + schema + )) as Topic[]; + }, + function (response: Topic[]): boolean { + return learnedTopicsValid(response, topics); + }, + MAX_RETRIES, + "Topic modeling failed.", + undefined, + this.getModel("categorizationModel") + ); } /** diff --git a/src/tasks/stats_checker.test.ts b/src/tasks/stats_checker.test.ts new file mode 100644 index 0000000..32f1358 --- /dev/null +++ b/src/tasks/stats_checker.test.ts @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import { SummarizationType, VoteTally } from "../types"; +import { SummaryStats } from "../stats_util"; +import { summaryContainsStats } from "./stats_checker"; + +// Has 5 comments and 60 votes. +const TEST_SUMMARY_STATS = new SummaryStats([ + { id: "1", text: "hello", voteTalliesByGroup: { "group 0": new VoteTally(10, 20, 30) } }, + { id: "2", text: "hello" }, + { id: "3", text: "hello" }, + { id: "4", text: "hello" }, + { id: "5", text: "hello" }, +]); + +describe("StatsCheckerTest", () => { + it("should return true for a good summary", () => { + const summary = "There are 60 votes and 5 statements."; + expect( + summaryContainsStats(summary, TEST_SUMMARY_STATS, SummarizationType.VOTE_TALLY) + ).toBeTruthy(); + }); + + it("should return false if missing the right statement count", () => { + const summary = "There are 60 votes and 6 statements."; + expect( + summaryContainsStats(summary, TEST_SUMMARY_STATS, SummarizationType.VOTE_TALLY) + ).toBeFalsy(); + }); + + it("should return false if missing the right vote count", () => { + const summary = "There are 6 votes and 5 statements."; + expect( + summaryContainsStats(summary, TEST_SUMMARY_STATS, SummarizationType.VOTE_TALLY) + ).toBeFalsy(); + }); +}); diff --git a/src/tasks/stats_checker.ts b/src/tasks/stats_checker.ts new file mode 100644 index 0000000..6b54ef1 --- /dev/null +++ b/src/tasks/stats_checker.ts @@ -0,0 +1,48 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Checks a Summary for simple string matches. + +import { SummarizationType } from "../types"; +import { SummaryStats } from "../stats_util"; + +/** + * Checks that a summary contains the numbers from the SummaryStats. + * @param summary the summary to consider + * @param summaryStats the numbers to check that the summary contains + * @param summarizationType the type of summarization done + * @returns true if the summary contains the statistics (not necessarily in the right context) + */ +export function summaryContainsStats( + summary: string, + summaryStats: SummaryStats, + summarizationType: SummarizationType +): boolean { + if (!summary.includes(`${summaryStats.commentCount} statements`)) { + console.error(`Summary does not contain the correct number of total comments from the + deliberation. commentCount=${summaryStats.commentCount} and summary=${summary}`); + return false; + } + + if ( + summarizationType == SummarizationType.VOTE_TALLY && + !summary.includes(`${summaryStats.voteCount} votes`) + ) { + console.error(`Summary does not contain the correct number of total votes from the + deliberation. voteCount=${summaryStats.voteCount} and summary=${summary}`); + return false; + } + + return true; +} diff --git a/src/tasks/summarization.test.ts b/src/tasks/summarization.test.ts index 0e6af2d..e18fd82 100644 --- a/src/tasks/summarization.test.ts +++ b/src/tasks/summarization.test.ts @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +import { SummaryStats } from "../stats_util"; import { formatCommentsWithVotes, getSummarizationInstructions, @@ -60,14 +61,22 @@ const TEST_COMMENTS = [ describe("SummaryTest", () => { it("prompt should include the comment count and the vote count", () => { // Has 2 comments and 55 votes. - expect(getSummarizationInstructions(true, TEST_COMMENTS)).toContain("2 statements"); - expect(getSummarizationInstructions(true, TEST_COMMENTS)).toContain("55 votes"); + expect(getSummarizationInstructions(true, new SummaryStats(TEST_COMMENTS))).toContain( + "2 statements" + ); + expect(getSummarizationInstructions(true, new SummaryStats(TEST_COMMENTS))).toContain( + "55 votes" + ); }); it("prompt shouldn't include votes if groups aren't included", () => { // Has 2 comments and 55 votes. - expect(getSummarizationInstructions(false, TEST_COMMENTS)).toContain("2 statements"); - expect(getSummarizationInstructions(false, TEST_COMMENTS)).not.toContain("55 votes"); + expect(getSummarizationInstructions(false, new SummaryStats(TEST_COMMENTS))).toContain( + "2 statements" + ); + expect(getSummarizationInstructions(false, new SummaryStats(TEST_COMMENTS))).not.toContain( + "55 votes" + ); }); it("should format comments with vote tallies via formatCommentsWithVotes", () => { diff --git a/src/tasks/summarization.ts b/src/tasks/summarization.ts index 8c15083..7aca01b 100644 --- a/src/tasks/summarization.ts +++ b/src/tasks/summarization.ts @@ -19,10 +19,12 @@ import { Comment, SummarizationType } from "../types"; import { getPrompt } from "../sensemaker_utils"; import { SummaryStats, TopicStats } from "../stats_util"; -export function getSummarizationInstructions(includeGroups: boolean, comments: Comment[]): string { +export function getSummarizationInstructions( + includeGroups: boolean, + summaryStats: SummaryStats +): string { // Prepare statistics like vote count and number of comments per topic for injecting in prompt as // well as sorts topics based on count. - const summaryStats = new SummaryStats(comments); const topicStats = summaryStats.getStatsByTopic(); const sortedTopics = _sortTopicsByComments(topicStats); const quantifiedTopics = _quantifyTopicNames(sortedTopics); @@ -93,14 +95,14 @@ ${includeGroups ? "There should be a one-paragraph section describing the voting */ export async function summarizeByType( model: Model, - comments: Comment[], + summaryStats: SummaryStats, summarizationType: SummarizationType, additionalInstructions?: string ): Promise { if (summarizationType === SummarizationType.BASIC) { - return await basicSummarize(comments, model, additionalInstructions); + return await basicSummarize(summaryStats, model, additionalInstructions); } else if (summarizationType === SummarizationType.VOTE_TALLY) { - return await voteTallySummarize(comments, model, additionalInstructions); + return await voteTallySummarize(summaryStats, model, additionalInstructions); } else { throw new TypeError("Unknown Summarization Type."); } @@ -114,13 +116,17 @@ export async function summarizeByType( * @returns: the LLM's summarization. */ export async function basicSummarize( - comments: Comment[], + summaryStats: SummaryStats, model: Model, additionalInstructions?: string ): Promise { - const commentTexts = comments.map((comment) => comment.text); + const commentTexts = summaryStats.comments.map((comment) => comment.text); return await model.generateText( - getPrompt(getSummarizationInstructions(false, comments), commentTexts, additionalInstructions) + getPrompt( + getSummarizationInstructions(false, summaryStats), + commentTexts, + additionalInstructions + ) ); } @@ -144,14 +150,14 @@ export function formatCommentsWithVotes(commentData: Comment[]): string[] { * @returns: the LLM's summarization. */ export async function voteTallySummarize( - comments: Comment[], + summaryStats: SummaryStats, model: Model, additionalInstructions?: string ): Promise { return await model.generateText( getPrompt( - getSummarizationInstructions(true, comments), - formatCommentsWithVotes(comments), + getSummarizationInstructions(true, summaryStats), + formatCommentsWithVotes(summaryStats.comments), additionalInstructions ) );