Skip to content

Commit

Permalink
Add strict json mode (#2917)
Browse files Browse the repository at this point in the history
  • Loading branch information
Weves authored Oct 25, 2024
1 parent d7a30b0 commit 4a47e9a
Show file tree
Hide file tree
Showing 11 changed files with 291 additions and 6 deletions.
1 change: 1 addition & 0 deletions backend/danswer/chat/process_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,7 @@ def stream_chat_message_objects(
all_docs_useful=selected_db_search_docs is not None
),
document_pruning_config=document_pruning_config,
structured_response_format=new_msg_req.structured_response_format,
),
prompt_config=prompt_config,
llm=(
Expand Down
4 changes: 4 additions & 0 deletions backend/danswer/llm/answering/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ class AnswerStyleConfig(BaseModel):
document_pruning_config: DocumentPruningConfig = Field(
default_factory=DocumentPruningConfig
)
# forces the LLM to return a structured response, see
# https://platform.openai.com/docs/guides/structured-outputs/introduction
# right now, only used by the simple chat API
structured_response_format: dict | None = None

@model_validator(mode="after")
def check_quotes_and_citation(self) -> "AnswerStyleConfig":
Expand Down
19 changes: 16 additions & 3 deletions backend/danswer/llm/chat_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def _completion(
tools: list[dict] | None,
tool_choice: ToolChoiceOptions | None,
stream: bool,
structured_response_format: dict | None = None,
) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
if isinstance(prompt, list):
prompt = [
Expand Down Expand Up @@ -313,6 +314,11 @@ def _completion(
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
**({"parallel_tool_calls": False} if tools else {}),
**(
{"response_format": structured_response_format}
if structured_response_format
else {}
),
**self._model_kwargs,
)
except Exception as e:
Expand All @@ -336,12 +342,16 @@ def _invoke_implementation(
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> BaseMessage:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()

response = cast(
litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False)
litellm.ModelResponse,
self._completion(
prompt, tools, tool_choice, False, structured_response_format
),
)
choice = response.choices[0]
if hasattr(choice, "message"):
Expand All @@ -354,18 +364,21 @@ def _stream_implementation(
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()

if DISABLE_LITELLM_STREAMING:
yield self.invoke(prompt)
yield self.invoke(prompt, tools, tool_choice, structured_response_format)
return

output = None
response = cast(
litellm.CustomStreamWrapper,
self._completion(prompt, tools, tool_choice, True),
self._completion(
prompt, tools, tool_choice, True, structured_response_format
),
)
try:
for part in response:
Expand Down
2 changes: 2 additions & 0 deletions backend/danswer/llm/custom_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def _invoke_implementation(
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> BaseMessage:
return self._execute(prompt)

Expand All @@ -88,5 +89,6 @@ def _stream_implementation(
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
yield self._execute(prompt)
12 changes: 10 additions & 2 deletions backend/danswer/llm/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,22 @@ def invoke(
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> BaseMessage:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
return self._invoke_implementation(prompt, tools, tool_choice)
return self._invoke_implementation(
prompt, tools, tool_choice, structured_response_format
)

@abc.abstractmethod
def _invoke_implementation(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> BaseMessage:
raise NotImplementedError

Expand All @@ -108,17 +112,21 @@ def stream(
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
return self._stream_implementation(prompt, tools, tool_choice)
return self._stream_implementation(
prompt, tools, tool_choice, structured_response_format
)

@abc.abstractmethod
def _stream_implementation(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
raise NotImplementedError
4 changes: 4 additions & 0 deletions backend/danswer/server/query_and_chat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ class CreateChatMessageRequest(ChunkContext):
# used for seeded chats to kick off the generation of an AI answer
use_existing_user_message: bool = False

# forces the LLM to return a structured response, see
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None

@model_validator(mode="after")
def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest":
if self.search_doc_ids is None and self.retrieval_options is None:
Expand Down
4 changes: 3 additions & 1 deletion backend/ee/danswer/server/query_and_chat/chat_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def handle_simplified_chat_message(
chunks_above=0,
chunks_below=0,
full_doc=chat_message_req.full_doc,
structured_response_format=chat_message_req.structured_response_format,
)

packets = stream_chat_message_objects(
Expand All @@ -202,7 +203,7 @@ def handle_send_message_simple_with_history(
raise HTTPException(status_code=400, detail="Messages cannot be zero length")

# This is a sanity check to make sure the chat history is valid
# It must start with a user message and alternate between user and assistant
# It must start with a user message and alternate between user and assistant
expected_role = MessageType.USER
for msg in req.messages:
if not msg.message:
Expand Down Expand Up @@ -296,6 +297,7 @@ def handle_send_message_simple_with_history(
chunks_above=0,
chunks_below=0,
full_doc=req.full_doc,
structured_response_format=req.structured_response_format,
)

packets = stream_chat_message_objects(
Expand Down
6 changes: 6 additions & 0 deletions backend/ee/danswer/server/query_and_chat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class BasicCreateChatMessageRequest(ChunkContext):
query_override: str | None = None
# If search_doc_ids provided, then retrieval options are unused
search_doc_ids: list[int] | None = None
# only works if using an OpenAI model. See the following for more details:
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None


class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
Expand All @@ -60,6 +63,9 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
skip_rerank: bool | None = None
# If search_doc_ids provided, then retrieval options are unused
search_doc_ids: list[int] | None = None
# only works if using an OpenAI model. See the following for more details:
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None


class SimpleDoc(BaseModel):
Expand Down
148 changes: 148 additions & 0 deletions backend/scripts/add_connector_creation_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from typing import Any
from typing import Dict

import requests

API_SERVER_URL = "http://localhost:3000"  # Adjust this to your Danswer server URL
HEADERS = {"Content-Type": "application/json"}
# NOTE(review): API_KEY is defined but never attached to HEADERS or any
# request — presumably it should be sent (e.g. as an Authorization header)
# when auth is enabled; confirm the expected header name against the server.
API_KEY = "danswer-api-key"  # API key here, if auth is enabled


def create_connector(
    name: str,
    source: str,
    input_type: str,
    connector_specific_config: dict[str, Any],
    is_public: bool = True,
    groups: list[int] | None = None,
) -> dict[str, Any]:
    """Create a connector via the Danswer admin API.

    Args:
        name: Human-readable connector name.
        source: Source type identifier (e.g. "web", "github").
        input_type: Ingestion mode (e.g. "load_state", "poll").
        connector_specific_config: Source-specific configuration payload.
        is_public: Whether the connector is visible to all users.
        groups: User-group IDs granted access; ``None`` means no groups.

    Returns:
        The created connector as parsed JSON from the API response.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    connector_update_request = {
        "name": name,
        "source": source,
        "input_type": input_type,
        "connector_specific_config": connector_specific_config,
        "is_public": is_public,
        # Normalize None to an empty list so the JSON body is always valid.
        "groups": groups or [],
    }

    response = requests.post(
        url=f"{API_SERVER_URL}/api/manage/admin/connector",
        json=connector_update_request,
        headers=HEADERS,
    )
    response.raise_for_status()
    return response.json()


def create_credential(
    name: str,
    source: str,
    credential_json: dict[str, Any],
    is_public: bool = True,
    groups: list[int] | None = None,
) -> dict[str, Any]:
    """Create a credential via the Danswer management API.

    Args:
        name: Human-readable credential name.
        source: Source type the credential is for (e.g. "web", "github").
        credential_json: Secret material for the source; may be empty for
            sources that require no authentication.
        is_public: Sent as ``admin_public`` — whether the credential is
            usable by all admins.
        groups: User-group IDs granted access; ``None`` means no groups.

    Returns:
        The created credential as parsed JSON from the API response.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    credential_request = {
        "name": name,
        "source": source,
        "credential_json": credential_json,
        "admin_public": is_public,
        # Normalize None to an empty list so the JSON body is always valid.
        "groups": groups or [],
    }

    response = requests.post(
        url=f"{API_SERVER_URL}/api/manage/credential",
        json=credential_request,
        headers=HEADERS,
    )
    response.raise_for_status()
    return response.json()


def create_cc_pair(
    connector_id: int,
    credential_id: int,
    name: str,
    access_type: str = "public",
    groups: list[int] | None = None,
) -> dict[str, Any]:
    """Pair an existing connector with an existing credential (a "CC pair").

    Args:
        connector_id: ID of a connector previously created via
            :func:`create_connector`; embedded in the request URL.
        credential_id: ID of a credential previously created via
            :func:`create_credential`; embedded in the request URL.
        name: Human-readable name for the pairing.
        access_type: Visibility of the pair (e.g. "public").
        groups: User-group IDs granted access; ``None`` means no groups.

    Returns:
        The API response as parsed JSON.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    cc_pair_request = {
        "name": name,
        "access_type": access_type,
        # Normalize None to an empty list so the JSON body is always valid.
        "groups": groups or [],
    }

    response = requests.put(
        url=f"{API_SERVER_URL}/api/manage/connector/{connector_id}/credential/{credential_id}",
        json=cc_pair_request,
        headers=HEADERS,
    )
    response.raise_for_status()
    return response.json()


def main() -> None:
    """Provision example Web and GitHub sources end to end.

    For each source this creates a connector, a credential, and the
    connector-credential ("CC") pair that links them, printing the API
    response at every step.
    """
    # --- Web source: connector, empty credential, and their pairing ---
    web_conn = create_connector(
        name="Example Web Connector",
        source="web",
        input_type="load_state",
        connector_specific_config={
            "base_url": "https://example.com",
            "web_connector_type": "recursive",
        },
    )
    print(f"Created Web Connector: {web_conn}")

    web_cred = create_credential(
        name="Example Web Credential",
        source="web",
        credential_json={},  # Web connectors typically don't need credentials
        is_public=True,
    )
    print(f"Created Web Credential: {web_cred}")

    web_pair = create_cc_pair(
        connector_id=web_conn["id"],
        credential_id=web_cred["id"],
        name="Example Web CC Pair",
        access_type="public",
    )
    print(f"Created Web CC Pair: {web_pair}")

    # --- GitHub source: connector, token credential, and their pairing ---
    gh_conn = create_connector(
        name="Example GitHub Connector",
        source="github",
        input_type="poll",
        connector_specific_config={
            "repo_owner": "example-owner",
            "repo_name": "example-repo",
            "include_prs": True,
            "include_issues": True,
        },
    )
    print(f"Created GitHub Connector: {gh_conn}")

    gh_cred = create_credential(
        name="Example GitHub Credential",
        source="github",
        credential_json={"github_access_token": "your_github_access_token_here"},
        is_public=True,
    )
    print(f"Created GitHub Credential: {gh_cred}")

    gh_pair = create_cc_pair(
        connector_id=gh_conn["id"],
        credential_id=gh_cred["id"],
        name="Example GitHub CC Pair",
        access_type="public",
    )
    print(f"Created GitHub CC Pair: {gh_pair}")


if __name__ == "__main__":
    main()
10 changes: 10 additions & 0 deletions backend/tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

from danswer.db.engine import get_session_context_manager
from danswer.db.search_settings import get_current_search_settings
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture


Expand Down Expand Up @@ -44,3 +46,11 @@ def vespa_client(db_session: Session) -> vespa_fixture:
@pytest.fixture
def reset() -> None:
reset_all()


@pytest.fixture
def new_admin_user() -> DATestUser | None:
    # Best-effort fixture: returns the user created by UserManager.create,
    # or None when creation fails for any reason (e.g. a user left over from
    # a prior run), letting dependent tests decide how to proceed.
    # NOTE(review): presumably the created user ends up with admin rights
    # (per the fixture name) — confirm against UserManager.create.
    try:
        return UserManager.create(name="admin_user")
    except Exception:
        return None
Loading

0 comments on commit 4a47e9a

Please sign in to comment.