diff --git a/docs/api.md b/docs/api.md index c2f2fed..ce41f77 100644 --- a/docs/api.md +++ b/docs/api.md @@ -75,7 +75,7 @@ Create a new graph based on a user-defined schema. ### `query_graph` ```python -def query_graph(self, namespace: str, query: str) -> QueryGraphReturn +def query_graph(self, namespace: str, query: str, include_triples: bool = False, include_chunks: bool = False) -> QueryGraphReturn ``` Query the graph. @@ -84,10 +84,12 @@ Query the graph. - `namespace` (str): The namespace of the graph. - `query` (str): The query to run. +- `include_triples` (bool): Include the triples used in the return. +- `include_chunks` (bool): Include the chunk ids and chunk text in the return. #### Returns -- (`QueryGraphReturn`): The answer, triples, and Cypher query. +- (`QueryGraphReturn`): The answer, triples (optional), and chunks (optional). ## Schemas @@ -148,6 +150,8 @@ class QueryGraphResponse(BaseResponse): namespace: str answer: str + include_triples: bool = False + include_chunks: bool = False ``` ### `QueryGraphReturn` diff --git a/docs/tutorial.md b/docs/tutorial.md index 43dd399..75849b5 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -160,4 +160,15 @@ namespace = "harry-potter-2" seed_questions_query_response = client.graph.query_graph(namespace, query) print("Query Response:", query_response) +# Include the triples in the return +query = "Who is Harry friends with?" +namespace = "harry-potter" +schema_query_response = client.graph.query_graph(namespace, query, include_triples = True) +print("Query Response:", query_response) + +# Include the chunk context in the return +query = "Who is Harry friends with?" +namespace = "harry-potter" +schema_query_response = client.graph.query_graph(namespace, query, include_chunks = True) +print("Query Response:", query_response) ``` diff --git a/examples/create_graph_from_questions.ipynb b/examples/create_graph_from_questions.ipynb index e1f3e84..8ef44df 100644 --- a/examples/create_graph_from_questions.ipynb +++ b/examples/create_graph_from_questions.ipynb @@ -89,11 +89,7 @@ ], "source": [ "# Add documents\n", - "documents_response = client.graph.add_documents(\n", - " namespace = namespace, \n", - " documents = documents\n", - ")\n", - "\n", + "documents_response = client.graph.add_documents(namespace, documents)\n", "print(documents_response)" ] }, diff --git a/src/whyhow/__init__.py b/src/whyhow/__init__.py index 4626c25..d2f2ba7 100644 --- a/src/whyhow/__init__.py +++ b/src/whyhow/__init__.py @@ -2,5 +2,5 @@ from whyhow.client import AsyncWhyHow, WhyHow -__version__ = "v0.0.4" +__version__ = "v0.0.5" __all__ = ["AsyncWhyHow", "WhyHow"] diff --git a/src/whyhow/apis/graph.py b/src/whyhow/apis/graph.py index 273eeaa..fb75336 100644 --- a/src/whyhow/apis/graph.py +++ b/src/whyhow/apis/graph.py @@ -13,7 +13,6 @@ CreateSchemaGraphRequest, QueryGraphRequest, QueryGraphResponse, - QueryGraphReturn, ) @@ -56,8 +55,9 @@ def add_documents(self, namespace: str, documents: list[str]) -> str: if len(document_paths) > 3: raise ValueError( - """Too many documents - please limit uploads to 3 files during the beta.""" + """Too many documents. + please limit uploads to 3 files during the beta. + """ ) files = [ @@ -142,7 +142,13 @@ def create_graph_from_schema( return response.message - def query_graph(self, namespace: str, query: str) -> QueryGraphReturn: + def query_graph( + self, + namespace: str, + query: str, + include_triples: bool = False, + include_chunks: bool = False, + ) -> QueryGraphResponse: """Query the graph. Parameters @@ -155,11 +161,15 @@ def query_graph(self, namespace: str, query: str) -> QueryGraphReturn: Returns ------- - QueryGraphReturn - The answer, triples, and Cypher query. + QueryGraphResponse + The namespace, answer, triples, and chunks and Cypher query. """ - request_body = QueryGraphRequest(query=query) + request_body = QueryGraphRequest( + query=query, + include_triples=include_triples, + include_chunks=include_chunks, + ) raw_response = self.client.post( f"{self.prefix}/{namespace}/query", @@ -170,6 +180,6 @@ def query_graph(self, namespace: str, query: str) -> QueryGraphReturn: response = QueryGraphResponse.model_validate(raw_response.json()) - retval = QueryGraphReturn(answer=response.answer) + # retval = QueryGraphReturn(answer=response.answer) - return retval + return response diff --git a/src/whyhow/schemas/graph.py b/src/whyhow/schemas/graph.py index 40bd648..46f7eeb 100644 --- a/src/whyhow/schemas/graph.py +++ b/src/whyhow/schemas/graph.py @@ -49,6 +49,26 @@ class QueryGraphRequest(BaseRequest): """Schema for the request body of the query graph endpoint.""" query: str + include_triples: bool = False + include_chunks: bool = False + + +class QueryGraphTripleResponse(BaseResponse): + """Schema for the triples within the query graph response.""" + + head: str + relation: str + tail: str + + +class QueryGraphChunkResponse(BaseResponse): + """Schema for the triples within the query graph response.""" + + head: str + relation: str + tail: str + chunk_ids: list[str] + chunk_texts: list[str] class QueryGraphResponse(BaseResponse): @@ -56,9 +76,14 @@ class QueryGraphResponse(BaseResponse): namespace: str answer: str + triples: list[QueryGraphTripleResponse] = [] + chunks: list[QueryGraphChunkResponse] = [] class QueryGraphReturn(BaseReturn): """Schema for the return value of the query graph endpoint.""" + namespace: str answer: str + triples: list[QueryGraphTripleResponse] = [] + chunks: list[QueryGraphChunkResponse] = [] diff --git a/tests/apis/test_graph.py b/tests/apis/test_graph.py index 0af2594..b6630fc 100644 --- a/tests/apis/test_graph.py +++ b/tests/apis/test_graph.py @@ -6,11 +6,7 @@ from whyhow.client import WhyHow from whyhow.schemas.common import Graph, Node, Relationship -from whyhow.schemas.graph import ( - QueryGraphRequest, - QueryGraphResponse, - QueryGraphReturn, -) +from whyhow.schemas.graph import QueryGraphRequest, QueryGraphResponse # Set fake environment variables os.environ["WHYHOW_API_KEY"] = "fake_api_key" @@ -33,16 +29,18 @@ class TestGraphAPIQuery: - """Tests for the query_graph method.""" + """Tests for the query method.""" def test_query_graph(self, httpx_mock): - """Test querying the graph.""" + """Test the query_graph method.""" client = WhyHow() query = "What friends does Alice have?" fake_response_body = QueryGraphResponse( namespace="something", answer="Alice knows Bob", + triples=[], + chunks=[], ) httpx_mock.add_response( method="POST", @@ -54,7 +52,12 @@ def test_query_graph(self, httpx_mock): query=query, ) - assert result == QueryGraphReturn(answer="Alice knows Bob") + assert result == QueryGraphResponse( + namespace="something", + answer="Alice knows Bob", + triples=[], + chunks=[], + ) actual_request = httpx_mock.get_requests()[0] expected_request_body = QueryGraphRequest(query=query) @@ -70,7 +73,7 @@ class TestGraphAPIAddDocuments: """Tests for the add_documents method.""" def test_errors(self, httpx_mock, tmp_path): - """Test error handling.""" + """Test various error cases.""" client = WhyHow() with pytest.raises(ValueError, match="No documents provided"): diff --git a/tests/conftest.py b/tests/conftest.py index cbb332e..f132012 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -"""This module contains pytest fixtures that are used across multiple files.""" +"""Configuration for the tests.""" import pathlib @@ -7,7 +7,7 @@ @pytest.fixture def test_path(): - """Return the path of the test file.""" + """Return the path to the tests directory.""" return pathlib.Path(__file__).parent diff --git a/tests/schemas/test_common.py b/tests/schemas/test_common.py index 91e5300..5bcafe5 100644 --- a/tests/schemas/test_common.py +++ b/tests/schemas/test_common.py @@ -1,4 +1,4 @@ -"""This module contains tests for the common schemas.""" +"""Tests for whyhow.schemas.common.""" import pytest @@ -6,7 +6,7 @@ class TestGraph: - """Tests for the Graph schema.""" + """Tests for the Graph class.""" def test_no_nodes(self): """Test creating a graph with no nodes.""" @@ -45,10 +45,10 @@ def test_3_nodes_1_rel(self): class TestEntity: - """Tests for the Entity schema.""" + """Tests for the Entity class.""" def test_overall(self): - """Test creating an entity.""" + """Test creating an entity and converting it to a node.""" entity = Entity( text="Alice", label="Person", properties={"foo": "bar"} ) @@ -67,7 +67,7 @@ def test_overall(self): assert entity.properties is not entity_reconstructed.properties def test_missing_name(self): - """Test creating an entity with missing name.""" + """Test creating an entity without a name property.""" node = Node(labels=["Person"], properties={}) with pytest.raises(ValueError, match="Node must have a name property"): @@ -75,10 +75,47 @@ def test_missing_name(self): class TestTriple: - """Tests for the Triple schema.""" + """Tests for the Triple class.""" + + # def test_overall(self): + # triple = Triple( + # head="Alice", + # head_type="Person", + # relationship="KNOWS", + # tail="Bob", + # tail_type="Person", + # properties={"since": 1999}, + # ) + + # assert triple.head == "Alice" + # assert triple.head_type == "Person" + # assert triple.relationship == "KNOWS" + # assert triple.tail == "Bob" + # assert triple.tail_type == "Person" + + # rel = triple.to_relationship() + + # assert rel.start_node.labels == ["Person"] + # assert rel.start_node.properties == {"name": "Alice"} + # assert rel.end_node.labels == ["Person"] + # assert rel.end_node.properties == {"name": "Bob"} + # assert rel.type == "KNOWS" + # assert rel.properties == {"since": 1999} + + # triple_reconstructed = Triple.from_relationship(rel) + + # assert triple.head == triple_reconstructed.head + # assert triple.head_type == triple_reconstructed.head_type + # assert triple.relationship == triple_reconstructed.relationship + # assert triple.tail == triple_reconstructed.tail + # assert triple.tail_type == triple_reconstructed.tail_type + # assert triple.properties == triple_reconstructed.properties + + # # test properties copied + # assert triple.properties is not triple_reconstructed.properties def test_missing_name(self): - """Test creating a triple with missing name.""" + """Test creating a triple with a node missing a name property.""" rel = Relationship( start_node=Node(labels=["Person"], properties={}), end_node=Node(labels=["Person"], properties={"name": "Bob"}), diff --git a/tests/test_client.py b/tests/test_client.py index 41568b3..5c89025 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -12,13 +12,13 @@ class TestWhyHow: """Tests for the WhyHow class.""" def test_constructor_missing_api_key(self, monkeypatch): - """Test creating a WhyHow instance without an API key.""" + """Test that an error is raised if the API key is missing.""" monkeypatch.delenv("WHYHOW_API_KEY", raising=False) with pytest.raises(ValueError, match="WHYHOW_API_KEY must be set"): WhyHow() def test_httpx_kwargs(self, monkeypatch): - """Test passing httpx_kwargs to the constructor.""" + """Test that the client is initialized with the correct arguments.""" fake_httpx_client_inst = Mock(spec=Client) fake_httpx_client_class = Mock(return_value=fake_httpx_client_inst) @@ -42,7 +42,7 @@ def test_httpx_kwargs(self, monkeypatch): assert client.httpx_client is fake_httpx_client_class.return_value def test_base_url_twice(self): - """Test setting base_url in httpx_kwargs.""" + """Test that an error is raised if base_url is set in httpx_kwargs.""" with pytest.raises( ValueError, match="base_url cannot be set in httpx_kwargs." ): diff --git a/tests/test_dummy.py b/tests/test_dummy.py new file mode 100644 index 0000000..def84ad --- /dev/null +++ b/tests/test_dummy.py @@ -0,0 +1,6 @@ +"""Dummy test.""" + + +def test(): + """Dummy test.""" + assert True