From ea472bfa5d98c2d1bab3e446e1bfc6368be05bc2 Mon Sep 17 00:00:00 2001
From: Jigsaw
Date: Wed, 18 Dec 2024 15:02:40 -0500
Subject: [PATCH] Internal change

GitOrigin-RevId: 13e4441438c6b82806ca55bf1c477bb3ff671542
---
 src/models/vertex_model.ts | 47 ++++++++++++++++++++++++++------------
 1 file changed, 33 insertions(+), 14 deletions(-)

diff --git a/src/models/vertex_model.ts b/src/models/vertex_model.ts
index 2d24238..3c7dccb 100644
--- a/src/models/vertex_model.ts
+++ b/src/models/vertex_model.ts
@@ -65,18 +65,36 @@ export class VertexModel extends Model {
    */
   async generateText(prompt: string): Promise<string> {
     const req = getRequest(prompt);
-    const streamingResp = await this.getGenerativeModel().generateContentStream(req);
-
-    const response = await streamingResp.response;
-    if (response.candidates![0].content.parts[0].text) {
-      const responseText = response.candidates![0].content.parts[0].text;
-      console.log(`Input token count: ${response.usageMetadata?.promptTokenCount}`);
-      console.log(`Output token count: ${response.usageMetadata?.candidatesTokenCount}`);
-      return responseText;
-    } else {
-      console.warn("Malformed response: ", response);
-      throw new Error("Error from Generative Model, response: " + response);
-    }
+    const model = this.getGenerativeModel();
+
+    const response = await retryCall(
+      // call LLM
+      async function (request: Request, model: GenerativeModel) {
+        return (await model.generateContentStream(request)).response;
+      },
+      // Check if the response exists and contains a text field.
+      function (response): boolean {
+        if (!response) {
+          console.error("Failed to get a model response.");
+          return false;
+        }
+        if (!response.candidates![0].content.parts[0].text) {
+          console.error(`Model returned a malformed response: ${response}`);
+          return false;
+        }
+        return true;
+      },
+      MAX_RETRIES,
+      "Failed to get a valid model response.",
+      RETRY_DELAY_MS,
+      [req, model], // Arguments for the LLM call
+      [] // Arguments for the validator function
+    );
+
+    const responseText = response.candidates![0].content.parts[0].text!;
+    console.log(`Input token count: ${response.usageMetadata?.promptTokenCount}`);
+    console.log(`Output token count: ${response.usageMetadata?.candidatesTokenCount}`);
+    return responseText;
   }
 
   /**
@@ -177,6 +195,7 @@ export async function generateJSON(prompt: string, model: GenerativeModel): Prom
   const req = getRequest(prompt);
 
   const response = await retryCall(
+    // call LLM
     async function (request: Request) {
       return (await model.generateContentStream(request)).response;
     },
@@ -195,8 +214,8 @@ export async function generateJSON(prompt: string, model: GenerativeModel): Prom
     MAX_RETRIES,
     "Failed to get a valid model response.",
     RETRY_DELAY_MS,
-    [req],
-    []
+    [req], // Arguments for the LLM call
+    [] // Arguments for the validator function
   );
 
   const responseText: string = response.candidates![0].content.parts[0].text!;
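
Note: the retryCall helper and the MAX_RETRIES and RETRY_DELAY_MS constants are referenced but not defined in this patch. Below is a minimal sketch of a retry helper that would satisfy these call sites, assuming the validator receives each response as its first argument; the names, defaults, and retry behavior here are assumptions, not the repository's actual implementation.

// Sketch only: retryCall, MAX_RETRIES, and RETRY_DELAY_MS are assumed to be
// defined elsewhere in the codebase; this is one plausible shape for them.
const MAX_RETRIES = 3;
const RETRY_DELAY_MS = 1000;

async function retryCall<T>(
  func: (...args: any[]) => Promise<T>, // the call to retry
  validator: (response: T, ...args: any[]) => boolean, // accepts or rejects a response
  maxRetries: number,
  errorMsg: string,
  retryDelayMS: number,
  funcArgs: any[], // arguments forwarded to func
  validatorArgs: any[] // extra arguments forwarded to validator
): Promise<T> {
  for (let attempt = 1; attempt <= maxRetries; attempt++) {
    try {
      const response = await func(...funcArgs);
      if (validator(response, ...validatorArgs)) {
        return response;
      }
    } catch (error) {
      console.error("Error during model call: ", error);
    }
    console.error(`Attempt ${attempt} failed. Retrying in ${retryDelayMS}ms...`);
    await new Promise((resolve) => setTimeout(resolve, retryDelayMS));
  }
  throw new Error(errorMsg);
}

With this shape, generateText can pass [req, model] as funcArgs so its anonymous LLM callback stays self-contained instead of closing over this, which is consistent with the call sites in the patch above.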