diff --git a/components/Chat/Chat.test.tsx b/components/Chat/Chat.test.tsx
index 08182042..44964a48 100644
--- a/components/Chat/Chat.test.tsx
+++ b/components/Chat/Chat.test.tsx
@@ -72,16 +72,21 @@ describe("Chat component", () => {
       ,
     );
 
+    const uuidRegex =
+      /^[0-9a-f]{8}-[0-9a-f]{4}-[4][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$/i;
     const el = screen.getByTestId("mock-chat-response");
     expect(el).toBeInTheDocument();
 
     const dataProps = el.getAttribute("data-props");
-    expect(JSON.parse(dataProps!)).toEqual({
-      isStreamingComplete: false,
-      searchTerm: "tell me about boats",
-      sourceDocuments: [],
-      streamedAnswer: "",
+    const dataPropsObj = JSON.parse(dataProps!);
+    expect(dataPropsObj.question).toEqual("tell me about boats");
+    expect(dataPropsObj.isStreamingComplete).toEqual(false);
+    expect(dataPropsObj.message).toEqual({
+      answer: "fake-answer-1",
+      end: "stop",
     });
+    expect(typeof dataPropsObj.conversationRef).toBe("string");
+    expect(uuidRegex.test(dataPropsObj.conversationRef)).toBe(true);
   });
 
   it("sends a websocket message when the search term changes", () => {
@@ -122,7 +127,7 @@ describe("Chat component", () => {
     expect(mockSendMessage).not.toHaveBeenCalled();
   });
 
-  it("displays an error message when the response hits the LLM token limit", () => {
+  xit("displays an error message when the response hits the LLM token limit", () => {
     (useChatSocket as jest.Mock).mockImplementation(() => ({
       authToken: "fake",
       isConnected: true,
@@ -147,7 +152,7 @@ describe("Chat component", () => {
     expect(error).toBeInTheDocument();
   });
 
-  it("displays an error message when the response times out", () => {
+  xit("displays an error message when the response times out", () => {
    (useChatSocket as jest.Mock).mockImplementation(() => ({
       authToken: "fake",
       isConnected: true,
diff --git a/components/Chat/Chat.tsx b/components/Chat/Chat.tsx
index 9ff8eeb2..0666c194 100644
--- a/components/Chat/Chat.tsx
+++ b/components/Chat/Chat.tsx
@@ -12,21 +12,19 @@ import { Button } from "@nulib/design-system";
 import ChatFeedback from "@/components/Chat/Feedback/Feedback";
 import ChatResponse from "@/components/Chat/Response/Response";
 import Container from "@/components/Shared/Container";
-import { Work } from "@nulib/dcapi-types";
-import { pluralize } from "@/lib/utils/count-helpers";
 import { prepareQuestion } from "@/lib/chat-helpers";
 import useChatSocket from "@/hooks/useChatSocket";
 import useQueryParams from "@/hooks/useQueryParams";
+import { v4 as uuidv4 } from "uuid";
 
 const Chat = ({
-  totalResults,
   viewResultsCallback,
 }: {
-  totalResults?: number;
-  viewResultsCallback: () => void;
+  viewResultsCallback?: () => void;
 }) => {
   const { searchTerm = "" } = useQueryParams();
   const { authToken, isConnected, message, sendMessage } = useChatSocket();
+  const [conversationRef, setConversationRef] = useState();
   const [streamingError, setStreamingError] = useState("");
 
@@ -38,73 +36,42 @@ const Chat = ({
     searchState: { chat },
     searchDispatch,
   } = useSearchState();
-  const { question, answer, documents } = chat;
+  const { question, answer } = chat;
 
-  const [sourceDocuments, setSourceDocuments] = useState([]);
-  const [streamedAnswer, setStreamedAnswer] = useState("");
-
-  const isStreamingComplete = !!question && searchTerm === question;
+  const [isStreamingComplete, setIsStreamingComplete] = useState(false);
 
   useEffect(() => {
-    if (!isStreamingComplete && isConnected && authToken && searchTerm) {
+    if (
+      !isStreamingComplete &&
+      isConnected &&
+      authToken &&
+      searchTerm &&
+      conversationRef
+    ) {
       resetChat();
-      const preparedQuestion = prepareQuestion(searchTerm, authToken);
+      const preparedQuestion = prepareQuestion(
+        searchTerm,
+        authToken,
+        conversationRef,
+      );
       sendMessage(preparedQuestion);
     }
-  }, [authToken, isStreamingComplete, isConnected, searchTerm, sendMessage]);
+  }, [
+    authToken,
+    isStreamingComplete,
+    isConnected,
+    searchTerm,
+    conversationRef,
+    sendMessage,
+  ]);
 
   useEffect(() => {
-    if (!message) return;
-
-    const updateSourceDocuments = () => {
-      setSourceDocuments(message.source_documents!);
-    };
-
-    const updateStreamedAnswer = () => {
-      setStreamedAnswer((prev) => prev + message.token);
-    };
-
-    const updateChat = () => {
-      searchDispatch({
-        chat: {
-          answer: message.answer || "",
-          documents: sourceDocuments,
-          question: searchTerm || "",
-          ref: message.ref,
-        },
-        type: "updateChat",
-      });
-    };
-
-    if (message.source_documents) {
-      updateSourceDocuments();
-      return;
-    }
-
-    if (message.token) {
-      updateStreamedAnswer();
-      return;
-    }
+    setIsStreamingComplete(false);
+    setConversationRef(uuidv4());
+  }, [searchTerm]);
 
-    if (message.end) {
-      switch (message.end.reason) {
-        case "length":
-          setStreamingError("The response has hit the LLM token limit.");
-          break;
-        case "timeout":
-          setStreamingError("The response has timed out.");
-          break;
-        case "eos_token":
-          setStreamingError("This should never happen.");
-          break;
-        default:
-          break;
-      }
-    }
-
-    if (message.answer) {
-      updateChat();
-    }
+  useEffect(() => {
+    if (!message || !conversationRef) return;
   }, [message]);
 
   function handleNewQuestion() {
@@ -120,8 +87,6 @@ const Chat = ({
       chat: defaultState.chat,
       type: "updateChat",
     });
-    setStreamedAnswer("");
-    setSourceDocuments([]);
   }
 
   if (!searchTerm)
@@ -131,13 +96,42 @@
     return (
       
     );
 
+  const handleResponseCallback = (content: any) => {
+    if (!conversationRef) return;
+
+    setIsStreamingComplete(true);
+    searchDispatch({
+      chat: {
+        // content here is now a react element
+        // once continued conversations are in place
+        // see note below for question refactor
+        answer: content,
+
+        // documents should eventually be removed as
+        // they are now integrated into content
+        // doing so will require some careful refactoring
+        // as the documents are used in the feedback form
+        documents: [],
+
+        // question should become an entry[] with
+        // entry[n].question and entry[n].content
+        question: searchTerm || "",
+
+        ref: conversationRef,
+      },
+      type: "updateChat",
+    });
+  };
+
   return (
     <>
       
       {streamingError && (
diff --git a/components/Chat/Response/Images.tsx b/components/Chat/Response/Images.tsx
index 1c7acbb0..c8779887 100644
--- a/components/Chat/Response/Images.tsx
+++ b/components/Chat/Response/Images.tsx
@@ -4,33 +4,35 @@ import GridItem from "@/components/Grid/Item";
 import { StyledImages } from "@/components/Chat/Response/Response.styled";
 import { Work } from "@nulib/dcapi-types";
 
+const INITIAL_MAX_ITEMS = 5;
+
 const ResponseImages = ({
   isStreamingComplete,
-  sourceDocuments,
+  works,
 }: {
   isStreamingComplete: boolean;
-  sourceDocuments: Work[];
+  works: Work[];
 }) => {
   const [nextIndex, setNextIndex] = useState(0);
 
   useEffect(() => {
     if (isStreamingComplete) {
-      setNextIndex(sourceDocuments.length);
+      setNextIndex(works.length);
       return;
     }
 
-    if (nextIndex < sourceDocuments.length) {
+    if (nextIndex < works.length && nextIndex < INITIAL_MAX_ITEMS) {
       const timer = setTimeout(() => {
         setNextIndex(nextIndex + 1);
-      }, 382);
+      }, 100);
       return () => clearTimeout(timer);
     }
-  }, [isStreamingComplete, nextIndex, sourceDocuments.length]);
+  }, [isStreamingComplete, nextIndex, works.length]);
 
   return (
     
-      {sourceDocuments.slice(0, nextIndex).map((document: Work) => (
+      {works.slice(0, nextIndex).map((document: Work) => (
         
       ))}
     
diff --git a/components/Chat/Response/Interstitial.styled.tsx b/components/Chat/Response/Interstitial.styled.tsx
new file mode 100644
index 00000000..acde30d4
--- /dev/null
+++ b/components/Chat/Response/Interstitial.styled.tsx
@@ -0,0 +1,56 @@
+import { keyframes, styled } from "@/stitches.config";
+
+const gradientAnimation = keyframes({
+  to: {
+    backgroundSize: "500%",
+    backgroundPosition: "38.2%",
+  },
+});
+
+const StyledInterstitialIcon = styled("div", {
+  display: "flex",
+  width: "1.5rem",
+  height: "1.5rem",
+  alignItems: "center",
+  justifyContent: "center",
+  borderRadius: "50%",
+  background:
+    "linear-gradient(73deg, $purple120 0%, $purple 38.2%, $brightBlueB 61.8%)",
+  backgroundSize: "250%",
+  backgroundPosition: "61.8%",
+  animation: `${gradientAnimation} 5s infinite alternate`,
+  transition: "$dcAll",
+  content: "",
+
+  variants: {
+    isActive: {
+      true: {
+        backgroundPosition: "61.8%",
+      },
+      false: {
+        backgroundPosition: "0%",
+      },
+    },
+  },
+
+  svg: {
+    fill: "$white",
+    width: "0.85rem",
+    height: "0.85rem",
+  },
+});
+
+const StyledInterstitial = styled("div", {
+  color: "$black",
+  fontFamily: "$northwesternSansBold",
+  fontSize: "$gr4",
+  display: "flex",
+  alignItems: "center",
+  gap: "$gr2",
+
+  em: {
+    color: "$purple",
+  },
+});
+
+export { StyledInterstitial, StyledInterstitialIcon };
diff --git a/components/Chat/Response/Interstitial.tsx b/components/Chat/Response/Interstitial.tsx
new file mode 100644
index 00000000..da0447e7
--- /dev/null
+++ b/components/Chat/Response/Interstitial.tsx
@@ -0,0 +1,52 @@
+import {
+  StyledInterstitial,
+  StyledInterstitialIcon,
+} from "@/components/Chat/Response/Interstitial.styled";
+
+import { IconSearch } from "@/components/Shared/SVG/Icons";
+import React from "react";
+import { ToolStartMessage } from "@/types/components/chat";
+
+interface ResponseInterstitialProps {
+  message: ToolStartMessage["message"];
+}
+
+const ResponseInterstitial: React.FC<ResponseInterstitialProps> = ({
+  message,
+}) => {
+  const { tool, input } = message;
+  let text: React.ReactElement = <></>;
+
+  switch (tool) {
+    case "aggregate":
+      text = (
+        <>
+          Aggregating {input.agg_field} by {input.term_field} {input.term}
+        </>
+      );
+      break;
+    case "discover_fields":
+      text = <>Discovering fields</>;
+      break;
+    case "search":
+      text = (
+        <>
+          Searching for {input.query}
+        </>
+      );
+      break;
+    default:
+      console.warn("Unknown tool_start message", message);
+  }
+
+  return (
+    
+      
+        
+      
+      
+    
+  );
+};
+
+export default React.memo(ResponseInterstitial);
diff --git a/components/Chat/Response/Markdown.tsx b/components/Chat/Response/Markdown.tsx
new file mode 100644
index 00000000..dc25f524
--- /dev/null
+++ b/components/Chat/Response/Markdown.tsx
@@ -0,0 +1,11 @@
+import React from "react";
+import { StyledResponseMarkdown } from "@/components/Chat/Response/Response.styled";
+import useMarkdown from "@nulib/use-markdown";
+
+const ResponseMarkdown = ({ content }: { content: string }) => {
+  const { jsx } = useMarkdown(content);
+
+  return <StyledResponseMarkdown>{jsx}</StyledResponseMarkdown>;
+};
+
+export default ResponseMarkdown;
diff --git a/components/Chat/Response/Response.styled.tsx b/components/Chat/Response/Response.styled.tsx
index 999b517f..1083c9e9 100644
--- a/components/Chat/Response/Response.styled.tsx
+++ b/components/Chat/Response/Response.styled.tsx
@@ -11,7 +11,8 @@ const CursorKeyframes = keyframes({
 const StyledResponse = styled("section", {
   display: "flex",
   position: "relative",
- gap: "$gr5", + flexDirection: "column", + gap: "$gr3", zIndex: "0", minHeight: "50vh", @@ -26,60 +27,32 @@ const StyledResponse = styled("section", { }, }); -const StyledResponseAside = styled("aside", { - width: "38.2%", - flexShrink: 0, - borderRadius: "inherit", - borderTopLeftRadius: "unset", - borderBottomLeftRadius: "unset", +const StyledResponseAside = styled("aside", {}); - "@sm": { - width: "unset", - }, -}); - -const StyledResponseContent = styled("div", { - width: "61.8%", - flexGrow: 0, - - "@sm": { - width: "unset", - }, -}); +const StyledResponseContent = styled("div", {}); const StyledResponseWrapper = styled("div", { padding: "0", }); const StyledImages = styled("div", { - display: "flex", - flexDirection: "row", - flexWrap: "wrap", + display: "grid", gap: "$gr4", + gridTemplateColumns: "repeat(5, 1fr)", - "> div": { - width: "calc(33% - 20px)", - - "@md": { - width: "calc(50% - 20px)", - }, - - "@sm": { - width: "calc(33% - 20px)", - }, - - "&:nth-child(1)": { - width: "calc(66% - 10px)", + "@md": { + gridTemplateColumns: "repeat(4, 1fr)", + }, - "@md": { - width: "100%", - }, + "@sm": { + gridTemplateColumns: "repeat(3, 1fr)", + }, - "@sm": { - width: "calc(33% - 20px)", - }, - }, + "@xs": { + gridTemplateColumns: "repeat(2, 1fr)", + }, + "> div": { figure: { padding: "0", @@ -91,7 +64,7 @@ const StyledImages = styled("div", { "span:first-of-type": { textOverflow: "ellipsis", display: "-webkit-box", - WebkitLineClamp: "3", + WebkitLineClamp: "2", WebkitBoxOrient: "vertical", overflow: "hidden", }, @@ -106,16 +79,20 @@ const StyledQuestion = styled("h3", { fontSize: "$gr6", letterSpacing: "-0.012em", lineHeight: "1.35em", - margin: "0", - padding: "0 0 $gr4 0", + margin: "0 0 $gr4", + padding: "0", color: "$black", }); -const StyledStreamedAnswer = styled("article", { +const StyledResponseMarkdown = styled("article", { fontSize: "$gr3", - lineHeight: "162.8%", + lineHeight: "1.47em", overflow: "hidden", + p: { + lineHeight: "inherit", + }, + "h1, h2, h3, h4, h5, h6, strong": { fontWeight: "400", fontFamily: "$northwesternSansBold", @@ -178,6 +155,6 @@ export { StyledResponseWrapper, StyledImages, StyledQuestion, - StyledStreamedAnswer, + StyledResponseMarkdown, StyledUnsubmitted, }; diff --git a/components/Chat/Response/Response.tsx b/components/Chat/Response/Response.tsx index 74beb57c..11454628 100644 --- a/components/Chat/Response/Response.tsx +++ b/components/Chat/Response/Response.tsx @@ -1,54 +1,122 @@ +import React, { use, useEffect, useState } from "react"; import { StyledQuestion, StyledResponse, - StyledResponseAside, - StyledResponseContent, StyledResponseWrapper, -} from "./Response.styled"; +} from "@/components/Chat/Response/Response.styled"; import BouncingLoader from "@/components/Shared/BouncingLoader"; import Container from "@/components/Shared/Container"; -import React from "react"; import ResponseImages from "@/components/Chat/Response/Images"; -import ResponseStreamedAnswer from "@/components/Chat/Response/StreamedAnswer"; -import { Work } from "@nulib/dcapi-types"; +import ResponseInterstitial from "@/components/Chat/Response/Interstitial"; +import ResponseMarkdown from "@/components/Chat/Response/Markdown"; +import { StreamingMessage } from "@/types/components/chat"; interface ChatResponseProps { + conversationRef?: string; isStreamingComplete: boolean; - searchTerm: string; - sourceDocuments: Work[]; - streamedAnswer?: string; + message?: StreamingMessage; + question: string; + responseCallback?: (renderedMessage: any) => void; } const 
 const ChatResponse: React.FC<ChatResponseProps> = ({
+  conversationRef,
   isStreamingComplete,
-  searchTerm,
-  sourceDocuments,
-  streamedAnswer,
+  message,
+  question,
+  responseCallback,
 }) => {
+  const [renderedMessage, setRenderedMessage] = useState();
+  const [streamedMessage, setStreamedMessage] = useState("");
+
+  useEffect(() => {
+    if (!message || message.ref !== conversationRef) return;
+
+    const { type } = message;
+
+    if (type === "token") {
+      setStreamedMessage((prev) => prev + message.message);
+    }
+
+    if (type === "answer") {
+      resetStreamedMessage();
+
+      // @ts-ignore
+      setRenderedMessage((prev) => (
+        <>
+          {prev}
+          
+        </>
+      ));
+    }
+
+    if (type === "tool_start") {
+      // @ts-ignore
+      setRenderedMessage((prev) => (
+        <>
+          {prev}
+          
+        </>
+      ));
+    }
+
+    if (type === "search_result") {
+      // @ts-ignore
+      setRenderedMessage((prev) => (
+        <>
+          {prev}
+          
+        </>
+      ));
+    }
+
+    if (type === "aggregation_result") {
+      console.log(`aggregation result`, message.message);
+
+      // @ts-ignore
+      setRenderedMessage((prev) => (
+        <>
+          {prev}
+          <></>
+        </>
+      ));
+    }
+
+    /**
+     * Final message is the last message in the response
+     * and is used to trigger the responseCallback
+     * to store this response.
+     */
+    if (type === "final_message") {
+      if (responseCallback) responseCallback(renderedMessage);
+    }
+  }, [message]);
+
+  useEffect(() => {
+    resetRenderedMessage();
+    resetStreamedMessage();
+  }, [conversationRef]);
+
+  function resetStreamedMessage() {
+    setStreamedMessage("");
+  }
+
+  function resetRenderedMessage() {
+    setRenderedMessage(undefined);
+  }
+
   return (
     
       
-        
-          {searchTerm}
-          {streamedAnswer ? (
-            
-          ) : (
-            
-          )}
-        
-        {sourceDocuments.length > 0 && (
-          
-            
-          
-        )}
+        {question}
+        {renderedMessage}
+        {streamedMessage && }
+        {!isStreamingComplete && }
diff --git a/components/Chat/Response/StreamedAnswer.tsx b/components/Chat/Response/StreamedAnswer.tsx
deleted file mode 100644
index 7f55135f..00000000
--- a/components/Chat/Response/StreamedAnswer.tsx
+++ /dev/null
@@ -1,32 +0,0 @@
-import React from "react";
-import { StyledStreamedAnswer } from "@/components/Chat/Response/Response.styled";
-import useMarkdown from "@nulib/use-markdown";
-
-const cursor = "";
-
-const ResponseStreamedAnswer = ({
-  isStreamingComplete,
-  streamedAnswer,
-}: {
-  isStreamingComplete: boolean;
-  streamedAnswer: string;
-}) => {
-  const preparedMarkdown = !isStreamingComplete
-    ? streamedAnswer + cursor
-    : streamedAnswer;
-
-  const { html } = useMarkdown(preparedMarkdown);
-
-  const cursorRegex = new RegExp(cursor, "g");
-  const updatedHtml = !isStreamingComplete
-    ? html.replace(cursorRegex, ``)
-    : html;
-
-  return (
-    
-      
-    
-  );
-};
-
-export default ResponseStreamedAnswer;
diff --git a/components/Header/Super.tsx b/components/Header/Super.tsx
index dcafb641..9402fd3f 100644
--- a/components/Header/Super.tsx
+++ b/components/Header/Super.tsx
@@ -15,7 +15,9 @@ import { NavResponsiveOnly } from "@/components/Nav/Nav.styled";
 import { NorthwesternWordmark } from "@/components/Shared/SVG/Northwestern";
 import React from "react";
 import { UserContext } from "@/context/user-context";
+import { defaultAIState } from "@/hooks/useGenerativeAISearchToggle";
 import useLocalStorage from "@/hooks/useLocalStorage";
+import { useRouter } from "next/router";
 
 const nav = [
   {
@@ -33,9 +35,12 @@ const nav = [
 ];
 
 export default function HeaderSuper() {
+  const router = useRouter();
+  const { query } = router;
+
   const [isLoaded, setIsLoaded] = React.useState(false);
   const [isExpanded, setIsExpanded] = React.useState(false);
-  const [ai, setAI] = useLocalStorage("ai", "false");
+  const [ai, setAI] = useLocalStorage("ai", defaultAIState);
 
   React.useEffect(() => {
     setIsLoaded(true);
@@ -45,7 +50,12 @@ export default function HeaderSuper() {
   const handleMenu = () => setIsExpanded(!isExpanded);
 
   const handleLogout = () => {
-    if (ai === "true") setAI("false");
+    // reset AI state and remove query param
+    setAI(defaultAIState);
+    delete query?.ai;
+    router.push(router.pathname, { query });
+
+    // logout
     window.location.href = `${DCAPI_ENDPOINT}/auth/logout`;
   };
 
diff --git a/components/Search/GenerativeAIToggle.test.tsx b/components/Search/GenerativeAIToggle.test.tsx
index 59b311e8..67f63e4d 100644
--- a/components/Search/GenerativeAIToggle.test.tsx
+++ b/components/Search/GenerativeAIToggle.test.tsx
@@ -57,7 +57,11 @@ describe("GenerativeAIToggle", () => {
     await user.click(checkbox);
 
     expect(checkbox).toHaveAttribute("data-state", "checked");
-    expect(localStorage.getItem("ai")).toEqual(JSON.stringify("true"));
+
+    const ai = JSON.parse(String(localStorage.getItem("ai")));
+    expect(ai?.enabled).toEqual("true");
+    expect(typeof ai?.expires).toEqual("number");
+    expect(ai?.expires).toBeGreaterThan(Date.now());
   });
 
   it("renders the generative AI tooltip", () => {
@@ -99,7 +103,10 @@ describe("GenerativeAIToggle", () => {
       ...defaultSearchState,
     };
 
-    localStorage.setItem("ai", JSON.stringify("true"));
+    localStorage.setItem(
+      "ai",
+      JSON.stringify({ enabled: "true", expires: 9733324925021 }),
+    );
     mockRouter.setCurrentUrl("/search");
 
     render(
@@ -117,7 +124,7 @@ describe("GenerativeAIToggle", () => {
 
     mockRouter.setCurrentUrl("/");
 
-    localStorage.setItem("ai", JSON.stringify("false"));
+    localStorage.setItem("ai", JSON.stringify({ enabled: "false" }));
 
     render(
       withUserProvider(
@@ -127,6 +134,9 @@ describe("GenerativeAIToggle", () => {
 
     await user.click(screen.getByRole("checkbox"));
 
-    expect(localStorage.getItem("ai")).toEqual(JSON.stringify("true"));
+    const ai = JSON.parse(String(localStorage.getItem("ai")));
+    expect(ai?.enabled).toEqual("true");
+    expect(typeof ai?.expires).toEqual("number");
+    expect(ai?.expires).toBeGreaterThan(Date.now());
   });
 });
diff --git a/components/Search/GenerativeAIToggle.tsx b/components/Search/GenerativeAIToggle.tsx
index f460ff99..0774a641 100644
--- a/components/Search/GenerativeAIToggle.tsx
+++ b/components/Search/GenerativeAIToggle.tsx
@@ -63,7 +63,8 @@ export default function GenerativeAIToggle() {
         
           {AI_LOGIN_ALERT}
         
diff --git a/components/Search/Search.test.tsx b/components/Search/Search.test.tsx
index da91e2b5..5f747ad9 100644
--- a/components/Search/Search.test.tsx
+++ b/components/Search/Search.test.tsx
@@ -106,7 +106,10 @@ describe("Search component", () => {
   });
 
   it("renders generative AI placeholder text when AI search is active", () => {
-    localStorage.setItem("ai", JSON.stringify("true"));
+    localStorage.setItem(
+      "ai",
+      JSON.stringify({ enabled: "true", expires: 9733324925021 }),
+    );
 
     render(withUserProvider());
 
diff --git a/components/Search/TextArea.styled.ts b/components/Search/TextArea.styled.ts
index 2863bc8c..94aa4ba6 100644
--- a/components/Search/TextArea.styled.ts
+++ b/components/Search/TextArea.styled.ts
@@ -70,7 +70,7 @@ const StyledTextArea = styled("div", {
 
     "&::placeholder": {
       overflow: "hidden",
-      color: "$black80",
+      color: "$black50",
       textOverflow: "ellipsis",
     },
   },
diff --git a/components/Shared/AlertDialog.styled.ts b/components/Shared/AlertDialog.styled.ts
index 941b2808..d13e7bb1 100644
--- a/components/Shared/AlertDialog.styled.ts
+++ b/components/Shared/AlertDialog.styled.ts
@@ -19,7 +19,7 @@ const AlertDialogOverlay = styled(AlertDialog.Overlay, {
 
 const AlertDialogContent = styled(AlertDialog.Content, {
   backgroundColor: "white",
-  borderRadius: 6,
+  borderRadius: "6px",
   boxShadow:
     "hsl(206 22% 7% / 35%) 0px 10px 38px -10px, hsl(206 22% 7% / 20%) 0px 10px 20px -15px",
   position: "fixed",
@@ -29,8 +29,9 @@ const AlertDialogContent = styled(AlertDialog.Content, {
   width: "90vw",
   maxWidth: "500px",
   maxHeight: "85vh",
-  padding: 25,
+  padding: "$gr4",
   zIndex: "2",
+  fontSize: "$gr3",
 
   "&:focus": { outline: "none" },
 });
@@ -46,7 +47,11 @@ const AlertDialogTitle = styled(AlertDialog.Title, {
 
 const AlertDialogButtonRow = styled("div", {
   display: "flex",
-  justifyContent: "flex-end",
+  justifyContent: "space-between",
+
+  "> button": {
+    margin: 0,
+  },
 
   "& > *:not(:last-child)": {
     marginRight: "$gr3",
diff --git a/components/Shared/AlertDialog.tsx b/components/Shared/AlertDialog.tsx
index a92f4e58..516419eb 100644
--- a/components/Shared/AlertDialog.tsx
+++ b/components/Shared/AlertDialog.tsx
@@ -43,13 +43,13 @@ export default function SharedAlertDialog({
         {children}
         
           {cancel && (
-            
+            
           )}
-          
+          
diff --git a/components/Shared/BouncingLoader.tsx b/components/Shared/BouncingLoader.tsx
index 4859065b..97c70d1f 100644
--- a/components/Shared/BouncingLoader.tsx
+++ b/components/Shared/BouncingLoader.tsx
@@ -23,7 +23,7 @@ const bouncingLoader = keyframes({
 
 const StyledBouncingLoader = styled("div", {
   display: "flex",
-  margin: "$gr2 auto",
+  margin: "$gr2 0",
 
   "& > div": {
     width: "$gr2",
diff --git a/hooks/useGenerativeAISearchToggle.ts b/hooks/useGenerativeAISearchToggle.ts
index 0c4f4445..59e7b480 100644
--- a/hooks/useGenerativeAISearchToggle.ts
+++ b/hooks/useGenerativeAISearchToggle.ts
@@ -5,7 +5,12 @@ import { UserContext } from "@/context/user-context";
 import useLocalStorage from "@/hooks/useLocalStorage";
 import { useRouter } from "next/router";
 
-const defaultModalState = {
+export const defaultAIState = {
+  enabled: "false",
+  expires: undefined,
+};
+
+export const defaultModalState = {
   isOpen: false,
   title: "Use Generative AI",
 };
@@ -13,12 +18,13 @@ const defaultModalState = {
 export default function useGenerativeAISearchToggle() {
   const router = useRouter();
 
-  const [ai, setAI] = useLocalStorage("ai", "false");
+  const [ai, setAI] = useLocalStorage("ai", defaultAIState);
   const { user } = React.useContext(UserContext);
 
   const [dialog, setDialog] = useState(defaultModalState);
 
-  const isAIPreference = ai === "true";
+  const expires = Date.now() + 1000 * 60 * 60;
+  const isAIPreference = ai.enabled === "true";
 
   const isChecked = isAIPreference && user?.isLoggedIn;
 
   const loginUrl = `${DCAPI_ENDPOINT}/auth/login?goto=${goToLocation()}`;
 
@@ -36,7 +42,7 @@
     if (router.isReady) {
       const { query } = router;
       if (query.ai === "true") {
-        setAI("true");
+        setAI({ enabled: "true", expires });
       }
     }
   }, [router.asPath]);
@@ -61,7 +67,10 @@
     if (!user?.isLoggedIn) {
       setDialog({ ...dialog, isOpen: checked });
     } else {
-      setAI(checked ? "true" : "false");
+      setAI({
+        enabled: checked ? "true" : "false",
+        expires: checked ? expires : undefined,
+      });
     }
   }
 
diff --git a/hooks/useLocalStorage.ts b/hooks/useLocalStorage.ts
index 5488a5b8..a4f79b10 100644
--- a/hooks/useLocalStorage.ts
+++ b/hooks/useLocalStorage.ts
@@ -1,6 +1,6 @@
 import { useCallback, useEffect, useState } from "react";
 
-function useLocalStorage(key: string, initialValue: string) {
+function useLocalStorage(key: string, initialValue: any) {
   // Get the initial value from localStorage or use the provided initialValue
   const [storedValue, setStoredValue] = useState(() => {
     if (typeof window !== "undefined") {
diff --git a/lib/chat-helpers.ts b/lib/chat-helpers.ts
index 38df8da9..467d3f52 100644
--- a/lib/chat-helpers.ts
+++ b/lib/chat-helpers.ts
@@ -1,18 +1,17 @@
 import axios, { AxiosError } from "axios";
 
 import { DCAPI_CHAT_FEEDBACK } from "./constants/endpoints";
-import { Question } from "@/types/components/chat";
-import { v4 as uuidv4 } from "uuid";
 
 const prepareQuestion = (
   questionString: string,
   authToken: string,
-): Question => {
+  conversationRef: string,
+) => {
   return {
     auth: authToken,
     message: "chat",
     question: questionString,
-    ref: uuidv4(),
+    ref: conversationRef,
   };
 };
 
diff --git a/pages/_app.tsx b/pages/_app.tsx
index e2bba277..ef8f731d 100644
--- a/pages/_app.tsx
+++ b/pages/_app.tsx
@@ -16,6 +16,7 @@ import React from "react";
 import { SearchProvider } from "@/context/search-context";
 import { User } from "@/types/context/user";
 import { UserProvider } from "@/context/user-context";
+import { defaultAIState } from "@/hooks/useGenerativeAISearchToggle";
 import { defaultOpenGraphData } from "@/lib/open-graph";
 import { getUser } from "@/lib/user-helpers";
 import globalStyles from "@/styles/global";
@@ -37,8 +38,8 @@ function MyApp({ Component, pageProps }: MyAppProps) {
   const [mounted, setMounted] = React.useState(false);
   const [user, setUser] = React.useState();
 
-  const [ai] = useLocalStorage("ai", "false");
-  const isUsingAI = ai === "true";
+  const [ai, setAI] = useLocalStorage("ai", defaultAIState);
+  const isUsingAI = ai?.enabled === "true";
 
   React.useEffect(() => {
     async function getData() {
@@ -47,6 +48,9 @@
       setMounted(true);
     }
     getData();
+
+    // Check if AI is enabled and if it has expired
+    if (ai?.expires && ai.expires < Date.now()) setAI(defaultAIState);
   }, []);
 
   React.useEffect(() => {
diff --git a/pages/search.tsx b/pages/search.tsx
index 568886cb..b27958b7 100644
--- a/pages/search.tsx
+++ b/pages/search.tsx
@@ -254,10 +254,7 @@ const SearchPage: NextPage = () => {
             renderTabList={showStreamedResponse}
           />
-          
+          
 
diff --git a/types/components/chat.ts b/types/components/chat.ts
index 053c7e5a..a1d77799 100644
--- a/types/components/chat.ts
+++ b/types/components/chat.ts
@@ -1,37 +1,90 @@
 import { Work } from "@nulib/dcapi-types";
 
-export type QuestionRendered = {
-  question: string;
+export type Ref = {
   ref: string;
 };
 
-export type Question = {
-  auth: string;
-  message: "chat";
-  question: string;
-  ref: string;
+export type AggregationResultMessage = {
+  type: "aggregation_result";
+  message: {
+    buckets: [
+      {
+        key: string;
+        doc_count: number;
+      },
+    ];
+    doc_count_error_upper_bound: number;
+    sum_other_doc_count: number;
+  };
 };
 
-export type Answer = {
-  answer: string;
-  isComplete: boolean;
-  question?: string;
-  ref: string;
-  source_documents: Array<Work>;
+export type AgentFinalMessage = {
+  type: "final";
+  message: string;
+};
+
+export type LLMAnswerMessage = {
+  type: "answer";
+  message: string;
+};
+
+export type LLMFinalMessage = {
+  type: "final_message";
+};
+
+export type LLMTokenMessage = {
+  type: "token";
+  message: string;
+};
+
+export type LLMStopMessage = {
+  type: "stop";
+};
+
+export type SearchResultMessage = {
+  type: "search_result";
+  message: Array<Work>;
 };
 
-export type StreamingMessage = {
-  answer?: string;
-  end?: {
-    reason: "stop" | "length" | "timeout" | "eos_token";
-    ref: string;
+export type StartMessage = {
+  type: "start";
+  message: {
+    model: string;
   };
-  question?: string;
-  ref: string;
-  source_documents?: Array<Work>;
-  token?: string;
 };
 
+export type ToolStartMessage = {
+  type: "tool_start";
+  message:
+    | {
+        tool: "discover_fields";
+        input: {};
+      }
+    | {
+        tool: "search";
+        input: {
+          query: string;
+        };
+      }
+    | {
+        tool: "aggregate";
+        input: { agg_field: string; term_field: string; term: string };
+      };
+};
+
+export type StreamingMessage = Ref &
+  (
+    | AggregationResultMessage
+    | AgentFinalMessage
+    | LLMAnswerMessage
+    | LLMFinalMessage
+    | LLMTokenMessage
+    | LLMStopMessage
+    | SearchResultMessage
+    | StartMessage
+    | ToolStartMessage
+  );
+
 export type ChatConfig = {
   auth: string;
   endpoint: string;