diff --git a/backend/alembic/versions/5809c0787398_add_chat_sessions.py b/backend/alembic/versions/5809c0787398_add_chat_sessions.py new file mode 100644 index 00000000000..1c5d9e5402a --- /dev/null +++ b/backend/alembic/versions/5809c0787398_add_chat_sessions.py @@ -0,0 +1,85 @@ +"""Add Chat Sessions + +Revision ID: 5809c0787398 +Revises: d929f0c1c6af +Create Date: 2023-09-04 15:29:44.002164 + +""" +import fastapi_users_db_sqlalchemy +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "5809c0787398" +down_revision = "d929f0c1c6af" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "chat_session", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column( + "user_id", + fastapi_users_db_sqlalchemy.generics.GUID(), + nullable=True, + ), + sa.Column("description", sa.Text(), nullable=False), + sa.Column("deleted", sa.Boolean(), nullable=False), + sa.Column( + "time_updated", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "time_created", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "chat_message", + sa.Column("chat_session_id", sa.Integer(), nullable=False), + sa.Column("message_number", sa.Integer(), nullable=False), + sa.Column("edit_number", sa.Integer(), nullable=False), + sa.Column("parent_edit_number", sa.Integer(), nullable=True), + sa.Column("latest", sa.Boolean(), nullable=False), + sa.Column("message", sa.Text(), nullable=False), + sa.Column( + "message_type", + sa.Enum( + "SYSTEM", + "USER", + "ASSISTANT", + "DANSWER", + name="messagetype", + native_enum=False, + ), + nullable=False, + ), + sa.Column( + "time_sent", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["chat_session_id"], + ["chat_session.id"], + ), + sa.PrimaryKeyConstraint("chat_session_id", "message_number", "edit_number"), + ) + + +def downgrade() -> None: + op.drop_table("chat_message") + op.drop_table("chat_session") diff --git a/backend/alembic/versions/d929f0c1c6af_feedback_feature.py b/backend/alembic/versions/d929f0c1c6af_feedback_feature.py index 985880e4025..e2f4e6ff575 100644 --- a/backend/alembic/versions/d929f0c1c6af_feedback_feature.py +++ b/backend/alembic/versions/d929f0c1c6af_feedback_feature.py @@ -24,13 +24,13 @@ def upgrade() -> None: sa.Column("query", sa.String(), nullable=False), sa.Column( "selected_search_flow", - sa.Enum("KEYWORD", "SEMANTIC", name="searchtype"), + sa.Enum("KEYWORD", "SEMANTIC", name="searchtype", native_enum=False), nullable=True, ), sa.Column("llm_answer", sa.String(), nullable=True), sa.Column( "feedback", - sa.Enum("LIKE", "DISLIKE", name="qafeedbacktype"), + sa.Enum("LIKE", "DISLIKE", name="qafeedbacktype", native_enum=False), nullable=True, ), sa.Column( @@ -65,6 +65,7 @@ def upgrade() -> None: "HIDE", "UNHIDE", name="searchfeedbacktype", + native_enum=False, ), nullable=True, ), diff --git a/backend/danswer/chat/__init__.py b/backend/danswer/chat/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/danswer/chat/chat_llm.py b/backend/danswer/chat/chat_llm.py new file mode 100644 index 00000000000..e496afe15a6 --- /dev/null +++ b/backend/danswer/chat/chat_llm.py @@ -0,0 +1,27 @@ +from collections.abc import Iterator + +from langchain.schema.messages import AIMessage +from langchain.schema.messages import BaseMessage +from langchain.schema.messages import HumanMessage +from langchain.schema.messages import SystemMessage + +from danswer.configs.constants import MessageType +from danswer.db.models import ChatMessage +from danswer.llm.build import get_default_llm + + +def llm_chat_answer(previous_messages: list[ChatMessage]) -> Iterator[str]: + prompt: list[BaseMessage] = [] + for msg in previous_messages: + content = msg.message + if msg.message_type == MessageType.SYSTEM: + prompt.append(SystemMessage(content=content)) + if msg.message_type == MessageType.ASSISTANT: + prompt.append(AIMessage(content=content)) + if ( + msg.message_type == MessageType.USER + or msg.message_type == MessageType.DANSWER # consider using FunctionMessage + ): + prompt.append(HumanMessage(content=content)) + + return get_default_llm().stream(prompt) diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 1a5f36824d3..eaadb241f74 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -144,6 +144,7 @@ QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "10") # 10 seconds # Include additional document/chunk metadata in prompt to GenerativeAI INCLUDE_METADATA = False +HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "True").lower() != "false" ##### diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 2038272e293..f4c82be43e8 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -79,3 +79,11 @@ class SearchFeedbackType(str, Enum): REJECT = "reject" # down-boost this document for all future queries HIDE = "hide" # mark this document as untrusted, hide from LLM UNHIDE = "unhide" + + +class MessageType(str, Enum): + # Using OpenAI standards, Langchain equivalent shown in comment + SYSTEM = "system" # SystemMessage + USER = "user" # HumanMessage + ASSISTANT = "assistant" # AIMessage + DANSWER = "danswer" # FunctionMessage diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py new file mode 100644 index 00000000000..13230929be9 --- /dev/null +++ b/backend/danswer/db/chat.py @@ -0,0 +1,247 @@ +from uuid import UUID + +from sqlalchemy import and_ +from sqlalchemy import delete +from sqlalchemy import func +from sqlalchemy import select +from sqlalchemy.exc import NoResultFound +from sqlalchemy.orm import selectinload +from sqlalchemy.orm import Session + +from danswer.configs.app_configs import HARD_DELETE_CHATS +from danswer.configs.constants import MessageType +from danswer.db.models import ChatMessage +from danswer.db.models import ChatSession + + +def fetch_chat_sessions_by_user( + user_id: UUID | None, + deleted: bool | None, + db_session: Session, +) -> list[ChatSession]: + stmt = select(ChatSession).where(ChatSession.user_id == user_id) + + if deleted is not None: + stmt = stmt.where(ChatSession.deleted == deleted) + + result = db_session.execute(stmt) + chat_sessions = result.scalars().all() + + return list(chat_sessions) + + +def fetch_chat_messages_by_session( + chat_session_id: int, db_session: Session +) -> list[ChatMessage]: + stmt = ( + select(ChatMessage) + .where(ChatMessage.chat_session_id == chat_session_id) + .order_by(ChatMessage.message_number.asc(), ChatMessage.edit_number.asc()) + ) + result = db_session.execute(stmt).scalars().all() + return list(result) + + +def fetch_chat_message( + chat_session_id: int, message_number: int, edit_number: int, db_session: Session +) -> ChatMessage: + stmt = ( + select(ChatMessage) + .where( + (ChatMessage.chat_session_id == chat_session_id) + & (ChatMessage.message_number == message_number) + & (ChatMessage.edit_number == edit_number) + ) + .options(selectinload(ChatMessage.chat_session)) + ) + + chat_message = db_session.execute(stmt).scalar_one_or_none() + + if not chat_message: + raise ValueError("Invalid Chat Message specified") + + return chat_message + + +def fetch_chat_session_by_id(chat_session_id: int, db_session: Session) -> ChatSession: + stmt = select(ChatSession).where(ChatSession.id == chat_session_id) + result = db_session.execute(stmt) + chat_session = result.scalar_one_or_none() + + if not chat_session: + raise ValueError("Invalid Chat Session ID provided") + + return chat_session + + +def verify_parent_exists( + chat_session_id: int, + message_number: int, + parent_edit_number: int | None, + db_session: Session, +) -> ChatMessage: + stmt = select(ChatMessage).where( + (ChatMessage.chat_session_id == chat_session_id) + & (ChatMessage.message_number == message_number - 1) + & (ChatMessage.edit_number == parent_edit_number) + ) + + result = db_session.execute(stmt) + + try: + return result.scalar_one() + except NoResultFound: + raise ValueError("Invalid message, parent message not found") + + +def create_chat_session( + description: str, user_id: UUID | None, db_session: Session +) -> ChatSession: + chat_session = ChatSession( + user_id=user_id, + description=description, + ) + + db_session.add(chat_session) + db_session.commit() + + return chat_session + + +def update_chat_session( + user_id: UUID | None, chat_session_id: int, description: str, db_session: Session +) -> ChatSession: + chat_session = fetch_chat_session_by_id(chat_session_id, db_session) + + if chat_session.deleted: + raise ValueError("Trying to rename a deleted chat session") + + if user_id != chat_session.user_id: + raise ValueError("User trying to update chat of another user.") + + chat_session.description = description + + db_session.commit() + + return chat_session + + +def delete_chat_session( + user_id: UUID | None, + chat_session_id: int, + db_session: Session, + hard_delete: bool = HARD_DELETE_CHATS, +) -> None: + chat_session = fetch_chat_session_by_id(chat_session_id, db_session) + + if user_id != chat_session.user_id: + raise ValueError("User trying to delete chat of another user.") + + if hard_delete: + stmt_messages = delete(ChatMessage).where( + ChatMessage.chat_session_id == chat_session_id + ) + db_session.execute(stmt_messages) + + stmt = delete(ChatSession).where(ChatSession.id == chat_session_id) + db_session.execute(stmt) + + else: + chat_session.deleted = True + + db_session.commit() + + +def _set_latest_chat_message_no_commit( + chat_session_id: int, + message_number: int, + parent_edit_number: int | None, + edit_number: int, + db_session: Session, +) -> None: + if message_number != 0 and parent_edit_number is None: + raise ValueError( + "Only initial message in a chat is allowed to not have a parent" + ) + + db_session.query(ChatMessage).filter( + and_( + ChatMessage.chat_session_id == chat_session_id, + ChatMessage.message_number == message_number, + ChatMessage.parent_edit_number == parent_edit_number, + ) + ).update({ChatMessage.latest: False}) + + db_session.query(ChatMessage).filter( + and_( + ChatMessage.chat_session_id == chat_session_id, + ChatMessage.message_number == message_number, + ChatMessage.edit_number == edit_number, + ) + ).update({ChatMessage.latest: True}) + + +def create_new_chat_message( + chat_session_id: int, + message_number: int, + message: str, + parent_edit_number: int | None, + message_type: MessageType, + db_session: Session, +) -> ChatMessage: + """Creates a new chat message and sets it to the latest message of its parent message""" + # Get the count of existing edits at the provided message number + latest_edit_number = ( + db_session.query(func.max(ChatMessage.edit_number)) + .filter_by( + chat_session_id=chat_session_id, + message_number=message_number, + ) + .scalar() + ) + + # The new message is a new edit at the provided message number + new_edit_number = latest_edit_number + 1 if latest_edit_number is not None else 0 + + # Create a new message and set it to be the latest for its parent message + new_chat_message = ChatMessage( + chat_session_id=chat_session_id, + message_number=message_number, + parent_edit_number=parent_edit_number, + edit_number=new_edit_number, + message=message, + message_type=message_type, + ) + + db_session.add(new_chat_message) + + # Set the previous latest message of the same parent, as no longer the latest + _set_latest_chat_message_no_commit( + chat_session_id=chat_session_id, + message_number=message_number, + parent_edit_number=parent_edit_number, + edit_number=new_edit_number, + db_session=db_session, + ) + + db_session.commit() + + return new_chat_message + + +def set_latest_chat_message( + chat_session_id: int, + message_number: int, + parent_edit_number: int | None, + edit_number: int, + db_session: Session, +) -> None: + _set_latest_chat_message_no_commit( + chat_session_id=chat_session_id, + message_number=message_number, + parent_edit_number=parent_edit_number, + edit_number=edit_number, + db_session=db_session, + ) + + db_session.commit() diff --git a/backend/danswer/db/feedback.py b/backend/danswer/db/feedback.py index 3cfcf3a2201..2ae5fc21265 100644 --- a/backend/danswer/db/feedback.py +++ b/backend/danswer/db/feedback.py @@ -21,7 +21,7 @@ def fetch_query_event_by_id(query_id: int, db_session: Session) -> QueryEvent: query_event = result.scalar_one_or_none() if not query_event: - raise ValueError("Invalid Query Event provided for updating") + raise ValueError("Invalid Query Event ID Provided") return query_event @@ -32,7 +32,7 @@ def fetch_docs_by_id(doc_id: str, db_session: Session) -> DbDocument: doc = result.scalar_one_or_none() if not doc: - raise ValueError("Invalid Document provided for updating") + raise ValueError("Invalid Document ID Provided") return doc diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 7d845b5d786..f88b4417a1f 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -13,6 +13,7 @@ from sqlalchemy import Enum from sqlalchemy import ForeignKey from sqlalchemy import func +from sqlalchemy import Index from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy import Text @@ -25,6 +26,7 @@ from danswer.auth.schemas import UserRole from danswer.configs.constants import DEFAULT_BOOST from danswer.configs.constants import DocumentSource +from danswer.configs.constants import MessageType from danswer.configs.constants import QAFeedbackType from danswer.configs.constants import SearchFeedbackType from danswer.connectors.models import InputType @@ -52,7 +54,7 @@ class Base(DeclarativeBase): class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base): # even an almost empty token from keycloak will not fit the default 1024 bytes - access_token: Mapped[str] = mapped_column(Text(), nullable=False) # type: ignore + access_token: Mapped[str] = mapped_column(Text, nullable=False) # type: ignore class User(SQLAlchemyBaseUserTableUUID, Base): @@ -68,6 +70,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base): query_events: Mapped[List["QueryEvent"]] = relationship( "QueryEvent", back_populates="user" ) + chat_sessions: Mapped[List["ChatSession"]] = relationship( + "ChatSession", back_populates="user" + ) class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base): @@ -193,7 +198,7 @@ class IndexAttempt(Base): status: Mapped[IndexingStatus] = mapped_column(Enum(IndexingStatus)) num_docs_indexed: Mapped[int | None] = mapped_column(Integer, default=0) error_msg: Mapped[str | None] = mapped_column( - String(), default=None + Text, default=None ) # only filled if status = "failed" time_created: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), @@ -217,6 +222,15 @@ class IndexAttempt(Base): "Credential", back_populates="index_attempts" ) + __table_args__ = ( + Index( + "ix_index_attempt_latest_for_connector_credential_pair", + "connector_id", + "credential_id", + "time_created", + ), + ) + def __repr__(self) -> str: return ( f" JSONResponse: def get_application() -> FastAPI: application = FastAPI(title="Internal Search QA Backend", debug=True, version="0.1") application.include_router(backend_router) + application.include_router(chat_router) application.include_router(event_processing_router) application.include_router(admin_router) application.include_router(user_router) diff --git a/backend/danswer/secondary_llm_flows/chat_helpers.py b/backend/danswer/secondary_llm_flows/chat_helpers.py new file mode 100644 index 00000000000..7925a3fa6a9 --- /dev/null +++ b/backend/danswer/secondary_llm_flows/chat_helpers.py @@ -0,0 +1,19 @@ +from danswer.llm.build import get_default_llm +from danswer.llm.utils import dict_based_prompt_to_langchain_prompt + + +def get_chat_name_messages(user_query: str) -> list[dict[str, str]]: + messages = [ + { + "role": "system", + "content": "Give a short name for this chat session based on the user's first message.", + }, + {"role": "user", "content": user_query}, + ] + return messages + + +def get_new_chat_name(user_query: str) -> str: + messages = get_chat_name_messages(user_query) + filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages) + return get_default_llm().invoke(filled_llm_prompt) diff --git a/backend/danswer/server/chat_backend.py b/backend/danswer/server/chat_backend.py new file mode 100644 index 00000000000..b74d083de10 --- /dev/null +++ b/backend/danswer/server/chat_backend.py @@ -0,0 +1,373 @@ +from collections.abc import Iterator +from dataclasses import asdict + +from fastapi import APIRouter +from fastapi import Depends +from fastapi.responses import StreamingResponse +from sqlalchemy.orm import Session + +from danswer.auth.users import current_user +from danswer.chat.chat_llm import llm_chat_answer +from danswer.configs.constants import MessageType +from danswer.db.chat import create_chat_session +from danswer.db.chat import create_new_chat_message +from danswer.db.chat import delete_chat_session +from danswer.db.chat import fetch_chat_message +from danswer.db.chat import fetch_chat_messages_by_session +from danswer.db.chat import fetch_chat_session_by_id +from danswer.db.chat import fetch_chat_sessions_by_user +from danswer.db.chat import set_latest_chat_message +from danswer.db.chat import update_chat_session +from danswer.db.chat import verify_parent_exists +from danswer.db.engine import get_session +from danswer.db.models import ChatMessage +from danswer.db.models import User +from danswer.direct_qa.interfaces import DanswerAnswerPiece +from danswer.secondary_llm_flows.chat_helpers import get_new_chat_name +from danswer.server.models import ChatMessageDetail +from danswer.server.models import ChatMessageIdentifier +from danswer.server.models import ChatRenameRequest +from danswer.server.models import ChatSessionDetailResponse +from danswer.server.models import ChatSessionIdsResponse +from danswer.server.models import CreateChatID +from danswer.server.models import CreateChatRequest +from danswer.server.models import RenameChatSessionResponse +from danswer.server.utils import get_json_line +from danswer.utils.logger import setup_logger +from danswer.utils.timing import log_generator_function_time + + +logger = setup_logger() + +router = APIRouter(prefix="/chat") + + +@router.get("/get-user-chat-sessions") +def get_user_chat_sessions( + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> ChatSessionIdsResponse: + user_id = user.id if user is not None else None + + # Don't included deleted chats, even if soft delete only + chat_sessions = fetch_chat_sessions_by_user( + user_id=user_id, deleted=False, db_session=db_session + ) + + return ChatSessionIdsResponse(sessions=[chat.id for chat in chat_sessions]) + + +@router.get("/get-chat-session/{session_id}") +def get_chat_session_messages( + session_id: int, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> ChatSessionDetailResponse: + user_id = user.id if user is not None else None + + try: + session = fetch_chat_session_by_id(session_id, db_session) + except ValueError: + raise ValueError("Chat Session has been deleted") + + if session.deleted: + raise ValueError("Chat Session has been deleted") + + if user_id != session.user_id: + if user is None: + raise PermissionError( + "The No-Auth User is trying to read a different user's chat" + ) + raise PermissionError( + f"User {user.email} is trying to read a different user's chat" + ) + + session_messages = fetch_chat_messages_by_session( + chat_session_id=session_id, db_session=db_session + ) + + return ChatSessionDetailResponse( + chat_session_id=session_id, + description=session.description, + messages=[ + ChatMessageDetail( + message_number=msg.message_number, + edit_number=msg.edit_number, + parent_edit_number=msg.parent_edit_number, + latest=msg.latest, + message=msg.message, + message_type=msg.message_type, + time_sent=msg.time_sent, + ) + for msg in session_messages + ], + ) + + +@router.post("/create-chat-session") +def create_new_chat_session( + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> CreateChatID: + user_id = user.id if user is not None else None + + new_chat_session = create_chat_session( + "", user_id, db_session # Leave the naming till later to prevent delay + ) + + return CreateChatID(chat_session_id=new_chat_session.id) + + +@router.put("/rename-chat-session") +def rename_chat_session( + rename: ChatRenameRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> RenameChatSessionResponse: + name = rename.name + message = rename.first_message + user_id = user.id if user is not None else None + + if not name and not message: + raise ValueError("Can't assign a name for the chat without context") + + new_name = name or get_new_chat_name(str(message)) + + update_chat_session(user_id, rename.chat_session_id, new_name, db_session) + + return RenameChatSessionResponse(new_name=new_name) + + +@router.delete("/delete-chat-session/{session_id}") +def delete_chat_session_by_id( + session_id: int, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> None: + user_id = user.id if user is not None else None + delete_chat_session(user_id, session_id, db_session) + + +def _create_chat_chain( + chat_session_id: int, + db_session: Session, + stop_after: int | None = None, +) -> list[ChatMessage]: + mainline_messages: list[ChatMessage] = [] + all_chat_messages = fetch_chat_messages_by_session(chat_session_id, db_session) + target_message_num = 0 + target_parent_edit_num = None + + # Chat messages must be ordered by message_number + # (fetch_chat_messages_by_session ensures this so no resorting here necessary) + for msg in all_chat_messages: + if ( + msg.message_number != target_message_num + or msg.parent_edit_number != target_parent_edit_num + or not msg.latest + ): + continue + + target_parent_edit_num = msg.edit_number + target_message_num += 1 + + mainline_messages.append(msg) + + if stop_after is not None and target_message_num > stop_after: + break + + if not mainline_messages: + raise RuntimeError("Could not trace chat message history") + + return mainline_messages + + +@router.post("/send-message") +def handle_new_chat_message( + chat_message: CreateChatRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> StreamingResponse: + """This endpoint is both used for sending new messages and for sending edited messages. + To avoid extra overhead/latency, this assumes (and checks) that previous messages on the path + have already been set as latest""" + chat_session_id = chat_message.chat_session_id + message_number = chat_message.message_number + message_content = chat_message.message + parent_edit_number = chat_message.parent_edit_number + user_id = user.id if user is not None else None + + chat_session = fetch_chat_session_by_id(chat_session_id, db_session) + + if chat_session.deleted: + raise ValueError("Cannot send messages to a deleted chat session") + + if chat_session.user_id != user_id: + if user is None: + raise PermissionError( + "The No-Auth User trying to interact with a different user's chat" + ) + raise PermissionError( + f"User {user.email} trying to interact with a different user's chat" + ) + + if message_number != 0: + if parent_edit_number is None: + raise ValueError("Message must have a valid parent message") + + verify_parent_exists( + chat_session_id=chat_session_id, + message_number=message_number, + parent_edit_number=parent_edit_number, + db_session=db_session, + ) + else: + if parent_edit_number is not None: + raise ValueError("Initial message in session cannot have parent") + + # Create new message at the right place in the tree and label it latest for its parent + new_message = create_new_chat_message( + chat_session_id=chat_session_id, + message_number=message_number, + parent_edit_number=parent_edit_number, + message=message_content, + message_type=MessageType.USER, + db_session=db_session, + ) + + mainline_messages = _create_chat_chain( + chat_session_id, + db_session, + ) + + if mainline_messages[-1].message != message_content: + raise RuntimeError( + "The new message was not on the mainline. " + "Be sure to update latests before calling this." + ) + + @log_generator_function_time() + def stream_chat_tokens() -> Iterator[str]: + tokens = llm_chat_answer(mainline_messages) + llm_output = "" + for token in tokens: + llm_output += token + yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=token))) + + create_new_chat_message( + chat_session_id=chat_session_id, + message_number=message_number + 1, + parent_edit_number=new_message.edit_number, + message=llm_output, + message_type=MessageType.ASSISTANT, + db_session=db_session, + ) + + return StreamingResponse(stream_chat_tokens(), media_type="application/json") + + +@router.post("/regenerate-from-parent") +def regenerate_message_given_parent( + parent_message: ChatMessageIdentifier, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> StreamingResponse: + """Regenerate an LLM response given a particular parent message + The parent message is set as latest and a new LLM response is set as + the latest following message""" + chat_session_id = parent_message.chat_session_id + message_number = parent_message.message_number + edit_number = parent_message.edit_number + user_id = user.id if user is not None else None + + chat_message = fetch_chat_message( + chat_session_id=chat_session_id, + message_number=message_number, + edit_number=edit_number, + db_session=db_session, + ) + + chat_session = chat_message.chat_session + + if chat_session.deleted: + raise ValueError("Chat session has been deleted") + + if chat_session.user_id != user_id: + if user is None: + raise PermissionError( + "The No-Auth User trying to regenerate chat messages of another user" + ) + raise PermissionError( + f"User {user.email} trying to regenerate chat messages of another user" + ) + + set_latest_chat_message( + chat_session_id, + message_number, + chat_message.parent_edit_number, + edit_number, + db_session, + ) + + # The parent message, now set as latest, may have follow on messages + # Don't want to include those in the context to LLM + mainline_messages = _create_chat_chain( + chat_session_id, db_session, stop_after=message_number + ) + + @log_generator_function_time() + def stream_regenerate_tokens() -> Iterator[str]: + tokens = llm_chat_answer(mainline_messages) + llm_output = "" + for token in tokens: + llm_output += token + yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=token))) + + create_new_chat_message( + chat_session_id=chat_session_id, + message_number=message_number + 1, + parent_edit_number=edit_number, + message=llm_output, + message_type=MessageType.ASSISTANT, + db_session=db_session, + ) + + return StreamingResponse(stream_regenerate_tokens(), media_type="application/json") + + +@router.put("/set-message-as-latest") +def set_message_as_latest( + message_identifier: ChatMessageIdentifier, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> None: + user_id = user.id if user is not None else None + + chat_message = fetch_chat_message( + chat_session_id=message_identifier.chat_session_id, + message_number=message_identifier.message_number, + edit_number=message_identifier.edit_number, + db_session=db_session, + ) + + chat_session = chat_message.chat_session + + if chat_session.deleted: + raise ValueError("Chat session has been deleted") + + if chat_session.user_id != user_id: + if user is None: + raise PermissionError( + "The No-Auth User trying to update chat messages of another user" + ) + raise PermissionError( + f"User {user.email} trying to update chat messages of another user" + ) + + set_latest_chat_message( + chat_session_id=chat_message.chat_session_id, + message_number=chat_message.message_number, + parent_edit_number=chat_message.parent_edit_number, + edit_number=chat_message.edit_number, + db_session=db_session, + ) diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index 292fba396b5..36930ffa74f 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -11,6 +11,7 @@ from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX from danswer.configs.constants import DocumentSource +from danswer.configs.constants import MessageType from danswer.configs.constants import QAFeedbackType from danswer.configs.constants import SearchFeedbackType from danswer.connectors.models import InputType @@ -130,6 +131,10 @@ class SearchDoc(BaseModel): score: float | None +class CreateChatID(BaseModel): + chat_session_id: int + + class QuestionRequest(BaseModel): query: str collection: str @@ -151,6 +156,49 @@ class SearchFeedbackRequest(BaseModel): search_feedback: SearchFeedbackType +class CreateChatRequest(BaseModel): + chat_session_id: int + message_number: int + parent_edit_number: int | None + message: str + + +class ChatMessageIdentifier(BaseModel): + chat_session_id: int + message_number: int + edit_number: int + + +class ChatRenameRequest(BaseModel): + chat_session_id: int + name: str | None + first_message: str | None + + +class RenameChatSessionResponse(BaseModel): + new_name: str # This is only really useful if the name is generated + + +class ChatSessionIdsResponse(BaseModel): + sessions: list[int] + + +class ChatMessageDetail(BaseModel): + message_number: int + edit_number: int + parent_edit_number: int | None + latest: bool + message: str + message_type: MessageType + time_sent: datetime + + +class ChatSessionDetailResponse(BaseModel): + chat_session_id: int + description: str + messages: list[ChatMessageDetail] + + class QueryValidationResponse(BaseModel): reasoning: str answerable: bool diff --git a/backend/danswer/utils/timing.py b/backend/danswer/utils/timing.py index e1ed5e14dd7..c92d91c0160 100644 --- a/backend/danswer/utils/timing.py +++ b/backend/danswer/utils/timing.py @@ -1,6 +1,7 @@ import time from collections.abc import Callable from collections.abc import Generator +from collections.abc import Iterator from typing import Any from typing import cast from typing import TypeVar @@ -10,7 +11,7 @@ logger = setup_logger() F = TypeVar("F", bound=Callable) -FG = TypeVar("FG", bound=Callable[..., Generator]) +FG = TypeVar("FG", bound=Callable[..., Generator | Iterator]) def log_function_time(