Skip to content

Commit

Permalink
v0.0.5
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsmoker committed May 7, 2024
1 parent a8afaa6 commit 07c795c
Show file tree
Hide file tree
Showing 11 changed files with 130 additions and 38 deletions.
8 changes: 6 additions & 2 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ Create a new graph based on a user-defined schema.
### `query_graph`

```python
def query_graph(self, namespace: str, query: str) -> QueryGraphReturn
def query_graph(self, namespace: str, query: str, include_triples: bool = False, include_chunks: bool = False) -> QueryGraphReturn
```

Query the graph.
Expand All @@ -84,10 +84,12 @@ Query the graph.

- `namespace` (str): The namespace of the graph.
- `query` (str): The query to run.
- `include_triples` (bool): Include the triples used in the return.
- `include_chunks` (bool): Include the chunk ids and chunk text in the return.

#### Returns

- (`QueryGraphReturn`): The answer, triples, and Cypher query.
- (`QueryGraphReturn`): The answer, triples (optional), and chunks (optional).

## Schemas

Expand Down Expand Up @@ -148,6 +150,8 @@ class QueryGraphResponse(BaseResponse):

namespace: str
answer: str
include_triples: bool = False
include_chunks: bool = False
```

### `QueryGraphReturn`
Expand Down
11 changes: 11 additions & 0 deletions docs/tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,15 @@ namespace = "harry-potter-2"
seed_questions_query_response = client.graph.query_graph(namespace, query)
print("Query Response:", query_response)

# Include the triples in the return
query = "Who is Harry friends with?"
namespace = "harry-potter"
schema_query_response = client.graph.query_graph(namespace, query, include_triples = True)
print("Query Response:", query_response)

# Include the chunk context in the return
query = "Who is Harry friends with?"
namespace = "harry-potter"
schema_query_response = client.graph.query_graph(namespace, query, include_chunks = True)
print("Query Response:", query_response)
```
6 changes: 1 addition & 5 deletions examples/create_graph_from_questions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,7 @@
],
"source": [
"# Add documents\n",
"documents_response = client.graph.add_documents(\n",
" namespace = namespace, \n",
" documents = documents\n",
")\n",
"\n",
"documents_response = client.graph.add_documents(namespace, documents)\n",
"print(documents_response)"
]
},
Expand Down
2 changes: 1 addition & 1 deletion src/whyhow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@

from whyhow.client import AsyncWhyHow, WhyHow

__version__ = "v0.0.4"
__version__ = "v0.0.5"
__all__ = ["AsyncWhyHow", "WhyHow"]
28 changes: 19 additions & 9 deletions src/whyhow/apis/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
CreateSchemaGraphRequest,
QueryGraphRequest,
QueryGraphResponse,
QueryGraphReturn,
)


Expand Down Expand Up @@ -56,8 +55,9 @@ def add_documents(self, namespace: str, documents: list[str]) -> str:

if len(document_paths) > 3:
raise ValueError(
"""Too many documents
please limit uploads to 3 files during the beta."""
"""Too many documents.
please limit uploads to 3 files during the beta.
"""
)

files = [
Expand Down Expand Up @@ -142,7 +142,13 @@ def create_graph_from_schema(

return response.message

def query_graph(self, namespace: str, query: str) -> QueryGraphReturn:
def query_graph(
self,
namespace: str,
query: str,
include_triples: bool = False,
include_chunks: bool = False,
) -> QueryGraphResponse:
"""Query the graph.
Parameters
Expand All @@ -155,11 +161,15 @@ def query_graph(self, namespace: str, query: str) -> QueryGraphReturn:
Returns
-------
QueryGraphReturn
The answer, triples, and Cypher query.
QueryGraphResponse
The namespace, answer, triples, and chunks and Cypher query.
"""
request_body = QueryGraphRequest(query=query)
request_body = QueryGraphRequest(
query=query,
include_triples=include_triples,
include_chunks=include_chunks,
)

raw_response = self.client.post(
f"{self.prefix}/{namespace}/query",
Expand All @@ -170,6 +180,6 @@ def query_graph(self, namespace: str, query: str) -> QueryGraphReturn:

response = QueryGraphResponse.model_validate(raw_response.json())

retval = QueryGraphReturn(answer=response.answer)
# retval = QueryGraphReturn(answer=response.answer)

return retval
return response
25 changes: 25 additions & 0 deletions src/whyhow/schemas/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,41 @@ class QueryGraphRequest(BaseRequest):
"""Schema for the request body of the query graph endpoint."""

query: str
include_triples: bool = False
include_chunks: bool = False


class QueryGraphTripleResponse(BaseResponse):
"""Schema for the triples within the query graph response."""

head: str
relation: str
tail: str


class QueryGraphChunkResponse(BaseResponse):
"""Schema for the triples within the query graph response."""

head: str
relation: str
tail: str
chunk_ids: list[str]
chunk_texts: list[str]


class QueryGraphResponse(BaseResponse):
"""Schema for the response body of the query graph endpoint."""

namespace: str
answer: str
triples: list[QueryGraphTripleResponse] = []
chunks: list[QueryGraphChunkResponse] = []


class QueryGraphReturn(BaseReturn):
"""Schema for the return value of the query graph endpoint."""

namespace: str
answer: str
triples: list[QueryGraphTripleResponse] = []
chunks: list[QueryGraphChunkResponse] = []
21 changes: 12 additions & 9 deletions tests/apis/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@

from whyhow.client import WhyHow
from whyhow.schemas.common import Graph, Node, Relationship
from whyhow.schemas.graph import (
QueryGraphRequest,
QueryGraphResponse,
QueryGraphReturn,
)
from whyhow.schemas.graph import QueryGraphRequest, QueryGraphResponse

# Set fake environment variables
os.environ["WHYHOW_API_KEY"] = "fake_api_key"
Expand All @@ -33,16 +29,18 @@


class TestGraphAPIQuery:
"""Tests for the query_graph method."""
"""Tests for the query method."""

def test_query_graph(self, httpx_mock):
"""Test querying the graph."""
"""Test the query_graph method."""
client = WhyHow()
query = "What friends does Alice have?"

fake_response_body = QueryGraphResponse(
namespace="something",
answer="Alice knows Bob",
triples=[],
chunks=[],
)
httpx_mock.add_response(
method="POST",
Expand All @@ -54,7 +52,12 @@ def test_query_graph(self, httpx_mock):
query=query,
)

assert result == QueryGraphReturn(answer="Alice knows Bob")
assert result == QueryGraphResponse(
namespace="something",
answer="Alice knows Bob",
triples=[],
chunks=[],
)

actual_request = httpx_mock.get_requests()[0]
expected_request_body = QueryGraphRequest(query=query)
Expand All @@ -70,7 +73,7 @@ class TestGraphAPIAddDocuments:
"""Tests for the add_documents method."""

def test_errors(self, httpx_mock, tmp_path):
"""Test error handling."""
"""Test various error cases."""
client = WhyHow()

with pytest.raises(ValueError, match="No documents provided"):
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""This module contains pytest fixtures that are used across multiple files."""
"""Configuration for the tests."""

import pathlib

Expand All @@ -7,7 +7,7 @@

@pytest.fixture
def test_path():
"""Return the path of the test file."""
"""Return the path to the tests directory."""
return pathlib.Path(__file__).parent


Expand Down
51 changes: 44 additions & 7 deletions tests/schemas/test_common.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""This module contains tests for the common schemas."""
"""Tests for whyhow.schemas.common."""

import pytest

from whyhow.schemas.common import Entity, Graph, Node, Relationship, Triple


class TestGraph:
"""Tests for the Graph schema."""
"""Tests for the Graph class."""

def test_no_nodes(self):
"""Test creating a graph with no nodes."""
Expand Down Expand Up @@ -45,10 +45,10 @@ def test_3_nodes_1_rel(self):


class TestEntity:
"""Tests for the Entity schema."""
"""Tests for the Entity class."""

def test_overall(self):
"""Test creating an entity."""
"""Test creating an entity and converting it to a node."""
entity = Entity(
text="Alice", label="Person", properties={"foo": "bar"}
)
Expand All @@ -67,18 +67,55 @@ def test_overall(self):
assert entity.properties is not entity_reconstructed.properties

def test_missing_name(self):
"""Test creating an entity with missing name."""
"""Test creating an entity without a name property."""
node = Node(labels=["Person"], properties={})

with pytest.raises(ValueError, match="Node must have a name property"):
Entity.from_node(node)


class TestTriple:
"""Tests for the Triple schema."""
"""Tests for the Triple class."""

# def test_overall(self):
# triple = Triple(
# head="Alice",
# head_type="Person",
# relationship="KNOWS",
# tail="Bob",
# tail_type="Person",
# properties={"since": 1999},
# )

# assert triple.head == "Alice"
# assert triple.head_type == "Person"
# assert triple.relationship == "KNOWS"
# assert triple.tail == "Bob"
# assert triple.tail_type == "Person"

# rel = triple.to_relationship()

# assert rel.start_node.labels == ["Person"]
# assert rel.start_node.properties == {"name": "Alice"}
# assert rel.end_node.labels == ["Person"]
# assert rel.end_node.properties == {"name": "Bob"}
# assert rel.type == "KNOWS"
# assert rel.properties == {"since": 1999}

# triple_reconstructed = Triple.from_relationship(rel)

# assert triple.head == triple_reconstructed.head
# assert triple.head_type == triple_reconstructed.head_type
# assert triple.relationship == triple_reconstructed.relationship
# assert triple.tail == triple_reconstructed.tail
# assert triple.tail_type == triple_reconstructed.tail_type
# assert triple.properties == triple_reconstructed.properties

# # test properties copied
# assert triple.properties is not triple_reconstructed.properties

def test_missing_name(self):
"""Test creating a triple with missing name."""
"""Test creating a triple with a node missing a name property."""
rel = Relationship(
start_node=Node(labels=["Person"], properties={}),
end_node=Node(labels=["Person"], properties={"name": "Bob"}),
Expand Down
6 changes: 3 additions & 3 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ class TestWhyHow:
"""Tests for the WhyHow class."""

def test_constructor_missing_api_key(self, monkeypatch):
"""Test creating a WhyHow instance without an API key."""
"""Test that an error is raised if the API key is missing."""
monkeypatch.delenv("WHYHOW_API_KEY", raising=False)
with pytest.raises(ValueError, match="WHYHOW_API_KEY must be set"):
WhyHow()

def test_httpx_kwargs(self, monkeypatch):
"""Test passing httpx_kwargs to the constructor."""
"""Test that the client is initialized with the correct arguments."""
fake_httpx_client_inst = Mock(spec=Client)
fake_httpx_client_class = Mock(return_value=fake_httpx_client_inst)

Expand All @@ -42,7 +42,7 @@ def test_httpx_kwargs(self, monkeypatch):
assert client.httpx_client is fake_httpx_client_class.return_value

def test_base_url_twice(self):
"""Test setting base_url in httpx_kwargs."""
"""Test that an error is raised if base_url is set in httpx_kwargs."""
with pytest.raises(
ValueError, match="base_url cannot be set in httpx_kwargs."
):
Expand Down
6 changes: 6 additions & 0 deletions tests/test_dummy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Dummy test."""


def test():
"""Dummy test."""
assert True

0 comments on commit 07c795c

Please sign in to comment.