Skip to content

Commit

Permalink
Add polling to handle run state updates.
Browse files Browse the repository at this point in the history
  • Loading branch information
lublagg committed Dec 9, 2024
1 parent f929309 commit d0c0ef5
Showing 1 changed file with 62 additions and 38 deletions.
100 changes: 62 additions & 38 deletions src/models/assistant-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,40 +69,69 @@ export const AssistantModel = types

const startRun = flow(function* () {
try {
const run = yield davai.beta.threads.runs.create(self.thread.id, {
const currentRun = yield davai.beta.threads.runs.create(self.thread.id, {
assistant_id: self.assistant.id,
});
transcriptStore.addMessage(DEBUG_SPEAKER, {
description: "Run created",
content: formatMessage(currentRun),
});

// Wait for run completion and handle responses
let runState = yield davai.beta.threads.runs.retrieve(self.thread.id, run.id);
while (runState.status !== "completed" && runState.status !== "requires_action") {
runState = yield davai.beta.threads.runs.retrieve(self.thread.id, run.id);
}
yield pollRunState(currentRun.id);
} catch (err) {
console.error("Failed to complete run:", err);
transcriptStore.addMessage(DEBUG_SPEAKER, {
description: "Failed to complete run",
content: formatMessage(err),
});
}
});

if (runState.status === "requires_action") {
transcriptStore.addMessage(DEBUG_SPEAKER, {description: "User request requires action", content: formatMessage(runState)});
yield handleRequiredAction(runState, run.id);
}
const pollRunState: (currentRunId: string) => Promise<any> = flow(function* (currentRunId) {
let runState = yield davai.beta.threads.runs.retrieve(self.thread.id, currentRunId);
transcriptStore.addMessage(DEBUG_SPEAKER, {
description: "Polling run state",
content: formatMessage(runState),
});

while (runState.status !== "completed" && runState.status !== "requires_action") {
yield new Promise((resolve) => setTimeout(resolve, 2000));
runState = yield davai.beta.threads.runs.retrieve(self.thread.id, currentRunId);
transcriptStore.addMessage(DEBUG_SPEAKER, {
description: "Polling run state",
content: formatMessage(runState),
});
}

// Get the last assistant message from the messages array
if (runState.status === "requires_action") {
transcriptStore.addMessage(DEBUG_SPEAKER, {
description: "Run requires action",
content: formatMessage(runState),
});
yield handleRequiredAction(runState, currentRunId);
yield pollRunState(currentRunId);
}

if (runState.status === "completed") {
const messages = yield davai.beta.threads.messages.list(self.thread.id);
transcriptStore.addMessage(DEBUG_SPEAKER, {description: "Updated thread messages list", content: formatMessage(messages)});

const lastMessageForRun = messages.data.filter(
(msg: Message) => msg.run_id === run.id && msg.role === "assistant"
).pop();
const lastMessageForRun = messages.data
.filter((msg: Message) => msg.run_id === currentRunId && msg.role === "assistant")
.pop();

transcriptStore.addMessage(DEBUG_SPEAKER, {
description: "Run completed, assistant response",
content: formatMessage(lastMessageForRun),
});

const lastMessageContent = lastMessageForRun?.content[0]?.text?.value;
if (lastMessageContent) {
transcriptStore.addMessage(DAVAI_SPEAKER, {content: lastMessageContent});
transcriptStore.addMessage(DAVAI_SPEAKER, { content: lastMessageContent });
} else {
transcriptStore.addMessage(DAVAI_SPEAKER, {content: "I'm sorry, I don't have a response for that."});
transcriptStore.addMessage(DEBUG_SPEAKER, {description: "No content in last message", content: formatMessage(lastMessageForRun)});
transcriptStore.addMessage(DAVAI_SPEAKER, {
content: "I'm sorry, I don't have a response for that.",
});
}

} catch (err) {
console.error("Failed to complete run:", err);
transcriptStore.addMessage(DEBUG_SPEAKER, {description: "Failed to complete run", content: formatMessage(err)});
}
});

Expand All @@ -111,7 +140,13 @@ export const AssistantModel = types
const toolOutputs = runState.required_action?.submit_tool_outputs.tool_calls
? yield Promise.all(
runState.required_action.submit_tool_outputs.tool_calls.map(async (toolCall: any) => {
if (toolCall.function.name === "get_attributes") {
if (toolCall.function.name === "get_data_contexts") {
const dataContextList = await getListOfDataContexts();
const { requestMessage, ...codapResponse } = dataContextList;
transcriptStore.addMessage(DEBUG_SPEAKER, { description: "Request sent to CODAP", content: formatMessage(requestMessage) });
transcriptStore.addMessage(DEBUG_SPEAKER, { description: "Response from CODAP", content: formatMessage(codapResponse) });
return { tool_call_id: toolCall.id, output: JSON.stringify(dataContextList) };
} else if (toolCall.function.name === "get_attributes") {
const { dataset } = JSON.parse(toolCall.function.arguments);
// getting the root collection won't always work. what if a user wants the attributes
// in the Mammals dataset but there is a hierarchy?
Expand All @@ -121,34 +156,23 @@ export const AssistantModel = types
transcriptStore.addMessage(DEBUG_SPEAKER, { description: "Request sent to CODAP", content: formatMessage(requestMessage) });
transcriptStore.addMessage(DEBUG_SPEAKER, { description: "Response from CODAP", content: formatMessage(codapResponse) });
return { tool_call_id: toolCall.id, output: JSON.stringify(attributeListRes) };
} else {
} else if (toolCall.function.name === "create_graph") {
const { dataset, name, xAttribute, yAttribute } = JSON.parse(toolCall.function.arguments);
const { requestMessage, ...codapResponse} = await createGraph(dataset, name, xAttribute, yAttribute);
transcriptStore.addMessage(DEBUG_SPEAKER, { description: "Request sent to CODAP", content: formatMessage(requestMessage) });
transcriptStore.addMessage(DEBUG_SPEAKER, { description: "Response from CODAP", content: formatMessage(codapResponse) });
return { tool_call_id: toolCall.id, output: "Graph created." };
} else {
return { tool_call_id: toolCall.id, output: "Tool call not recognized." };
}
})
)
: [];

if (toolOutputs) {
davai.beta.threads.runs.submitToolOutputsStream(
yield davai.beta.threads.runs.submitToolOutputs(
self.thread.id, runId, { tool_outputs: toolOutputs }
);

const threadMessageList = yield davai.beta.threads.messages.list(self.thread.id);
const threadMessages = threadMessageList.data.map((msg: any) => ({
role: msg.role,
content: msg.content[0].text.value,
}));

yield davai.chat.completions.create({
model: "gpt-4o-mini",
messages: [
...threadMessages
],
});
}
} catch (err) {
console.error(err);
Expand Down

0 comments on commit d0c0ef5

Please sign in to comment.