Skip to content

Commit

Permalink
add checks to make sure the numbers in the summary match the computed…
Browse files Browse the repository at this point in the history
… numbers

GitOrigin-RevId: 0934d4508823dfcd4d7c7a5b27c353df4d90f4b3
  • Loading branch information
alyssachvasta authored and copybara-github committed Dec 17, 2024
1 parent 1a9f9f0 commit 11fca1f
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 36 deletions.
87 changes: 66 additions & 21 deletions src/sensemaker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,45 @@ import { getPrompt, hydrateCommentRecord } from "./sensemaker_utils";
import { Type } from "@sinclair/typebox";
import { ModelSettings, Model } from "./models/model";
import { groundSummary } from "./tasks/grounding";
import { SummaryStats } from "./stats_util";
import { summaryContainsStats } from "./tasks/stats_checker";

/**
 * Calls an async function repeatedly until it yields a response that passes validation.
 *
 * Each attempt awaits `func(...funcArgs)` and then runs `isValid(response, ...funcArgs)`.
 * An attempt counts as failed when `func` rejects, when `isValid` throws, or when `isValid`
 * returns false; every failure is logged with console.error. Between failed attempts the
 * function waits `retryDelayMS`; it does not wait after the final attempt.
 *
 * @param func the function to attempt
 * @param isValid checks that the response from func is valid
 * @param maxRetries the maximum number of times func is called (total attempts)
 * @param errorMsg text appended to the error thrown once every attempt has failed
 * @param retryDelayMS how long to wait in milliseconds between attempts
 * @param funcArgs the args passed to both func and isValid
 * @returns the first response from func that isValid accepts
 * @throws Error when no attempt produced a valid response
 */
/* eslint-disable @typescript-eslint/no-explicit-any */
async function retryCall<T>(
  func: (...args: any[]) => Promise<T>,
  isValid: (response: T, ...args: any[]) => boolean,
  maxRetries: number,
  errorMsg: string,
  retryDelayMS: number = RETRY_DELAY_MS,
  ...funcArgs: any[]
): Promise<T> {
  /* eslint-enable @typescript-eslint/no-explicit-any */
  for (let attempt = 1; attempt <= maxRetries; attempt++) {
    try {
      const response = await func(...funcArgs);
      if (isValid(response, ...funcArgs)) {
        return response;
      }
      console.error(`Attempt ${attempt} failed. Invalid response:`, response);
    } catch (error) {
      // Errors thrown by func or isValid are treated like an invalid response: log and retry.
      console.error(`Attempt ${attempt} failed:`, error);
    }

    // Only announce and wait when another attempt will actually follow; sleeping after the
    // last attempt would just postpone the error below.
    if (attempt < maxRetries) {
      console.log(`Retrying in ${retryDelayMS / 1000} seconds`);
      await new Promise((resolve) => setTimeout(resolve, retryDelayMS));
    }
  }
  throw new Error(`Failed after ${maxRetries} attempts: ${errorMsg}`);
}

// Class to make sense of a deliberation. Uses LLMs to learn what topics were discussed and
// categorize comments. Then these categorized comments can be used with optional Vote data to
Expand Down Expand Up @@ -107,13 +146,20 @@ export class Sensemaker {
}
comments = await this.categorizeComments(comments, true, topics, additionalInstructions);
}

const summary = await summarizeByType(
const summary = await retryCall(
async function (model: Model, summaryStats: SummaryStats): Promise<string> {
return summarizeByType(model, summaryStats, summarizationType, additionalInstructions);
},
function (summary: string, summaryStats: SummaryStats): boolean {
return summaryContainsStats(summary, summaryStats, summarizationType);
},
MAX_RETRIES,
"The statistics don't match what's in the summary.",
undefined,
this.getModel("summarizationModel"),
comments,
summarizationType,
additionalInstructions
new SummaryStats(comments)
);

return groundSummary(this.getModel("groundingModel"), summary, comments);
}

Expand All @@ -139,23 +185,22 @@ export class Sensemaker {
const commentTexts = comments.map((comment) => "```" + comment.text + "```");
// decide which schema to use based on includeSubtopics
const schema = Type.Array(includeSubtopics ? NestedTopic : FlatTopic);
for (let attempt = 1; attempt <= MAX_RETRIES; attempt++) {
const response = (await this.getModel("categorizationModel").generateData(
getPrompt(instructions, commentTexts, additionalInstructions),
schema
)) as Topic[];

if (learnedTopicsValid(response, topics)) {
return response;
} else {
console.warn(
`Learned topics failed validation, attempt ${attempt}. Retrying in ${RETRY_DELAY_MS / 1000} seconds...`
);
await new Promise((resolve) => setTimeout(resolve, RETRY_DELAY_MS));
}
}

throw new Error("Topic modeling failed after multiple retries.");
return retryCall(
async function (model: Model): Promise<Topic[]> {
return (await model.generateData(
getPrompt(instructions, commentTexts, additionalInstructions),
schema
)) as Topic[];
},
function (response: Topic[]): boolean {
return learnedTopicsValid(response, topics);
},
MAX_RETRIES,
"Topic modeling failed.",
undefined,
this.getModel("categorizationModel")
);
}

/**
Expand Down
49 changes: 49 additions & 0 deletions src/tasks/stats_checker.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import { SummarizationType, VoteTally } from "../types";
import { SummaryStats } from "../stats_util";
import { summaryContainsStats } from "./stats_checker";

// Shared fixture: 5 comments and 60 votes in total (only comment "1" carries a vote tally).
const TEST_SUMMARY_STATS = new SummaryStats([
  { id: "1", text: "hello", voteTalliesByGroup: { "group 0": new VoteTally(10, 20, 30) } },
  { id: "2", text: "hello" },
  { id: "3", text: "hello" },
  { id: "4", text: "hello" },
  { id: "5", text: "hello" },
]);

describe("StatsCheckerTest", () => {
  // Every case checks a vote-tally summary against the shared fixture.
  const matchesFixture = (summaryText: string): boolean =>
    summaryContainsStats(summaryText, TEST_SUMMARY_STATS, SummarizationType.VOTE_TALLY);

  it("should return true for a good summary", () => {
    expect(matchesFixture("There are 60 votes and 5 statements.")).toBeTruthy();
  });

  it("should return false if missing the right statement count", () => {
    expect(matchesFixture("There are 60 votes and 6 statements.")).toBeFalsy();
  });

  it("should return false if missing the right vote count", () => {
    expect(matchesFixture("There are 6 votes and 5 statements.")).toBeFalsy();
  });
});
48 changes: 48 additions & 0 deletions src/tasks/stats_checker.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Checks a Summary for simple string matches.

import { SummarizationType } from "../types";
import { SummaryStats } from "../stats_util";

/**
 * Tests whether `text` mentions `<count> <noun>` with the count as a standalone number.
 * A bare substring search would let "15 statements" or "1,005 statements" satisfy a count
 * of 5, so a match is rejected when its digits are the tail of a longer number.
 * @param text the text to search
 * @param count the number that must appear
 * @param noun the word that must directly follow the number, separated by one space
 * @returns true if the standalone phrase is present
 */
function containsExactCount(text: string, count: number, noun: string): boolean {
  // Lookbehinds: not preceded by a digit, nor by a digit plus a "," or "." separator.
  const pattern = new RegExp(`(?<!\\d)(?<!\\d[.,])${count} ${noun}`);
  return pattern.test(text);
}

/**
 * Checks that a summary contains the numbers from the SummaryStats.
 *
 * The summary must contain "<commentCount> statements". For VOTE_TALLY summaries it must also
 * contain "<voteCount> votes". A failed check is logged with console.error.
 * @param summary the summary to consider
 * @param summaryStats the numbers to check that the summary contains
 * @param summarizationType the type of summarization done
 * @returns true if the summary contains the statistics (not necessarily in the right context)
 */
export function summaryContainsStats(
  summary: string,
  summaryStats: SummaryStats,
  summarizationType: SummarizationType
): boolean {
  if (!containsExactCount(summary, summaryStats.commentCount, "statements")) {
    console.error(
      "Summary does not contain the correct number of total comments from the\n" +
        `deliberation. commentCount=${summaryStats.commentCount} and summary=${summary}`
    );
    return false;
  }

  // Vote totals are only expected in summaries that were generated with vote tallies.
  if (
    summarizationType === SummarizationType.VOTE_TALLY &&
    !containsExactCount(summary, summaryStats.voteCount, "votes")
  ) {
    console.error(
      "Summary does not contain the correct number of total votes from the\n" +
        `deliberation. voteCount=${summaryStats.voteCount} and summary=${summary}`
    );
    return false;
  }

  return true;
}
17 changes: 13 additions & 4 deletions src/tasks/summarization.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { SummaryStats } from "../stats_util";
import {
formatCommentsWithVotes,
getSummarizationInstructions,
Expand Down Expand Up @@ -60,14 +61,22 @@ const TEST_COMMENTS = [
describe("SummaryTest", () => {
it("prompt should include the comment count and the vote count", () => {
// Has 2 comments and 55 votes.
expect(getSummarizationInstructions(true, TEST_COMMENTS)).toContain("2 statements");
expect(getSummarizationInstructions(true, TEST_COMMENTS)).toContain("55 votes");
expect(getSummarizationInstructions(true, new SummaryStats(TEST_COMMENTS))).toContain(
"2 statements"
);
expect(getSummarizationInstructions(true, new SummaryStats(TEST_COMMENTS))).toContain(
"55 votes"
);
});

it("prompt shouldn't include votes if groups aren't included", () => {
// Has 2 comments and 55 votes.
expect(getSummarizationInstructions(false, TEST_COMMENTS)).toContain("2 statements");
expect(getSummarizationInstructions(false, TEST_COMMENTS)).not.toContain("55 votes");
expect(getSummarizationInstructions(false, new SummaryStats(TEST_COMMENTS))).toContain(
"2 statements"
);
expect(getSummarizationInstructions(false, new SummaryStats(TEST_COMMENTS))).not.toContain(
"55 votes"
);
});

it("should format comments with vote tallies via formatCommentsWithVotes", () => {
Expand Down
28 changes: 17 additions & 11 deletions src/tasks/summarization.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ import { Comment, SummarizationType } from "../types";
import { getPrompt } from "../sensemaker_utils";
import { SummaryStats, TopicStats } from "../stats_util";

export function getSummarizationInstructions(includeGroups: boolean, comments: Comment[]): string {
export function getSummarizationInstructions(
includeGroups: boolean,
summaryStats: SummaryStats
): string {
// Prepare statistics like vote count and number of comments per topic for injecting in prompt as
// well as sorts topics based on count.
const summaryStats = new SummaryStats(comments);
const topicStats = summaryStats.getStatsByTopic();
const sortedTopics = _sortTopicsByComments(topicStats);
const quantifiedTopics = _quantifyTopicNames(sortedTopics);
Expand Down Expand Up @@ -93,14 +95,14 @@ ${includeGroups ? "There should be a one-paragraph section describing the voting
*/
export async function summarizeByType(
model: Model,
comments: Comment[],
summaryStats: SummaryStats,
summarizationType: SummarizationType,
additionalInstructions?: string
): Promise<string> {
if (summarizationType === SummarizationType.BASIC) {
return await basicSummarize(comments, model, additionalInstructions);
return await basicSummarize(summaryStats, model, additionalInstructions);
} else if (summarizationType === SummarizationType.VOTE_TALLY) {
return await voteTallySummarize(comments, model, additionalInstructions);
return await voteTallySummarize(summaryStats, model, additionalInstructions);
} else {
throw new TypeError("Unknown Summarization Type.");
}
Expand All @@ -114,13 +116,17 @@ export async function summarizeByType(
* @returns: the LLM's summarization.
*/
export async function basicSummarize(
comments: Comment[],
summaryStats: SummaryStats,
model: Model,
additionalInstructions?: string
): Promise<string> {
const commentTexts = comments.map((comment) => comment.text);
const commentTexts = summaryStats.comments.map((comment) => comment.text);
return await model.generateText(
getPrompt(getSummarizationInstructions(false, comments), commentTexts, additionalInstructions)
getPrompt(
getSummarizationInstructions(false, summaryStats),
commentTexts,
additionalInstructions
)
);
}

Expand All @@ -144,14 +150,14 @@ export function formatCommentsWithVotes(commentData: Comment[]): string[] {
* @returns: the LLM's summarization.
*/
export async function voteTallySummarize(
comments: Comment[],
summaryStats: SummaryStats,
model: Model,
additionalInstructions?: string
): Promise<string> {
return await model.generateText(
getPrompt(
getSummarizationInstructions(true, comments),
formatCommentsWithVotes(comments),
getSummarizationInstructions(true, summaryStats),
formatCommentsWithVotes(summaryStats.comments),
additionalInstructions
)
);
Expand Down

0 comments on commit 11fca1f

Please sign in to comment.