diff --git a/.gitignore b/.gitignore
index 47a82df0..1c86b9be 100644
--- a/.gitignore
+++ b/.gitignore
@@ -71,5 +71,8 @@ target/
*.sqlite3
.vscode
+# Schema
+*.gql
+
# mypy cache
.mypy_cache/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 470a29eb..262e7608 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,5 +1,5 @@
default_language_version:
- python: python3.7
+ python: python3.8
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.2.0
@@ -12,7 +12,7 @@ repos:
- id: trailing-whitespace
exclude: README.md
- repo: https://github.com/pycqa/isort
- rev: 5.10.1
+ rev: 5.12.0
hooks:
- id: isort
name: isort (python)
diff --git a/docs/filters.rst b/docs/filters.rst
new file mode 100644
index 00000000..ac36803d
--- /dev/null
+++ b/docs/filters.rst
@@ -0,0 +1,213 @@
+=======
+Filters
+=======
+
+Starting in graphene-sqlalchemy version 3, the SQLAlchemyConnectionField class implements filtering by default. The query utilizes a ``filter`` keyword to specify a filter class that inherits from ``graphene.InputObjectType``.
+
+Migrating from graphene-sqlalchemy-filter
+---------------------------------------------
+
+If, like many of us, you have been using |graphene-sqlalchemy-filter|_ to implement filters and would like to use the built-in mechanism here, there are a couple of key differences to note. Mainly, in an effort to simplify the generated schema, filter keywords are nested under their respective fields instead of concatenated. For example, the filter partial ``{usernameIn: ["moderator", "cool guy"]}`` would be represented as ``{username: {in: ["moderator", "cool guy"]}}``.
+
+.. |graphene-sqlalchemy-filter| replace:: ``graphene-sqlalchemy-filter``
+.. _graphene-sqlalchemy-filter: https://github.com/art1415926535/graphene-sqlalchemy-filter
+
+Further, some of the constructs found in libraries like `DGraph's DQL <https://dgraph.io/docs/query-language/>`_ have been implemented, so if you have created custom implementations for these features, you may want to take a look at the examples below.
+
+
+Example model
+-------------
+
+Take as example a Pet model similar to that in the sorting example. We will use variations on this arrangement for the following examples.
+
+.. code::
+
+ class Pet(Base):
+ __tablename__ = 'pets'
+ id = Column(Integer(), primary_key=True)
+ name = Column(String(30))
+ age = Column(Integer())
+
+
+ class PetNode(SQLAlchemyObjectType):
+ class Meta:
+ model = Pet
+
+
+ class Query(graphene.ObjectType):
+ allPets = SQLAlchemyConnectionField(PetNode.connection)
+
+
+Simple filter example
+---------------------
+
+Filters are defined at the object level through the ``BaseTypeFilter`` class. The ``BaseType`` encompasses both Graphene ``ObjectType``\ s and ``Interface``\ s. Each ``BaseTypeFilter`` instance may define fields via ``FieldFilter`` and relationships via ``RelationshipFilter``. Here's a basic example querying a single field on the Pet model:
+
+.. code::
+
+ allPets(filter: {name: {eq: "Fido"}}){
+ edges {
+ node {
+ name
+ }
+ }
+ }
+
+This will return all pets with the name "Fido".
+
+
+Custom filter types
+-------------------
+
+If you'd like to implement custom behavior for filtering a field, you can do so by extending one of the base filter classes in ``graphene_sqlalchemy.filters``. For example, if you'd like to add a ``divisible_by`` keyword to filter the age attribute on the ``Pet`` model, you can do so as follows:
+
+.. code:: python
+
+ class MathFilter(FloatFilter):
+ class Meta:
+ graphene_type = graphene.Float
+
+ @classmethod
+ def divisible_by_filter(cls, query, field, val: int) -> bool:
+ return is_(field % val, 0)
+
+ class PetType(SQLAlchemyObjectType):
+ ...
+
+ age = ORMField(filter_type=MathFilter)
+
+ class Query(graphene.ObjectType):
+ pets = SQLAlchemyConnectionField(PetType.connection)
+
+
+Filtering over relationships with RelationshipFilter
+----------------------------------------------------
+
+When a filter class field refers to another object in a relationship, you may nest filters on relationship object attributes. This happens directly for 1:1 and m:1 relationships and through the ``contains`` and ``containsExactly`` keywords for 1:n and m:n relationships.
+
+
+:1 relationships
+^^^^^^^^^^^^^^^^
+
+When an object or interface defines a singular relationship, relationship object attributes may be filtered directly like so:
+
+Take the following SQLAlchemy model definition as an example:
+
+.. code:: python
+
+    class Pet(Base):
+        ...
+        person_id = Column(Integer(), ForeignKey("people.id"))
+
+    class Person(Base):
+        ...
+        pets = relationship("Pet", backref="person")
+
+
+Then, this query will return all pets whose person is named "Ada":
+
+.. code::
+
+ allPets(filter: {
+ person: {name: {eq: "Ada"}}
+ }) {
+ ...
+ }
+
+
+:n relationships
+^^^^^^^^^^^^^^^^
+
+However, for plural relationships, relationship object attributes must be filtered through either ``contains`` or ``containsExactly``:
+
+Now, using a many-to-many model definition:
+
+.. code:: python
+
+ people_pets_table = sqlalchemy.Table(
+ "people_pets",
+ Base.metadata,
+ Column("person_id", ForeignKey("people.id")),
+ Column("pet_id", ForeignKey("pets.id")),
+ )
+
+    class Pet(Base):
+        ...
+
+    class Person(Base):
+        ...
+        pets = relationship("Pet", backref="people")
+
+
+this query will return all pets which have a person named "Ben" in their ``people`` list.
+
+.. code::
+
+ allPets(filter: {
+ people: {
+ contains: [{name: {eq: "Ben"}}],
+ }
+ }) {
+ ...
+ }
+
+
+and this one will return all pets whose list of people contains exactly the people "Ada" and "Ben" — no fewer, and no people with other names.
+
+.. code::
+
+ allPets(filter: {
+        people: {
+ containsExactly: [
+ {name: {eq: "Ada"}},
+ {name: {eq: "Ben"}},
+ ],
+ }
+ }) {
+ ...
+ }
+
+And/Or Logic
+------------
+
+Filters can also be chained together logically using ``and`` and ``or`` keywords nested under ``filter``. Clauses are passed directly to ``sqlalchemy.and_`` and ``sqlalchemy.or_``, respectively. To return all pets named "Fido" or "Spot", use:
+
+
+.. code::
+
+ allPets(filter: {
+ or: [
+ {name: {eq: "Fido"}},
+ {name: {eq: "Spot"}},
+ ]
+ }) {
+ ...
+ }
+
+And to return all pets that are named "Fido" or are 5 years old and named "Spot", use:
+
+.. code::
+
+    allPets(filter: {
+        or: [
+            {name: {eq: "Fido"}},
+            {and: [
+                {name: {eq: "Spot"}},
+                {age: {eq: 5}}
+            ]}
+        ]
+    }) {
+        ...
+    }
+
+
+Hybrid Property support
+-----------------------
+
+Filtering over SQLAlchemy `hybrid properties <https://docs.sqlalchemy.org/en/latest/orm/extensions/hybrid.html>`_ is fully supported.
+
+
+Reporting feedback and bugs
+---------------------------
+
+Filtering is a new feature to graphene-sqlalchemy, so please `post an issue on GitHub <https://github.com/graphql-python/graphene-sqlalchemy/issues>`_ if you run into any problems or have ideas on how to improve the implementation.
diff --git a/docs/index.rst b/docs/index.rst
index b663752a..4245eba8 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -10,6 +10,7 @@ Contents:
inheritance
relay
tips
+ filters
examples
tutorial
api
diff --git a/examples/filters/README.md b/examples/filters/README.md
new file mode 100644
index 00000000..a72e75de
--- /dev/null
+++ b/examples/filters/README.md
@@ -0,0 +1,47 @@
+Example Filters Project
+================================
+
+This example highlights the ability to filter queries in graphene-sqlalchemy.
+
+The project contains two models, one named `Department` and another
+named `Employee`.
+
+Getting started
+---------------
+
+First you'll need to get the source of the project. Do this by cloning the
+whole Graphene-SQLAlchemy repository:
+
+```bash
+# Get the example project code
+git clone https://github.com/graphql-python/graphene-sqlalchemy.git
+cd graphene-sqlalchemy/examples/filters
+```
+
+It is recommended to create a virtual environment
+for this project. We'll do this using
+[virtualenv](http://docs.python-guide.org/en/latest/dev/virtualenvs/)
+to keep things simple,
+but you may also find something like
+[virtualenvwrapper](https://virtualenvwrapper.readthedocs.org/en/latest/)
+to be useful:
+
+```bash
+# Create a virtualenv in which we can install the dependencies
+virtualenv env
+source env/bin/activate
+```
+
+Install our dependencies:
+
+```bash
+pip install -r requirements.txt
+```
+
+The following command will setup the database, and start the server:
+
+```bash
+python app.py
+```
+
+Now head over to your favorite GraphQL client, POST to [http://127.0.0.1:5000/graphql](http://127.0.0.1:5000/graphql) and run some queries!
diff --git a/examples/filters/__init__.py b/examples/filters/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/filters/app.py b/examples/filters/app.py
new file mode 100644
index 00000000..ab918da7
--- /dev/null
+++ b/examples/filters/app.py
@@ -0,0 +1,16 @@
+from database import init_db
+from fastapi import FastAPI
+from schema import schema
+from starlette_graphene3 import GraphQLApp, make_playground_handler
+
+
+def create_app() -> FastAPI:
+ init_db()
+ app = FastAPI()
+
+ app.mount("/graphql", GraphQLApp(schema, on_get=make_playground_handler()))
+
+ return app
+
+
+app = create_app()
diff --git a/examples/filters/database.py b/examples/filters/database.py
new file mode 100644
index 00000000..8f6522f7
--- /dev/null
+++ b/examples/filters/database.py
@@ -0,0 +1,49 @@
+from sqlalchemy import create_engine
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import sessionmaker
+
+Base = declarative_base()
+engine = create_engine(
+ "sqlite://", connect_args={"check_same_thread": False}, echo=True
+)
+session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+
+from sqlalchemy.orm import scoped_session as scoped_session_factory
+
+scoped_session = scoped_session_factory(session_factory)
+
+Base.query = scoped_session.query_property()
+Base.metadata.bind = engine
+
+
+def init_db():
+ from models import Person, Pet, Toy
+
+ Base.metadata.create_all()
+ scoped_session.execute("PRAGMA foreign_keys=on")
+ db = scoped_session()
+
+ person1 = Person(name="A")
+ person2 = Person(name="B")
+
+ pet1 = Pet(name="Spot")
+ pet2 = Pet(name="Milo")
+
+ toy1 = Toy(name="disc")
+ toy2 = Toy(name="ball")
+
+ person1.pet = pet1
+ person2.pet = pet2
+
+ pet1.toys.append(toy1)
+ pet2.toys.append(toy1)
+ pet2.toys.append(toy2)
+
+ db.add(person1)
+ db.add(person2)
+ db.add(pet1)
+ db.add(pet2)
+ db.add(toy1)
+ db.add(toy2)
+
+ db.commit()
diff --git a/examples/filters/models.py b/examples/filters/models.py
new file mode 100644
index 00000000..1b22956b
--- /dev/null
+++ b/examples/filters/models.py
@@ -0,0 +1,34 @@
+import sqlalchemy
+from database import Base
+from sqlalchemy import Column, ForeignKey, Integer, String
+from sqlalchemy.orm import relationship
+
+
+class Pet(Base):
+ __tablename__ = "pets"
+ id = Column(Integer(), primary_key=True)
+ name = Column(String(30))
+ age = Column(Integer())
+ person_id = Column(Integer(), ForeignKey("people.id"))
+
+
+class Person(Base):
+ __tablename__ = "people"
+ id = Column(Integer(), primary_key=True)
+ name = Column(String(100))
+ pets = relationship("Pet", backref="person")
+
+
+pets_toys_table = sqlalchemy.Table(
+ "pets_toys",
+ Base.metadata,
+ Column("pet_id", ForeignKey("pets.id")),
+ Column("toy_id", ForeignKey("toys.id")),
+)
+
+
+class Toy(Base):
+ __tablename__ = "toys"
+ id = Column(Integer(), primary_key=True)
+ name = Column(String(30))
+ pets = relationship("Pet", secondary=pets_toys_table, backref="toys")
diff --git a/examples/filters/requirements.txt b/examples/filters/requirements.txt
new file mode 100644
index 00000000..b433ec59
--- /dev/null
+++ b/examples/filters/requirements.txt
@@ -0,0 +1,3 @@
+-e ../../
+fastapi
+uvicorn
diff --git a/examples/filters/run.sh b/examples/filters/run.sh
new file mode 100755
index 00000000..ec365444
--- /dev/null
+++ b/examples/filters/run.sh
@@ -0,0 +1 @@
+uvicorn app:app --port 5000
diff --git a/examples/filters/schema.py b/examples/filters/schema.py
new file mode 100644
index 00000000..2728cab7
--- /dev/null
+++ b/examples/filters/schema.py
@@ -0,0 +1,42 @@
+from models import Person as PersonModel
+from models import Pet as PetModel
+from models import Toy as ToyModel
+
+import graphene
+from graphene import relay
+from graphene_sqlalchemy import SQLAlchemyObjectType
+from graphene_sqlalchemy.fields import SQLAlchemyConnectionField
+
+
+class Pet(SQLAlchemyObjectType):
+ class Meta:
+ model = PetModel
+ name = "Pet"
+ interfaces = (relay.Node,)
+ batching = True
+
+
+class Person(SQLAlchemyObjectType):
+ class Meta:
+ model = PersonModel
+ name = "Person"
+ interfaces = (relay.Node,)
+ batching = True
+
+
+class Toy(SQLAlchemyObjectType):
+ class Meta:
+ model = ToyModel
+ name = "Toy"
+ interfaces = (relay.Node,)
+ batching = True
+
+
+class Query(graphene.ObjectType):
+ node = relay.Node.Field()
+ pets = SQLAlchemyConnectionField(Pet.connection)
+ people = SQLAlchemyConnectionField(Person.connection)
+ toys = SQLAlchemyConnectionField(Toy.connection)
+
+
+schema = graphene.Schema(query=Query)
diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py
index 84c7886c..efcf3c6c 100644
--- a/graphene_sqlalchemy/converter.py
+++ b/graphene_sqlalchemy/converter.py
@@ -3,7 +3,7 @@
import typing
import uuid
from decimal import Decimal
-from typing import Any, Optional, Union, cast
+from typing import Any, Dict, Optional, TypeVar, Union, cast
from sqlalchemy import types as sqa_types
from sqlalchemy.dialects import postgresql
@@ -21,7 +21,6 @@
from .batching import get_batch_resolver
from .enums import enum_for_sa_enum
-from .fields import BatchSQLAlchemyConnectionField, default_connection_field_factory
from .registry import Registry, get_global_registry
from .resolvers import get_attr_resolver, get_custom_resolver
from .utils import (
@@ -237,6 +236,8 @@ def _convert_o2m_or_m2m_relationship(
:param dict field_kwargs:
:rtype: Field
"""
+ from .fields import BatchSQLAlchemyConnectionField, default_connection_field_factory
+
child_type = obj_type._meta.registry.get_type_for_model(
relationship_prop.mapper.entity
)
@@ -332,8 +333,12 @@ def convert_sqlalchemy_type( # noqa
type_arg: Any,
column: Optional[Union[MapperProperty, hybrid_property]] = None,
registry: Registry = None,
+ replace_type_vars: typing.Dict[str, Any] = None,
**kwargs,
):
+ if replace_type_vars and type_arg in replace_type_vars:
+ return replace_type_vars[type_arg]
+
# No valid type found, raise an error
raise TypeError(
@@ -373,6 +378,11 @@ def convert_scalar_type(type_arg: Any, **kwargs):
return type_arg
+@convert_sqlalchemy_type.register(safe_isinstance(TypeVar))
+def convert_type_var(type_arg: Any, replace_type_vars: Dict[TypeVar, Any], **kwargs):
+ return replace_type_vars[type_arg]
+
+
@convert_sqlalchemy_type.register(column_type_eq(str))
@convert_sqlalchemy_type.register(column_type_eq(sqa_types.String))
@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Text))
@@ -618,6 +628,7 @@ def convert_sqlalchemy_hybrid_property_union(type_arg: Any, **kwargs):
# Just get the T out of the list of arguments by filtering out the NoneType
nested_types = list(filter(lambda x: not type(None) == x, type_arg.__args__))
+    # TODO redo this to also pass through *args and **kwargs
# Map the graphene types to the nested types.
# We use convert_sqlalchemy_hybrid_property_type instead of the registry to account for ForwardRefs, Lists,...
graphene_types = list(map(convert_sqlalchemy_type, nested_types))
diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py
index 6dbc134f..ef798852 100644
--- a/graphene_sqlalchemy/fields.py
+++ b/graphene_sqlalchemy/fields.py
@@ -5,13 +5,19 @@
from promise import Promise, is_thenable
from sqlalchemy.orm.query import Query
-from graphene import NonNull
from graphene.relay import Connection, ConnectionField
from graphene.relay.connection import connection_adapter, page_info_adapter
from graphql_relay import connection_from_array_slice
from .batching import get_batch_resolver
-from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, EnumValue, get_query, get_session
+from .filters import BaseTypeFilter
+from .utils import (
+ SQL_VERSION_HIGHER_EQUAL_THAN_1_4,
+ EnumValue,
+ get_nullable_type,
+ get_query,
+ get_session,
+)
if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
from sqlalchemy.ext.asyncio import AsyncSession
@@ -40,6 +46,7 @@ def type(self):
def __init__(self, type_, *args, **kwargs):
nullable_type = get_nullable_type(type_)
+ # Handle Sorting and Filtering
if (
"sort" not in kwargs
and nullable_type
@@ -57,6 +64,19 @@ def __init__(self, type_, *args, **kwargs):
)
elif "sort" in kwargs and kwargs["sort"] is None:
del kwargs["sort"]
+
+ if (
+ "filter" not in kwargs
+ and nullable_type
+ and issubclass(nullable_type, Connection)
+ ):
+ # Only add filtering if a filter argument exists on the object type
+ filter_argument = nullable_type.Edge.node._type.get_filter_argument()
+ if filter_argument:
+ kwargs.setdefault("filter", filter_argument)
+ elif "filter" in kwargs and kwargs["filter"] is None:
+ del kwargs["filter"]
+
super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs)
@property
@@ -64,7 +84,7 @@ def model(self):
return get_nullable_type(self.type)._meta.node._meta.model
@classmethod
- def get_query(cls, model, info, sort=None, **args):
+ def get_query(cls, model, info, sort=None, filter=None, **args):
query = get_query(model, info.context)
if sort is not None:
if not isinstance(sort, list):
@@ -80,6 +100,12 @@ def get_query(cls, model, info, sort=None, **args):
else:
sort_args.append(item)
query = query.order_by(*sort_args)
+
+ if filter is not None:
+ assert isinstance(filter, dict)
+ filter_type: BaseTypeFilter = type(filter)
+ query, clauses = filter_type.execute_filters(query, filter)
+ query = query.filter(*clauses)
return query
@classmethod
@@ -264,9 +290,3 @@ def unregisterConnectionFieldFactory():
)
global __connectionFactory
__connectionFactory = UnsortedSQLAlchemyConnectionField
-
-
-def get_nullable_type(_type):
- if isinstance(_type, NonNull):
- return _type.of_type
- return _type
diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py
new file mode 100644
index 00000000..bb422724
--- /dev/null
+++ b/graphene_sqlalchemy/filters.py
@@ -0,0 +1,525 @@
+import re
+from typing import Any, Dict, List, Tuple, Type, TypeVar, Union
+
+from graphql import Undefined
+from sqlalchemy import and_, not_, or_
+from sqlalchemy.orm import Query, aliased # , selectinload
+
+import graphene
+from graphene.types.inputobjecttype import (
+ InputObjectTypeContainer,
+ InputObjectTypeOptions,
+)
+from graphene_sqlalchemy.utils import is_list
+
+BaseTypeFilterSelf = TypeVar(
+ "BaseTypeFilterSelf", Dict[str, Any], InputObjectTypeContainer
+)
+
+
+class SQLAlchemyFilterInputField(graphene.InputField):
+ def __init__(
+ self,
+ type_,
+ model_attr,
+ name=None,
+ default_value=Undefined,
+ deprecation_reason=None,
+ description=None,
+ required=False,
+ _creation_counter=None,
+ **extra_args,
+ ):
+ super(SQLAlchemyFilterInputField, self).__init__(
+ type_,
+ name,
+ default_value,
+ deprecation_reason,
+ description,
+ required,
+ _creation_counter,
+ **extra_args,
+ )
+
+ self.model_attr = model_attr
+
+
+def _get_functions_by_regex(
+ regex: str, subtract_regex: str, class_: Type
+) -> List[Tuple[str, Dict[str, Any]]]:
+ function_regex = re.compile(regex)
+
+ matching_functions = []
+
+ # Search the entire class for functions matching the filter regex
+ for fn in dir(class_):
+ func_attr = getattr(class_, fn)
+ # Check if attribute is a function
+ if callable(func_attr) and function_regex.match(fn):
+ # add function and attribute name to the list
+ matching_functions.append(
+ (re.sub(subtract_regex, "", fn), func_attr.__annotations__)
+ )
+ return matching_functions
+
+
+class BaseTypeFilter(graphene.InputObjectType):
+ @classmethod
+ def __init_subclass_with_meta__(
+ cls, filter_fields=None, model=None, _meta=None, **options
+ ):
+ from graphene_sqlalchemy.converter import convert_sqlalchemy_type
+
+ # Init meta options class if it doesn't exist already
+ if not _meta:
+ _meta = InputObjectTypeOptions(cls)
+
+ logic_functions = _get_functions_by_regex(".+_logic$", "_logic$", cls)
+
+ new_filter_fields = {}
+ # Generate Graphene Fields from the filter functions based on type hints
+ for field_name, _annotations in logic_functions:
+ assert (
+ "val" in _annotations
+ ), "Each filter method must have a value field with valid type annotations"
+ # If type is generic, replace with actual type of filter class
+
+ replace_type_vars = {BaseTypeFilterSelf: cls}
+ field_type = convert_sqlalchemy_type(
+ _annotations.get("val", str), replace_type_vars=replace_type_vars
+ )
+ new_filter_fields.update({field_name: graphene.InputField(field_type)})
+ # Add all fields to the meta options. graphene.InputObjectType will take care of the rest
+
+ if _meta.fields:
+ _meta.fields.update(filter_fields)
+ else:
+ _meta.fields = filter_fields
+ _meta.fields.update(new_filter_fields)
+
+ _meta.model = model
+
+ super(BaseTypeFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options)
+
+ @classmethod
+ def and_logic(
+ cls,
+ query,
+ filter_type: "BaseTypeFilter",
+ val: List[BaseTypeFilterSelf],
+ model_alias=None,
+ ):
+ # # Get the model to join on the Filter Query
+ # joined_model = filter_type._meta.model
+ # # Always alias the model
+ # joined_model_alias = aliased(joined_model)
+ clauses = []
+ for value in val:
+ # # Join the aliased model onto the query
+ # query = query.join(model_field.of_type(joined_model_alias))
+
+ query, _clauses = filter_type.execute_filters(
+ query, value, model_alias=model_alias
+ ) # , model_alias=joined_model_alias)
+ clauses += _clauses
+
+ return query, [and_(*clauses)]
+
+ @classmethod
+ def or_logic(
+ cls,
+ query,
+ filter_type: "BaseTypeFilter",
+ val: List[BaseTypeFilterSelf],
+ model_alias=None,
+ ):
+ # # Get the model to join on the Filter Query
+ # joined_model = filter_type._meta.model
+ # # Always alias the model
+ # joined_model_alias = aliased(joined_model)
+
+ clauses = []
+ for value in val:
+ # # Join the aliased model onto the query
+ # query = query.join(model_field.of_type(joined_model_alias))
+
+ query, _clauses = filter_type.execute_filters(
+ query, value, model_alias=model_alias
+ ) # , model_alias=joined_model_alias)
+ clauses += _clauses
+
+ return query, [or_(*clauses)]
+
+ @classmethod
+ def execute_filters(
+ cls, query, filter_dict: Dict[str, Any], model_alias=None
+ ) -> Tuple[Query, List[Any]]:
+ model = cls._meta.model
+ if model_alias:
+ model = model_alias
+
+ clauses = []
+
+ for field, field_filters in filter_dict.items():
+            # Relationships are Dynamic, we need to resolve them first
+ # Maybe we can cache these dynamics to improve efficiency
+ # Check with a profiler is required to determine necessity
+ input_field = cls._meta.fields[field]
+ if isinstance(input_field, graphene.Dynamic):
+ input_field = input_field.get_type()
+ field_filter_type = input_field.type
+ else:
+ field_filter_type = cls._meta.fields[field].type
+ # raise Exception
+ # TODO we need to save the relationship props in the meta fields array
+ # to conduct joins and alias the joins (in case there are duplicate joins: A->B A->C B->C)
+ if field == "and":
+ query, _clauses = cls.and_logic(
+ query, field_filter_type.of_type, field_filters, model_alias=model
+ )
+ clauses.extend(_clauses)
+ elif field == "or":
+ query, _clauses = cls.or_logic(
+ query, field_filter_type.of_type, field_filters, model_alias=model
+ )
+ clauses.extend(_clauses)
+ else:
+ # Get the model attr from the inputfield in case the field is aliased in graphql
+ model_field = getattr(model, input_field.model_attr or field)
+ if issubclass(field_filter_type, BaseTypeFilter):
+ # Get the model to join on the Filter Query
+ joined_model = field_filter_type._meta.model
+ # Always alias the model
+ joined_model_alias = aliased(joined_model)
+ # Join the aliased model onto the query
+ query = query.join(model_field.of_type(joined_model_alias))
+ # Pass the joined query down to the next object type filter for processing
+ query, _clauses = field_filter_type.execute_filters(
+ query, field_filters, model_alias=joined_model_alias
+ )
+ clauses.extend(_clauses)
+ if issubclass(field_filter_type, RelationshipFilter):
+ # TODO see above; not yet working
+ relationship_prop = field_filter_type._meta.model
+ # Always alias the model
+ # joined_model_alias = aliased(relationship_prop)
+
+ # Join the aliased model onto the query
+ # query = query.join(model_field.of_type(joined_model_alias))
+ # todo should we use selectinload here instead of join for large lists?
+
+ query, _clauses = field_filter_type.execute_filters(
+ query, model, model_field, field_filters, relationship_prop
+ )
+ clauses.extend(_clauses)
+ elif issubclass(field_filter_type, FieldFilter):
+ query, _clauses = field_filter_type.execute_filters(
+ query, model_field, field_filters
+ )
+ clauses.extend(_clauses)
+
+ return query, clauses
+
+
+ScalarFilterInputType = TypeVar("ScalarFilterInputType")
+
+
+class FieldFilterOptions(InputObjectTypeOptions):
+ graphene_type: Type = None
+
+
+class FieldFilter(graphene.InputObjectType):
+ """Basic Filter for Scalars in Graphene.
+ We want this filter to use Dynamic fields so it provides the base
+ filtering methods ("eq, nEq") for different types of scalars.
+ The Dynamic fields will resolve to Meta.filtered_type"""
+
+ @classmethod
+ def __init_subclass_with_meta__(cls, graphene_type=None, _meta=None, **options):
+ from .converter import convert_sqlalchemy_type
+
+ # get all filter functions
+
+ filter_functions = _get_functions_by_regex(".+_filter$", "_filter$", cls)
+
+ # Init meta options class if it doesn't exist already
+ if not _meta:
+ _meta = FieldFilterOptions(cls)
+
+ if not _meta.graphene_type:
+ _meta.graphene_type = graphene_type
+
+ new_filter_fields = {}
+ # Generate Graphene Fields from the filter functions based on type hints
+ for field_name, _annotations in filter_functions:
+ assert (
+ "val" in _annotations
+ ), "Each filter method must have a value field with valid type annotations"
+ # If type is generic, replace with actual type of filter class
+ replace_type_vars = {ScalarFilterInputType: _meta.graphene_type}
+ field_type = convert_sqlalchemy_type(
+ _annotations.get("val", str), replace_type_vars=replace_type_vars
+ )
+ new_filter_fields.update({field_name: graphene.InputField(field_type)})
+
+        # Add all fields to the meta options. graphene.InputObjectType will take care of the rest
+ if _meta.fields:
+ _meta.fields.update(new_filter_fields)
+ else:
+ _meta.fields = new_filter_fields
+
+ # Pass modified meta to the super class
+ super(FieldFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options)
+
+ # Abstract methods can be marked using ScalarFilterInputType. See comment on the init method
+ @classmethod
+ def eq_filter(
+ cls, query, field, val: ScalarFilterInputType
+ ) -> Union[Tuple[Query, Any], Any]:
+ return field == val
+
+ @classmethod
+ def n_eq_filter(
+ cls, query, field, val: ScalarFilterInputType
+ ) -> Union[Tuple[Query, Any], Any]:
+ return not_(field == val)
+
+ @classmethod
+ def in_filter(cls, query, field, val: List[ScalarFilterInputType]):
+ return field.in_(val)
+
+ @classmethod
+ def not_in_filter(cls, query, field, val: List[ScalarFilterInputType]):
+ return field.notin_(val)
+
+ # TODO add like/ilike
+
+ @classmethod
+ def execute_filters(
+ cls, query, field, filter_dict: Dict[str, any]
+ ) -> Tuple[Query, List[Any]]:
+ clauses = []
+ for filt, val in filter_dict.items():
+ clause = getattr(cls, filt + "_filter")(query, field, val)
+ if isinstance(clause, tuple):
+ query, clause = clause
+ clauses.append(clause)
+
+ return query, clauses
+
+
+class SQLEnumFilter(FieldFilter):
+ """Basic Filter for Scalars in Graphene.
+ We want this filter to use Dynamic fields so it provides the base
+ filtering methods ("eq, nEq") for different types of scalars.
+ The Dynamic fields will resolve to Meta.filtered_type"""
+
+ class Meta:
+ graphene_type = graphene.Enum
+
+ # Abstract methods can be marked using ScalarFilterInputType. See comment on the init method
+ @classmethod
+ def eq_filter(
+ cls, query, field, val: ScalarFilterInputType
+ ) -> Union[Tuple[Query, Any], Any]:
+ return field == val.value
+
+ @classmethod
+ def n_eq_filter(
+ cls, query, field, val: ScalarFilterInputType
+ ) -> Union[Tuple[Query, Any], Any]:
+ return not_(field == val.value)
+
+
+class PyEnumFilter(FieldFilter):
+ """Basic Filter for Scalars in Graphene.
+ We want this filter to use Dynamic fields so it provides the base
+ filtering methods ("eq, nEq") for different types of scalars.
+ The Dynamic fields will resolve to Meta.filtered_type"""
+
+ class Meta:
+ graphene_type = graphene.Enum
+
+ # Abstract methods can be marked using ScalarFilterInputType. See comment on the init method
+ @classmethod
+ def eq_filter(
+ cls, query, field, val: ScalarFilterInputType
+ ) -> Union[Tuple[Query, Any], Any]:
+ return field == val
+
+ @classmethod
+ def n_eq_filter(
+ cls, query, field, val: ScalarFilterInputType
+ ) -> Union[Tuple[Query, Any], Any]:
+ return not_(field == val)
+
+
+class StringFilter(FieldFilter):
+ class Meta:
+ graphene_type = graphene.String
+
+ @classmethod
+ def like_filter(cls, query, field, val: ScalarFilterInputType) -> bool:
+ return field.like(val)
+
+ @classmethod
+ def ilike_filter(cls, query, field, val: ScalarFilterInputType) -> bool:
+ return field.ilike(val)
+
+ @classmethod
+ def notlike_filter(cls, query, field, val: ScalarFilterInputType) -> bool:
+ return field.notlike(val)
+
+
+class BooleanFilter(FieldFilter):
+ class Meta:
+ graphene_type = graphene.Boolean
+
+
+class OrderedFilter(FieldFilter):
+ class Meta:
+ abstract = True
+
+ @classmethod
+ def gt_filter(cls, query, field, val: ScalarFilterInputType) -> bool:
+ return field > val
+
+ @classmethod
+ def gte_filter(cls, query, field, val: ScalarFilterInputType) -> bool:
+ return field >= val
+
+ @classmethod
+ def lt_filter(cls, query, field, val: ScalarFilterInputType) -> bool:
+ return field < val
+
+ @classmethod
+ def lte_filter(cls, query, field, val: ScalarFilterInputType) -> bool:
+ return field <= val
+
+
+class NumberFilter(OrderedFilter):
+ """Intermediate Filter class since all Numbers are in an order relationship (support <, > etc)"""
+
+ class Meta:
+ abstract = True
+
+
+class FloatFilter(NumberFilter):
+ """Concrete Filter Class which specifies a type for all the abstract filter methods defined in the super classes"""
+
+ class Meta:
+ graphene_type = graphene.Float
+
+
+class IntFilter(NumberFilter):
+ class Meta:
+ graphene_type = graphene.Int
+
+
+class DateFilter(OrderedFilter):
+ """Concrete Filter Class which specifies a type for all the abstract filter methods defined in the super classes"""
+
+ class Meta:
+ graphene_type = graphene.Date
+
+
+class IdFilter(FieldFilter):
+ class Meta:
+ graphene_type = graphene.ID
+
+
+class RelationshipFilter(graphene.InputObjectType):
+ @classmethod
+ def __init_subclass_with_meta__(
+ cls, base_type_filter=None, model=None, _meta=None, **options
+ ):
+ if not base_type_filter:
+ raise Exception("Relationship Filters must be specific to an object type")
+ # Init meta options class if it doesn't exist already
+ if not _meta:
+ _meta = InputObjectTypeOptions(cls)
+
+ # get all filter functions
+ filter_functions = _get_functions_by_regex(".+_filter$", "_filter$", cls)
+
+ relationship_filters = {}
+
+ # Generate Graphene Fields from the filter functions based on type hints
+ for field_name, _annotations in filter_functions:
+ assert (
+ "val" in _annotations
+ ), "Each filter method must have a value field with valid type annotations"
+ # If type is generic, replace with actual type of filter class
+ if is_list(_annotations["val"]):
+ relationship_filters.update(
+ {field_name: graphene.InputField(graphene.List(base_type_filter))}
+ )
+ else:
+ relationship_filters.update(
+ {field_name: graphene.InputField(base_type_filter)}
+ )
+
+ # Add all fields to the meta options. graphene.InputObjectType will take care of the rest
+ if _meta.fields:
+ _meta.fields.update(relationship_filters)
+ else:
+ _meta.fields = relationship_filters
+
+ _meta.model = model
+ _meta.base_type_filter = base_type_filter
+ super(RelationshipFilter, cls).__init_subclass_with_meta__(
+ _meta=_meta, **options
+ )
+
+ @classmethod
+ def contains_filter(
+ cls,
+ query,
+ parent_model,
+ field,
+ relationship_prop,
+ val: List[ScalarFilterInputType],
+ ):
+ clauses = []
+ for v in val:
+ # Always alias the model
+ joined_model_alias = aliased(relationship_prop)
+
+ # Join the aliased model onto the query
+ query = query.join(field.of_type(joined_model_alias)).distinct()
+ # pass the alias so group can join group
+ query, _clauses = cls._meta.base_type_filter.execute_filters(
+ query, v, model_alias=joined_model_alias
+ )
+ clauses.append(and_(*_clauses))
+ return query, [or_(*clauses)]
+
+ @classmethod
+ def contains_exactly_filter(
+ cls,
+ query,
+ parent_model,
+ field,
+ relationship_prop,
+ val: List[ScalarFilterInputType],
+ ):
+ raise NotImplementedError
+
+ @classmethod
+ def execute_filters(
+ cls: Type[FieldFilter],
+ query,
+ parent_model,
+ field,
+ filter_dict: Dict,
+ relationship_prop,
+ ) -> Tuple[Query, List[Any]]:
+ query, clauses = (query, [])
+
+ for filt, val in filter_dict.items():
+ query, _clauses = getattr(cls, filt + "_filter")(
+ query, parent_model, field, relationship_prop, val
+ )
+ clauses += _clauses
+
+ return query, clauses
diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py
index 3c463013..b959d221 100644
--- a/graphene_sqlalchemy/registry.py
+++ b/graphene_sqlalchemy/registry.py
@@ -1,10 +1,15 @@
+import inspect
from collections import defaultdict
-from typing import List, Type
+from typing import TYPE_CHECKING, List, Type
from sqlalchemy.types import Enum as SQLAlchemyEnumType
import graphene
from graphene import Enum
+from graphene.types.base import BaseType
+
+if TYPE_CHECKING:  # pragma: no cover
+ from .filters import BaseTypeFilter, FieldFilter, RelationshipFilter
class Registry(object):
@@ -16,6 +21,30 @@ def __init__(self):
self._registry_enums = {}
self._registry_sort_enums = {}
self._registry_unions = {}
+ self._registry_scalar_filters = {}
+ self._registry_base_type_filters = {}
+ self._registry_relationship_filters = {}
+
+ self._init_base_filters()
+
+    def _init_base_filters(self):
+        """Pre-register the default FieldFilter subclasses shipped in
+        ``graphene_sqlalchemy.filters`` for their declared graphene types."""
+        import graphene_sqlalchemy.filters as gsqa_filters
+
+        from .filters import FieldFilter
+
+        # Collect every concrete FieldFilter subclass in the module that
+        # declares a graphene_type on its Meta (the abstract base and
+        # typeless helpers are skipped).
+        field_filter_classes = [
+            filter_cls[1]
+            for filter_cls in inspect.getmembers(gsqa_filters, inspect.isclass)
+            if (
+                filter_cls[1] is not FieldFilter
+                and FieldFilter in filter_cls[1].__mro__
+                and getattr(filter_cls[1]._meta, "graphene_type", False)
+            )
+        ]
+        for field_filter_class in field_filter_classes:
+            self.register_filter_for_scalar_type(
+                field_filter_class._meta.graphene_type, field_filter_class
+            )
def register(self, obj_type):
from .types import SQLAlchemyBase
@@ -99,6 +128,110 @@ def register_union_type(
def get_union_for_object_types(self, obj_types: List[Type[graphene.ObjectType]]):
return self._registry_unions.get(frozenset(obj_types))
+ # Filter Scalar Fields of Object Types
+    def register_filter_for_scalar_type(
+        self, scalar_type: Type[graphene.Scalar], filter_obj: Type["FieldFilter"]
+    ):
+        """Register *filter_obj* as the FieldFilter for *scalar_type*.
+
+        Raises:
+            TypeError: if *scalar_type* is not a graphene Scalar class or
+                *filter_obj* is not a FieldFilter subclass.
+        """
+        from .filters import FieldFilter
+
+        # isinstance against the metaclass accepts Scalar *classes*
+        # (including subclasses), rejecting scalar instances.
+        if not isinstance(scalar_type, type(graphene.Scalar)):
+            raise TypeError("Expected Scalar, but got: {!r}".format(scalar_type))
+
+        if not issubclass(filter_obj, FieldFilter):
+            raise TypeError("Expected ScalarFilter, but got: {!r}".format(filter_obj))
+        self._registry_scalar_filters[scalar_type] = filter_obj
+
+    def get_filter_for_sql_enum_type(
+        self, enum_type: Type[graphene.Enum]
+    ) -> Type["FieldFilter"]:
+        """Return the filter registered for the SQL enum *enum_type*,
+        lazily creating and caching a default SQLEnumFilter on first use."""
+        from .filters import SQLEnumFilter
+
+        filter_type = self._registry_scalar_filters.get(enum_type)
+        if not filter_type:
+            filter_type = SQLEnumFilter.create_type(
+                f"Default{enum_type.__name__}EnumFilter", graphene_type=enum_type
+            )
+            # Memoize so repeated lookups reuse the same generated type.
+            self._registry_scalar_filters[enum_type] = filter_type
+        return filter_type
+
+    def get_filter_for_py_enum_type(
+        self, enum_type: Type[graphene.Enum]
+    ) -> Type["FieldFilter"]:
+        """Return the filter registered for the Python enum *enum_type*,
+        lazily creating and caching a default PyEnumFilter on first use.
+
+        NOTE(review): shares both the cache dict and the generated
+        ``Default<Name>EnumFilter`` naming with the SQL enum variant —
+        confirm SQL and Python enum types can never collide here.
+        """
+        from .filters import PyEnumFilter
+
+        filter_type = self._registry_scalar_filters.get(enum_type)
+        if not filter_type:
+            filter_type = PyEnumFilter.create_type(
+                f"Default{enum_type.__name__}EnumFilter", graphene_type=enum_type
+            )
+            self._registry_scalar_filters[enum_type] = filter_type
+        return filter_type
+
+    def get_filter_for_scalar_type(
+        self, scalar_type: Type[graphene.Scalar]
+    ) -> Type["FieldFilter"]:
+        """Return the filter registered for *scalar_type*, lazily creating
+        and caching a default FieldFilter type on first use."""
+        from .filters import FieldFilter
+
+        filter_type = self._registry_scalar_filters.get(scalar_type)
+        if not filter_type:
+            filter_type = FieldFilter.create_type(
+                f"Default{scalar_type.__name__}ScalarFilter", graphene_type=scalar_type
+            )
+            self._registry_scalar_filters[scalar_type] = filter_type
+
+        return filter_type
+
+ # TODO register enums automatically
+    def register_filter_for_enum_type(
+        self, enum_type: Type[graphene.Enum], filter_obj: Type["FieldFilter"]
+    ):
+        """Register *filter_obj* as the filter for the graphene Enum
+        *enum_type*.
+
+        The entry is stored in ``_registry_scalar_filters``, i.e. enum and
+        scalar filters share a single lookup table.
+
+        Raises:
+            TypeError: if *enum_type* is not a graphene Enum subclass or
+                *filter_obj* is not a FieldFilter subclass.
+        """
+        from .filters import FieldFilter
+
+        if not issubclass(enum_type, graphene.Enum):
+            raise TypeError("Expected Enum, but got: {!r}".format(enum_type))
+
+        if not issubclass(filter_obj, FieldFilter):
+            raise TypeError("Expected FieldFilter, but got: {!r}".format(filter_obj))
+        self._registry_scalar_filters[enum_type] = filter_obj
+
+ # Filter Base Types
+    def register_filter_for_base_type(
+        self,
+        base_type: Type[BaseType],
+        filter_obj: Type["BaseTypeFilter"],
+    ):
+        """Register *filter_obj* as the BaseTypeFilter for *base_type*.
+
+        Raises:
+            TypeError: if *base_type* is not a BaseType subclass or
+                *filter_obj* is not a BaseTypeFilter subclass.
+        """
+        from .filters import BaseTypeFilter
+
+        if not issubclass(base_type, BaseType):
+            raise TypeError("Expected BaseType, but got: {!r}".format(base_type))
+
+        if not issubclass(filter_obj, BaseTypeFilter):
+            raise TypeError("Expected BaseTypeFilter, but got: {!r}".format(filter_obj))
+        self._registry_base_type_filters[base_type] = filter_obj
+
+    def get_filter_for_base_type(self, base_type: Type[BaseType]):
+        """Return the BaseTypeFilter registered for *base_type*, or None
+        if nothing was registered (no lazy default is created here)."""
+        return self._registry_base_type_filters.get(base_type)
+
+ # Filter Relationships between base types
+    def register_relationship_filter_for_base_type(
+        self, base_type: Type[BaseType], filter_obj: Type["RelationshipFilter"]
+    ):
+        """Register *filter_obj* as the RelationshipFilter for *base_type*.
+
+        NOTE(review): the check uses ``isinstance(base_type, type(BaseType))``
+        (metaclass check) while ``register_filter_for_base_type`` uses
+        ``issubclass`` — confirm the looser check is intentional.
+
+        Raises:
+            TypeError: if *base_type* fails the metaclass check or
+                *filter_obj* is not a RelationshipFilter subclass.
+        """
+        from .filters import RelationshipFilter
+
+        if not isinstance(base_type, type(BaseType)):
+            raise TypeError("Expected BaseType, but got: {!r}".format(base_type))
+
+        if not issubclass(filter_obj, RelationshipFilter):
+            raise TypeError(
+                "Expected RelationshipFilter, but got: {!r}".format(filter_obj)
+            )
+        self._registry_relationship_filters[base_type] = filter_obj
+
+    def get_relationship_filter_for_base_type(
+        self, base_type: Type[BaseType]
+        # NOTE(review): returns a filter *class* or None, so the annotation
+        # would be more accurate as Optional[Type["RelationshipFilter"]].
+    ) -> "RelationshipFilter":
+        """Return the RelationshipFilter registered for *base_type*, or
+        None if nothing was registered."""
+        return self._registry_relationship_filters.get(base_type)
+
registry = None
diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py
index 89b357a4..2c749da7 100644
--- a/graphene_sqlalchemy/tests/conftest.py
+++ b/graphene_sqlalchemy/tests/conftest.py
@@ -2,6 +2,7 @@
import pytest_asyncio
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
+from typing_extensions import Literal
import graphene
from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4
@@ -25,14 +26,23 @@ def convert_composite_class(composite, registry):
return graphene.Field(graphene.Int)
-@pytest.fixture(params=[False, True])
-def async_session(request):
+# Typed literal for the two session fixture flavours: sync and async.
+SESSION_TYPE = Literal["sync", "async"]
+
+
+@pytest.fixture(params=["sync", "async"])
+def session_type(request) -> SESSION_TYPE:
return request.param
@pytest.fixture
-def test_db_url(async_session: bool):
- if async_session:
+def async_session(session_type):
+ return session_type == "async"
+
+
+@pytest.fixture
+def test_db_url(session_type: SESSION_TYPE):
+ if session_type == "async":
return "sqlite+aiosqlite://"
else:
return "sqlite://"
@@ -40,8 +50,8 @@ def test_db_url(async_session: bool):
@pytest.mark.asyncio
@pytest_asyncio.fixture(scope="function")
-async def session_factory(async_session: bool, test_db_url: str):
- if async_session:
+async def session_factory(session_type: SESSION_TYPE, test_db_url: str):
+ if session_type == "async":
if not SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
pytest.skip("Async Sessions only work in sql alchemy 1.4 and above")
engine = create_async_engine(test_db_url)
diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py
index be07b896..8911b0a2 100644
--- a/graphene_sqlalchemy/tests/models.py
+++ b/graphene_sqlalchemy/tests/models.py
@@ -6,6 +6,7 @@
from decimal import Decimal
from typing import List, Optional
+# fmt: off
from sqlalchemy import (
Column,
Date,
@@ -24,13 +25,16 @@
from sqlalchemy.sql.type_api import TypeEngine
from graphene_sqlalchemy.tests.utils import wrap_select_func
-from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, SQL_VERSION_HIGHER_EQUAL_THAN_2
+from graphene_sqlalchemy.utils import (
+ SQL_VERSION_HIGHER_EQUAL_THAN_1_4,
+ SQL_VERSION_HIGHER_EQUAL_THAN_2,
+)
# fmt: off
if SQL_VERSION_HIGHER_EQUAL_THAN_2:
- from sqlalchemy.sql.sqltypes import HasExpressionLookup # noqa # isort:skip
+ from sqlalchemy.sql.sqltypes import HasExpressionLookup # noqa # isort:skip
else:
- from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter as HasExpressionLookup # noqa # isort:skip
+ from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter as HasExpressionLookup # noqa # isort:skip
# fmt: on
PetKind = Enum("cat", "dog", name="pet_kind")
@@ -64,6 +68,7 @@ class Pet(Base):
pet_kind = Column(PetKind, nullable=False)
hair_kind = Column(Enum(HairKind, name="hair_kind"), nullable=False)
reporter_id = Column(Integer(), ForeignKey("reporters.id"))
+ legs = Column(Integer(), default=4)
class CompositeFullName(object):
@@ -150,6 +155,27 @@ def hybrid_prop_list(self) -> List[int]:
headlines = association_proxy("articles", "headline")
+articles_tags_table = Table(
+ "articles_tags",
+ Base.metadata,
+ Column("article_id", ForeignKey("articles.id")),
+ Column("tag_id", ForeignKey("tags.id")),
+)
+
+
+class Image(Base):
+ __tablename__ = "images"
+ id = Column(Integer(), primary_key=True)
+ external_id = Column(Integer())
+ description = Column(String(30))
+
+
+class Tag(Base):
+ __tablename__ = "tags"
+ id = Column(Integer(), primary_key=True)
+ name = Column(String(30))
+
+
class Article(Base):
__tablename__ = "articles"
id = Column(Integer(), primary_key=True)
@@ -161,6 +187,13 @@ class Article(Base):
)
recommended_reads = association_proxy("reporter", "articles")
+ # one-to-one relationship with image
+ image_id = Column(Integer(), ForeignKey("images.id"), unique=True)
+ image = relationship("Image", backref=backref("articles", uselist=False))
+
+ # many-to-many relationship with tags
+ tags = relationship("Tag", secondary=articles_tags_table, backref="articles")
+
class Reader(Base):
__tablename__ = "readers"
@@ -273,11 +306,20 @@ def hybrid_prop_deeply_nested_list_int(self) -> List[List[List[int]]]:
],
]
- # Other SQLAlchemy Instances
+ # Other SQLAlchemy Instance
@hybrid_property
def hybrid_prop_first_shopping_cart_item(self) -> ShoppingCartItem:
return ShoppingCartItem(id=1)
+ # Other SQLAlchemy Instance with expression
+ @hybrid_property
+ def hybrid_prop_first_shopping_cart_item_expression(self) -> ShoppingCartItem:
+ return ShoppingCartItem(id=1)
+
+ @hybrid_prop_first_shopping_cart_item_expression.expression
+ def hybrid_prop_first_shopping_cart_item_expression(cls):
+ return ShoppingCartItem
+
# Other SQLAlchemy Instances
@hybrid_property
def hybrid_prop_shopping_cart_item_list(self) -> List[ShoppingCartItem]:
diff --git a/graphene_sqlalchemy/tests/models_batching.py b/graphene_sqlalchemy/tests/models_batching.py
index 5dde366f..e0f5d4bd 100644
--- a/graphene_sqlalchemy/tests/models_batching.py
+++ b/graphene_sqlalchemy/tests/models_batching.py
@@ -2,16 +2,7 @@
import enum
-from sqlalchemy import (
- Column,
- Date,
- Enum,
- ForeignKey,
- Integer,
- String,
- Table,
- func,
-)
+from sqlalchemy import Column, Date, Enum, ForeignKey, Integer, String, Table, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import column_property, relationship
diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py
index 84069245..e62e07d2 100644
--- a/graphene_sqlalchemy/tests/test_converter.py
+++ b/graphene_sqlalchemy/tests/test_converter.py
@@ -1,13 +1,10 @@
import enum
import sys
-from typing import Dict, Tuple, Union
+from typing import Dict, Tuple, TypeVar, Union
-import graphene
import pytest
import sqlalchemy
import sqlalchemy_utils as sqa_utils
-from graphene.relay import Node
-from graphene.types.structures import Structure
from sqlalchemy import Column, func, types
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.declarative import declarative_base
@@ -15,15 +12,10 @@
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import column_property, composite
-from .models import (
- Article,
- CompositeFullName,
- Pet,
- Reporter,
- ShoppingCart,
- ShoppingCartItem,
-)
-from .utils import wrap_select_func
+import graphene
+from graphene.relay import Node
+from graphene.types.structures import Structure
+
from ..converter import (
convert_sqlalchemy_association_proxy,
convert_sqlalchemy_column,
@@ -47,6 +39,7 @@
ShoppingCart,
ShoppingCartItem,
)
+from .utils import wrap_select_func
def mock_resolver():
@@ -206,6 +199,17 @@ def hybrid_prop(self) -> "ShoppingCartItem":
get_hybrid_property_type(hybrid_prop).type == ShoppingCartType
+def test_converter_replace_type_var():
+
+ T = TypeVar("T")
+
+ replace_type_vars = {T: graphene.String}
+
+ field_type = convert_sqlalchemy_type(T, replace_type_vars=replace_type_vars)
+
+ assert field_type == graphene.String
+
+
@pytest.mark.skipif(
sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10"
)
@@ -215,9 +219,9 @@ def prop_method() -> int | str:
return "not allowed in gql schema"
with pytest.raises(
- ValueError,
- match=r"Cannot convert hybrid_property Union to "
- r"graphene.Union: the Union contains scalars. \.*",
+ ValueError,
+ match=r"Cannot convert hybrid_property Union to "
+ r"graphene.Union: the Union contains scalars. \.*",
):
get_hybrid_property_type(prop_method)
@@ -471,7 +475,9 @@ class TestEnum(enum.IntEnum):
def test_should_columproperty_convert():
field = get_field_from_column(
- column_property(wrap_select_func(func.sum(func.cast(id, types.Integer))).where(id == 1))
+ column_property(
+ wrap_select_func(func.sum(func.cast(id, types.Integer))).where(id == 1)
+ )
)
assert field.type == graphene.Int
@@ -888,8 +894,8 @@ class Meta:
)
for (
- hybrid_prop_name,
- hybrid_prop_expected_return_type,
+ hybrid_prop_name,
+ hybrid_prop_expected_return_type,
) in shopping_cart_item_expected_types.items():
hybrid_prop_field = ShoppingCartItemType._meta.fields[hybrid_prop_name]
@@ -900,7 +906,7 @@ class Meta:
str(hybrid_prop_expected_return_type),
)
assert (
- hybrid_prop_field.description is None
+ hybrid_prop_field.description is None
) # "doc" is ignored by hybrid property
###################################################
@@ -925,6 +931,7 @@ class Meta:
graphene.List(graphene.List(graphene.Int))
),
"hybrid_prop_first_shopping_cart_item": ShoppingCartItemType,
+ "hybrid_prop_first_shopping_cart_item_expression": ShoppingCartItemType,
"hybrid_prop_shopping_cart_item_list": graphene.List(ShoppingCartItemType),
# Self Referential List
"hybrid_prop_self_referential": ShoppingCartType,
@@ -947,8 +954,8 @@ class Meta:
)
for (
- hybrid_prop_name,
- hybrid_prop_expected_return_type,
+ hybrid_prop_name,
+ hybrid_prop_expected_return_type,
) in shopping_cart_expected_types.items():
hybrid_prop_field = ShoppingCartType._meta.fields[hybrid_prop_name]
@@ -959,5 +966,5 @@ class Meta:
str(hybrid_prop_expected_return_type),
)
assert (
- hybrid_prop_field.description is None
+ hybrid_prop_field.description is None
) # "doc" is ignored by hybrid property
diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py
new file mode 100644
index 00000000..4acf89a8
--- /dev/null
+++ b/graphene_sqlalchemy/tests/test_filters.py
@@ -0,0 +1,1201 @@
+import pytest
+from sqlalchemy.sql.operators import is_
+
+import graphene
+from graphene import Connection, relay
+
+from ..fields import SQLAlchemyConnectionField
+from ..filters import FloatFilter
+from ..types import ORMField, SQLAlchemyObjectType
+from .models import (
+ Article,
+ Editor,
+ HairKind,
+ Image,
+ Pet,
+ Reader,
+ Reporter,
+ ShoppingCart,
+ ShoppingCartItem,
+ Tag,
+)
+from .utils import eventually_await_session, to_std_dicts
+
+# TODO test that generated schema is correct for all examples with:
+# with open('schema.gql', 'w') as fp:
+# fp.write(str(schema))
+
+
+def assert_and_raise_result(result, expected):
+    """Fail by re-raising the first GraphQL execution error (preserving its
+    traceback), otherwise assert the result data equals *expected*."""
+    if result.errors:
+        for error in result.errors:
+            raise error
+    assert not result.errors
+    # Normalize OrderedDict/ExecutionResult structures to plain dicts.
+    result = to_std_dicts(result.data)
+    assert result == expected
+
+
+async def add_test_data(session):
+    """Seed the session with the fixture data used by most filter tests:
+    three reporters (Doe/Woe/Roe), three pets, two articles and an editor,
+    then commit (awaited when the session is async)."""
+    reporter = Reporter(first_name="John", last_name="Doe", favorite_pet_kind="cat")
+    session.add(reporter)
+
+    pet = Pet(name="Garfield", pet_kind="cat", hair_kind=HairKind.SHORT, legs=4)
+    pet.reporter = reporter
+    session.add(pet)
+
+    pet = Pet(name="Snoopy", pet_kind="dog", hair_kind=HairKind.SHORT, legs=3)
+    pet.reporter = reporter
+    session.add(pet)
+
+    reporter = Reporter(first_name="John", last_name="Woe", favorite_pet_kind="cat")
+    session.add(reporter)
+
+    article = Article(headline="Hi!")
+    article.reporter = reporter
+    session.add(article)
+
+    article = Article(headline="Hello!")
+    article.reporter = reporter
+    session.add(article)
+
+    reporter = Reporter(first_name="Jane", last_name="Roe", favorite_pet_kind="dog")
+    session.add(reporter)
+
+    # Lassie relies on the model's default for ``legs`` (no explicit value).
+    pet = Pet(name="Lassie", pet_kind="dog", hair_kind=HairKind.LONG)
+    pet.reporter = reporter
+    session.add(pet)
+
+    editor = Editor(name="Jack")
+    session.add(editor)
+
+    await eventually_await_session(session, "commit")
+
+
+def create_schema(session):
+    """Build and return a Query root exposing relay connection fields for
+    all test models (Article, Image, Pet, Reader, Reporter, Tag).
+
+    NOTE(review): *session* is never used in the body — confirm it can be
+    dropped or is kept for signature symmetry with other helpers.
+    """
+    class ArticleType(SQLAlchemyObjectType):
+        class Meta:
+            model = Article
+            name = "Article"
+            interfaces = (relay.Node,)
+
+    class ImageType(SQLAlchemyObjectType):
+        class Meta:
+            model = Image
+            name = "Image"
+            interfaces = (relay.Node,)
+
+    class PetType(SQLAlchemyObjectType):
+        class Meta:
+            model = Pet
+            name = "Pet"
+            interfaces = (relay.Node,)
+
+    class ReaderType(SQLAlchemyObjectType):
+        class Meta:
+            model = Reader
+            name = "Reader"
+            interfaces = (relay.Node,)
+
+    class ReporterType(SQLAlchemyObjectType):
+        class Meta:
+            model = Reporter
+            name = "Reporter"
+            interfaces = (relay.Node,)
+
+    class TagType(SQLAlchemyObjectType):
+        class Meta:
+            model = Tag
+            name = "Tag"
+            interfaces = (relay.Node,)
+
+    class Query(graphene.ObjectType):
+        node = relay.Node.Field()
+        articles = SQLAlchemyConnectionField(ArticleType.connection)
+        images = SQLAlchemyConnectionField(ImageType.connection)
+        readers = SQLAlchemyConnectionField(ReaderType.connection)
+        reporters = SQLAlchemyConnectionField(ReporterType.connection)
+        pets = SQLAlchemyConnectionField(PetType.connection)
+        tags = SQLAlchemyConnectionField(TagType.connection)
+
+    return Query
+
+
+# Test a simple example of filtering
+@pytest.mark.asyncio
+async def test_filter_simple(session):
+    """Multiple operators on one field ({eq, like}) are combined with AND."""
+    await add_test_data(session)
+
+    Query = create_schema(session)
+
+    query = """
+        query {
+          reporters (filter: {lastName: {eq: "Roe", like: "%oe"}}) {
+            edges {
+                node {
+                    firstName
+                }
+            }
+          }
+        }
+    """
+    expected = {
+        "reporters": {"edges": [{"node": {"firstName": "Jane"}}]},
+    }
+    schema = graphene.Schema(query=Query)
+    result = await schema.execute_async(query, context_value={"session": session})
+    assert_and_raise_result(result, expected)
+
+
+@pytest.mark.asyncio
+async def test_filter_alias(session):
+    """
+    Test aliasing of column names in the type: an ORMField alias
+    (lastNameAlias -> last_name) must be filterable under its alias.
+    """
+    await add_test_data(session)
+
+    class ReporterType(SQLAlchemyObjectType):
+        class Meta:
+            model = Reporter
+            name = "Reporter"
+            interfaces = (relay.Node,)
+
+        # Expose last_name under a different GraphQL field name.
+        lastNameAlias = ORMField(model_attr="last_name")
+
+    class Query(graphene.ObjectType):
+        node = relay.Node.Field()
+        reporters = SQLAlchemyConnectionField(ReporterType.connection)
+
+    query = """
+        query {
+          reporters (filter: {lastNameAlias: {eq: "Roe", like: "%oe"}}) {
+            edges {
+                node {
+                    firstName
+                }
+            }
+          }
+        }
+    """
+    expected = {
+        "reporters": {"edges": [{"node": {"firstName": "Jane"}}]},
+    }
+    schema = graphene.Schema(query=Query)
+    result = await schema.execute_async(query, context_value={"session": session})
+    assert_and_raise_result(result, expected)
+
+
+# Test a custom filter type
+@pytest.mark.asyncio
+async def test_filter_custom_type(session):
+    """A FloatFilter subclass can add a custom operator (divisibleBy)
+    attached to a field via ORMField(filter_type=...)."""
+    await add_test_data(session)
+
+    class MathFilter(FloatFilter):
+        class Meta:
+            graphene_type = graphene.Float
+
+        @classmethod
+        def divisible_by_filter(cls, query, field, val: int) -> bool:
+            # NOTE(review): despite the ``-> bool`` annotation this returns
+            # a SQLAlchemy boolean clause, not a Python bool.
+            return is_(field % val, 0)
+
+    class PetType(SQLAlchemyObjectType):
+        class Meta:
+            model = Pet
+            name = "Pet"
+            interfaces = (relay.Node,)
+            connection_class = Connection
+
+        legs = ORMField(filter_type=MathFilter)
+
+    class Query(graphene.ObjectType):
+        pets = SQLAlchemyConnectionField(PetType.connection)
+
+    query = """
+        query {
+          pets (filter: {
+            legs: {divisibleBy: 2}
+          }) {
+            edges {
+                node {
+                    name
+                }
+            }
+          }
+        }
+    """
+    # Garfield (4 legs) and Lassie (default legs) match; Snoopy (3) does not.
+    expected = {
+        "pets": {
+            "edges": [{"node": {"name": "Garfield"}}, {"node": {"name": "Lassie"}}]
+        },
+    }
+    schema = graphene.Schema(query=Query)
+    result = await schema.execute_async(query, context_value={"session": session})
+    assert_and_raise_result(result, expected)
+
+
+# Test filtering on enums
+@pytest.mark.asyncio
+async def test_filter_enum(session):
+ await add_test_data(session)
+
+ Query = create_schema(session)
+
+ # test sqlalchemy enum
+ query = """
+ query {
+ reporters (filter: {
+ favoritePetKind: {eq: DOG}
+ }
+ ) {
+ edges {
+ node {
+ firstName
+ lastName
+ favoritePetKind
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "reporters": {
+ "edges": [
+ {
+ "node": {
+ "firstName": "Jane",
+ "lastName": "Roe",
+ "favoritePetKind": "DOG",
+ }
+ }
+ ]
+ },
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+ # test Python enum and sqlalchemy enum
+ query = """
+ query {
+ pets (filter: {
+ and: [
+ { hairKind: {eq: LONG} },
+ { petKind: {eq: DOG} }
+ ]}) {
+ edges {
+ node {
+ name
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "pets": {"edges": [{"node": {"name": "Lassie"}}]},
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+
+# Test a 1:1 relationship
+@pytest.mark.asyncio
+async def test_filter_relationship_one_to_one(session):
+ article = Article(headline="Hi!")
+ image = Image(external_id=1, description="A beautiful image.")
+ article.image = image
+ session.add(article)
+ session.add(image)
+ await eventually_await_session(session, "commit")
+
+ Query = create_schema(session)
+
+ query = """
+ query {
+ articles (filter: {
+ image: {description: {eq: "A beautiful image."}}
+ }) {
+ edges {
+ node {
+ headline
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "articles": {"edges": [{"node": {"headline": "Hi!"}}]},
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+
+# Test a 1:n relationship
+@pytest.mark.asyncio
+async def test_filter_relationship_one_to_many(session):
+ await add_test_data(session)
+ Query = create_schema(session)
+
+ # test contains
+ query = """
+ query {
+ reporters (filter: {
+ articles: {
+ contains: [{headline: {eq: "Hi!"}}],
+ }
+ }) {
+ edges {
+ node {
+ lastName
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "reporters": {"edges": [{"node": {"lastName": "Woe"}}]},
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+ # TODO test containsExactly
+ # # test containsExactly
+ # query = """
+ # query {
+ # reporters (filter: {
+ # articles: {
+ # containsExactly: [
+ # {headline: {eq: "Hi!"}}
+ # {headline: {eq: "Hello!"}}
+ # ]
+ # }
+ # }) {
+ # edges {
+ # node {
+ # firstName
+ # lastName
+ # }
+ # }
+ # }
+ # }
+ # """
+ # expected = {
+ # "reporters": {"edges": [{"node": {"firstName": "John", "lastName": "Woe"}}]}
+ # }
+ # schema = graphene.Schema(query=Query)
+ # result = await schema.execute_async(query, context_value={"session": session})
+ # assert_and_raise_result(result, expected)
+
+
+async def add_n2m_test_data(session):
+    """Seed data for the many-to-many tests: two readers, two articles,
+    two tags and two images, wired so tag1 is on both articles, tag2 only
+    on article2, and reader2 reads both articles."""
+    # create objects
+    reader1 = Reader(name="Ada")
+    reader2 = Reader(name="Bip")
+    article1 = Article(headline="Article! Look!")
+    article2 = Article(headline="Woah! Another!")
+    tag1 = Tag(name="sensational")
+    tag2 = Tag(name="eye-grabbing")
+    image1 = Image(description="article 1")
+    image2 = Image(description="article 2")
+
+    # set relationships
+    article1.tags = [tag1]
+    article2.tags = [tag1, tag2]
+    article1.image = image1
+    article2.image = image2
+    reader1.articles = [article1]
+    reader2.articles = [article1, article2]
+
+    # save
+    session.add(image1)
+    session.add(image2)
+    session.add(tag1)
+    session.add(tag2)
+    session.add(article1)
+    session.add(article2)
+    session.add(reader1)
+    session.add(reader2)
+    await eventually_await_session(session, "commit")
+
+
+# Test n:m relationship contains
+@pytest.mark.asyncio
+async def test_filter_relationship_many_to_many_contains(session):
+ await add_n2m_test_data(session)
+ Query = create_schema(session)
+
+ # test contains 1
+ query = """
+ query {
+ articles (filter: {
+ tags: {
+ contains: [
+ { name: { in: ["sensational", "eye-grabbing"] } },
+ ]
+ }
+ }) {
+ edges {
+ node {
+ headline
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "articles": {
+ "edges": [
+ {"node": {"headline": "Article! Look!"}},
+ {"node": {"headline": "Woah! Another!"}},
+ ],
+ },
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+ # test contains 2
+ query = """
+ query {
+ articles (filter: {
+ tags: {
+ contains: [
+ { name: { eq: "eye-grabbing" } },
+ ]
+ }
+ }) {
+ edges {
+ node {
+ headline
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "articles": {
+ "edges": [
+ {"node": {"headline": "Woah! Another!"}},
+ ],
+ },
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+ # test reverse
+ query = """
+ query {
+ tags (filter: {
+ articles: {
+ contains: [
+ { headline: { eq: "Article! Look!" } },
+ ]
+ }
+ }) {
+ edges {
+ node {
+ name
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "tags": {
+ "edges": [
+ {"node": {"name": "sensational"}},
+ ],
+ },
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+
+@pytest.mark.asyncio
+async def test_filter_relationship_many_to_many_contains_with_and(session):
+ """
+ This test is necessary to ensure we don't accidentally turn and-contains filter
+ into or-contains filters due to incorrect aliasing of the joined table.
+ """
+ await add_n2m_test_data(session)
+ Query = create_schema(session)
+
+ # test contains 1
+ query = """
+ query {
+ articles (filter: {
+ tags: {
+ contains: [{
+ and: [
+ { name: { in: ["sensational", "eye-grabbing"] } },
+ { name: { eq: "eye-grabbing" } },
+ ]
+
+ }
+ ]
+ }
+ }) {
+ edges {
+ node {
+ headline
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "articles": {
+ "edges": [
+ {"node": {"headline": "Woah! Another!"}},
+ ],
+ },
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+ # test contains 2
+ query = """
+ query {
+ articles (filter: {
+ tags: {
+ contains: [
+ { name: { eq: "eye-grabbing" } },
+ ]
+ }
+ }) {
+ edges {
+ node {
+ headline
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "articles": {
+ "edges": [
+ {"node": {"headline": "Woah! Another!"}},
+ ],
+ },
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+ # test reverse
+ query = """
+ query {
+ tags (filter: {
+ articles: {
+ contains: [
+ { headline: { eq: "Article! Look!" } },
+ ]
+ }
+ }) {
+ edges {
+ node {
+ name
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "tags": {
+ "edges": [
+ {"node": {"name": "sensational"}},
+ ],
+ },
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+
+# Test n:m relationship containsExactly
+@pytest.mark.xfail
+@pytest.mark.asyncio
+async def test_filter_relationship_many_to_many_contains_exactly(session):
+ raise NotImplementedError
+ await add_n2m_test_data(session)
+ Query = create_schema(session)
+
+ # test containsExactly 1
+ query = """
+ query {
+ articles (filter: {
+ tags: {
+ containsExactly: [
+ { name: { eq: "eye-grabbing" } },
+ { name: { eq: "sensational" } },
+ ]
+ }
+ }) {
+ edges {
+ node {
+ headline
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "articles": {"edges": [{"node": {"headline": "Woah! Another!"}}]},
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+ # test containsExactly 2
+ query = """
+ query {
+ articles (filter: {
+ tags: {
+ containsExactly: [
+ { name: { eq: "sensational" } }
+ ]
+ }
+ }) {
+ edges {
+ node {
+ headline
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "articles": {"edges": [{"node": {"headline": "Article! Look!"}}]},
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+ # test reverse
+ query = """
+ query {
+ tags (filter: {
+ articles: {
+ containsExactly: [
+ { headline: { eq: "Article! Look!" } },
+ { headline: { eq: "Woah! Another!" } },
+ ]
+ }
+ }) {
+ edges {
+ node {
+ name
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "tags": {"edges": [{"node": {"name": "eye-grabbing"}}]},
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+
+# Test n:m relationship both contains and containsExactly
+@pytest.mark.xfail
+@pytest.mark.asyncio
+async def test_filter_relationship_many_to_many_contains_and_contains_exactly(session):
+ raise NotImplementedError
+ await add_n2m_test_data(session)
+ Query = create_schema(session)
+
+ query = """
+ query {
+ articles (filter: {
+ tags: {
+ contains: [
+ { name: { eq: "eye-grabbing" } },
+ ]
+ containsExactly: [
+ { name: { eq: "eye-grabbing" } },
+ { name: { eq: "sensational" } },
+ ]
+ }
+ }) {
+ edges {
+ node {
+ headline
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "articles": {"edges": [{"node": {"headline": "Woah! Another!"}}]},
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+
+# Test n:m nested relationship
+# TODO add containsExactly
+@pytest.mark.asyncio
+async def test_filter_relationship_many_to_many_nested(session):
+ await add_n2m_test_data(session)
+ Query = create_schema(session)
+
+ # test readers->articles relationship
+ query = """
+ query {
+ readers (filter: {
+ articles: {
+ contains: [
+ { headline: { eq: "Woah! Another!" } },
+ ]
+ }
+ }) {
+ edges {
+ node {
+ name
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "readers": {"edges": [{"node": {"name": "Bip"}}]},
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+ # test nested readers->articles->tags
+ query = """
+ query {
+ readers (filter: {
+ articles: {
+ contains: [
+ {
+ tags: {
+ contains: [
+ { name: { eq: "eye-grabbing" } },
+ ]
+ }
+ }
+ ]
+ }
+ }) {
+ edges {
+ node {
+ name
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "readers": {"edges": [{"node": {"name": "Bip"}}]},
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+ # test nested reverse
+ query = """
+ query {
+ tags (filter: {
+ articles: {
+ contains: [
+ {
+ readers: {
+ contains: [
+ { name: { eq: "Ada" } },
+ ]
+ }
+ }
+ ]
+ }
+ }) {
+ edges {
+ node {
+ name
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "tags": {"edges": [{"node": {"name": "sensational"}}]},
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+ # test filter on both levels of nesting
+ query = """
+ query {
+ readers (filter: {
+ articles: {
+ contains: [
+ { headline: { eq: "Woah! Another!" } },
+ {
+ tags: {
+ contains: [
+ { name: { eq: "eye-grabbing" } },
+ ]
+ }
+ }
+ ]
+ }
+ }) {
+ edges {
+ node {
+ name
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "readers": {"edges": [{"node": {"name": "Bip"}}]},
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+
+# Test connecting filters with "and"
+@pytest.mark.asyncio
+async def test_filter_logic_and(session):
+ await add_test_data(session)
+
+ Query = create_schema(session)
+
+ query = """
+ query {
+ reporters (filter: {
+ and: [
+ { firstName: { eq: "John" } },
+ { favoritePetKind: { eq: CAT } },
+ ]
+ }) {
+ edges {
+ node {
+ lastName
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "reporters": {
+ "edges": [{"node": {"lastName": "Doe"}}, {"node": {"lastName": "Woe"}}]
+ },
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+
+# Test connecting filters with "or"
+@pytest.mark.asyncio
+async def test_filter_logic_or(session):
+ await add_test_data(session)
+ Query = create_schema(session)
+
+ query = """
+ query {
+ reporters (filter: {
+ or: [
+ { lastName: { eq: "Woe" } },
+ { favoritePetKind: { eq: DOG } },
+ ]
+ }) {
+ edges {
+ node {
+ firstName
+ lastName
+ favoritePetKind
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "reporters": {
+ "edges": [
+ {
+ "node": {
+ "firstName": "John",
+ "lastName": "Woe",
+ "favoritePetKind": "CAT",
+ }
+ },
+ {
+ "node": {
+ "firstName": "Jane",
+ "lastName": "Roe",
+ "favoritePetKind": "DOG",
+ }
+ },
+ ]
+ }
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+
+# Test connecting filters with "and" and "or" together
+@pytest.mark.asyncio
+async def test_filter_logic_and_or(session):
+ await add_test_data(session)
+ Query = create_schema(session)
+
+ query = """
+ query {
+ reporters (filter: {
+ and: [
+ { firstName: { eq: "John" } },
+ {
+ or: [
+ { lastName: { eq: "Doe" } },
+ # TODO get enums working for filters
+ # { favoritePetKind: { eq: "cat" } },
+ ]
+ }
+ ]
+ }) {
+ edges {
+ node {
+ firstName
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "reporters": {
+ "edges": [
+ {"node": {"firstName": "John"}},
+ # {"node": {"firstName": "Jane"}},
+ ],
+ }
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+
+async def add_hybrid_prop_test_data(session):
+ cart = ShoppingCart()
+ session.add(cart)
+ await eventually_await_session(session, "commit")
+
+
+def create_hybrid_prop_schema(session):
+ class ShoppingCartItemType(SQLAlchemyObjectType):
+ class Meta:
+ model = ShoppingCartItem
+ name = "ShoppingCartItem"
+ interfaces = (relay.Node,)
+ connection_class = Connection
+
+ class ShoppingCartType(SQLAlchemyObjectType):
+ class Meta:
+ model = ShoppingCart
+ name = "ShoppingCart"
+ interfaces = (relay.Node,)
+ connection_class = Connection
+
+ class Query(graphene.ObjectType):
+ node = relay.Node.Field()
+ items = SQLAlchemyConnectionField(ShoppingCartItemType.connection)
+ carts = SQLAlchemyConnectionField(ShoppingCartType.connection)
+
+ return Query
+
+
+# Test filtering over and returning hybrid_property
+@pytest.mark.asyncio
+async def test_filter_hybrid_property(session):
+ await add_hybrid_prop_test_data(session)
+ Query = create_hybrid_prop_schema(session)
+
+ # test hybrid_prop_int
+ query = """
+ query {
+ carts (filter: {hybridPropInt: {eq: 42}}) {
+ edges {
+ node {
+ hybridPropInt
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "carts": {
+ "edges": [
+ {"node": {"hybridPropInt": 42}},
+ ]
+ },
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+ # test hybrid_prop_float
+ query = """
+ query {
+ carts (filter: {hybridPropFloat: {gt: 42}}) {
+ edges {
+ node {
+ hybridPropFloat
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "carts": {
+ "edges": [
+ {"node": {"hybridPropFloat": 42.3}},
+ ]
+ },
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+ # test hybrid_prop different model without expression
+ query = """
+ query {
+ carts {
+ edges {
+ node {
+ hybridPropFirstShoppingCartItem {
+ id
+ }
+ }
+ }
+ }
+ }
+ """
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert not result.errors
+ result = to_std_dicts(result.data)
+ assert len(result["carts"]["edges"]) == 1
+
+ # test hybrid_prop different model with expression
+ query = """
+ query {
+ carts {
+ edges {
+ node {
+ hybridPropFirstShoppingCartItemExpression {
+ id
+ }
+ }
+ }
+ }
+ }
+ """
+
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert not result.errors
+ result = to_std_dicts(result.data)
+ assert len(result["carts"]["edges"]) == 1
+
+ # test hybrid_prop list of models
+ query = """
+ query {
+ carts {
+ edges {
+ node {
+ hybridPropShoppingCartItemList {
+ id
+ }
+ }
+ }
+ }
+ }
+ """
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert not result.errors
+ result = to_std_dicts(result.data)
+ assert len(result["carts"]["edges"]) == 1
+ assert (
+ len(result["carts"]["edges"][0]["node"]["hybridPropShoppingCartItemList"]) == 2
+ )
+
+
+# Test edge cases to improve test coverage
+@pytest.mark.asyncio
+async def test_filter_edge_cases(session):
+ await add_test_data(session)
+
+ # test disabling filtering
+ class ArticleType(SQLAlchemyObjectType):
+ class Meta:
+ model = Article
+ name = "Article"
+ interfaces = (relay.Node,)
+ connection_class = Connection
+
+ class Query(graphene.ObjectType):
+ node = relay.Node.Field()
+ articles = SQLAlchemyConnectionField(ArticleType.connection, filter=None)
+
+ schema = graphene.Schema(query=Query)
+ assert "ArticleTypeFilter" not in schema.graphql_schema.type_map
+
+
+# Test additional filter types to improve test coverage
+@pytest.mark.asyncio
+async def test_additional_filters(session):
+ await add_test_data(session)
+ Query = create_schema(session)
+
+ # test n_eq and not_in filters
+ query = """
+ query {
+ reporters (filter: {firstName: {nEq: "Jane"}, lastName: {notIn: "Doe"}}) {
+ edges {
+ node {
+ lastName
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "reporters": {"edges": [{"node": {"lastName": "Woe"}}]},
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
+
+ # test gt, lt, gte, and lte filters
+ query = """
+ query {
+ pets (filter: {legs: {gt: 2, lt: 4, gte: 3, lte: 3}}) {
+ edges {
+ node {
+ name
+ }
+ }
+ }
+ }
+ """
+ expected = {
+ "pets": {"edges": [{"node": {"name": "Snoopy"}}]},
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
+ assert_and_raise_result(result, expected)
diff --git a/graphene_sqlalchemy/tests/test_sort_enums.py b/graphene_sqlalchemy/tests/test_sort_enums.py
index f8f1ff8c..bb530f2c 100644
--- a/graphene_sqlalchemy/tests/test_sort_enums.py
+++ b/graphene_sqlalchemy/tests/test_sort_enums.py
@@ -41,6 +41,8 @@ class Meta:
"HAIR_KIND_DESC",
"REPORTER_ID_ASC",
"REPORTER_ID_DESC",
+ "LEGS_ASC",
+ "LEGS_DESC",
]
assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC"
assert str(sort_enum.ID_DESC.value.value) == "pets.id DESC"
@@ -95,6 +97,8 @@ class Meta:
"PET_KIND_DESC",
"HAIR_KIND_ASC",
"HAIR_KIND_DESC",
+ "LEGS_ASC",
+ "LEGS_DESC",
]
@@ -135,6 +139,8 @@ class Meta:
"HAIR_KIND_DESC",
"REPORTER_ID_ASC",
"REPORTER_ID_DESC",
+ "LEGS_ASC",
+ "LEGS_DESC",
]
assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC"
assert str(sort_enum.ID_DESC.value.value) == "pets.id DESC"
@@ -149,7 +155,7 @@ def test_sort_argument_with_excluded_fields_in_object_type():
class PetType(SQLAlchemyObjectType):
class Meta:
model = Pet
- exclude_fields = ["hair_kind", "reporter_id"]
+ exclude_fields = ["hair_kind", "reporter_id", "legs"]
sort_arg = PetType.sort_argument()
sort_enum = sort_arg.type._of_type
@@ -238,6 +244,8 @@ def get_symbol_name(column_name, sort_asc=True):
"HairKindDown",
"ReporterIdUp",
"ReporterIdDown",
+ "LegsUp",
+ "LegsDown",
]
assert sort_arg.default_value == ["IdUp"]
diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py
index dac5b15f..18d06eef 100644
--- a/graphene_sqlalchemy/types.py
+++ b/graphene_sqlalchemy/types.py
@@ -1,6 +1,10 @@
+import inspect
+import logging
+import warnings
from collections import OrderedDict
+from functools import partial
from inspect import isawaitable
-from typing import Any
+from typing import Any, Optional, Type, Union
import sqlalchemy
from sqlalchemy.ext.associationproxy import AssociationProxy
@@ -8,11 +12,13 @@
from sqlalchemy.orm import ColumnProperty, CompositeProperty, RelationshipProperty
from sqlalchemy.orm.exc import NoResultFound
-from graphene import Field
+import graphene
+from graphene import Dynamic, Field, InputField
from graphene.relay import Connection, Node
from graphene.types.base import BaseType
from graphene.types.interface import Interface, InterfaceOptions
from graphene.types.objecttype import ObjectType, ObjectTypeOptions
+from graphene.types.unmountedtype import UnmountedType
from graphene.types.utils import yank_fields_from_attrs
from graphene.utils.orderedtype import OrderedType
@@ -28,10 +34,12 @@
sort_argument_for_object_type,
sort_enum_for_object_type,
)
+from .filters import BaseTypeFilter, RelationshipFilter, SQLAlchemyFilterInputField
from .registry import Registry, get_global_registry
from .resolvers import get_attr_resolver, get_custom_resolver
from .utils import (
SQL_VERSION_HIGHER_EQUAL_THAN_1_4,
+ get_nullable_type,
get_query,
get_session,
is_mapped_class,
@@ -41,6 +49,8 @@
if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
from sqlalchemy.ext.asyncio import AsyncSession
+logger = logging.getLogger(__name__)
+
class ORMField(OrderedType):
def __init__(
@@ -51,8 +61,10 @@ def __init__(
description=None,
deprecation_reason=None,
batching=None,
+ create_filter=None,
+ filter_type: Optional[Type] = None,
_creation_counter=None,
- **field_kwargs
+ **field_kwargs,
):
"""
Use this to override fields automatically generated by SQLAlchemyObjectType.
@@ -89,6 +101,12 @@ class Meta:
Same behavior as in graphene.Field. Defaults to None.
:param bool batching:
Toggle SQL batching. Defaults to None, that is `SQLAlchemyObjectType.meta.batching`.
+ :param bool create_filter:
+ Create a filter for this field. Defaults to True.
+ :param Type filter_type:
+ Override for the filter of this field with a custom filter type.
+ Default behavior is to get a matching filter type for this field from the registry.
+ ``create_filter`` must be set to ``True`` for this override to take effect.
:param int _creation_counter:
Same behavior as in graphene.Field.
"""
@@ -100,6 +118,8 @@ class Meta:
"required": required,
"description": description,
"deprecation_reason": deprecation_reason,
+ "create_filter": create_filter,
+ "filter_type": filter_type,
"batching": batching,
}
common_kwargs = {
@@ -109,6 +129,139 @@ class Meta:
self.kwargs.update(common_kwargs)
+def get_or_create_relationship_filter(
+ base_type: Type[BaseType], registry: Registry
+) -> Type[RelationshipFilter]:
+ relationship_filter = registry.get_relationship_filter_for_base_type(base_type)
+
+ if not relationship_filter:
+ try:
+ base_type_filter = registry.get_filter_for_base_type(base_type)
+ relationship_filter = RelationshipFilter.create_type(
+ f"{base_type.__name__}RelationshipFilter",
+ base_type_filter=base_type_filter,
+ model=base_type._meta.model,
+ )
+ registry.register_relationship_filter_for_base_type(
+ base_type, relationship_filter
+ )
+ except Exception:
+ logger.exception("Failed to create relationship filter for %s", base_type)
+ raise
+
+ return relationship_filter
+
+
+def filter_field_from_field(
+ field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]],
+ type_,
+ registry: Registry,
+ model_attr: Any,
+ model_attr_name: str,
+) -> Optional[graphene.InputField]:
+ # Field might be a SQLAlchemyObjectType, due to hybrid properties
+ if issubclass(type_, SQLAlchemyObjectType):
+ filter_class = registry.get_filter_for_base_type(type_)
+ # Enum Special Case
+ elif issubclass(type_, graphene.Enum) and isinstance(model_attr, ColumnProperty):
+ column = model_attr.columns[0]
+ model_enum_type: Optional[sqlalchemy.types.Enum] = getattr(column, "type", None)
+ if not getattr(model_enum_type, "enum_class", None):
+ filter_class = registry.get_filter_for_sql_enum_type(type_)
+ else:
+ filter_class = registry.get_filter_for_py_enum_type(type_)
+ else:
+ filter_class = registry.get_filter_for_scalar_type(type_)
+ if not filter_class:
+ warnings.warn(
+ f"No compatible filters found for {field.type} with db name {model_attr_name}. Skipping field."
+ )
+ return None
+ return SQLAlchemyFilterInputField(filter_class, model_attr_name)
+
+
+def resolve_dynamic_relationship_filter(
+ field: graphene.Dynamic, registry: Registry, model_attr_name: str
+) -> Optional[Union[graphene.InputField, graphene.Dynamic]]:
+ # Resolve Dynamic Type
+ type_ = get_nullable_type(field.get_type())
+ from graphene_sqlalchemy import SQLAlchemyConnectionField
+
+ # Connections always result in list filters
+ if isinstance(type_, SQLAlchemyConnectionField):
+ inner_type = get_nullable_type(type_.type.Edge.node._type)
+ reg_res = get_or_create_relationship_filter(inner_type, registry)
+ # Field relationships can either be a list or a single object
+ elif isinstance(type_, Field):
+ if isinstance(type_.type, graphene.List):
+ inner_type = get_nullable_type(type_.type.of_type)
+ reg_res = get_or_create_relationship_filter(inner_type, registry)
+ else:
+ reg_res = registry.get_filter_for_base_type(type_.type)
+ else:
+ # Other dynamic type constellations are not yet supported;
+ # please open an issue with a reproduction if you need them
+ reg_res = None
+
+ if not reg_res:
+ warnings.warn(
+ f"No compatible filters found for {field} with db name {model_attr_name}. Skipping field."
+ )
+ return None
+
+ return SQLAlchemyFilterInputField(reg_res, model_attr_name)
+
+
+def filter_field_from_type_field(
+ field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]],
+ registry: Registry,
+ filter_type: Optional[Type],
+ model_attr: Any,
+ model_attr_name: str,
+) -> Optional[Union[graphene.InputField, graphene.Dynamic]]:
+ # If a custom filter type was set for this field, use it here
+ if filter_type:
+ return SQLAlchemyFilterInputField(filter_type, model_attr_name)
+ elif issubclass(type(field), graphene.Scalar):
+ filter_class = registry.get_filter_for_scalar_type(type(field))
+ return SQLAlchemyFilterInputField(filter_class, model_attr_name)
+ # If the generated field is Dynamic, it is always a relationship
+ # (due to graphene-sqlalchemy's conversion mechanism).
+ elif isinstance(field, graphene.Dynamic):
+ return Dynamic(
+ partial(
+ resolve_dynamic_relationship_filter, field, registry, model_attr_name
+ )
+ )
+ # Unsupported but theoretically possible cases, please drop us an issue with reproduction if you need them
+ elif isinstance(field, graphene.List) or isinstance(field._type, graphene.List):
+ # Pure lists are not yet supported
+ pass
+ elif isinstance(field._type, graphene.Dynamic):
+ # Fields with a nested Dynamic type are not yet supported
+ pass
+ # Order matters, this comes last as field._type == list also matches Field
+ elif isinstance(field, graphene.Field):
+ if inspect.isfunction(field._type) or isinstance(field._type, partial):
+ return Dynamic(
+ lambda: filter_field_from_field(
+ field,
+ get_nullable_type(field.type),
+ registry,
+ model_attr,
+ model_attr_name,
+ )
+ )
+ else:
+ return filter_field_from_field(
+ field,
+ get_nullable_type(field.type),
+ registry,
+ model_attr,
+ model_attr_name,
+ )
+
+
def get_polymorphic_on(model):
"""
Check whether this model is a polymorphic type, and if so return the name
@@ -121,13 +274,14 @@ def get_polymorphic_on(model):
return polymorphic_on.name
-def construct_fields(
+def construct_fields_and_filters(
obj_type,
model,
registry,
only_fields,
exclude_fields,
batching,
+ create_filters,
connection_field_factory,
):
"""
@@ -143,6 +297,7 @@ def construct_fields(
:param tuple[string] only_fields:
:param tuple[string] exclude_fields:
:param bool batching:
+ :param bool create_filters: Enable filter generation for this type
:param function|None connection_field_factory:
+ :rtype: Tuple[OrderedDict[str, graphene.Field], OrderedDict[str, graphene.InputField]]
"""
@@ -201,7 +356,12 @@ def construct_fields(
# Build all the field dictionary
fields = OrderedDict()
+ filters = OrderedDict()
for orm_field_name, orm_field in orm_fields.items():
+ filtering_enabled_for_field = orm_field.kwargs.pop(
+ "create_filter", create_filters
+ )
+ filter_type = orm_field.kwargs.pop("filter_type", None)
attr_name = orm_field.kwargs.pop("model_attr")
attr = all_model_attrs[attr_name]
resolver = get_custom_resolver(obj_type, orm_field_name) or get_attr_resolver(
@@ -220,7 +380,7 @@ def construct_fields(
connection_field_factory,
batching_,
orm_field_name,
- **orm_field.kwargs
+ **orm_field.kwargs,
)
elif isinstance(attr, CompositeProperty):
if attr_name != orm_field_name or orm_field.kwargs:
@@ -241,15 +401,21 @@ def construct_fields(
connection_field_factory,
batching,
resolver,
- **orm_field.kwargs
+ **orm_field.kwargs,
)
else:
raise Exception("Property type is not supported") # Should never happen
registry.register_orm_field(obj_type, orm_field_name, attr)
fields[orm_field_name] = field
+ if filtering_enabled_for_field and not isinstance(attr, AssociationProxy):
+ # we don't support filtering on association proxies yet.
+ # Support will be patched in a future release of graphene-sqlalchemy
+ filters[orm_field_name] = filter_field_from_type_field(
+ field, registry, filter_type, attr, attr_name
+ )
- return fields
+ return fields, filters
class SQLAlchemyBase(BaseType):
@@ -274,7 +440,7 @@ def __init_subclass_with_meta__(
batching=False,
connection_field_factory=None,
_meta=None,
- **options
+ **options,
):
# We always want to bypass this hook unless we're defining a concrete
# `SQLAlchemyObjectType` or `SQLAlchemyInterface`.
@@ -301,16 +467,19 @@ def __init_subclass_with_meta__(
"The options 'only_fields' and 'exclude_fields' cannot be both set on the same type."
)
+ fields, filters = construct_fields_and_filters(
+ obj_type=cls,
+ model=model,
+ registry=registry,
+ only_fields=only_fields,
+ exclude_fields=exclude_fields,
+ batching=batching,
+ create_filters=True,
+ connection_field_factory=connection_field_factory,
+ )
+
sqla_fields = yank_fields_from_attrs(
- construct_fields(
- obj_type=cls,
- model=model,
- registry=registry,
- only_fields=only_fields,
- exclude_fields=exclude_fields,
- batching=batching,
- connection_field_factory=connection_field_factory,
- ),
+ fields,
_as=Field,
sort=False,
)
@@ -342,6 +511,19 @@ def __init_subclass_with_meta__(
else:
_meta.fields = sqla_fields
+ # Save Generated filter class in Meta Class
+ if not _meta.filter_class:
+ # Map graphene fields to filters
+ # TODO we might need to pass the ORMFields containing the SQLAlchemy models
+ # to the scalar filters here (to generate expressions from the model)
+
+ filter_fields = yank_fields_from_attrs(filters, _as=InputField, sort=False)
+
+ _meta.filter_class = BaseTypeFilter.create_type(
+ f"{cls.__name__}Filter", filter_fields=filter_fields, model=model
+ )
+ registry.register_filter_for_base_type(cls, _meta.filter_class)
+
_meta.connection = connection
_meta.id = id or "id"
@@ -401,6 +583,12 @@ def resolve_id(self, info):
def enum_for_field(cls, field_name):
return enum_for_field(cls, field_name)
+ @classmethod
+ def get_filter_argument(cls):
+ if cls._meta.filter_class:
+ return graphene.Argument(cls._meta.filter_class)
+ return None
+
sort_enum = classmethod(sort_enum_for_object_type)
sort_argument = classmethod(sort_argument_for_object_type)
@@ -411,6 +599,7 @@ class SQLAlchemyObjectTypeOptions(ObjectTypeOptions):
registry = None # type: sqlalchemy.Registry
connection = None # type: sqlalchemy.Type[sqlalchemy.Connection]
id = None # type: str
+ filter_class: Type[BaseTypeFilter] = None
class SQLAlchemyObjectType(SQLAlchemyBase, ObjectType):
@@ -447,6 +636,7 @@ class SQLAlchemyInterfaceOptions(InterfaceOptions):
registry = None # type: sqlalchemy.Registry
connection = None # type: sqlalchemy.Type[sqlalchemy.Connection]
id = None # type: str
+ filter_class: Type[BaseTypeFilter] = None
class SQLAlchemyInterface(SQLAlchemyBase, Interface):
diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py
index bb9386e8..3ba14865 100644
--- a/graphene_sqlalchemy/utils.py
+++ b/graphene_sqlalchemy/utils.py
@@ -1,4 +1,5 @@
import re
+import typing
import warnings
from collections import OrderedDict
from functools import _c3_mro
@@ -10,6 +11,14 @@
from sqlalchemy.orm import class_mapper, object_mapper
from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError
+from graphene import NonNull
+
+
+def get_nullable_type(_type):
+ if isinstance(_type, NonNull):
+ return _type.of_type
+ return _type
+
def is_sqlalchemy_version_less_than(version_string):
"""Check the installed SQLAlchemy version"""
@@ -259,6 +268,10 @@ def registry_sqlalchemy_model_from_str(model_name: str) -> Optional[Any]:
pass
+def is_list(x):
+ return getattr(x, "__origin__", None) in [list, typing.List]
+
+
class DummyImport:
"""The dummy module returns 'object' for a query for any member"""