diff --git a/.gitignore b/.gitignore index 47a82df0..1c86b9be 100644 --- a/.gitignore +++ b/.gitignore @@ -71,5 +71,8 @@ target/ *.sqlite3 .vscode +# Schema +*.gql + # mypy cache .mypy_cache/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 470a29eb..262e7608 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ default_language_version: - python: python3.7 + python: python3.8 repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.2.0 @@ -12,7 +12,7 @@ repos: - id: trailing-whitespace exclude: README.md - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort name: isort (python) diff --git a/docs/filters.rst b/docs/filters.rst new file mode 100644 index 00000000..ac36803d --- /dev/null +++ b/docs/filters.rst @@ -0,0 +1,213 @@ +======= +Filters +======= + +Starting in graphene-sqlalchemy version 3, the SQLAlchemyConnectionField class implements filtering by default. The query utilizes a ``filter`` keyword to specify a filter class that inherits from ``graphene.InputObjectType``. + +Migrating from graphene-sqlalchemy-filter +--------------------------------------------- + +If like many of us, you have been using |graphene-sqlalchemy-filter|_ to implement filters and would like to use the in-built mechanism here, there are a couple key differences to note. Mainly, in an effort to simplify the generated schema, filter keywords are nested under their respective fields instead of concatenated. For example, the filter partial ``{usernameIn: ["moderator", "cool guy"]}`` would be represented as ``{username: {in: ["moderator", "cool guy"]}}``. + +.. |graphene-sqlalchemy-filter| replace:: ``graphene-sqlalchemy-filter`` +.. _graphene-sqlalchemy-filter: https://github.com/art1415926535/graphene-sqlalchemy-filter + +Further, some of the constructs found in libraries like `DGraph's DQL `_ have been implemented, so if you have created custom implementations for these features, you may want to take a look at the examples below. + + +Example model +------------- + +Take as example a Pet model similar to that in the sorting example. We will use variations on this arrangement for the following examples. + +.. code:: + + class Pet(Base): + __tablename__ = 'pets' + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + age = Column(Integer()) + + + class PetNode(SQLAlchemyObjectType): + class Meta: + model = Pet + + + class Query(graphene.ObjectType): + allPets = SQLAlchemyConnectionField(PetNode.connection) + + +Simple filter example +--------------------- + +Filters are defined at the object level through the ``BaseTypeFilter`` class. The ``BaseType`` encompasses both Graphene ``ObjectType``\ s and ``Interface``\ s. Each ``BaseTypeFilter`` instance may define fields via ``FieldFilter`` and relationships via ``RelationshipFilter``. Here's a basic example querying a single field on the Pet model: + +.. code:: + + allPets(filter: {name: {eq: "Fido"}}){ + edges { + node { + name + } + } + } + +This will return all pets with the name "Fido". + + +Custom filter types +------------------- + +If you'd like to implement custom behavior for filtering a field, you can do so by extending one of the base filter classes in ``graphene_sqlalchemy.filters``. For example, if you'd like to add a ``divisible_by`` keyword to filter the age attribute on the ``Pet`` model, you can do so as follows: + +.. code:: python + + class MathFilter(FloatFilter): + class Meta: + graphene_type = graphene.Float + + @classmethod + def divisible_by_filter(cls, query, field, val: int) -> bool: + return is_(field % val, 0) + + class PetType(SQLAlchemyObjectType): + ... + + age = ORMField(filter_type=MathFilter) + + class Query(graphene.ObjectType): + pets = SQLAlchemyConnectionField(PetType.connection) + + +Filtering over relationships with RelationshipFilter +---------------------------------------------------- + +When a filter class field refers to another object in a relationship, you may nest filters on relationship object attributes. This happens directly for 1:1 and m:1 relationships and through the ``contains`` and ``containsExactly`` keywords for 1:n and m:n relationships. + + +:1 relationships +^^^^^^^^^^^^^^^^ + +When an object or interface defines a singular relationship, relationship object attributes may be filtered directly like so: + +Take the following SQLAlchemy model definition as an example: + +.. code:: python + + class Pet + ... + person_id = Column(Integer(), ForeignKey("people.id")) + + class Person + ... + pets = relationship("Pet", backref="person") + + +Then, this query will return all pets whose person is named "Ada": + +.. code:: + + allPets(filter: { + person: {name: {eq: "Ada"}} + }) { + ... + } + + +:n relationships +^^^^^^^^^^^^^^^^ + +However, for plural relationships, relationship object attributes must be filtered through either ``contains`` or ``containsExactly``: + +Now, using a many-to-many model definition: + +.. code:: python + + people_pets_table = sqlalchemy.Table( + "people_pets", + Base.metadata, + Column("person_id", ForeignKey("people.id")), + Column("pet_id", ForeignKey("pets.id")), + ) + + class Pet + ... + + class Person + ... + pets = relationship("Pet", backref="people") + + +this query will return all pets which have a person named "Ben" in their ``people`` list. + +.. code:: + + allPets(filter: { + people: { + contains: [{name: {eq: "Ben"}}], + } + }) { + ... + } + + +and this one will return all pets which hvae a person list that contains exactly the people "Ada" and "Ben" and no fewer or people with other names. + +.. code:: + + allPets(filter: { + articles: { + containsExactly: [ + {name: {eq: "Ada"}}, + {name: {eq: "Ben"}}, + ], + } + }) { + ... + } + +And/Or Logic +------------ + +Filters can also be chained together logically using `and` and `or` keywords nested under `filter`. Clauses are passed directly to `sqlalchemy.and_` and `slqlalchemy.or_`, respectively. To return all pets named "Fido" or "Spot", use: + + +.. code:: + + allPets(filter: { + or: [ + {name: {eq: "Fido"}}, + {name: {eq: "Spot"}}, + ] + }) { + ... + } + +And to return all pets that are named "Fido" or are 5 years old and named "Spot", use: + +.. code:: + + allPets(filter: { + or: [ + {name: {eq: "Fido"}}, + { and: [ + {name: {eq: "Spot"}}, + {age: {eq: 5}} + } + ] + }) { + ... + } + + +Hybrid Property support +----------------------- + +Filtering over SQLAlchemy `hybrid properties `_ is fully supported. + + +Reporting feedback and bugs +--------------------------- + +Filtering is a new feature to graphene-sqlalchemy, so please `post an issue on Github `_ if you run into any problems or have ideas on how to improve the implementation. diff --git a/docs/index.rst b/docs/index.rst index b663752a..4245eba8 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -10,6 +10,7 @@ Contents: inheritance relay tips + filters examples tutorial api diff --git a/examples/filters/README.md b/examples/filters/README.md new file mode 100644 index 00000000..a72e75de --- /dev/null +++ b/examples/filters/README.md @@ -0,0 +1,47 @@ +Example Filters Project +================================ + +This example highlights the ability to filter queries in graphene-sqlalchemy. + +The project contains two models, one named `Department` and another +named `Employee`. + +Getting started +--------------- + +First you'll need to get the source of the project. Do this by cloning the +whole Graphene-SQLAlchemy repository: + +```bash +# Get the example project code +git clone https://github.com/graphql-python/graphene-sqlalchemy.git +cd graphene-sqlalchemy/examples/filters +``` + +It is recommended to create a virtual environment +for this project. We'll do this using +[virtualenv](http://docs.python-guide.org/en/latest/dev/virtualenvs/) +to keep things simple, +but you may also find something like +[virtualenvwrapper](https://virtualenvwrapper.readthedocs.org/en/latest/) +to be useful: + +```bash +# Create a virtualenv in which we can install the dependencies +virtualenv env +source env/bin/activate +``` + +Install our dependencies: + +```bash +pip install -r requirements.txt +``` + +The following command will setup the database, and start the server: + +```bash +python app.py +``` + +Now head over to your favorite GraphQL client, POST to [http://127.0.0.1:5000/graphql](http://127.0.0.1:5000/graphql) and run some queries! diff --git a/examples/filters/__init__.py b/examples/filters/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/filters/app.py b/examples/filters/app.py new file mode 100644 index 00000000..ab918da7 --- /dev/null +++ b/examples/filters/app.py @@ -0,0 +1,16 @@ +from database import init_db +from fastapi import FastAPI +from schema import schema +from starlette_graphene3 import GraphQLApp, make_playground_handler + + +def create_app() -> FastAPI: + init_db() + app = FastAPI() + + app.mount("/graphql", GraphQLApp(schema, on_get=make_playground_handler())) + + return app + + +app = create_app() diff --git a/examples/filters/database.py b/examples/filters/database.py new file mode 100644 index 00000000..8f6522f7 --- /dev/null +++ b/examples/filters/database.py @@ -0,0 +1,49 @@ +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +Base = declarative_base() +engine = create_engine( + "sqlite://", connect_args={"check_same_thread": False}, echo=True +) +session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +from sqlalchemy.orm import scoped_session as scoped_session_factory + +scoped_session = scoped_session_factory(session_factory) + +Base.query = scoped_session.query_property() +Base.metadata.bind = engine + + +def init_db(): + from models import Person, Pet, Toy + + Base.metadata.create_all() + scoped_session.execute("PRAGMA foreign_keys=on") + db = scoped_session() + + person1 = Person(name="A") + person2 = Person(name="B") + + pet1 = Pet(name="Spot") + pet2 = Pet(name="Milo") + + toy1 = Toy(name="disc") + toy2 = Toy(name="ball") + + person1.pet = pet1 + person2.pet = pet2 + + pet1.toys.append(toy1) + pet2.toys.append(toy1) + pet2.toys.append(toy2) + + db.add(person1) + db.add(person2) + db.add(pet1) + db.add(pet2) + db.add(toy1) + db.add(toy2) + + db.commit() diff --git a/examples/filters/models.py b/examples/filters/models.py new file mode 100644 index 00000000..1b22956b --- /dev/null +++ b/examples/filters/models.py @@ -0,0 +1,34 @@ +import sqlalchemy +from database import Base +from sqlalchemy import Column, ForeignKey, Integer, String +from sqlalchemy.orm import relationship + + +class Pet(Base): + __tablename__ = "pets" + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + age = Column(Integer()) + person_id = Column(Integer(), ForeignKey("people.id")) + + +class Person(Base): + __tablename__ = "people" + id = Column(Integer(), primary_key=True) + name = Column(String(100)) + pets = relationship("Pet", backref="person") + + +pets_toys_table = sqlalchemy.Table( + "pets_toys", + Base.metadata, + Column("pet_id", ForeignKey("pets.id")), + Column("toy_id", ForeignKey("toys.id")), +) + + +class Toy(Base): + __tablename__ = "toys" + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + pets = relationship("Pet", secondary=pets_toys_table, backref="toys") diff --git a/examples/filters/requirements.txt b/examples/filters/requirements.txt new file mode 100644 index 00000000..b433ec59 --- /dev/null +++ b/examples/filters/requirements.txt @@ -0,0 +1,3 @@ +-e ../../ +fastapi +uvicorn diff --git a/examples/filters/run.sh b/examples/filters/run.sh new file mode 100755 index 00000000..ec365444 --- /dev/null +++ b/examples/filters/run.sh @@ -0,0 +1 @@ +uvicorn app:app --port 5000 diff --git a/examples/filters/schema.py b/examples/filters/schema.py new file mode 100644 index 00000000..2728cab7 --- /dev/null +++ b/examples/filters/schema.py @@ -0,0 +1,42 @@ +from models import Person as PersonModel +from models import Pet as PetModel +from models import Toy as ToyModel + +import graphene +from graphene import relay +from graphene_sqlalchemy import SQLAlchemyObjectType +from graphene_sqlalchemy.fields import SQLAlchemyConnectionField + + +class Pet(SQLAlchemyObjectType): + class Meta: + model = PetModel + name = "Pet" + interfaces = (relay.Node,) + batching = True + + +class Person(SQLAlchemyObjectType): + class Meta: + model = PersonModel + name = "Person" + interfaces = (relay.Node,) + batching = True + + +class Toy(SQLAlchemyObjectType): + class Meta: + model = ToyModel + name = "Toy" + interfaces = (relay.Node,) + batching = True + + +class Query(graphene.ObjectType): + node = relay.Node.Field() + pets = SQLAlchemyConnectionField(Pet.connection) + people = SQLAlchemyConnectionField(Person.connection) + toys = SQLAlchemyConnectionField(Toy.connection) + + +schema = graphene.Schema(query=Query) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 84c7886c..efcf3c6c 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -3,7 +3,7 @@ import typing import uuid from decimal import Decimal -from typing import Any, Optional, Union, cast +from typing import Any, Dict, Optional, TypeVar, Union, cast from sqlalchemy import types as sqa_types from sqlalchemy.dialects import postgresql @@ -21,7 +21,6 @@ from .batching import get_batch_resolver from .enums import enum_for_sa_enum -from .fields import BatchSQLAlchemyConnectionField, default_connection_field_factory from .registry import Registry, get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver from .utils import ( @@ -237,6 +236,8 @@ def _convert_o2m_or_m2m_relationship( :param dict field_kwargs: :rtype: Field """ + from .fields import BatchSQLAlchemyConnectionField, default_connection_field_factory + child_type = obj_type._meta.registry.get_type_for_model( relationship_prop.mapper.entity ) @@ -332,8 +333,12 @@ def convert_sqlalchemy_type( # noqa type_arg: Any, column: Optional[Union[MapperProperty, hybrid_property]] = None, registry: Registry = None, + replace_type_vars: typing.Dict[str, Any] = None, **kwargs, ): + if replace_type_vars and type_arg in replace_type_vars: + return replace_type_vars[type_arg] + # No valid type found, raise an error raise TypeError( @@ -373,6 +378,11 @@ def convert_scalar_type(type_arg: Any, **kwargs): return type_arg +@convert_sqlalchemy_type.register(safe_isinstance(TypeVar)) +def convert_type_var(type_arg: Any, replace_type_vars: Dict[TypeVar, Any], **kwargs): + return replace_type_vars[type_arg] + + @convert_sqlalchemy_type.register(column_type_eq(str)) @convert_sqlalchemy_type.register(column_type_eq(sqa_types.String)) @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Text)) @@ -618,6 +628,7 @@ def convert_sqlalchemy_hybrid_property_union(type_arg: Any, **kwargs): # Just get the T out of the list of arguments by filtering out the NoneType nested_types = list(filter(lambda x: not type(None) == x, type_arg.__args__)) + # TODO redo this for , *args, **kwargs # Map the graphene types to the nested types. # We use convert_sqlalchemy_hybrid_property_type instead of the registry to account for ForwardRefs, Lists,... graphene_types = list(map(convert_sqlalchemy_type, nested_types)) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 6dbc134f..ef798852 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -5,13 +5,19 @@ from promise import Promise, is_thenable from sqlalchemy.orm.query import Query -from graphene import NonNull from graphene.relay import Connection, ConnectionField from graphene.relay.connection import connection_adapter, page_info_adapter from graphql_relay import connection_from_array_slice from .batching import get_batch_resolver -from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, EnumValue, get_query, get_session +from .filters import BaseTypeFilter +from .utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + EnumValue, + get_nullable_type, + get_query, + get_session, +) if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: from sqlalchemy.ext.asyncio import AsyncSession @@ -40,6 +46,7 @@ def type(self): def __init__(self, type_, *args, **kwargs): nullable_type = get_nullable_type(type_) + # Handle Sorting and Filtering if ( "sort" not in kwargs and nullable_type @@ -57,6 +64,19 @@ def __init__(self, type_, *args, **kwargs): ) elif "sort" in kwargs and kwargs["sort"] is None: del kwargs["sort"] + + if ( + "filter" not in kwargs + and nullable_type + and issubclass(nullable_type, Connection) + ): + # Only add filtering if a filter argument exists on the object type + filter_argument = nullable_type.Edge.node._type.get_filter_argument() + if filter_argument: + kwargs.setdefault("filter", filter_argument) + elif "filter" in kwargs and kwargs["filter"] is None: + del kwargs["filter"] + super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs) @property @@ -64,7 +84,7 @@ def model(self): return get_nullable_type(self.type)._meta.node._meta.model @classmethod - def get_query(cls, model, info, sort=None, **args): + def get_query(cls, model, info, sort=None, filter=None, **args): query = get_query(model, info.context) if sort is not None: if not isinstance(sort, list): @@ -80,6 +100,12 @@ def get_query(cls, model, info, sort=None, **args): else: sort_args.append(item) query = query.order_by(*sort_args) + + if filter is not None: + assert isinstance(filter, dict) + filter_type: BaseTypeFilter = type(filter) + query, clauses = filter_type.execute_filters(query, filter) + query = query.filter(*clauses) return query @classmethod @@ -264,9 +290,3 @@ def unregisterConnectionFieldFactory(): ) global __connectionFactory __connectionFactory = UnsortedSQLAlchemyConnectionField - - -def get_nullable_type(_type): - if isinstance(_type, NonNull): - return _type.of_type - return _type diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py new file mode 100644 index 00000000..bb422724 --- /dev/null +++ b/graphene_sqlalchemy/filters.py @@ -0,0 +1,525 @@ +import re +from typing import Any, Dict, List, Tuple, Type, TypeVar, Union + +from graphql import Undefined +from sqlalchemy import and_, not_, or_ +from sqlalchemy.orm import Query, aliased # , selectinload + +import graphene +from graphene.types.inputobjecttype import ( + InputObjectTypeContainer, + InputObjectTypeOptions, +) +from graphene_sqlalchemy.utils import is_list + +BaseTypeFilterSelf = TypeVar( + "BaseTypeFilterSelf", Dict[str, Any], InputObjectTypeContainer +) + + +class SQLAlchemyFilterInputField(graphene.InputField): + def __init__( + self, + type_, + model_attr, + name=None, + default_value=Undefined, + deprecation_reason=None, + description=None, + required=False, + _creation_counter=None, + **extra_args, + ): + super(SQLAlchemyFilterInputField, self).__init__( + type_, + name, + default_value, + deprecation_reason, + description, + required, + _creation_counter, + **extra_args, + ) + + self.model_attr = model_attr + + +def _get_functions_by_regex( + regex: str, subtract_regex: str, class_: Type +) -> List[Tuple[str, Dict[str, Any]]]: + function_regex = re.compile(regex) + + matching_functions = [] + + # Search the entire class for functions matching the filter regex + for fn in dir(class_): + func_attr = getattr(class_, fn) + # Check if attribute is a function + if callable(func_attr) and function_regex.match(fn): + # add function and attribute name to the list + matching_functions.append( + (re.sub(subtract_regex, "", fn), func_attr.__annotations__) + ) + return matching_functions + + +class BaseTypeFilter(graphene.InputObjectType): + @classmethod + def __init_subclass_with_meta__( + cls, filter_fields=None, model=None, _meta=None, **options + ): + from graphene_sqlalchemy.converter import convert_sqlalchemy_type + + # Init meta options class if it doesn't exist already + if not _meta: + _meta = InputObjectTypeOptions(cls) + + logic_functions = _get_functions_by_regex(".+_logic$", "_logic$", cls) + + new_filter_fields = {} + # Generate Graphene Fields from the filter functions based on type hints + for field_name, _annotations in logic_functions: + assert ( + "val" in _annotations + ), "Each filter method must have a value field with valid type annotations" + # If type is generic, replace with actual type of filter class + + replace_type_vars = {BaseTypeFilterSelf: cls} + field_type = convert_sqlalchemy_type( + _annotations.get("val", str), replace_type_vars=replace_type_vars + ) + new_filter_fields.update({field_name: graphene.InputField(field_type)}) + # Add all fields to the meta options. graphene.InputObjectType will take care of the rest + + if _meta.fields: + _meta.fields.update(filter_fields) + else: + _meta.fields = filter_fields + _meta.fields.update(new_filter_fields) + + _meta.model = model + + super(BaseTypeFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) + + @classmethod + def and_logic( + cls, + query, + filter_type: "BaseTypeFilter", + val: List[BaseTypeFilterSelf], + model_alias=None, + ): + # # Get the model to join on the Filter Query + # joined_model = filter_type._meta.model + # # Always alias the model + # joined_model_alias = aliased(joined_model) + clauses = [] + for value in val: + # # Join the aliased model onto the query + # query = query.join(model_field.of_type(joined_model_alias)) + + query, _clauses = filter_type.execute_filters( + query, value, model_alias=model_alias + ) # , model_alias=joined_model_alias) + clauses += _clauses + + return query, [and_(*clauses)] + + @classmethod + def or_logic( + cls, + query, + filter_type: "BaseTypeFilter", + val: List[BaseTypeFilterSelf], + model_alias=None, + ): + # # Get the model to join on the Filter Query + # joined_model = filter_type._meta.model + # # Always alias the model + # joined_model_alias = aliased(joined_model) + + clauses = [] + for value in val: + # # Join the aliased model onto the query + # query = query.join(model_field.of_type(joined_model_alias)) + + query, _clauses = filter_type.execute_filters( + query, value, model_alias=model_alias + ) # , model_alias=joined_model_alias) + clauses += _clauses + + return query, [or_(*clauses)] + + @classmethod + def execute_filters( + cls, query, filter_dict: Dict[str, Any], model_alias=None + ) -> Tuple[Query, List[Any]]: + model = cls._meta.model + if model_alias: + model = model_alias + + clauses = [] + + for field, field_filters in filter_dict.items(): + # Relationships are Dynamic, we need to resolve them fist + # Maybe we can cache these dynamics to improve efficiency + # Check with a profiler is required to determine necessity + input_field = cls._meta.fields[field] + if isinstance(input_field, graphene.Dynamic): + input_field = input_field.get_type() + field_filter_type = input_field.type + else: + field_filter_type = cls._meta.fields[field].type + # raise Exception + # TODO we need to save the relationship props in the meta fields array + # to conduct joins and alias the joins (in case there are duplicate joins: A->B A->C B->C) + if field == "and": + query, _clauses = cls.and_logic( + query, field_filter_type.of_type, field_filters, model_alias=model + ) + clauses.extend(_clauses) + elif field == "or": + query, _clauses = cls.or_logic( + query, field_filter_type.of_type, field_filters, model_alias=model + ) + clauses.extend(_clauses) + else: + # Get the model attr from the inputfield in case the field is aliased in graphql + model_field = getattr(model, input_field.model_attr or field) + if issubclass(field_filter_type, BaseTypeFilter): + # Get the model to join on the Filter Query + joined_model = field_filter_type._meta.model + # Always alias the model + joined_model_alias = aliased(joined_model) + # Join the aliased model onto the query + query = query.join(model_field.of_type(joined_model_alias)) + # Pass the joined query down to the next object type filter for processing + query, _clauses = field_filter_type.execute_filters( + query, field_filters, model_alias=joined_model_alias + ) + clauses.extend(_clauses) + if issubclass(field_filter_type, RelationshipFilter): + # TODO see above; not yet working + relationship_prop = field_filter_type._meta.model + # Always alias the model + # joined_model_alias = aliased(relationship_prop) + + # Join the aliased model onto the query + # query = query.join(model_field.of_type(joined_model_alias)) + # todo should we use selectinload here instead of join for large lists? + + query, _clauses = field_filter_type.execute_filters( + query, model, model_field, field_filters, relationship_prop + ) + clauses.extend(_clauses) + elif issubclass(field_filter_type, FieldFilter): + query, _clauses = field_filter_type.execute_filters( + query, model_field, field_filters + ) + clauses.extend(_clauses) + + return query, clauses + + +ScalarFilterInputType = TypeVar("ScalarFilterInputType") + + +class FieldFilterOptions(InputObjectTypeOptions): + graphene_type: Type = None + + +class FieldFilter(graphene.InputObjectType): + """Basic Filter for Scalars in Graphene. + We want this filter to use Dynamic fields so it provides the base + filtering methods ("eq, nEq") for different types of scalars. + The Dynamic fields will resolve to Meta.filtered_type""" + + @classmethod + def __init_subclass_with_meta__(cls, graphene_type=None, _meta=None, **options): + from .converter import convert_sqlalchemy_type + + # get all filter functions + + filter_functions = _get_functions_by_regex(".+_filter$", "_filter$", cls) + + # Init meta options class if it doesn't exist already + if not _meta: + _meta = FieldFilterOptions(cls) + + if not _meta.graphene_type: + _meta.graphene_type = graphene_type + + new_filter_fields = {} + # Generate Graphene Fields from the filter functions based on type hints + for field_name, _annotations in filter_functions: + assert ( + "val" in _annotations + ), "Each filter method must have a value field with valid type annotations" + # If type is generic, replace with actual type of filter class + replace_type_vars = {ScalarFilterInputType: _meta.graphene_type} + field_type = convert_sqlalchemy_type( + _annotations.get("val", str), replace_type_vars=replace_type_vars + ) + new_filter_fields.update({field_name: graphene.InputField(field_type)}) + + # Add all fields to the meta options. graphene.InputbjectType will take care of the rest + if _meta.fields: + _meta.fields.update(new_filter_fields) + else: + _meta.fields = new_filter_fields + + # Pass modified meta to the super class + super(FieldFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) + + # Abstract methods can be marked using ScalarFilterInputType. See comment on the init method + @classmethod + def eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return field == val + + @classmethod + def n_eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return not_(field == val) + + @classmethod + def in_filter(cls, query, field, val: List[ScalarFilterInputType]): + return field.in_(val) + + @classmethod + def not_in_filter(cls, query, field, val: List[ScalarFilterInputType]): + return field.notin_(val) + + # TODO add like/ilike + + @classmethod + def execute_filters( + cls, query, field, filter_dict: Dict[str, any] + ) -> Tuple[Query, List[Any]]: + clauses = [] + for filt, val in filter_dict.items(): + clause = getattr(cls, filt + "_filter")(query, field, val) + if isinstance(clause, tuple): + query, clause = clause + clauses.append(clause) + + return query, clauses + + +class SQLEnumFilter(FieldFilter): + """Basic Filter for Scalars in Graphene. + We want this filter to use Dynamic fields so it provides the base + filtering methods ("eq, nEq") for different types of scalars. + The Dynamic fields will resolve to Meta.filtered_type""" + + class Meta: + graphene_type = graphene.Enum + + # Abstract methods can be marked using ScalarFilterInputType. See comment on the init method + @classmethod + def eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return field == val.value + + @classmethod + def n_eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return not_(field == val.value) + + +class PyEnumFilter(FieldFilter): + """Basic Filter for Scalars in Graphene. + We want this filter to use Dynamic fields so it provides the base + filtering methods ("eq, nEq") for different types of scalars. + The Dynamic fields will resolve to Meta.filtered_type""" + + class Meta: + graphene_type = graphene.Enum + + # Abstract methods can be marked using ScalarFilterInputType. See comment on the init method + @classmethod + def eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return field == val + + @classmethod + def n_eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return not_(field == val) + + +class StringFilter(FieldFilter): + class Meta: + graphene_type = graphene.String + + @classmethod + def like_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field.like(val) + + @classmethod + def ilike_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field.ilike(val) + + @classmethod + def notlike_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field.notlike(val) + + +class BooleanFilter(FieldFilter): + class Meta: + graphene_type = graphene.Boolean + + +class OrderedFilter(FieldFilter): + class Meta: + abstract = True + + @classmethod + def gt_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field > val + + @classmethod + def gte_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field >= val + + @classmethod + def lt_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field < val + + @classmethod + def lte_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field <= val + + +class NumberFilter(OrderedFilter): + """Intermediate Filter class since all Numbers are in an order relationship (support <, > etc)""" + + class Meta: + abstract = True + + +class FloatFilter(NumberFilter): + """Concrete Filter Class which specifies a type for all the abstract filter methods defined in the super classes""" + + class Meta: + graphene_type = graphene.Float + + +class IntFilter(NumberFilter): + class Meta: + graphene_type = graphene.Int + + +class DateFilter(OrderedFilter): + """Concrete Filter Class which specifies a type for all the abstract filter methods defined in the super classes""" + + class Meta: + graphene_type = graphene.Date + + +class IdFilter(FieldFilter): + class Meta: + graphene_type = graphene.ID + + +class RelationshipFilter(graphene.InputObjectType): + @classmethod + def __init_subclass_with_meta__( + cls, base_type_filter=None, model=None, _meta=None, **options + ): + if not base_type_filter: + raise Exception("Relationship Filters must be specific to an object type") + # Init meta options class if it doesn't exist already + if not _meta: + _meta = InputObjectTypeOptions(cls) + + # get all filter functions + filter_functions = _get_functions_by_regex(".+_filter$", "_filter$", cls) + + relationship_filters = {} + + # Generate Graphene Fields from the filter functions based on type hints + for field_name, _annotations in filter_functions: + assert ( + "val" in _annotations + ), "Each filter method must have a value field with valid type annotations" + # If type is generic, replace with actual type of filter class + if is_list(_annotations["val"]): + relationship_filters.update( + {field_name: graphene.InputField(graphene.List(base_type_filter))} + ) + else: + relationship_filters.update( + {field_name: graphene.InputField(base_type_filter)} + ) + + # Add all fields to the meta options. graphene.InputObjectType will take care of the rest + if _meta.fields: + _meta.fields.update(relationship_filters) + else: + _meta.fields = relationship_filters + + _meta.model = model + _meta.base_type_filter = base_type_filter + super(RelationshipFilter, cls).__init_subclass_with_meta__( + _meta=_meta, **options + ) + + @classmethod + def contains_filter( + cls, + query, + parent_model, + field, + relationship_prop, + val: List[ScalarFilterInputType], + ): + clauses = [] + for v in val: + # Always alias the model + joined_model_alias = aliased(relationship_prop) + + # Join the aliased model onto the query + query = query.join(field.of_type(joined_model_alias)).distinct() + # pass the alias so group can join group + query, _clauses = cls._meta.base_type_filter.execute_filters( + query, v, model_alias=joined_model_alias + ) + clauses.append(and_(*_clauses)) + return query, [or_(*clauses)] + + @classmethod + def contains_exactly_filter( + cls, + query, + parent_model, + field, + relationship_prop, + val: List[ScalarFilterInputType], + ): + raise NotImplementedError + + @classmethod + def execute_filters( + cls: Type[FieldFilter], + query, + parent_model, + field, + filter_dict: Dict, + relationship_prop, + ) -> Tuple[Query, List[Any]]: + query, clauses = (query, []) + + for filt, val in filter_dict.items(): + query, _clauses = getattr(cls, filt + "_filter")( + query, parent_model, field, relationship_prop, val + ) + clauses += _clauses + + return query, clauses diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index 3c463013..b959d221 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -1,10 +1,15 @@ +import inspect from collections import defaultdict -from typing import List, Type +from typing import TYPE_CHECKING, List, Type from sqlalchemy.types import Enum as SQLAlchemyEnumType import graphene from graphene import Enum +from graphene.types.base import BaseType + +if TYPE_CHECKING: # pragma: no_cover + from .filters import BaseTypeFilter, FieldFilter, RelationshipFilter class Registry(object): @@ -16,6 +21,30 @@ def __init__(self): self._registry_enums = {} self._registry_sort_enums = {} self._registry_unions = {} + self._registry_scalar_filters = {} + self._registry_base_type_filters = {} + self._registry_relationship_filters = {} + + self._init_base_filters() + + def _init_base_filters(self): + import graphene_sqlalchemy.filters as gsqa_filters + + from .filters import FieldFilter + + field_filter_classes = [ + filter_cls[1] + for filter_cls in inspect.getmembers(gsqa_filters, inspect.isclass) + if ( + filter_cls[1] is not FieldFilter + and FieldFilter in filter_cls[1].__mro__ + and getattr(filter_cls[1]._meta, "graphene_type", False) + ) + ] + for field_filter_class in field_filter_classes: + self.register_filter_for_scalar_type( + field_filter_class._meta.graphene_type, field_filter_class + ) def register(self, obj_type): from .types import SQLAlchemyBase @@ -99,6 +128,110 @@ def register_union_type( def get_union_for_object_types(self, obj_types: List[Type[graphene.ObjectType]]): return self._registry_unions.get(frozenset(obj_types)) + # Filter Scalar Fields of Object Types + def register_filter_for_scalar_type( + self, scalar_type: Type[graphene.Scalar], filter_obj: Type["FieldFilter"] + ): + from .filters import FieldFilter + + if not isinstance(scalar_type, type(graphene.Scalar)): + raise TypeError("Expected Scalar, but got: {!r}".format(scalar_type)) + + if not issubclass(filter_obj, FieldFilter): + raise TypeError("Expected ScalarFilter, but got: {!r}".format(filter_obj)) + self._registry_scalar_filters[scalar_type] = filter_obj + + def get_filter_for_sql_enum_type( + self, enum_type: Type[graphene.Enum] + ) -> Type["FieldFilter"]: + from .filters import SQLEnumFilter + + filter_type = self._registry_scalar_filters.get(enum_type) + if not filter_type: + filter_type = SQLEnumFilter.create_type( + f"Default{enum_type.__name__}EnumFilter", graphene_type=enum_type + ) + self._registry_scalar_filters[enum_type] = filter_type + return filter_type + + def get_filter_for_py_enum_type( + self, enum_type: Type[graphene.Enum] + ) -> Type["FieldFilter"]: + from .filters import PyEnumFilter + + filter_type = self._registry_scalar_filters.get(enum_type) + if not filter_type: + filter_type = PyEnumFilter.create_type( + f"Default{enum_type.__name__}EnumFilter", graphene_type=enum_type + ) + self._registry_scalar_filters[enum_type] = filter_type + return filter_type + + def get_filter_for_scalar_type( + self, scalar_type: Type[graphene.Scalar] + ) -> Type["FieldFilter"]: + from .filters import FieldFilter + + filter_type = self._registry_scalar_filters.get(scalar_type) + if not filter_type: + filter_type = FieldFilter.create_type( + f"Default{scalar_type.__name__}ScalarFilter", graphene_type=scalar_type + ) + self._registry_scalar_filters[scalar_type] = filter_type + + return filter_type + + # TODO register enums automatically + def register_filter_for_enum_type( + self, enum_type: Type[graphene.Enum], filter_obj: Type["FieldFilter"] + ): + from .filters import FieldFilter + + if not issubclass(enum_type, graphene.Enum): + raise TypeError("Expected Enum, but got: {!r}".format(enum_type)) + + if not issubclass(filter_obj, FieldFilter): + raise TypeError("Expected FieldFilter, but got: {!r}".format(filter_obj)) + self._registry_scalar_filters[enum_type] = filter_obj + + # Filter Base Types + def register_filter_for_base_type( + self, + base_type: Type[BaseType], + filter_obj: Type["BaseTypeFilter"], + ): + from .filters import BaseTypeFilter + + if not issubclass(base_type, BaseType): + raise TypeError("Expected BaseType, but got: {!r}".format(base_type)) + + if not issubclass(filter_obj, BaseTypeFilter): + raise TypeError("Expected BaseTypeFilter, but got: {!r}".format(filter_obj)) + self._registry_base_type_filters[base_type] = filter_obj + + def get_filter_for_base_type(self, base_type: Type[BaseType]): + return self._registry_base_type_filters.get(base_type) + + # Filter Relationships between base types + def register_relationship_filter_for_base_type( + self, base_type: BaseType, filter_obj: Type["RelationshipFilter"] + ): + from .filters import RelationshipFilter + + if not isinstance(base_type, type(BaseType)): + raise TypeError("Expected BaseType, but got: {!r}".format(base_type)) + + if not issubclass(filter_obj, RelationshipFilter): + raise TypeError( + "Expected RelationshipFilter, but got: {!r}".format(filter_obj) + ) + self._registry_relationship_filters[base_type] = filter_obj + + def get_relationship_filter_for_base_type( + self, base_type: Type[BaseType] + ) -> "RelationshipFilter": + return self._registry_relationship_filters.get(base_type) + registry = None diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 89b357a4..2c749da7 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -2,6 +2,7 @@ import pytest_asyncio from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker +from typing_extensions import Literal import graphene from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 @@ -25,14 +26,23 @@ def convert_composite_class(composite, registry): return graphene.Field(graphene.Int) -@pytest.fixture(params=[False, True]) -def async_session(request): +# make a typed literal for session one is sync and one is async +SESSION_TYPE = Literal["sync", "session_factory"] + + +@pytest.fixture(params=["sync", "async"]) +def session_type(request) -> SESSION_TYPE: return request.param @pytest.fixture -def test_db_url(async_session: bool): - if async_session: +def async_session(session_type): + return session_type == "async" + + +@pytest.fixture +def test_db_url(session_type: SESSION_TYPE): + if session_type == "async": return "sqlite+aiosqlite://" else: return "sqlite://" @@ -40,8 +50,8 @@ def test_db_url(async_session: bool): @pytest.mark.asyncio @pytest_asyncio.fixture(scope="function") -async def session_factory(async_session: bool, test_db_url: str): - if async_session: +async def session_factory(session_type: SESSION_TYPE, test_db_url: str): + if session_type == "async": if not SQL_VERSION_HIGHER_EQUAL_THAN_1_4: pytest.skip("Async Sessions only work in sql alchemy 1.4 and above") engine = create_async_engine(test_db_url) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index be07b896..8911b0a2 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -6,6 +6,7 @@ from decimal import Decimal from typing import List, Optional +# fmt: off from sqlalchemy import ( Column, Date, @@ -24,13 +25,16 @@ from sqlalchemy.sql.type_api import TypeEngine from graphene_sqlalchemy.tests.utils import wrap_select_func -from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, SQL_VERSION_HIGHER_EQUAL_THAN_2 +from graphene_sqlalchemy.utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + SQL_VERSION_HIGHER_EQUAL_THAN_2, +) # fmt: off if SQL_VERSION_HIGHER_EQUAL_THAN_2: - from sqlalchemy.sql.sqltypes import HasExpressionLookup # noqa # isort:skip + from sqlalchemy.sql.sqltypes import HasExpressionLookup # noqa # isort:skip else: - from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter as HasExpressionLookup # noqa # isort:skip + from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter as HasExpressionLookup # noqa # isort:skip # fmt: on PetKind = Enum("cat", "dog", name="pet_kind") @@ -64,6 +68,7 @@ class Pet(Base): pet_kind = Column(PetKind, nullable=False) hair_kind = Column(Enum(HairKind, name="hair_kind"), nullable=False) reporter_id = Column(Integer(), ForeignKey("reporters.id")) + legs = Column(Integer(), default=4) class CompositeFullName(object): @@ -150,6 +155,27 @@ def hybrid_prop_list(self) -> List[int]: headlines = association_proxy("articles", "headline") +articles_tags_table = Table( + "articles_tags", + Base.metadata, + Column("article_id", ForeignKey("articles.id")), + Column("tag_id", ForeignKey("tags.id")), +) + + +class Image(Base): + __tablename__ = "images" + id = Column(Integer(), primary_key=True) + external_id = Column(Integer()) + description = Column(String(30)) + + +class Tag(Base): + __tablename__ = "tags" + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + + class Article(Base): __tablename__ = "articles" id = Column(Integer(), primary_key=True) @@ -161,6 +187,13 @@ class Article(Base): ) recommended_reads = association_proxy("reporter", "articles") + # one-to-one relationship with image + image_id = Column(Integer(), ForeignKey("images.id"), unique=True) + image = relationship("Image", backref=backref("articles", uselist=False)) + + # many-to-many relationship with tags + tags = relationship("Tag", secondary=articles_tags_table, backref="articles") + class Reader(Base): __tablename__ = "readers" @@ -273,11 +306,20 @@ def hybrid_prop_deeply_nested_list_int(self) -> List[List[List[int]]]: ], ] - # Other SQLAlchemy Instances + # Other SQLAlchemy Instance @hybrid_property def hybrid_prop_first_shopping_cart_item(self) -> ShoppingCartItem: return ShoppingCartItem(id=1) + # Other SQLAlchemy Instance with expression + @hybrid_property + def hybrid_prop_first_shopping_cart_item_expression(self) -> ShoppingCartItem: + return ShoppingCartItem(id=1) + + @hybrid_prop_first_shopping_cart_item_expression.expression + def hybrid_prop_first_shopping_cart_item_expression(cls): + return ShoppingCartItem + # Other SQLAlchemy Instances @hybrid_property def hybrid_prop_shopping_cart_item_list(self) -> List[ShoppingCartItem]: diff --git a/graphene_sqlalchemy/tests/models_batching.py b/graphene_sqlalchemy/tests/models_batching.py index 5dde366f..e0f5d4bd 100644 --- a/graphene_sqlalchemy/tests/models_batching.py +++ b/graphene_sqlalchemy/tests/models_batching.py @@ -2,16 +2,7 @@ import enum -from sqlalchemy import ( - Column, - Date, - Enum, - ForeignKey, - Integer, - String, - Table, - func, -) +from sqlalchemy import Column, Date, Enum, ForeignKey, Integer, String, Table, func from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import column_property, relationship diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 84069245..e62e07d2 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -1,13 +1,10 @@ import enum import sys -from typing import Dict, Tuple, Union +from typing import Dict, Tuple, TypeVar, Union -import graphene import pytest import sqlalchemy import sqlalchemy_utils as sqa_utils -from graphene.relay import Node -from graphene.types.structures import Structure from sqlalchemy import Column, func, types from sqlalchemy.dialects import postgresql from sqlalchemy.ext.declarative import declarative_base @@ -15,15 +12,10 @@ from sqlalchemy.inspection import inspect from sqlalchemy.orm import column_property, composite -from .models import ( - Article, - CompositeFullName, - Pet, - Reporter, - ShoppingCart, - ShoppingCartItem, -) -from .utils import wrap_select_func +import graphene +from graphene.relay import Node +from graphene.types.structures import Structure + from ..converter import ( convert_sqlalchemy_association_proxy, convert_sqlalchemy_column, @@ -47,6 +39,7 @@ ShoppingCart, ShoppingCartItem, ) +from .utils import wrap_select_func def mock_resolver(): @@ -206,6 +199,17 @@ def hybrid_prop(self) -> "ShoppingCartItem": get_hybrid_property_type(hybrid_prop).type == ShoppingCartType +def test_converter_replace_type_var(): + + T = TypeVar("T") + + replace_type_vars = {T: graphene.String} + + field_type = convert_sqlalchemy_type(T, replace_type_vars=replace_type_vars) + + assert field_type == graphene.String + + @pytest.mark.skipif( sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10" ) @@ -215,9 +219,9 @@ def prop_method() -> int | str: return "not allowed in gql schema" with pytest.raises( - ValueError, - match=r"Cannot convert hybrid_property Union to " - r"graphene.Union: the Union contains scalars. \.*", + ValueError, + match=r"Cannot convert hybrid_property Union to " + r"graphene.Union: the Union contains scalars. \.*", ): get_hybrid_property_type(prop_method) @@ -471,7 +475,9 @@ class TestEnum(enum.IntEnum): def test_should_columproperty_convert(): field = get_field_from_column( - column_property(wrap_select_func(func.sum(func.cast(id, types.Integer))).where(id == 1)) + column_property( + wrap_select_func(func.sum(func.cast(id, types.Integer))).where(id == 1) + ) ) assert field.type == graphene.Int @@ -888,8 +894,8 @@ class Meta: ) for ( - hybrid_prop_name, - hybrid_prop_expected_return_type, + hybrid_prop_name, + hybrid_prop_expected_return_type, ) in shopping_cart_item_expected_types.items(): hybrid_prop_field = ShoppingCartItemType._meta.fields[hybrid_prop_name] @@ -900,7 +906,7 @@ class Meta: str(hybrid_prop_expected_return_type), ) assert ( - hybrid_prop_field.description is None + hybrid_prop_field.description is None ) # "doc" is ignored by hybrid property ################################################### @@ -925,6 +931,7 @@ class Meta: graphene.List(graphene.List(graphene.Int)) ), "hybrid_prop_first_shopping_cart_item": ShoppingCartItemType, + "hybrid_prop_first_shopping_cart_item_expression": ShoppingCartItemType, "hybrid_prop_shopping_cart_item_list": graphene.List(ShoppingCartItemType), # Self Referential List "hybrid_prop_self_referential": ShoppingCartType, @@ -947,8 +954,8 @@ class Meta: ) for ( - hybrid_prop_name, - hybrid_prop_expected_return_type, + hybrid_prop_name, + hybrid_prop_expected_return_type, ) in shopping_cart_expected_types.items(): hybrid_prop_field = ShoppingCartType._meta.fields[hybrid_prop_name] @@ -959,5 +966,5 @@ class Meta: str(hybrid_prop_expected_return_type), ) assert ( - hybrid_prop_field.description is None + hybrid_prop_field.description is None ) # "doc" is ignored by hybrid property diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py new file mode 100644 index 00000000..4acf89a8 --- /dev/null +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -0,0 +1,1201 @@ +import pytest +from sqlalchemy.sql.operators import is_ + +import graphene +from graphene import Connection, relay + +from ..fields import SQLAlchemyConnectionField +from ..filters import FloatFilter +from ..types import ORMField, SQLAlchemyObjectType +from .models import ( + Article, + Editor, + HairKind, + Image, + Pet, + Reader, + Reporter, + ShoppingCart, + ShoppingCartItem, + Tag, +) +from .utils import eventually_await_session, to_std_dicts + +# TODO test that generated schema is correct for all examples with: +# with open('schema.gql', 'w') as fp: +# fp.write(str(schema)) + + +def assert_and_raise_result(result, expected): + if result.errors: + for error in result.errors: + raise error + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +async def add_test_data(session): + reporter = Reporter(first_name="John", last_name="Doe", favorite_pet_kind="cat") + session.add(reporter) + + pet = Pet(name="Garfield", pet_kind="cat", hair_kind=HairKind.SHORT, legs=4) + pet.reporter = reporter + session.add(pet) + + pet = Pet(name="Snoopy", pet_kind="dog", hair_kind=HairKind.SHORT, legs=3) + pet.reporter = reporter + session.add(pet) + + reporter = Reporter(first_name="John", last_name="Woe", favorite_pet_kind="cat") + session.add(reporter) + + article = Article(headline="Hi!") + article.reporter = reporter + session.add(article) + + article = Article(headline="Hello!") + article.reporter = reporter + session.add(article) + + reporter = Reporter(first_name="Jane", last_name="Roe", favorite_pet_kind="dog") + session.add(reporter) + + pet = Pet(name="Lassie", pet_kind="dog", hair_kind=HairKind.LONG) + pet.reporter = reporter + session.add(pet) + + editor = Editor(name="Jack") + session.add(editor) + + await eventually_await_session(session, "commit") + + +def create_schema(session): + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + name = "Article" + interfaces = (relay.Node,) + + class ImageType(SQLAlchemyObjectType): + class Meta: + model = Image + name = "Image" + interfaces = (relay.Node,) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + name = "Pet" + interfaces = (relay.Node,) + + class ReaderType(SQLAlchemyObjectType): + class Meta: + model = Reader + name = "Reader" + interfaces = (relay.Node,) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + name = "Reporter" + interfaces = (relay.Node,) + + class TagType(SQLAlchemyObjectType): + class Meta: + model = Tag + name = "Tag" + interfaces = (relay.Node,) + + class Query(graphene.ObjectType): + node = relay.Node.Field() + articles = SQLAlchemyConnectionField(ArticleType.connection) + images = SQLAlchemyConnectionField(ImageType.connection) + readers = SQLAlchemyConnectionField(ReaderType.connection) + reporters = SQLAlchemyConnectionField(ReporterType.connection) + pets = SQLAlchemyConnectionField(PetType.connection) + tags = SQLAlchemyConnectionField(TagType.connection) + + return Query + + +# Test a simple example of filtering +@pytest.mark.asyncio +async def test_filter_simple(session): + await add_test_data(session) + + Query = create_schema(session) + + query = """ + query { + reporters (filter: {lastName: {eq: "Roe", like: "%oe"}}) { + edges { + node { + firstName + } + } + } + } + """ + expected = { + "reporters": {"edges": [{"node": {"firstName": "Jane"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +@pytest.mark.asyncio +async def test_filter_alias(session): + """ + Test aliasing of column names in the type + """ + await add_test_data(session) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + name = "Reporter" + interfaces = (relay.Node,) + + lastNameAlias = ORMField(model_attr="last_name") + + class Query(graphene.ObjectType): + node = relay.Node.Field() + reporters = SQLAlchemyConnectionField(ReporterType.connection) + + query = """ + query { + reporters (filter: {lastNameAlias: {eq: "Roe", like: "%oe"}}) { + edges { + node { + firstName + } + } + } + } + """ + expected = { + "reporters": {"edges": [{"node": {"firstName": "Jane"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test a custom filter type +@pytest.mark.asyncio +async def test_filter_custom_type(session): + await add_test_data(session) + + class MathFilter(FloatFilter): + class Meta: + graphene_type = graphene.Float + + @classmethod + def divisible_by_filter(cls, query, field, val: int) -> bool: + return is_(field % val, 0) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + name = "Pet" + interfaces = (relay.Node,) + connection_class = Connection + + legs = ORMField(filter_type=MathFilter) + + class Query(graphene.ObjectType): + pets = SQLAlchemyConnectionField(PetType.connection) + + query = """ + query { + pets (filter: { + legs: {divisibleBy: 2} + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "pets": { + "edges": [{"node": {"name": "Garfield"}}, {"node": {"name": "Lassie"}}] + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test filtering on enums +@pytest.mark.asyncio +async def test_filter_enum(session): + await add_test_data(session) + + Query = create_schema(session) + + # test sqlalchemy enum + query = """ + query { + reporters (filter: { + favoritePetKind: {eq: DOG} + } + ) { + edges { + node { + firstName + lastName + favoritePetKind + } + } + } + } + """ + expected = { + "reporters": { + "edges": [ + { + "node": { + "firstName": "Jane", + "lastName": "Roe", + "favoritePetKind": "DOG", + } + } + ] + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test Python enum and sqlalchemy enum + query = """ + query { + pets (filter: { + and: [ + { hairKind: {eq: LONG} }, + { petKind: {eq: DOG} } + ]}) { + edges { + node { + name + } + } + } + } + """ + expected = { + "pets": {"edges": [{"node": {"name": "Lassie"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test a 1:1 relationship +@pytest.mark.asyncio +async def test_filter_relationship_one_to_one(session): + article = Article(headline="Hi!") + image = Image(external_id=1, description="A beautiful image.") + article.image = image + session.add(article) + session.add(image) + await eventually_await_session(session, "commit") + + Query = create_schema(session) + + query = """ + query { + articles (filter: { + image: {description: {eq: "A beautiful image."}} + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": {"edges": [{"node": {"headline": "Hi!"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test a 1:n relationship +@pytest.mark.asyncio +async def test_filter_relationship_one_to_many(session): + await add_test_data(session) + Query = create_schema(session) + + # test contains + query = """ + query { + reporters (filter: { + articles: { + contains: [{headline: {eq: "Hi!"}}], + } + }) { + edges { + node { + lastName + } + } + } + } + """ + expected = { + "reporters": {"edges": [{"node": {"lastName": "Woe"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # TODO test containsExactly + # # test containsExactly + # query = """ + # query { + # reporters (filter: { + # articles: { + # containsExactly: [ + # {headline: {eq: "Hi!"}} + # {headline: {eq: "Hello!"}} + # ] + # } + # }) { + # edges { + # node { + # firstName + # lastName + # } + # } + # } + # } + # """ + # expected = { + # "reporters": {"edges": [{"node": {"firstName": "John", "lastName": "Woe"}}]} + # } + # schema = graphene.Schema(query=Query) + # result = await schema.execute_async(query, context_value={"session": session}) + # assert_and_raise_result(result, expected) + + +async def add_n2m_test_data(session): + # create objects + reader1 = Reader(name="Ada") + reader2 = Reader(name="Bip") + article1 = Article(headline="Article! Look!") + article2 = Article(headline="Woah! Another!") + tag1 = Tag(name="sensational") + tag2 = Tag(name="eye-grabbing") + image1 = Image(description="article 1") + image2 = Image(description="article 2") + + # set relationships + article1.tags = [tag1] + article2.tags = [tag1, tag2] + article1.image = image1 + article2.image = image2 + reader1.articles = [article1] + reader2.articles = [article1, article2] + + # save + session.add(image1) + session.add(image2) + session.add(tag1) + session.add(tag2) + session.add(article1) + session.add(article2) + session.add(reader1) + session.add(reader2) + await eventually_await_session(session, "commit") + + +# Test n:m relationship contains +@pytest.mark.asyncio +async def test_filter_relationship_many_to_many_contains(session): + await add_n2m_test_data(session) + Query = create_schema(session) + + # test contains 1 + query = """ + query { + articles (filter: { + tags: { + contains: [ + { name: { in: ["sensational", "eye-grabbing"] } }, + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": { + "edges": [ + {"node": {"headline": "Article! Look!"}}, + {"node": {"headline": "Woah! Another!"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test contains 2 + query = """ + query { + articles (filter: { + tags: { + contains: [ + { name: { eq: "eye-grabbing" } }, + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": { + "edges": [ + {"node": {"headline": "Woah! Another!"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test reverse + query = """ + query { + tags (filter: { + articles: { + contains: [ + { headline: { eq: "Article! Look!" } }, + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "tags": { + "edges": [ + {"node": {"name": "sensational"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +@pytest.mark.asyncio +async def test_filter_relationship_many_to_many_contains_with_and(session): + """ + This test is necessary to ensure we don't accidentally turn and-contains filter + into or-contains filters due to incorrect aliasing of the joined table. + """ + await add_n2m_test_data(session) + Query = create_schema(session) + + # test contains 1 + query = """ + query { + articles (filter: { + tags: { + contains: [{ + and: [ + { name: { in: ["sensational", "eye-grabbing"] } }, + { name: { eq: "eye-grabbing" } }, + ] + + } + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": { + "edges": [ + {"node": {"headline": "Woah! Another!"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test contains 2 + query = """ + query { + articles (filter: { + tags: { + contains: [ + { name: { eq: "eye-grabbing" } }, + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": { + "edges": [ + {"node": {"headline": "Woah! Another!"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test reverse + query = """ + query { + tags (filter: { + articles: { + contains: [ + { headline: { eq: "Article! Look!" } }, + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "tags": { + "edges": [ + {"node": {"name": "sensational"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test n:m relationship containsExactly +@pytest.mark.xfail +@pytest.mark.asyncio +async def test_filter_relationship_many_to_many_contains_exactly(session): + raise NotImplementedError + await add_n2m_test_data(session) + Query = create_schema(session) + + # test containsExactly 1 + query = """ + query { + articles (filter: { + tags: { + containsExactly: [ + { name: { eq: "eye-grabbing" } }, + { name: { eq: "sensational" } }, + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": {"edges": [{"node": {"headline": "Woah! Another!"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test containsExactly 2 + query = """ + query { + articles (filter: { + tags: { + containsExactly: [ + { name: { eq: "sensational" } } + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": {"edges": [{"node": {"headline": "Article! Look!"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test reverse + query = """ + query { + tags (filter: { + articles: { + containsExactly: [ + { headline: { eq: "Article! Look!" } }, + { headline: { eq: "Woah! Another!" } }, + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "tags": {"edges": [{"node": {"name": "eye-grabbing"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test n:m relationship both contains and containsExactly +@pytest.mark.xfail +@pytest.mark.asyncio +async def test_filter_relationship_many_to_many_contains_and_contains_exactly(session): + raise NotImplementedError + await add_n2m_test_data(session) + Query = create_schema(session) + + query = """ + query { + articles (filter: { + tags: { + contains: [ + { name: { eq: "eye-grabbing" } }, + ] + containsExactly: [ + { name: { eq: "eye-grabbing" } }, + { name: { eq: "sensational" } }, + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": {"edges": [{"node": {"headline": "Woah! Another!"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test n:m nested relationship +# TODO add containsExactly +@pytest.mark.asyncio +async def test_filter_relationship_many_to_many_nested(session): + await add_n2m_test_data(session) + Query = create_schema(session) + + # test readers->articles relationship + query = """ + query { + readers (filter: { + articles: { + contains: [ + { headline: { eq: "Woah! Another!" } }, + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "readers": {"edges": [{"node": {"name": "Bip"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test nested readers->articles->tags + query = """ + query { + readers (filter: { + articles: { + contains: [ + { + tags: { + contains: [ + { name: { eq: "eye-grabbing" } }, + ] + } + } + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "readers": {"edges": [{"node": {"name": "Bip"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test nested reverse + query = """ + query { + tags (filter: { + articles: { + contains: [ + { + readers: { + contains: [ + { name: { eq: "Ada" } }, + ] + } + } + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "tags": {"edges": [{"node": {"name": "sensational"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test filter on both levels of nesting + query = """ + query { + readers (filter: { + articles: { + contains: [ + { headline: { eq: "Woah! Another!" } }, + { + tags: { + contains: [ + { name: { eq: "eye-grabbing" } }, + ] + } + } + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "readers": {"edges": [{"node": {"name": "Bip"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test connecting filters with "and" +@pytest.mark.asyncio +async def test_filter_logic_and(session): + await add_test_data(session) + + Query = create_schema(session) + + query = """ + query { + reporters (filter: { + and: [ + { firstName: { eq: "John" } }, + { favoritePetKind: { eq: CAT } }, + ] + }) { + edges { + node { + lastName + } + } + } + } + """ + expected = { + "reporters": { + "edges": [{"node": {"lastName": "Doe"}}, {"node": {"lastName": "Woe"}}] + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test connecting filters with "or" +@pytest.mark.asyncio +async def test_filter_logic_or(session): + await add_test_data(session) + Query = create_schema(session) + + query = """ + query { + reporters (filter: { + or: [ + { lastName: { eq: "Woe" } }, + { favoritePetKind: { eq: DOG } }, + ] + }) { + edges { + node { + firstName + lastName + favoritePetKind + } + } + } + } + """ + expected = { + "reporters": { + "edges": [ + { + "node": { + "firstName": "John", + "lastName": "Woe", + "favoritePetKind": "CAT", + } + }, + { + "node": { + "firstName": "Jane", + "lastName": "Roe", + "favoritePetKind": "DOG", + } + }, + ] + } + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test connecting filters with "and" and "or" together +@pytest.mark.asyncio +async def test_filter_logic_and_or(session): + await add_test_data(session) + Query = create_schema(session) + + query = """ + query { + reporters (filter: { + and: [ + { firstName: { eq: "John" } }, + { + or: [ + { lastName: { eq: "Doe" } }, + # TODO get enums working for filters + # { favoritePetKind: { eq: "cat" } }, + ] + } + ] + }) { + edges { + node { + firstName + } + } + } + } + """ + expected = { + "reporters": { + "edges": [ + {"node": {"firstName": "John"}}, + # {"node": {"firstName": "Jane"}}, + ], + } + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +async def add_hybrid_prop_test_data(session): + cart = ShoppingCart() + session.add(cart) + await eventually_await_session(session, "commit") + + +def create_hybrid_prop_schema(session): + class ShoppingCartItemType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCartItem + name = "ShoppingCartItem" + interfaces = (relay.Node,) + connection_class = Connection + + class ShoppingCartType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCart + name = "ShoppingCart" + interfaces = (relay.Node,) + connection_class = Connection + + class Query(graphene.ObjectType): + node = relay.Node.Field() + items = SQLAlchemyConnectionField(ShoppingCartItemType.connection) + carts = SQLAlchemyConnectionField(ShoppingCartType.connection) + + return Query + + +# Test filtering over and returning hybrid_property +@pytest.mark.asyncio +async def test_filter_hybrid_property(session): + await add_hybrid_prop_test_data(session) + Query = create_hybrid_prop_schema(session) + + # test hybrid_prop_int + query = """ + query { + carts (filter: {hybridPropInt: {eq: 42}}) { + edges { + node { + hybridPropInt + } + } + } + } + """ + expected = { + "carts": { + "edges": [ + {"node": {"hybridPropInt": 42}}, + ] + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test hybrid_prop_float + query = """ + query { + carts (filter: {hybridPropFloat: {gt: 42}}) { + edges { + node { + hybridPropFloat + } + } + } + } + """ + expected = { + "carts": { + "edges": [ + {"node": {"hybridPropFloat": 42.3}}, + ] + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test hybrid_prop different model without expression + query = """ + query { + carts { + edges { + node { + hybridPropFirstShoppingCartItem { + id + } + } + } + } + } + """ + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert len(result["carts"]["edges"]) == 1 + + # test hybrid_prop different model with expression + query = """ + query { + carts { + edges { + node { + hybridPropFirstShoppingCartItemExpression { + id + } + } + } + } + } + """ + + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert len(result["carts"]["edges"]) == 1 + + # test hybrid_prop list of models + query = """ + query { + carts { + edges { + node { + hybridPropShoppingCartItemList { + id + } + } + } + } + } + """ + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert len(result["carts"]["edges"]) == 1 + assert ( + len(result["carts"]["edges"][0]["node"]["hybridPropShoppingCartItemList"]) == 2 + ) + + +# Test edge cases to improve test coverage +@pytest.mark.asyncio +async def test_filter_edge_cases(session): + await add_test_data(session) + + # test disabling filtering + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + name = "Article" + interfaces = (relay.Node,) + connection_class = Connection + + class Query(graphene.ObjectType): + node = relay.Node.Field() + articles = SQLAlchemyConnectionField(ArticleType.connection, filter=None) + + schema = graphene.Schema(query=Query) + assert not hasattr(schema, "ArticleTypeFilter") + + +# Test additional filter types to improve test coverage +@pytest.mark.asyncio +async def test_additional_filters(session): + await add_test_data(session) + Query = create_schema(session) + + # test n_eq and not_in filters + query = """ + query { + reporters (filter: {firstName: {nEq: "Jane"}, lastName: {notIn: "Doe"}}) { + edges { + node { + lastName + } + } + } + } + """ + expected = { + "reporters": {"edges": [{"node": {"lastName": "Woe"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test gt, lt, gte, and lte filters + query = """ + query { + pets (filter: {legs: {gt: 2, lt: 4, gte: 3, lte: 3}}) { + edges { + node { + name + } + } + } + } + """ + expected = { + "pets": {"edges": [{"node": {"name": "Snoopy"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) diff --git a/graphene_sqlalchemy/tests/test_sort_enums.py b/graphene_sqlalchemy/tests/test_sort_enums.py index f8f1ff8c..bb530f2c 100644 --- a/graphene_sqlalchemy/tests/test_sort_enums.py +++ b/graphene_sqlalchemy/tests/test_sort_enums.py @@ -41,6 +41,8 @@ class Meta: "HAIR_KIND_DESC", "REPORTER_ID_ASC", "REPORTER_ID_DESC", + "LEGS_ASC", + "LEGS_DESC", ] assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC" assert str(sort_enum.ID_DESC.value.value) == "pets.id DESC" @@ -95,6 +97,8 @@ class Meta: "PET_KIND_DESC", "HAIR_KIND_ASC", "HAIR_KIND_DESC", + "LEGS_ASC", + "LEGS_DESC", ] @@ -135,6 +139,8 @@ class Meta: "HAIR_KIND_DESC", "REPORTER_ID_ASC", "REPORTER_ID_DESC", + "LEGS_ASC", + "LEGS_DESC", ] assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC" assert str(sort_enum.ID_DESC.value.value) == "pets.id DESC" @@ -149,7 +155,7 @@ def test_sort_argument_with_excluded_fields_in_object_type(): class PetType(SQLAlchemyObjectType): class Meta: model = Pet - exclude_fields = ["hair_kind", "reporter_id"] + exclude_fields = ["hair_kind", "reporter_id", "legs"] sort_arg = PetType.sort_argument() sort_enum = sort_arg.type._of_type @@ -238,6 +244,8 @@ def get_symbol_name(column_name, sort_asc=True): "HairKindDown", "ReporterIdUp", "ReporterIdDown", + "LegsUp", + "LegsDown", ] assert sort_arg.default_value == ["IdUp"] diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index dac5b15f..18d06eef 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,6 +1,10 @@ +import inspect +import logging +import warnings from collections import OrderedDict +from functools import partial from inspect import isawaitable -from typing import Any +from typing import Any, Optional, Type, Union import sqlalchemy from sqlalchemy.ext.associationproxy import AssociationProxy @@ -8,11 +12,13 @@ from sqlalchemy.orm import ColumnProperty, CompositeProperty, RelationshipProperty from sqlalchemy.orm.exc import NoResultFound -from graphene import Field +import graphene +from graphene import Dynamic, Field, InputField from graphene.relay import Connection, Node from graphene.types.base import BaseType from graphene.types.interface import Interface, InterfaceOptions from graphene.types.objecttype import ObjectType, ObjectTypeOptions +from graphene.types.unmountedtype import UnmountedType from graphene.types.utils import yank_fields_from_attrs from graphene.utils.orderedtype import OrderedType @@ -28,10 +34,12 @@ sort_argument_for_object_type, sort_enum_for_object_type, ) +from .filters import BaseTypeFilter, RelationshipFilter, SQLAlchemyFilterInputField from .registry import Registry, get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver from .utils import ( SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + get_nullable_type, get_query, get_session, is_mapped_class, @@ -41,6 +49,8 @@ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: from sqlalchemy.ext.asyncio import AsyncSession +logger = logging.getLogger(__name__) + class ORMField(OrderedType): def __init__( @@ -51,8 +61,10 @@ def __init__( description=None, deprecation_reason=None, batching=None, + create_filter=None, + filter_type: Optional[Type] = None, _creation_counter=None, - **field_kwargs + **field_kwargs, ): """ Use this to override fields automatically generated by SQLAlchemyObjectType. @@ -89,6 +101,12 @@ class Meta: Same behavior as in graphene.Field. Defaults to None. :param bool batching: Toggle SQL batching. Defaults to None, that is `SQLAlchemyObjectType.meta.batching`. + :param bool create_filter: + Create a filter for this field. Defaults to True. + :param Type filter_type: + Override for the filter of this field with a custom filter type. + Default behavior is to get a matching filter type for this field from the registry. + Create_filter needs to be true :param int _creation_counter: Same behavior as in graphene.Field. """ @@ -100,6 +118,8 @@ class Meta: "required": required, "description": description, "deprecation_reason": deprecation_reason, + "create_filter": create_filter, + "filter_type": filter_type, "batching": batching, } common_kwargs = { @@ -109,6 +129,139 @@ class Meta: self.kwargs.update(common_kwargs) +def get_or_create_relationship_filter( + base_type: Type[BaseType], registry: Registry +) -> Type[RelationshipFilter]: + relationship_filter = registry.get_relationship_filter_for_base_type(base_type) + + if not relationship_filter: + try: + base_type_filter = registry.get_filter_for_base_type(base_type) + relationship_filter = RelationshipFilter.create_type( + f"{base_type.__name__}RelationshipFilter", + base_type_filter=base_type_filter, + model=base_type._meta.model, + ) + registry.register_relationship_filter_for_base_type( + base_type, relationship_filter + ) + except Exception as e: + print("e") + raise e + + return relationship_filter + + +def filter_field_from_field( + field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]], + type_, + registry: Registry, + model_attr: Any, + model_attr_name: str, +) -> Optional[graphene.InputField]: + # Field might be a SQLAlchemyObjectType, due to hybrid properties + if issubclass(type_, SQLAlchemyObjectType): + filter_class = registry.get_filter_for_base_type(type_) + # Enum Special Case + elif issubclass(type_, graphene.Enum) and isinstance(model_attr, ColumnProperty): + column = model_attr.columns[0] + model_enum_type: Optional[sqlalchemy.types.Enum] = getattr(column, "type", None) + if not getattr(model_enum_type, "enum_class", None): + filter_class = registry.get_filter_for_sql_enum_type(type_) + else: + filter_class = registry.get_filter_for_py_enum_type(type_) + else: + filter_class = registry.get_filter_for_scalar_type(type_) + if not filter_class: + warnings.warn( + f"No compatible filters found for {field.type} with db name {model_attr_name}. Skipping field." + ) + return None + return SQLAlchemyFilterInputField(filter_class, model_attr_name) + + +def resolve_dynamic_relationship_filter( + field: graphene.Dynamic, registry: Registry, model_attr_name: str +) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: + # Resolve Dynamic Type + type_ = get_nullable_type(field.get_type()) + from graphene_sqlalchemy import SQLAlchemyConnectionField + + # Connections always result in list filters + if isinstance(type_, SQLAlchemyConnectionField): + inner_type = get_nullable_type(type_.type.Edge.node._type) + reg_res = get_or_create_relationship_filter(inner_type, registry) + # Field relationships can either be a list or a single object + elif isinstance(type_, Field): + if isinstance(type_.type, graphene.List): + inner_type = get_nullable_type(type_.type.of_type) + reg_res = get_or_create_relationship_filter(inner_type, registry) + else: + reg_res = registry.get_filter_for_base_type(type_.type) + else: + # Other dynamic type constellation are not yet supported, + # please open an issue with reproduction if you need them + reg_res = None + + if not reg_res: + warnings.warn( + f"No compatible filters found for {field} with db name {model_attr_name}. Skipping field." + ) + return None + + return SQLAlchemyFilterInputField(reg_res, model_attr_name) + + +def filter_field_from_type_field( + field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]], + registry: Registry, + filter_type: Optional[Type], + model_attr: Any, + model_attr_name: str, +) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: + # If a custom filter type was set for this field, use it here + if filter_type: + return SQLAlchemyFilterInputField(filter_type, model_attr_name) + elif issubclass(type(field), graphene.Scalar): + filter_class = registry.get_filter_for_scalar_type(type(field)) + return SQLAlchemyFilterInputField(filter_class, model_attr_name) + # If the generated field is Dynamic, it is always a relationship + # (due to graphene-sqlalchemy's conversion mechanism). + elif isinstance(field, graphene.Dynamic): + return Dynamic( + partial( + resolve_dynamic_relationship_filter, field, registry, model_attr_name + ) + ) + # Unsupported but theoretically possible cases, please drop us an issue with reproduction if you need them + elif isinstance(field, graphene.List) or isinstance(field._type, graphene.List): + # Pure lists are not yet supported + pass + elif isinstance(field._type, graphene.Dynamic): + # Fields with nested dynamic Dynamic are not yet supported + pass + # Order matters, this comes last as field._type == list also matches Field + elif isinstance(field, graphene.Field): + if inspect.isfunction(field._type) or isinstance(field._type, partial): + return Dynamic( + lambda: filter_field_from_field( + field, + get_nullable_type(field.type), + registry, + model_attr, + model_attr_name, + ) + ) + else: + return filter_field_from_field( + field, + get_nullable_type(field.type), + registry, + model_attr, + model_attr_name, + ) + + def get_polymorphic_on(model): """ Check whether this model is a polymorphic type, and if so return the name @@ -121,13 +274,14 @@ def get_polymorphic_on(model): return polymorphic_on.name -def construct_fields( +def construct_fields_and_filters( obj_type, model, registry, only_fields, exclude_fields, batching, + create_filters, connection_field_factory, ): """ @@ -143,6 +297,7 @@ def construct_fields( :param tuple[string] only_fields: :param tuple[string] exclude_fields: :param bool batching: + :param bool create_filters: Enable filter generation for this type :param function|None connection_field_factory: :rtype: OrderedDict[str, graphene.Field] """ @@ -201,7 +356,12 @@ def construct_fields( # Build all the field dictionary fields = OrderedDict() + filters = OrderedDict() for orm_field_name, orm_field in orm_fields.items(): + filtering_enabled_for_field = orm_field.kwargs.pop( + "create_filter", create_filters + ) + filter_type = orm_field.kwargs.pop("filter_type", None) attr_name = orm_field.kwargs.pop("model_attr") attr = all_model_attrs[attr_name] resolver = get_custom_resolver(obj_type, orm_field_name) or get_attr_resolver( @@ -220,7 +380,7 @@ def construct_fields( connection_field_factory, batching_, orm_field_name, - **orm_field.kwargs + **orm_field.kwargs, ) elif isinstance(attr, CompositeProperty): if attr_name != orm_field_name or orm_field.kwargs: @@ -241,15 +401,21 @@ def construct_fields( connection_field_factory, batching, resolver, - **orm_field.kwargs + **orm_field.kwargs, ) else: raise Exception("Property type is not supported") # Should never happen registry.register_orm_field(obj_type, orm_field_name, attr) fields[orm_field_name] = field + if filtering_enabled_for_field and not isinstance(attr, AssociationProxy): + # we don't support filtering on association proxies yet. + # Support will be patched in a future release of graphene-sqlalchemy + filters[orm_field_name] = filter_field_from_type_field( + field, registry, filter_type, attr, attr_name + ) - return fields + return fields, filters class SQLAlchemyBase(BaseType): @@ -274,7 +440,7 @@ def __init_subclass_with_meta__( batching=False, connection_field_factory=None, _meta=None, - **options + **options, ): # We always want to bypass this hook unless we're defining a concrete # `SQLAlchemyObjectType` or `SQLAlchemyInterface`. @@ -301,16 +467,19 @@ def __init_subclass_with_meta__( "The options 'only_fields' and 'exclude_fields' cannot be both set on the same type." ) + fields, filters = construct_fields_and_filters( + obj_type=cls, + model=model, + registry=registry, + only_fields=only_fields, + exclude_fields=exclude_fields, + batching=batching, + create_filters=True, + connection_field_factory=connection_field_factory, + ) + sqla_fields = yank_fields_from_attrs( - construct_fields( - obj_type=cls, - model=model, - registry=registry, - only_fields=only_fields, - exclude_fields=exclude_fields, - batching=batching, - connection_field_factory=connection_field_factory, - ), + fields, _as=Field, sort=False, ) @@ -342,6 +511,19 @@ def __init_subclass_with_meta__( else: _meta.fields = sqla_fields + # Save Generated filter class in Meta Class + if not _meta.filter_class: + # Map graphene fields to filters + # TODO we might need to pass the ORMFields containing the SQLAlchemy models + # to the scalar filters here (to generate expressions from the model) + + filter_fields = yank_fields_from_attrs(filters, _as=InputField, sort=False) + + _meta.filter_class = BaseTypeFilter.create_type( + f"{cls.__name__}Filter", filter_fields=filter_fields, model=model + ) + registry.register_filter_for_base_type(cls, _meta.filter_class) + _meta.connection = connection _meta.id = id or "id" @@ -401,6 +583,12 @@ def resolve_id(self, info): def enum_for_field(cls, field_name): return enum_for_field(cls, field_name) + @classmethod + def get_filter_argument(cls): + if cls._meta.filter_class: + return graphene.Argument(cls._meta.filter_class) + return None + sort_enum = classmethod(sort_enum_for_object_type) sort_argument = classmethod(sort_argument_for_object_type) @@ -411,6 +599,7 @@ class SQLAlchemyObjectTypeOptions(ObjectTypeOptions): registry = None # type: sqlalchemy.Registry connection = None # type: sqlalchemy.Type[sqlalchemy.Connection] id = None # type: str + filter_class: Type[BaseTypeFilter] = None class SQLAlchemyObjectType(SQLAlchemyBase, ObjectType): @@ -447,6 +636,7 @@ class SQLAlchemyInterfaceOptions(InterfaceOptions): registry = None # type: sqlalchemy.Registry connection = None # type: sqlalchemy.Type[sqlalchemy.Connection] id = None # type: str + filter_class: Type[BaseTypeFilter] = None class SQLAlchemyInterface(SQLAlchemyBase, Interface): diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index bb9386e8..3ba14865 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -1,4 +1,5 @@ import re +import typing import warnings from collections import OrderedDict from functools import _c3_mro @@ -10,6 +11,14 @@ from sqlalchemy.orm import class_mapper, object_mapper from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError +from graphene import NonNull + + +def get_nullable_type(_type): + if isinstance(_type, NonNull): + return _type.of_type + return _type + def is_sqlalchemy_version_less_than(version_string): """Check the installed SQLAlchemy version""" @@ -259,6 +268,10 @@ def registry_sqlalchemy_model_from_str(model_name: str) -> Optional[Any]: pass +def is_list(x): + return getattr(x, "__origin__", None) in [list, typing.List] + + class DummyImport: """The dummy module returns 'object' for a query for any member"""