Skip to content

Commit

Permalink
Update chat types
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesLoder committed Dec 17, 2024
1 parent af28bcd commit c3a6200
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 63 deletions.
16 changes: 7 additions & 9 deletions components/Chat/Response/Response.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@ import {
StyledResponseWrapper,
} from "./Response.styled";

import BouncingLoader from "@/components/Shared/BouncingLoader";
import Container from "@/components/Shared/Container";
import ResponseImages from "@/components/Chat/Response/Images";
import ResponseMarkdown from "@/components/Chat/Response/Markdown";
import BouncingLoader from "@/components/Shared/BouncingLoader";
import Container from "@/components/Shared/Container";
import { StreamingMessage } from "@/types/components/chat";
import { Work } from "@nulib/dcapi-types";

interface ChatResponseProps {
conversationRef?: string;
Expand All @@ -35,7 +34,7 @@ const ChatResponse: React.FC<ChatResponseProps> = ({
const { type } = message;

if (type === "token") {
setStreamedMessage((prev) => prev + message?.message);
setStreamedMessage((prev) => prev + message.message);
}

if (type === "answer") {
Expand All @@ -51,19 +50,18 @@ const ChatResponse: React.FC<ChatResponseProps> = ({
}

if (type === "tool_start") {
// @ts-ignore
const { tool, input } = message?.message;
const { tool, input } = message.message;
let interstitialMessage = "";
switch (tool) {
case "discover_fields":
interstitialMessage = "Discovering fields";
break;
case "search":
interstitialMessage = `Searching for: ${input?.query}`;
interstitialMessage = `Searching for: ${input.query}`;
break;
case "aggregate":
console.log(`aggregate input`, input);
interstitialMessage = `Aggregating ${input?.agg_field} by ${input?.term_field} ${input?.term}`;
interstitialMessage = `Aggregating ${input.agg_field} by ${input.term_field} ${input.term}`;
break;
default:
console.warn("Unknown tool_start message", message);
Expand All @@ -87,7 +85,7 @@ const ChatResponse: React.FC<ChatResponseProps> = ({
<>
{prev}
<ResponseImages
works={message?.message as Work[]}
works={message.message}
isStreamingComplete={isStreamingComplete}
/>
</>
Expand Down
136 changes: 82 additions & 54 deletions types/components/chat.ts
Original file line number Diff line number Diff line change
@@ -1,62 +1,90 @@
import { Work } from "@nulib/dcapi-types";

export type MessageTypes =
| "answer"
| "aggregation_result"
| "final"
| "final_message"
| "search_result"
| "start"
| "stop"
| "token"
| "tool_start";

type MessageAggregationResult = {
buckets: [
{
key: string;
doc_count: number;
},
];
doc_count_error_upper_bound: number;
sum_other_doc_count: number;
};

type MessageSearchResult = Array<Work>;

type MessageModel = {
model: string;
};

type MessageTool =
| {
tool: "discover_fields";
input: {};
}
| {
tool: "search";
input: {
query: string;
};
}
| {
tool: "aggregate";
input: { agg_field: string; term_field: string; term: string };
};

type MessageShape =
| string
| MessageAggregationResult
| MessageSearchResult
| MessageModel
| MessageTool;

export type StreamingMessage = {
export type Ref = {
ref: string;
message?: MessageShape;
type: MessageTypes;
};

export type AggregationResultMessage = {
type: "aggregation_result";
message: {
buckets: [
{
key: string;
doc_count: number;
},
];
doc_count_error_upper_bound: number;
sum_other_doc_count: number;
};
};

export type AgentFinalMessage = {
type: "final";
message: string;
};

export type LLMAnswerMessage = {
type: "answer";
message: string;
};

export type LLMFinalMessage = {
type: "final_message";
};

export type LLMTokenMessage = {
type: "token";
message: string;
};

export type LLMStopMessage = {
type: "stop";
};

export type SearchResultMessage = {
type: "search_result";
message: Array<Work>;
};

export type StartMessage = {
type: "start";
message: {
model: string;
};
};

export type ToolStartMessage = {
type: "tool_start";
message:
| {
tool: "discover_fields";
input: {};
}
| {
tool: "search";
input: {
query: string;
};
}
| {
tool: "aggregate";
input: { agg_field: string; term_field: string; term: string };
};
};

export type StreamingMessage = Ref &
(
| AgentFinalMessage
| LLMFinalMessage
| LLMStopMessage
| LLMAnswerMessage
| LLMTokenMessage
| ToolStartMessage
| AggregationResultMessage
| SearchResultMessage
| StartMessage
);

export type ChatConfig = {
auth: string;
endpoint: string;
Expand Down

0 comments on commit c3a6200

Please sign in to comment.