Skip to content

Commit

Permalink
Merge pull request #52 from mirumee/fix-45-proxy-errors
Browse files Browse the repository at this point in the history
Proxy errors and extensions in ProxySchema
  • Loading branch information
rafalp authored Mar 6, 2024
2 parents ff8698a + 92ea368 commit c243831
Show file tree
Hide file tree
Showing 6 changed files with 382 additions and 12 deletions.
2 changes: 2 additions & 0 deletions ariadne_graphql_proxy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from .narrow_graphql_query import narrow_graphql_query
from .proxy_resolver import ProxyResolver
from .proxy_root_value import ProxyRootValue
from .proxy_schema import ProxySchema
from .query_filter import QueryFilter, QueryFilterContext
from .remote_schema import get_remote_schema
Expand All @@ -47,6 +48,7 @@
__all__ = [
"ForeignKeyResolver",
"ProxyResolver",
"ProxyRootValue",
"ProxySchema",
"QueryFilter",
"QueryFilterContext",
Expand Down
30 changes: 30 additions & 0 deletions ariadne_graphql_proxy/proxy_root_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import List, Optional

from ariadne.types import BaseProxyRootValue, GraphQLResult


class ProxyRootValue(BaseProxyRootValue):
__slots__ = ("root_value", "errors", "extensions")

def __init__(
self,
root_value: Optional[dict] = None,
errors: Optional[List[dict]] = None,
extensions: Optional[dict] = None,
):
super().__init__(root_value)
self.errors = errors
self.extensions = extensions

def update_result(self, result: GraphQLResult) -> GraphQLResult:
success, data = super().update_result(result)

if self.errors:
data.setdefault("errors", [])
data["errors"] += self.errors

if self.extensions:
data.setdefault("extensions", {})
data["extensions"].update(self.extensions)

return success, data
76 changes: 66 additions & 10 deletions ariadne_graphql_proxy/proxy_schema.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from asyncio import gather
from functools import reduce
from inspect import isawaitable
from typing import Any, Callable, Dict, List, Optional, Set, Union
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union

from ariadne.types import RootValue
from ariadne.types import BaseProxyRootValue, RootValue
from graphql import (
DocumentNode,
GraphQLInterfaceType,
Expand All @@ -17,6 +17,7 @@

from .copy import copy_schema
from .merge import merge_schemas
from .proxy_root_value import ProxyRootValue
from .query_filter import QueryFilter
from .remote_schema import get_remote_schema
from .standard_types import STANDARD_TYPES, add_missing_scalar_types
Expand All @@ -30,15 +31,24 @@


class ProxySchema:
def __init__(self, root_value: Optional[RootValue] = None):
def __init__(
self,
root_value: Optional[RootValue] = None,
proxy_root_value: Type[ProxyRootValue] = ProxyRootValue,
):
self.schemas: List[GraphQLSchema] = []
self.urls: List[Optional[str]] = []
self.headers: List[Optional[ProxyHeaders]] = []
self.proxy_errors: List[bool] = []
self.proxy_extensions: List[bool] = []
self.labels: List[str] = []
self.fields_map: Dict[str, Dict[str, Set[int]]] = {}
self.fields_types: Dict[str, Dict[str, str]] = {}
self.unions: Dict[str, List[str]] = {}
self.foreign_keys: Dict[str, Dict[str, List[str]]] = {}

self.proxy_root_value = proxy_root_value

self.schema: Optional[GraphQLSchema] = None
self.query_filter: Optional[QueryFilter] = None
self.root_value: Optional[RootValue] = root_value
Expand All @@ -54,12 +64,17 @@ def add_remote_schema(
exclude_directives: Optional[List[str]] = None,
exclude_directives_args: Optional[Dict[str, List[str]]] = None,
extra_fields: Optional[Dict[str, List[str]]] = None,
label: Optional[str] = None,
proxy_errors: bool = True,
proxy_extensions: bool = True,
) -> int:
if callable(headers):
remote_schema = get_remote_schema(url, headers(None))
else:
remote_schema = get_remote_schema(url, headers)

schema_id = len(self.schemas)

return self.add_schema(
remote_schema,
url,
Expand All @@ -70,6 +85,9 @@ def add_remote_schema(
exclude_directives=exclude_directives,
exclude_directives_args=exclude_directives_args,
extra_fields=extra_fields,
label=label or f"remote_{schema_id}",
proxy_errors=proxy_errors,
proxy_extensions=proxy_extensions,
)

def add_schema(
Expand All @@ -84,6 +102,9 @@ def add_schema(
exclude_directives: Optional[List[str]] = None,
exclude_directives_args: Optional[Dict[str, List[str]]] = None,
extra_fields: Optional[Dict[str, List[str]]] = None,
label: Optional[str] = None,
proxy_errors: bool = True,
proxy_extensions: bool = True,
) -> int:
if (
exclude_types
Expand All @@ -103,11 +124,15 @@ def add_schema(

schema.type_map = add_missing_scalar_types(schema.type_map)

schema_id = len(self.schemas)

self.schemas.append(schema)
self.urls.append(url)
self.headers.append(headers)
self.labels.append(label or f"schema_{schema_id}")
self.proxy_errors.append(proxy_errors)
self.proxy_extensions.append(proxy_extensions)

schema_id = len(self.schemas) - 1
for type_name, type_def in schema.type_map.items():
if type_name in STANDARD_TYPES:
continue
Expand Down Expand Up @@ -212,7 +237,7 @@ async def root_resolver(
operation_name: Optional[str],
variables: Optional[dict],
document: DocumentNode,
) -> Optional[dict]:
) -> Optional[Union[dict, BaseProxyRootValue]]:
if not self.query_filter:
raise RuntimeError(
"'get_final_schema' needs to be called to build final schema "
Expand Down Expand Up @@ -241,9 +266,13 @@ async def root_resolver(
if not queries:
return root_value

root_errors: List[dict] = []
root_extensions: dict = {}

subqueries_data = await gather(
*[
self.fetch_data(
schema_id,
context_value,
self.urls[schema_id],
self.headers[schema_id],
Expand All @@ -266,13 +295,32 @@ async def root_resolver(
]
)

for subquery_data in subqueries_data:
if subquery_data:
root_value.update(subquery_data)
for schema_id, subquery_data in subqueries_data:
label = self.labels[schema_id]
if isinstance(subquery_data.get("data"), dict):
root_value.update(subquery_data["data"])
if (
isinstance(subquery_data.get("errors"), list)
and self.proxy_errors[schema_id]
):
root_errors += self.clean_errors(label, subquery_data["errors"])
if (
isinstance(subquery_data.get("extensions"), dict)
and self.proxy_extensions[schema_id]
):
print("HERE")
root_extensions[label] = subquery_data["extensions"]

if root_errors or root_extensions:
return self.proxy_root_value(
root_value,
root_errors or None,
root_extensions or None,
)

return root_value or None

async def fetch_data(self, context, url, headers, json):
async def fetch_data(self, schema_id, context, url, headers, json):
async with AsyncClient() as client:
if callable(headers):
headers = headers(context)
Expand All @@ -284,4 +332,12 @@ async def fetch_data(self, context, url, headers, json):
)

query_data = r.json()
return query_data.get("data")
return (schema_id, query_data)

def clean_errors(self, label: str, errors: List[dict]) -> List[dict]:
clean_errors: List[dict] = []
for error in errors:
if isinstance(error, dict) and isinstance(error.get("path"), list):
error["path"].insert(0, label)
clean_errors.append(error)
return clean_errors
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ classifiers = [
"Topic :: Software Development :: Libraries :: Python Modules",
]
version = "0.2.0"
dependencies = ["graphql-core>=3.2.0,<3.3", "httpx", "ariadne"]
dependencies = [
"graphql-core>=3.2.0,<3.3",
"httpx",
"ariadne==0.23.0.b1",
]

[project.optional-dependencies]
test = [
Expand Down
128 changes: 128 additions & 0 deletions tests/test_proxy_root_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from ariadne_graphql_proxy import ProxyRootValue


def test_proxy_root_value_without_errors_or_extensions_skips_result_update():
result = False, {"data": "ok"}
root_value = ProxyRootValue()
assert root_value.update_result(result) == result


def test_proxy_root_value_with_errors_extends_result():
result = False, {"data": "ok"}
root_value = ProxyRootValue(errors=[{"message": "Test"}])
assert root_value.update_result(result) == (
False,
{
"data": "ok",
"errors": [
{
"message": "Test",
},
],
},
)


def test_proxy_root_value_with_extensions_extends_result():
result = False, {"data": "ok"}
root_value = ProxyRootValue(extensions={"score": "100"})
assert root_value.update_result(result) == (
False,
{
"data": "ok",
"extensions": {
"score": "100",
},
},
)


def test_proxy_root_value_with_errors_and_extensions_extends_result():
result = False, {"data": "ok"}
root_value = ProxyRootValue(
errors=[{"message": "Test"}],
extensions={"score": "100"},
)
assert root_value.update_result(result) == (
False,
{
"data": "ok",
"errors": [
{
"message": "Test",
},
],
"extensions": {
"score": "100",
},
},
)


def test_proxy_root_value_with_errors_updates_result():
result = False, {"data": "ok", "errors": [{"message": "Org"}]}

root_value = ProxyRootValue(errors=[{"message": "Test"}])
assert root_value.update_result(result) == (
False,
{
"data": "ok",
"errors": [
{
"message": "Org",
},
{
"message": "Test",
},
],
},
)


def test_proxy_root_value_with_extensions_updates_result():
result = False, {"data": "ok", "extensions": {"core": True}}
root_value = ProxyRootValue(extensions={"score": "100"})
assert root_value.update_result(result) == (
False,
{
"data": "ok",
"extensions": {
"core": True,
"score": "100",
},
},
)


def test_proxy_root_value_with_errors_and_extensions_updates_result():
result = (
False,
{
"data": "ok",
"errors": [{"message": "Org"}],
"extensions": {"core": True},
},
)

root_value = ProxyRootValue(
errors=[{"message": "Test"}],
extensions={"score": "100"},
)
assert root_value.update_result(result) == (
False,
{
"data": "ok",
"errors": [
{
"message": "Org",
},
{
"message": "Test",
},
],
"extensions": {
"core": True,
"score": "100",
},
},
)
Loading

0 comments on commit c243831

Please sign in to comment.