From c0d0e0a94c4546da2db8d7ae0333c575d42e82a1 Mon Sep 17 00:00:00 2001 From: Paul Schweizer Date: Fri, 26 Feb 2021 18:40:31 +0100 Subject: [PATCH 1/7] Generic solution for filtering over relationships --- graphene_sqlalchemy_filter/filters.py | 386 +++++++++++++++++--------- tests/graphql_objects.py | 33 ++- tests/models.py | 20 ++ tests/test_filter_set.py | 16 +- tests/test_integration.py | 85 +++++- tests/test_query.py | 37 ++- 6 files changed, 441 insertions(+), 136 deletions(-) diff --git a/graphene_sqlalchemy_filter/filters.py b/graphene_sqlalchemy_filter/filters.py index 87f4906..509529d 100644 --- a/graphene_sqlalchemy_filter/filters.py +++ b/graphene_sqlalchemy_filter/filters.py @@ -21,7 +21,9 @@ from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import aliased from sqlalchemy.orm.attributes import InstrumentedAttribute +from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.orm.query import Query +from sqlalchemy.orm.relationships import RelationshipProperty from sqlalchemy.sql import sqltypes @@ -495,10 +497,10 @@ def _aliases_from_query(cls, query: Query) -> 'Dict[str, _MapperEntity]': @classmethod def _generate_default_filters( - cls, model, field_filters: 'Dict[str, Union[Iterable[str], Any]]' - ) -> dict: + cls, model, + field_filters: 'Dict[str, Union[Iterable[str], Any]]') -> dict: """ - Generate GraphQL fields from SQLAlchemy model columns. + Generate GraphQL fields from SQLAlchemy model columns and relationships. Args: model: SQLAlchemy model. @@ -510,47 +512,16 @@ def _generate_default_filters( """ graphql_filters = {} - filters_map = cls.ALLOWED_FILTERS - model_fields = cls._get_model_fields_data(model, field_filters.keys()) - - for field_name, field_object in model_fields.items(): - column_type = field_object['type'] - - expressions = field_filters[field_name] - if expressions == cls.ALL: - if column_type is None: - raise ValueError( - 'Unsupported field type for automatic filter binding' - ) - - type_class = column_type.__class__ - try: - expressions = filters_map[type_class].copy() - except KeyError: - for type_, exprs in filters_map.items(): - if issubclass(type_class, type_): - expressions = exprs.copy() - break - else: - raise KeyError( - 'Unsupported column type. ' - 'Hint: use EXTRA_ALLOWED_FILTERS.' - ) - - if field_object['nullable']: - expressions.append(cls.IS_NULL) - - field_type = cls._get_gql_type_from_sqla_type( - column_type, field_object['column'] - ) - - fields = cls._generate_filter_fields( - expressions, field_name, field_type, field_object['nullable'] - ) - for name, field in fields.items(): - graphql_filters[name] = get_field_as( - field, graphene.InputField - ) + fields = cls._recursively_generate_filters( + model=model, + fields={}, + parent_field=None, + parent=None, + field_filters=field_filters, + parents=[] + ) + for name, field in fields.items(): + graphql_filters[name] = get_field_as(field, graphene.InputField) return graphql_filters @@ -606,7 +577,7 @@ def _get_model_fields_data(cls, model, only_fields: 'Iterable[str]'): 'nullable': True, } - elif isinstance(descr, InstrumentedAttribute): + elif isinstance(descr.property, ColumnProperty): attr = descr.property name = attr.key if name not in only_fields: @@ -619,6 +590,16 @@ def _get_model_fields_data(cls, model, only_fields: 'Iterable[str]'): 'nullable': column.nullable, } + elif isinstance(descr.property, RelationshipProperty): + attr = descr.property + name = attr.key + if name not in only_fields: + continue + model_fields[name] = { + 'column': attr, + 'type': 'relationship' + } + return model_fields @classmethod @@ -666,6 +647,107 @@ def _generate_filter_fields( return filters + @classmethod + def _recursively_generate_filters( + cls, + model: 'sqlalchemy.ext.declarative.api.DeclarativeMeta', + fields: dict, + parent_field: 'Union[String, None]', + parent: 'Union[dict, None]', + field_filters: dict, + parents: 'List[String]' + ) -> dict: + """ + Recursively geenerate filters from the given fields. + + Args: + model: The model. + fields: Holds the resulting filters. + parent_field: The parent field name. + parent: Parent for the resulting fields. + field_filters: The fields to turn into filters. + parents: List of parent fields to create unique type names. + + Returns: + Dict. + + """ + model_fields = cls._get_model_fields_data(model, field_filters.keys()) + filters_map = cls.ALLOWED_FILTERS + + for field_name, field_object in model_fields.items(): + column_type = field_object['type'] + if column_type == 'relationship': + rel_model = field_object['column'].mapper.class_ + fields[field_name] = {} + cls._recursively_generate_filters( + model=rel_model, + fields=fields[field_name], + parent_field=field_name, + parent=fields, + field_filters=field_filters[field_name], + parents=parents + [field_name] + ) + continue + + expressions = field_filters[field_name] + if expressions == cls.ALL: + if column_type is None: + raise ValueError( + 'Unsupported field type for automatic filter binding' + ) + + type_class = column_type.__class__ + try: + expressions = filters_map[type_class].copy() + except KeyError: + for type_, exprs in filters_map.items(): + if issubclass(type_class, type_): + expressions = exprs.copy() + break + else: + raise KeyError( + 'Unsupported column type. ' + 'Hint: use EXTRA_ALLOWED_FILTERS.' + ) + + if field_object['nullable']: + expressions.append(cls.IS_NULL) + + field_type = cls._get_gql_type_from_sqla_type( + column_type, field_object['column'] + ) + + fields.update( + cls._generate_filter_fields( + expressions, field_name, field_type, field_object['nullable'] + ) + ) + + # Create a graphene type from this level + if parent is not None: + filters = {} + for name, field in fields.items(): + filters[name] = get_field_as(field, graphene.InputField) + + for op in [cls.AND, cls.OR, cls.NOT]: + doc = cls.DESCRIPTIONS.get(op) + graphql_name = cls.GRAPHQL_EXPRESSION_NAMES[op] + + # Need to concatenate to keep the types unique + OperatorAutoType = type( + "_".join(parents + [op]), (graphene.InputObjectType, ), filters) + filters[graphql_name] = graphene.InputField( + cls.FILTER_OBJECT_TYPES[op](OperatorAutoType, False, doc), description=doc + ) + + AutoType = type( + "_".join(parents), (graphene.InputObjectType, ), filters) + parent[parent_field] = graphene.InputField( + AutoType, description=parent_field) + + return fields + @classmethod def filter( cls, info: ResolveInfo, query: Query, filters: 'FilterType' @@ -683,7 +765,6 @@ def filter( """ context = info.context - if isinstance(context, dict): context[cls._filter_aliases] = {} elif '__dict__' in context.__dir__(): @@ -697,11 +778,25 @@ def filter( ).format(type(context)) warnings.warn(msg, RuntimeWarning) - query, sqla_filters = cls._translate_many_filter(info, query, filters) - if sqla_filters is not None: - query = query.filter(*sqla_filters) - - return query + result = { + "__root__": {} + } + query = cls._translate_many_filter( + info=info, + query=query, + filters=filters, + join_by=and_, + parent="__root__", + result=result["__root__"], + parent_result=result, + model=cls.model, + parent_attr=None + ) + print(result) + if not isinstance(result["__root__"], dict): + return query.filter(result["__root__"]) + else: + return query @classmethod @lru_cache(maxsize=500) @@ -736,105 +831,132 @@ def _split_graphql_field(cls, graphql_field: str) -> 'Tuple[str, str]': raise KeyError('Operator not found "{}"'.format(graphql_field)) - @classmethod - def _translate_filter( - cls, info: ResolveInfo, query: Query, key: str, value: 'Any' - ) -> 'Tuple[Query, Any]': - """ - Translate GraphQL to SQLAlchemy filters. - - Args: - info: GraphQL resolve info. - query: SQLAlchemy query. - key: Filter key: model field, 'or', 'and', 'not', custom filter. - value: Filter value. - - Returns: - SQLAlchemy clause. - - """ - if key in cls._custom_filters: - filter_name = key + '_filter' - with warnings.catch_warnings(): - warnings.simplefilter('ignore', SAWarning) - clause = getattr(cls, filter_name)(info, query, value) - if isinstance(clause, tuple): - query, clause = clause - - return query, clause - - if key == cls.GRAPHQL_EXPRESSION_NAMES[cls.AND]: - return cls._translate_many_filter(info, query, value, and_) - - if key == cls.GRAPHQL_EXPRESSION_NAMES[cls.OR]: - return cls._translate_many_filter(info, query, value, or_) - - if key == cls.GRAPHQL_EXPRESSION_NAMES[cls.NOT]: - return cls._translate_many_filter( - info, query, value, lambda *x: not_(and_(*x)) - ) - - field, expression = cls._split_graphql_field(key) - filter_function = cls.FILTER_FUNCTIONS[expression] - - try: - model_field = getattr(cls.model, field) - except AttributeError: - raise KeyError('Field not found: ' + field) - - model_field_type = getattr(model_field, 'type', None) - is_enum = isinstance(model_field_type, sqltypes.Enum) - if is_enum and model_field_type.enum_class: - if isinstance(value, list): - value = [model_field_type.enum_class(v) for v in value] - else: - value = model_field_type.enum_class(value) - - clause = filter_function(model_field, value) - return query, clause - @classmethod def _translate_many_filter( cls, info: ResolveInfo, query: Query, filters: 'Union[List[FilterType], FilterType]', - join_by: 'Callable' = None, - ) -> 'Tuple[Query, Any]': + join_by: 'Callable', + parent: 'String', + result: dict, + parent_result: dict, + model: 'sqlalchemy.ext.declarative.api.DeclarativeMeta', + parent_attr: 'Union[sqlalchemy.orm.attributes.InstrumentedAttribute, None]' + ) -> 'Query': """ - Translate several filters. + Recursively translate filters. Args: info: GraphQL resolve info. query: SQLAlchemy query. filters: GraphQL filters. join_by: Join translated filters. + parent: The parent key. + result: The result dict. + parent_result: Parent for the result. + model: The model to translate. + parent_attr: The parent attribute (for relationships). Returns: - SQLAlchemy clause. + Query. """ - result = [] - - # Filters from 'and', 'or', 'not'. - if isinstance(filters, list): - for f in filters: - query, local_filters = cls._translate_many_filter( - info, query, f, and_ - ) - if local_filters is not None: - result.append(local_filters) + join_by_map = { + "and": and_, + "or": or_, + "not": lambda *x: not_(and_(*x)) + } + for key, value in filters.items(): + + if key in cls._custom_filters: + filter_name = key + '_filter' + with warnings.catch_warnings(): + warnings.simplefilter('ignore', SAWarning) + clause = getattr(cls, filter_name)(info, query, value) + if isinstance(clause, tuple): + query, clause = clause + result[key] = clause + continue - else: - for k, v in filters.items(): - query, r = cls._translate_filter(info, query, k, v) - if r is not None: - result.append(r) + if key in ("and", "or", "not"): + result.setdefault(key, {}) + if key in ("and", "or"): + for op_key, op_filters in enumerate(value): + result[key].setdefault(op_key, {}) + query = cls._translate_many_filter( + info=info, + query=query, + filters=op_filters, + join_by=and_, + parent=op_key, + result=result[key][op_key], + parent_result=result[key], + model=model, + parent_attr=parent_attr + ) + result[key] = join_by_map[key]( + *[v for v in result[key].values() if not isinstance(v, dict)]) + else: + query = cls._translate_many_filter( + info=info, + query=query, + filters=value, + join_by=join_by_map[key], + parent=key, + result=result[key], + parent_result=result, + model=model, + parent_attr=parent_attr + ) + continue - if not result: - return query, None + field, expression = cls._split_graphql_field(key) + filter_function = cls.FILTER_FUNCTIONS[expression] + + if isinstance(value, dict) and expression != "range": + result.setdefault(key, {}) + inspected = inspection.inspect(model) + for relationship in inspected.relationships: + if relationship.key == key: + query = cls._translate_many_filter( + info=info, + query=query, + filters=value, + join_by=and_, + parent=key, + result=result[key], + parent_result=result, + model=relationship.mapper.class_, + parent_attr=getattr(model, key) + ) + break + else: + try: + model_field = getattr(model, field) + except AttributeError: + raise KeyError('Field not found: ' + field) + + model_field_type = getattr(model_field, 'type', None) + is_enum = isinstance(model_field_type, sqltypes.Enum) + if is_enum and model_field_type.enum_class: + if isinstance(value, list): + value = [model_field_type.enum_class(v) for v in value] + else: + value = model_field_type.enum_class(value) - if join_by is None: - return query, result + result[key] = filter_function(getattr(model, field), value) - return query, join_by(*result) + # Join this level of filters and store under the parent key + if not result: + return query + if parent_attr is None: + parent_result[parent] = join_by( + *[v for v in result.values() if not isinstance(v, dict)]) + elif parent_attr.property.uselist: + parent_result[parent] = parent_attr.any(join_by( + *[v for v in result.values() if not isinstance(v, dict)])) + else: + parent_result[parent] = parent_attr.has(join_by( + *[v for v in result.values() if not isinstance(v, dict)])) + return query diff --git a/tests/graphql_objects.py b/tests/graphql_objects.py index 6e79039..09e34e8 100644 --- a/tests/graphql_objects.py +++ b/tests/graphql_objects.py @@ -11,7 +11,7 @@ from tests import gqls_version # This module -from .models import Article, Author, Group, Membership, User +from .models import Article, Assignment, Author, Group, Membership, Task, User class BaseFilter(FilterSet): @@ -35,6 +35,13 @@ class Meta: 'balance': ['eq', 'ne', 'gt', 'lt', 'range', 'is_null'], 'is_active': ['eq', 'ne'], 'username_hybrid_property': ['eq', 'ne', 'in'], + 'assignments': { + 'task': { + 'name': ['eq'], + 'id': ['eq'] + }, + 'active': ['eq'] + } } @@ -191,11 +198,35 @@ class Meta: node = ArticleNode +class TaskNode(SQLAlchemyObjectType): + class Meta: + model = Task + interfaces = (graphene.relay.Node,) + + +class TaskConnection(Connection): + class Meta: + node = TaskNode + + +class AssignmentNode(SQLAlchemyObjectType): + class Meta: + model = Assignment + interfaces = (graphene.relay.Node,) + + +class AssignmentConnection(Connection): + class Meta: + node = AssignmentNode + + class Query(graphene.ObjectType): field = MyFilterableConnectionField(UserConnection) all_groups = MyFilterableConnectionField(GroupConnection) all_authors = MyFilterableConnectionField(AuthorConnection) all_articles = MyFilterableConnectionField(ArticleConnection) + tasks = MyFilterableConnectionField(TaskConnection) + assignments = MyFilterableConnectionField(AssignmentConnection) schema = graphene.Schema(query=Query) diff --git a/tests/models.py b/tests/models.py index b249a49..929dfcc 100644 --- a/tests/models.py +++ b/tests/models.py @@ -46,6 +46,7 @@ class User(Base): username = Column(String(50), nullable=False, unique=True, index=True) balance = Column(Integer, default=None) is_active = Column(Boolean, default=True) + assignments = relationship('Assignment', back_populates='user') if gqls_version >= (2, 2, 0): status = Column(Enum(StatusEnum), default=StatusEnum.offline) @@ -106,3 +107,22 @@ class Article(Base): ) author = relationship('Author', back_populates='articles') + + +class Task(Base): + __tablename__ = 'task' + + id = Column(Integer, primary_key=True) + name = Column(String(32)) + assignments = relationship('Assignment', back_populates='task') + + +class Assignment(Base): + __tablename__ = 'assignment' + + task_id = Column(Integer, ForeignKey('task.id'), primary_key=True) + task = relationship('Task', back_populates='assignments') + user_id = Column(Integer, ForeignKey('user.user_id'), primary_key=True) + user = relationship('User', back_populates='assignments') + + active = Column(Boolean) diff --git a/tests/test_filter_set.py b/tests/test_filter_set.py index ab07144..734c599 100644 --- a/tests/test_filter_set.py +++ b/tests/test_filter_set.py @@ -33,10 +33,11 @@ def test_default_filter_field_types(): for model_field, operators in USER_FILTER_FIELDS.items(): for op in operators: field = model_field + if op not in UserFilter.GRAPHQL_EXPRESSION_NAMES: + continue graphql_op = UserFilter.GRAPHQL_EXPRESSION_NAMES[op] if graphql_op: field += filters.DELIMITER + graphql_op - assert field in filter_fields, 'Field not found: ' + field assert isinstance(filter_fields[field], graphene.InputField) del filter_fields[field] @@ -320,3 +321,16 @@ class TestFilter(F): class Meta: model = models.User fields = {'id': [...]} + + +def test_generate_relationship_filter_field_names_concatenate_parents(): + filter_fields = deepcopy(UserFilter._meta.fields) + + assert filter_fields["assignments"].type._meta.name == "assignments" + assert getattr(filter_fields["assignments"].type, "and").type.of_type.of_type._meta.name == "assignments_and" + assert getattr(filter_fields["assignments"].type, "or").type.of_type.of_type._meta.name == "assignments_or" + assert getattr(filter_fields["assignments"].type, "not").type._meta.name == "assignments_not" + assert filter_fields["assignments"].type.task.type._meta.name == "assignments_task" + assert getattr(filter_fields["assignments"].type.task.type, "and").type.of_type.of_type._meta.name == "assignments_task_and" + assert getattr(filter_fields["assignments"].type.task.type, "or").type.of_type.of_type._meta.name == "assignments_task_or" + assert getattr(filter_fields["assignments"].type.task.type, "not").type._meta.name == "assignments_task_not" diff --git a/tests/test_integration.py b/tests/test_integration.py index 5b673d5..04b18e8 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -6,7 +6,7 @@ graphene_sqlalchemy_version_lt_2_1_2, ) from tests.graphql_objects import schema -from tests.models import Group, Membership, User +from tests.models import Group, Membership, User, Task, Assignment from tests.utils import SQLAlchemyQueryCounter @@ -82,6 +82,43 @@ def add_users_to_new_groups(session, users): return groups +def add_tasks(session): + tasks = [ + Task(name='Write code'), + Task(name='Write documentation'), + Task(name='Make breakfast'), + ] + session.bulk_save_objects(tasks, return_defaults=True) + session.flush() + + return tasks + + +def assign_users_to_tasks(session, users): + tasks = add_tasks(session) + + assignments = [ + Assignment( + user_id=users[0].id, + task_id=tasks[0].id, + active=True + ), + Assignment( + user_id=users[0].id, + task_id=tasks[1].id, + active=False + ), + Assignment( + user_id=users[1].id, + task_id=tasks[2].id, + active=True + ) + ] + session.bulk_save_objects(assignments, return_defaults=True) + session.flush() + return assignments + + def test_response_without_filters(session): add_users(session) session.commit() @@ -322,3 +359,49 @@ def test_nested_response_with_recursive_model(session): assert len(group_0_sub_groups_edges) == 2 sub_group_name = group_0_sub_groups_edges[0]['node']['name'] assert sub_group_name == 'group_2' + + +def test_relationship_filtering(session): + users = add_users(session) + assign_users_to_tasks(session, users) + session.commit() + + request_string = """{ + field(filters: { + assignments: { + and: [ + { + task: { + name: "Write code", + } + }, + { + active: true + } + ] + } + }){ + edges{ + node{ + username + assignments{ + edges{ + node{ + active + task { + name + } + } + } + } + } + } + } + }""" + + execution_result = schema.execute( + request_string, context={'session': session} + ) + assert len(execution_result.data["field"]["edges"]) == 1 # Only user_1 matches + assert execution_result.data["field"]["edges"][0]["node"]["username"] == "user_1" + assert len(execution_result.data["field"]["edges"][0]["node"]["assignments"]["edges"]) == 2 diff --git a/tests/test_query.py b/tests/test_query.py index 1b405d2..111db0c 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -54,7 +54,6 @@ def test_enum(info_and_user_query): where_clause = query.whereclause ok = '"user".status = :status_1' assert str(where_clause) == ok - assert where_clause.right.effective_value == models.StatusEnum.online @@ -194,3 +193,39 @@ def test_complex_filters(info_and_user_query): str_query = str(query) assert str_query.lower().count('join') == 4, str_query + + +def test_complex_relationship_filters(info_and_user_query): + info, user_query = info_and_user_query + + filters = { + 'not': {'is_active': True}, + 'or': [ + {'is_admin': False}, + { + 'assignments': { + 'or': [ + { + 'task': {'name': 'Write code'} + }, + {'active': True} + ] + } + } + ] + } + query = UserFilter.filter(info, user_query, filters) + + ok = ( + '"user".is_active != true AND ("user".username != :username_1 OR (EXISTS (SELECT 1' + ' FROM "user", assignment' + ' WHERE "user".user_id = assignment.user_id AND ((EXISTS (SELECT 1' + ' FROM assignment' + ' WHERE "user".user_id = assignment.user_id AND (EXISTS (SELECT 1' + ' FROM task' + ' WHERE task.id = assignment.task_id AND task.name = :name_1)))) OR (EXISTS (SELECT 1' + ' FROM assignment' + ' WHERE "user".user_id = assignment.user_id AND assignment.active = true))))))' + ) + where_clause = str(query.whereclause).replace('\n', '') + assert where_clause == ok From 20e0a697bb507f3a0d42fab9bfe7fbf7a5db0a4f Mon Sep 17 00:00:00 2001 From: Paul Schweizer Date: Fri, 26 Feb 2021 18:58:55 +0100 Subject: [PATCH 2/7] Remove print --- graphene_sqlalchemy_filter/filters.py | 1 - 1 file changed, 1 deletion(-) diff --git a/graphene_sqlalchemy_filter/filters.py b/graphene_sqlalchemy_filter/filters.py index 509529d..7a6c0a0 100644 --- a/graphene_sqlalchemy_filter/filters.py +++ b/graphene_sqlalchemy_filter/filters.py @@ -792,7 +792,6 @@ def filter( model=cls.model, parent_attr=None ) - print(result) if not isinstance(result["__root__"], dict): return query.filter(result["__root__"]) else: From 3b12d84cce2844de60e22f4c2df02a8f0aae3be2 Mon Sep 17 00:00:00 2001 From: Paul Schweizer Date: Mon, 1 Mar 2021 13:13:12 +0100 Subject: [PATCH 3/7] Fixing linter issues --- graphene_sqlalchemy_filter/filters.py | 84 ++++++++++++++++----------- tests/graphql_objects.py | 10 +--- tests/test_filter_set.py | 45 +++++++++++--- tests/test_integration.py | 27 +++------ tests/test_query.py | 33 +++++------ 5 files changed, 114 insertions(+), 85 deletions(-) diff --git a/graphene_sqlalchemy_filter/filters.py b/graphene_sqlalchemy_filter/filters.py index 7a6c0a0..eb7e90c 100644 --- a/graphene_sqlalchemy_filter/filters.py +++ b/graphene_sqlalchemy_filter/filters.py @@ -20,11 +20,11 @@ from sqlalchemy.exc import SAWarning from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import aliased -from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.orm.query import Query from sqlalchemy.orm.relationships import RelationshipProperty from sqlalchemy.sql import sqltypes +import sqlalchemy.ext.declarative.api MYPY = False @@ -497,10 +497,10 @@ def _aliases_from_query(cls, query: Query) -> 'Dict[str, _MapperEntity]': @classmethod def _generate_default_filters( - cls, model, - field_filters: 'Dict[str, Union[Iterable[str], Any]]') -> dict: + cls, model, field_filters: 'Dict[str, Union[Iterable[str], Any]]' + ) -> dict: """ - Generate GraphQL fields from SQLAlchemy model columns and relationships. + Generate GraphQL fields from SQLAlchemy model columns & relationships. Args: model: SQLAlchemy model. @@ -518,7 +518,7 @@ def _generate_default_filters( parent_field=None, parent=None, field_filters=field_filters, - parents=[] + parents=[], ) for name, field in fields.items(): graphql_filters[name] = get_field_as(field, graphene.InputField) @@ -595,10 +595,7 @@ def _get_model_fields_data(cls, model, only_fields: 'Iterable[str]'): name = attr.key if name not in only_fields: continue - model_fields[name] = { - 'column': attr, - 'type': 'relationship' - } + model_fields[name] = {'column': attr, 'type': 'relationship'} return model_fields @@ -652,10 +649,10 @@ def _recursively_generate_filters( cls, model: 'sqlalchemy.ext.declarative.api.DeclarativeMeta', fields: dict, - parent_field: 'Union[String, None]', + parent_field: 'Union[str, None]', parent: 'Union[dict, None]', field_filters: dict, - parents: 'List[String]' + parents: 'List[str]', ) -> dict: """ Recursively geenerate filters from the given fields. @@ -686,7 +683,7 @@ def _recursively_generate_filters( parent_field=field_name, parent=fields, field_filters=field_filters[field_name], - parents=parents + [field_name] + parents=parents + [field_name], ) continue @@ -720,7 +717,10 @@ def _recursively_generate_filters( fields.update( cls._generate_filter_fields( - expressions, field_name, field_type, field_object['nullable'] + expressions, + field_name, + field_type, + field_object['nullable'], ) ) @@ -736,15 +736,21 @@ def _recursively_generate_filters( # Need to concatenate to keep the types unique OperatorAutoType = type( - "_".join(parents + [op]), (graphene.InputObjectType, ), filters) + "_".join(parents + [op]), + (graphene.InputObjectType,), + filters, + ) filters[graphql_name] = graphene.InputField( - cls.FILTER_OBJECT_TYPES[op](OperatorAutoType, False, doc), description=doc + cls.FILTER_OBJECT_TYPES[op](OperatorAutoType, False, doc), + description=doc, ) AutoType = type( - "_".join(parents), (graphene.InputObjectType, ), filters) + "_".join(parents), (graphene.InputObjectType,), filters + ) parent[parent_field] = graphene.InputField( - AutoType, description=parent_field) + AutoType, description=parent_field + ) return fields @@ -778,9 +784,7 @@ def filter( ).format(type(context)) warnings.warn(msg, RuntimeWarning) - result = { - "__root__": {} - } + result = {"__root__": {}} query = cls._translate_many_filter( info=info, query=query, @@ -790,7 +794,7 @@ def filter( result=result["__root__"], parent_result=result, model=cls.model, - parent_attr=None + parent_attr=None, ) if not isinstance(result["__root__"], dict): return query.filter(result["__root__"]) @@ -837,11 +841,13 @@ def _translate_many_filter( query: Query, filters: 'Union[List[FilterType], FilterType]', join_by: 'Callable', - parent: 'String', + parent: str, result: dict, parent_result: dict, model: 'sqlalchemy.ext.declarative.api.DeclarativeMeta', - parent_attr: 'Union[sqlalchemy.orm.attributes.InstrumentedAttribute, None]' + parent_attr: ( + 'Union[' 'sqlalchemy.orm.attributes.InstrumentedAttribute, None]' + ), ) -> 'Query': """ Recursively translate filters. @@ -864,7 +870,7 @@ def _translate_many_filter( join_by_map = { "and": and_, "or": or_, - "not": lambda *x: not_(and_(*x)) + "not": lambda *x: not_(and_(*x)), } for key, value in filters.items(): @@ -892,10 +898,15 @@ def _translate_many_filter( result=result[key][op_key], parent_result=result[key], model=model, - parent_attr=parent_attr + parent_attr=parent_attr, ) result[key] = join_by_map[key]( - *[v for v in result[key].values() if not isinstance(v, dict)]) + *[ + v + for v in result[key].values() + if not isinstance(v, dict) + ] + ) else: query = cls._translate_many_filter( info=info, @@ -906,7 +917,7 @@ def _translate_many_filter( result=result[key], parent_result=result, model=model, - parent_attr=parent_attr + parent_attr=parent_attr, ) continue @@ -927,7 +938,7 @@ def _translate_many_filter( result=result[key], parent_result=result, model=relationship.mapper.class_, - parent_attr=getattr(model, key) + parent_attr=getattr(model, key), ) break else: @@ -951,11 +962,18 @@ def _translate_many_filter( return query if parent_attr is None: parent_result[parent] = join_by( - *[v for v in result.values() if not isinstance(v, dict)]) + *[v for v in result.values() if not isinstance(v, dict)] + ) elif parent_attr.property.uselist: - parent_result[parent] = parent_attr.any(join_by( - *[v for v in result.values() if not isinstance(v, dict)])) + parent_result[parent] = parent_attr.any( + join_by( + *[v for v in result.values() if not isinstance(v, dict)] + ) + ) else: - parent_result[parent] = parent_attr.has(join_by( - *[v for v in result.values() if not isinstance(v, dict)])) + parent_result[parent] = parent_attr.has( + join_by( + *[v for v in result.values() if not isinstance(v, dict)] + ) + ) return query diff --git a/tests/graphql_objects.py b/tests/graphql_objects.py index 09e34e8..316fc4d 100644 --- a/tests/graphql_objects.py +++ b/tests/graphql_objects.py @@ -35,13 +35,7 @@ class Meta: 'balance': ['eq', 'ne', 'gt', 'lt', 'range', 'is_null'], 'is_active': ['eq', 'ne'], 'username_hybrid_property': ['eq', 'ne', 'in'], - 'assignments': { - 'task': { - 'name': ['eq'], - 'id': ['eq'] - }, - 'active': ['eq'] - } + 'assignments': {'task': {'name': ['eq'], 'id': ['eq']}, 'active': ['eq']}, } @@ -226,7 +220,7 @@ class Query(graphene.ObjectType): all_authors = MyFilterableConnectionField(AuthorConnection) all_articles = MyFilterableConnectionField(ArticleConnection) tasks = MyFilterableConnectionField(TaskConnection) - assignments = MyFilterableConnectionField(AssignmentConnection) + assignments = MyFilterableConnectionField(AssignmentConnection) schema = graphene.Schema(query=Query) diff --git a/tests/test_filter_set.py b/tests/test_filter_set.py index 734c599..66459f6 100644 --- a/tests/test_filter_set.py +++ b/tests/test_filter_set.py @@ -327,10 +327,41 @@ def test_generate_relationship_filter_field_names_concatenate_parents(): filter_fields = deepcopy(UserFilter._meta.fields) assert filter_fields["assignments"].type._meta.name == "assignments" - assert getattr(filter_fields["assignments"].type, "and").type.of_type.of_type._meta.name == "assignments_and" - assert getattr(filter_fields["assignments"].type, "or").type.of_type.of_type._meta.name == "assignments_or" - assert getattr(filter_fields["assignments"].type, "not").type._meta.name == "assignments_not" - assert filter_fields["assignments"].type.task.type._meta.name == "assignments_task" - assert getattr(filter_fields["assignments"].type.task.type, "and").type.of_type.of_type._meta.name == "assignments_task_and" - assert getattr(filter_fields["assignments"].type.task.type, "or").type.of_type.of_type._meta.name == "assignments_task_or" - assert getattr(filter_fields["assignments"].type.task.type, "not").type._meta.name == "assignments_task_not" + assert ( + getattr( + filter_fields["assignments"].type, "and" + ).type.of_type.of_type._meta.name + == "assignments_and" + ) + assert ( + getattr( + filter_fields["assignments"].type, "or" + ).type.of_type.of_type._meta.name + == "assignments_or" + ) + assert ( + getattr(filter_fields["assignments"].type, "not").type._meta.name + == "assignments_not" + ) + assert ( + filter_fields["assignments"].type.task.type._meta.name + == "assignments_task" + ) + assert ( + getattr( + filter_fields["assignments"].type.task.type, "and" + ).type.of_type.of_type._meta.name + == "assignments_task_and" + ) + assert ( + getattr( + filter_fields["assignments"].type.task.type, "or" + ).type.of_type.of_type._meta.name + == "assignments_task_or" + ) + assert ( + getattr( + filter_fields["assignments"].type.task.type, "not" + ).type._meta.name + == "assignments_task_not" + ) diff --git a/tests/test_integration.py b/tests/test_integration.py index 04b18e8..e6bc081 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -6,7 +6,7 @@ graphene_sqlalchemy_version_lt_2_1_2, ) from tests.graphql_objects import schema -from tests.models import Group, Membership, User, Task, Assignment +from tests.models import Assignment, Group, Membership, Task, User from tests.utils import SQLAlchemyQueryCounter @@ -98,21 +98,9 @@ def assign_users_to_tasks(session, users): tasks = add_tasks(session) assignments = [ - Assignment( - user_id=users[0].id, - task_id=tasks[0].id, - active=True - ), - Assignment( - user_id=users[0].id, - task_id=tasks[1].id, - active=False - ), - Assignment( - user_id=users[1].id, - task_id=tasks[2].id, - active=True - ) + Assignment(user_id=users[0].id, task_id=tasks[0].id, active=True), + Assignment(user_id=users[0].id, task_id=tasks[1].id, active=False), + Assignment(user_id=users[1].id, task_id=tasks[2].id, active=True), ] session.bulk_save_objects(assignments, return_defaults=True) session.flush() @@ -402,6 +390,7 @@ def test_relationship_filtering(session): execution_result = schema.execute( request_string, context={'session': session} ) - assert len(execution_result.data["field"]["edges"]) == 1 # Only user_1 matches - assert execution_result.data["field"]["edges"][0]["node"]["username"] == "user_1" - assert len(execution_result.data["field"]["edges"][0]["node"]["assignments"]["edges"]) == 2 + edges = execution_result.data["field"]["edges"] + assert len(execution_result.data["field"]["edges"]) == 1 + assert edges[0]["node"]["username"] == "user_1" + assert len(edges[0]["node"]["assignments"]["edges"]) == 2 diff --git a/tests/test_query.py b/tests/test_query.py index 111db0c..580007f 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -199,33 +199,30 @@ def test_complex_relationship_filters(info_and_user_query): info, user_query = info_and_user_query filters = { - 'not': {'is_active': True}, - 'or': [ - {'is_admin': False}, - { - 'assignments': { - 'or': [ - { - 'task': {'name': 'Write code'} - }, - {'active': True} - ] - } + 'not': {'is_active': True}, + 'or': [ + {'is_admin': False}, + { + 'assignments': { + 'or': [{'task': {'name': 'Write code'}}, {'active': True}] } - ] - } + }, + ], + } query = UserFilter.filter(info, user_query, filters) ok = ( - '"user".is_active != true AND ("user".username != :username_1 OR (EXISTS (SELECT 1' + '"user".is_active != true AND ' + '("user".username != :username_1 OR (EXISTS (SELECT 1' ' FROM "user", assignment' ' WHERE "user".user_id = assignment.user_id AND ((EXISTS (SELECT 1' ' FROM assignment' ' WHERE "user".user_id = assignment.user_id AND (EXISTS (SELECT 1' ' FROM task' - ' WHERE task.id = assignment.task_id AND task.name = :name_1)))) OR (EXISTS (SELECT 1' - ' FROM assignment' - ' WHERE "user".user_id = assignment.user_id AND assignment.active = true))))))' + ' WHERE task.id = assignment.task_id AND task.name = :name_1)))) OR ' + '(EXISTS (SELECT 1 FROM assignment WHERE ' + '"user".user_id = assignment.user_id AND ' + 'assignment.active = true))))))' ) where_clause = str(query.whereclause).replace('\n', '') assert where_clause == ok From 283b88c8b84ffa1e8d9cdeba5a9c329c090d6768 Mon Sep 17 00:00:00 2001 From: Paul Schweizer Date: Mon, 1 Mar 2021 14:44:20 +0100 Subject: [PATCH 4/7] Adding relationships filtering to readme --- README.md | 41 +++++++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index c6fe987..52f4e95 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,12 @@ class UserFilter(FilterSet): fields = { 'username': ['eq', 'ne', 'in', 'ilike'], 'is_active': [...], # shortcut! + 'assignments': { # (nested) relationships are supported + 'task': { + 'name': ['eq'], + }, + 'active': ['eq'] + } } @staticmethod @@ -45,7 +51,12 @@ Now, we're going to create query. or: [ {isAdmin: true}, {usernameIn: ["moderator", "cool guy"]} - ] + ], + assignments: { + task: { + name: "Write code" + } + } } ){ edges { @@ -65,10 +76,10 @@ Now, we're going to create query. FilterSet class must inherit `graphene_sqlalchemy_filter.FilterSet` or your subclass of this class. -There are three types of filters: - 1. [automatically generated filters](#automatically-generated-filters) - 1. [simple filters](#simple-filters) - 1. [filters that require join](#filters-that-require-join) +There are three types of filters: + 1. [automatically generated filters](#automatically-generated-filters) + 1. [simple filters](#simple-filters) + 1. [filters that require join](#filters-that-require-join) ## Automatically generated filters @@ -77,14 +88,20 @@ class UserFilter(FilterSet): class Meta: model = User fields = { - 'username': ['eq', 'ne', 'in', 'ilike'], - 'is_active': [...], # shortcut! + 'username': ['eq', 'ne', 'in', 'ilike'], + 'is_active': [...], # shortcut! + 'assignments': { # (nested) relationships are supported + 'task': { + 'name': ['eq'], + }, + 'active': ['eq'] + } } ``` Metaclass must contain the sqlalchemy model and fields. -Automatically generated filters must be specified by `fields` variable. -Key - field name of sqlalchemy model, value - list of expressions (or shortcut). +Automatically generated filters must be specified by `fields` variable. +Key - field name of sqlalchemy model, value - list of expressions (or shortcut). For relationship fields, the value is another dictionary defining the filters for the related model. Shortcut (default: `[...]`) will add all the allowed filters for this type of sqlalchemy field (does not work with hybrid property). @@ -141,7 +158,7 @@ class UserFilter(FilterSet): @classmethod def is_moderator_filter(cls, info, query, value): membership = cls.aliased(query, Membership, name='is_moderator') - + query = query.outerjoin( membership, and_( @@ -207,7 +224,7 @@ class UserFilter(FilterSet): class Meta: model = User fields = {'is_active': [...]} - + class CustomField(FilterableConnectionField): @@ -431,7 +448,7 @@ class MyString(types.String): class BaseFilter(FilterSet): # You can override all allowed filters # ALLOWED_FILTERS = {types.Integer: ['eq']} - + # Or add new column type EXTRA_ALLOWED_FILTERS = {MyString: ['eq']} From e7aaa3911741441adae74661a0ef783a36d42e36 Mon Sep 17 00:00:00 2001 From: Paul Schweizer Date: Fri, 12 Mar 2021 10:25:09 +0100 Subject: [PATCH 5/7] Support for associationproxy Also fixing a problem with the schema type names --- graphene_sqlalchemy_filter/filters.py | 125 ++++++++++++++++++++++---- tests/graphql_objects.py | 7 ++ tests/models.py | 12 +++ tests/test_associationproxy.py | 24 +++++ tests/test_filter_set.py | 16 ++-- 5 files changed, 157 insertions(+), 27 deletions(-) create mode 100644 tests/test_associationproxy.py diff --git a/graphene_sqlalchemy_filter/filters.py b/graphene_sqlalchemy_filter/filters.py index eb7e90c..28f0218 100644 --- a/graphene_sqlalchemy_filter/filters.py +++ b/graphene_sqlalchemy_filter/filters.py @@ -18,6 +18,10 @@ from sqlalchemy import and_, cast, inspection, not_, or_, types from sqlalchemy.dialects import postgresql from sqlalchemy.exc import SAWarning +from sqlalchemy.ext.associationproxy import ( + AssociationProxy, + AssociationProxyInstance, +) from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import aliased from sqlalchemy.orm.properties import ColumnProperty @@ -518,7 +522,7 @@ def _generate_default_filters( parent_field=None, parent=None, field_filters=field_filters, - parents=[], + parents=[model.__tablename__], ) for name, field in fields.items(): graphql_filters[name] = get_field_as(field, graphene.InputField) @@ -577,6 +581,41 @@ def _get_model_fields_data(cls, model, only_fields: 'Iterable[str]'): 'nullable': True, } + elif isinstance(descr, AssociationProxy): + name = None + for attr in dir(model): + field = getattr(model, attr) + if ( + isinstance(field, AssociationProxyInstance) + # We have to get the actual field name + and not attr.startswith("_AssociationProxy") + and field.value_attr + == descr.for_class(model).value_attr + and field.target_class + == descr.for_class(model).target_class + ): + name = attr + break + if name not in only_fields: + continue + proxy_class = descr.for_class(model).target_class + proxy_attr = descr.for_class(model).value_attr + column = getattr(proxy_class, proxy_attr) + if hasattr(column, "mapper"): + related_model = column.mapper.class_ + model_fields[name] = { + 'column': descr.for_class(model), + 'related_model': related_model, + 'type': 'relationship', + } + else: + print(column.nullable) + model_fields[name] = { + 'column': column, + 'type': column.type, + 'nullable': column.nullable, + } + elif isinstance(descr.property, ColumnProperty): attr = descr.property name = attr.key @@ -595,7 +634,11 @@ def _get_model_fields_data(cls, model, only_fields: 'Iterable[str]'): name = attr.key if name not in only_fields: continue - model_fields[name] = {'column': attr, 'type': 'relationship'} + model_fields[name] = { + 'column': attr, + 'related_model': attr.mapper.class_, + 'type': 'relationship', + } return model_fields @@ -675,10 +718,9 @@ def _recursively_generate_filters( for field_name, field_object in model_fields.items(): column_type = field_object['type'] if column_type == 'relationship': - rel_model = field_object['column'].mapper.class_ fields[field_name] = {} cls._recursively_generate_filters( - model=rel_model, + model=field_object['related_model'], fields=fields[field_name], parent_field=field_name, parent=fields, @@ -848,6 +890,7 @@ def _translate_many_filter( parent_attr: ( 'Union[' 'sqlalchemy.orm.attributes.InstrumentedAttribute, None]' ), + expression_type: str = None, ) -> 'Query': """ Recursively translate filters. @@ -862,6 +905,7 @@ def _translate_many_filter( parent_result: Parent for the result. model: The model to translate. parent_attr: The parent attribute (for relationships). + expression_type: Either 'any' or 'has' for relationships. Returns: Query. @@ -899,6 +943,7 @@ def _translate_many_filter( parent_result=result[key], model=model, parent_attr=parent_attr, + expression_type=expression_type ) result[key] = join_by_map[key]( *[ @@ -918,6 +963,7 @@ def _translate_many_filter( parent_result=result, model=model, parent_attr=parent_attr, + expression_type=expression_type ) continue @@ -927,20 +973,61 @@ def _translate_many_filter( if isinstance(value, dict) and expression != "range": result.setdefault(key, {}) inspected = inspection.inspect(model) - for relationship in inspected.relationships: - if relationship.key == key: - query = cls._translate_many_filter( - info=info, - query=query, - filters=value, - join_by=and_, - parent=key, - result=result[key], - parent_result=result, - model=relationship.mapper.class_, - parent_attr=getattr(model, key), - ) - break + for descr in inspected.all_orm_descriptors: + if isinstance( + getattr(descr, "property", None), RelationshipProperty + ): + if descr.key == key: + expression_type_ = "has" + if getattr(model, key).property.uselist: + expression_type_ = "any" + query = cls._translate_many_filter( + info=info, + query=query, + filters=value, + join_by=and_, + parent=key, + result=result[key], + parent_result=result, + model=descr.mapper.class_, + parent_attr=getattr(model, key), + expression_type=expression_type_, + ) + break + elif isinstance(descr, AssociationProxy): + parent_attr_ = getattr(model, key) + if ( + isinstance(parent_attr_, AssociationProxyInstance) + and parent_attr_.value_attr + == descr.for_class(model).value_attr + and parent_attr_.target_class + == descr.for_class(model).target_class + ): + proxy_class = descr.for_class(model).target_class + proxy_attr = descr.for_class(model).value_attr + related_model = getattr( + proxy_class, proxy_attr + ).mapper.class_ + + expression_type_ = "has" + if getattr( + model, descr.target_collection + ).property.uselist: + expression_type_ = "any" + + query = cls._translate_many_filter( + info=info, + query=query, + filters=value, + join_by=and_, + parent=key, + result=result[key], + parent_result=result, + model=related_model, + parent_attr=getattr(model, key), + expression_type=expression_type_, + ) + break else: try: model_field = getattr(model, field) @@ -964,7 +1051,7 @@ def _translate_many_filter( parent_result[parent] = join_by( *[v for v in result.values() if not isinstance(v, dict)] ) - elif parent_attr.property.uselist: + elif expression_type == "any": parent_result[parent] = parent_attr.any( join_by( *[v for v in result.values() if not isinstance(v, dict)] diff --git a/tests/graphql_objects.py b/tests/graphql_objects.py index 316fc4d..90c5d02 100644 --- a/tests/graphql_objects.py +++ b/tests/graphql_objects.py @@ -122,6 +122,12 @@ class Meta: } +class TaskFilter(FilterSet): + class Meta: + model = Task + fields = {'users': {'username': [...]}, 'status_name': [...]} + + class MyFilterableConnectionField(FilterableConnectionField): filters = { User: UserFilter(), @@ -129,6 +135,7 @@ class MyFilterableConnectionField(FilterableConnectionField): Group: GroupFilter(), Article: ArticleFilter(), Author: AuthorFilter(), + Task: TaskFilter(), } diff --git a/tests/models.py b/tests/models.py index 929dfcc..165ccd4 100644 --- a/tests/models.py +++ b/tests/models.py @@ -12,6 +12,7 @@ String, func, ) +from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import backref, relationship @@ -115,6 +116,10 @@ class Task(Base): id = Column(Integer, primary_key=True) name = Column(String(32)) assignments = relationship('Assignment', back_populates='task') + users = association_proxy('assignments', 'user') + status_id = Column(Integer, ForeignKey('status.id')) + status = relationship('Status') + status_name = association_proxy('status', 'name') class Assignment(Base): @@ -126,3 +131,10 @@ class Assignment(Base): user = relationship('User', back_populates='assignments') active = Column(Boolean) + + +class Status(Base): + __tablename__ = 'status' + + id = Column(Integer, primary_key=True) + name = Column(String(32)) diff --git a/tests/test_associationproxy.py b/tests/test_associationproxy.py new file mode 100644 index 0000000..efbfc69 --- /dev/null +++ b/tests/test_associationproxy.py @@ -0,0 +1,24 @@ +# Database +from sqlalchemy.orm import Query + +# Project +from tests import models +from tests.graphql_objects import TaskFilter + + +def test_sql_query(info): + filters = {'users': {'username': 'user name'}, 'status_name': 'done'} + task_query = Query(models.Task) + query = TaskFilter.filter(info, task_query, filters) + where_clause = str(query.whereclause) + ok = ( + '(EXISTS (SELECT 1' + ' FROM task, assignment' + ' WHERE task.id = assignment.task_id AND (EXISTS (SELECT 1' + ' FROM "user"' + ' WHERE "user".user_id = assignment.user_id AND ' + '"user".username = :username_1)))) AND (EXISTS (SELECT 1' + ' FROM status, task' + ' WHERE status.id = task.status_id AND status.name = :name_1))' + ) + assert where_clause.replace('\n', '') == ok diff --git a/tests/test_filter_set.py b/tests/test_filter_set.py index 66459f6..12d5621 100644 --- a/tests/test_filter_set.py +++ b/tests/test_filter_set.py @@ -326,42 +326,42 @@ class Meta: def test_generate_relationship_filter_field_names_concatenate_parents(): filter_fields = deepcopy(UserFilter._meta.fields) - assert filter_fields["assignments"].type._meta.name == "assignments" + assert filter_fields["assignments"].type._meta.name == "user_assignments" assert ( getattr( filter_fields["assignments"].type, "and" ).type.of_type.of_type._meta.name - == "assignments_and" + == "user_assignments_and" ) assert ( getattr( filter_fields["assignments"].type, "or" ).type.of_type.of_type._meta.name - == "assignments_or" + == "user_assignments_or" ) assert ( getattr(filter_fields["assignments"].type, "not").type._meta.name - == "assignments_not" + == "user_assignments_not" ) assert ( filter_fields["assignments"].type.task.type._meta.name - == "assignments_task" + == "user_assignments_task" ) assert ( getattr( filter_fields["assignments"].type.task.type, "and" ).type.of_type.of_type._meta.name - == "assignments_task_and" + == "user_assignments_task_and" ) assert ( getattr( filter_fields["assignments"].type.task.type, "or" ).type.of_type.of_type._meta.name - == "assignments_task_or" + == "user_assignments_task_or" ) assert ( getattr( filter_fields["assignments"].type.task.type, "not" ).type._meta.name - == "assignments_task_not" + == "user_assignments_task_not" ) From 6939fe6fde17c5a8a748d6acea6138e2df2c1e70 Mon Sep 17 00:00:00 2001 From: Paul Schweizer Date: Mon, 31 May 2021 12:30:52 +0200 Subject: [PATCH 6/7] Appendix number for duplicated auto type names --- graphene_sqlalchemy_filter/filters.py | 23 ++++++++++++++++++----- setup.py | 2 +- tests/conftest.py | 6 +++--- tests/graphql_objects.py | 13 ++++++++++++- tests/models.py | 2 +- tests/test_associationproxy.py | 6 +++--- tests/test_query.py | 18 +++++++++--------- 7 files changed, 47 insertions(+), 23 deletions(-) diff --git a/graphene_sqlalchemy_filter/filters.py b/graphene_sqlalchemy_filter/filters.py index 28f0218..14df78a 100644 --- a/graphene_sqlalchemy_filter/filters.py +++ b/graphene_sqlalchemy_filter/filters.py @@ -265,6 +265,8 @@ class FilterSet(graphene.InputObjectType): NOT: 'Negation of filters.', } + AUTO_TYPE_NAMES = {} + class Meta: abstract = True @@ -698,7 +700,7 @@ def _recursively_generate_filters( parents: 'List[str]', ) -> dict: """ - Recursively geenerate filters from the given fields. + Recursively generate filters from the given fields. Args: model: The model. @@ -772,13 +774,24 @@ def _recursively_generate_filters( for name, field in fields.items(): filters[name] = get_field_as(field, graphene.InputField) + # Make sure the name of the type does not repeat globally + auto_type_name = "_".join(parents) + if auto_type_name in FilterSet.AUTO_TYPE_NAMES: + FilterSet.AUTO_TYPE_NAMES[auto_type_name] += 1 + auto_type_name = ( + f"{auto_type_name}" + f"{FilterSet.AUTO_TYPE_NAMES[auto_type_name]}" + ) + else: + FilterSet.AUTO_TYPE_NAMES[auto_type_name] = 0 + for op in [cls.AND, cls.OR, cls.NOT]: doc = cls.DESCRIPTIONS.get(op) graphql_name = cls.GRAPHQL_EXPRESSION_NAMES[op] # Need to concatenate to keep the types unique OperatorAutoType = type( - "_".join(parents + [op]), + f"{auto_type_name}_{op}", (graphene.InputObjectType,), filters, ) @@ -788,7 +801,7 @@ def _recursively_generate_filters( ) AutoType = type( - "_".join(parents), (graphene.InputObjectType,), filters + auto_type_name, (graphene.InputObjectType,), filters ) parent[parent_field] = graphene.InputField( AutoType, description=parent_field @@ -943,7 +956,7 @@ def _translate_many_filter( parent_result=result[key], model=model, parent_attr=parent_attr, - expression_type=expression_type + expression_type=expression_type, ) result[key] = join_by_map[key]( *[ @@ -963,7 +976,7 @@ def _translate_many_filter( parent_result=result, model=model, parent_attr=parent_attr, - expression_type=expression_type + expression_type=expression_type, ) continue diff --git a/setup.py b/setup.py index 930e7de..9daed6d 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ requirements = [ 'graphene-sqlalchemy>=2.1.0,<3', - 'SQLAlchemy<2', + 'SQLAlchemy==1.3.23', ] diff --git a/tests/conftest.py b/tests/conftest.py index 5252a21..82ddd15 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,7 +13,7 @@ from tests.models import Base -@pytest.yield_fixture(scope="function") +@pytest.fixture(scope="function") def session(): db = create_engine('sqlite://') # in-memory connection = db.engine.connect() @@ -30,7 +30,7 @@ def session(): session.remove() -@pytest.yield_fixture(scope="function") +@pytest.fixture(scope="function") def info(): db = create_engine('sqlite://') # in-memory connection = db.engine.connect() @@ -47,7 +47,7 @@ def info(): session.remove() -@pytest.yield_fixture(scope="function") +@pytest.fixture(scope="function") def info_and_user_query(): db = create_engine('sqlite://') # in-memory connection = db.engine.connect() diff --git a/tests/graphql_objects.py b/tests/graphql_objects.py index 90c5d02..26a1959 100644 --- a/tests/graphql_objects.py +++ b/tests/graphql_objects.py @@ -125,7 +125,17 @@ class Meta: class TaskFilter(FilterSet): class Meta: model = Task - fields = {'users': {'username': [...]}, 'status_name': [...]} + fields = { + 'users': {'username': [...]}, + 'status_name': [...], + 'assignments': {'task': {'name': [...]}}, + } + + +class AssignmentFilter(FilterSet): + class Meta: + model = Assignment + fields = {'task': {'assignments': {'active': [...]}}} class MyFilterableConnectionField(FilterableConnectionField): @@ -136,6 +146,7 @@ class MyFilterableConnectionField(FilterableConnectionField): Article: ArticleFilter(), Author: AuthorFilter(), Task: TaskFilter(), + Assignment: AssignmentFilter(), } diff --git a/tests/models.py b/tests/models.py index 165ccd4..d4827fa 100644 --- a/tests/models.py +++ b/tests/models.py @@ -123,7 +123,7 @@ class Task(Base): class Assignment(Base): - __tablename__ = 'assignment' + __tablename__ = 'task_assignments' task_id = Column(Integer, ForeignKey('task.id'), primary_key=True) task = relationship('Task', back_populates='assignments') diff --git a/tests/test_associationproxy.py b/tests/test_associationproxy.py index efbfc69..c58d37c 100644 --- a/tests/test_associationproxy.py +++ b/tests/test_associationproxy.py @@ -13,10 +13,10 @@ def test_sql_query(info): where_clause = str(query.whereclause) ok = ( '(EXISTS (SELECT 1' - ' FROM task, assignment' - ' WHERE task.id = assignment.task_id AND (EXISTS (SELECT 1' + ' FROM task, task_assignments' + ' WHERE task.id = task_assignments.task_id AND (EXISTS (SELECT 1' ' FROM "user"' - ' WHERE "user".user_id = assignment.user_id AND ' + ' WHERE "user".user_id = task_assignments.user_id AND ' '"user".username = :username_1)))) AND (EXISTS (SELECT 1' ' FROM status, task' ' WHERE status.id = task.status_id AND status.name = :name_1))' diff --git a/tests/test_query.py b/tests/test_query.py index 580007f..e21c633 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -214,15 +214,15 @@ def test_complex_relationship_filters(info_and_user_query): ok = ( '"user".is_active != true AND ' '("user".username != :username_1 OR (EXISTS (SELECT 1' - ' FROM "user", assignment' - ' WHERE "user".user_id = assignment.user_id AND ((EXISTS (SELECT 1' - ' FROM assignment' - ' WHERE "user".user_id = assignment.user_id AND (EXISTS (SELECT 1' - ' FROM task' - ' WHERE task.id = assignment.task_id AND task.name = :name_1)))) OR ' - '(EXISTS (SELECT 1 FROM assignment WHERE ' - '"user".user_id = assignment.user_id AND ' - 'assignment.active = true))))))' + ' FROM "user", task_assignments' + ' WHERE "user".user_id = task_assignments.user_id AND ((EXISTS' + ' (SELECT 1 FROM task_assignments' + ' WHERE "user".user_id = task_assignments.user_id AND (EXISTS' + ' (SELECT 1 FROM task' + ' WHERE task.id = task_assignments.task_id AND task.name = :name_1))))' + ' OR (EXISTS (SELECT 1 FROM task_assignments WHERE ' + '"user".user_id = task_assignments.user_id AND ' + 'task_assignments.active = true))))))' ) where_clause = str(query.whereclause).replace('\n', '') assert where_clause == ok From c513a34edff93f374ec40a5c2895362bb58dfa2b Mon Sep 17 00:00:00 2001 From: Paul Schweizer Date: Thu, 3 Jun 2021 10:49:48 +0200 Subject: [PATCH 7/7] Adjust tests for sqlalchemy 1.4+ --- graphene_sqlalchemy_filter/filters.py | 5 ++--- setup.py | 2 +- tests/test_postgres.py | 8 ++++---- tests/test_query.py | 12 +++++++++--- 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/graphene_sqlalchemy_filter/filters.py b/graphene_sqlalchemy_filter/filters.py index 1e03082..3d72c35 100644 --- a/graphene_sqlalchemy_filter/filters.py +++ b/graphene_sqlalchemy_filter/filters.py @@ -28,7 +28,6 @@ from sqlalchemy.orm.query import Query from sqlalchemy.orm.relationships import RelationshipProperty from sqlalchemy.sql import sqltypes -import sqlalchemy.ext.declarative.api MYPY = False @@ -504,7 +503,7 @@ def _aliases_from_query(cls, query: Query) -> 'Dict[str, _MapperEntity]': else: aliases = { (join_entity._target, join_entity.name): join_entity.entity - for join_entity in query._compile_state()._join_entities + for join_entity in query._compile_state()._join_entities } return aliases @@ -909,7 +908,7 @@ def _translate_many_filter( parent_result: dict, model: 'sqlalchemy.ext.declarative.api.DeclarativeMeta', parent_attr: ( - 'Union[' 'sqlalchemy.orm.attributes.InstrumentedAttribute, None]' + 'Union[sqlalchemy.orm.attributes.InstrumentedAttribute, None]' ), expression_type: str = None, ) -> 'Query': diff --git a/setup.py b/setup.py index 9daed6d..930e7de 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ requirements = [ 'graphene-sqlalchemy>=2.1.0,<3', - 'SQLAlchemy==1.3.23', + 'SQLAlchemy<2', ] diff --git a/tests/test_postgres.py b/tests/test_postgres.py index 8e0c936..96b1b55 100644 --- a/tests/test_postgres.py +++ b/tests/test_postgres.py @@ -34,7 +34,7 @@ def test_eq(info): ok = ( 'SELECT post.id, post.tags \n' 'FROM post \n' - 'WHERE post.tags = CAST(%(param_1)s AS VARCHAR(10)[])' + 'WHERE post.tags = CAST(%(param_1)s::VARCHAR(10)[] AS VARCHAR(10)[])' ) assert sql == ok @@ -48,7 +48,7 @@ def test_contained_by(info): ok = ( 'SELECT post.id, post.tags \n' 'FROM post \n' - 'WHERE post.tags <@ CAST(%(param_1)s AS VARCHAR(10)[])' + 'WHERE post.tags <@ CAST(%(param_1)s::VARCHAR(10)[] AS VARCHAR(10)[])' ) assert sql == ok @@ -62,7 +62,7 @@ def test_contains(info): ok = ( 'SELECT post.id, post.tags \n' 'FROM post \n' - 'WHERE post.tags @> CAST(%(param_1)s AS VARCHAR(10)[])' + 'WHERE post.tags @> CAST(%(param_1)s::VARCHAR(10)[] AS VARCHAR(10)[])' ) assert sql == ok @@ -76,6 +76,6 @@ def test_overlap(info): ok = ( 'SELECT post.id, post.tags \n' 'FROM post \n' - 'WHERE post.tags && CAST(%(param_1)s AS VARCHAR(10)[])' + 'WHERE post.tags && CAST(%(param_1)s::VARCHAR(10)[] AS VARCHAR(10)[])' ) assert sql == ok diff --git a/tests/test_query.py b/tests/test_query.py index e21c633..781e3e2 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -4,6 +4,9 @@ # GraphQL from graphene_sqlalchemy.utils import EnumValue +# Database +from sqlalchemy.sql import text + # Project from graphene_sqlalchemy_filter import FilterSet from tests import gqls_version, models @@ -14,13 +17,16 @@ def test_sort(info): filters = None sort = 'username desc' query = Query.field.get_query( - models.User, info, sort=EnumValue('username', sort), filters=filters + models.User, + info, + sort=EnumValue('username', text(sort)), + filters=filters, ) where_clause = query.whereclause assert where_clause is None - assert str(query._order_by[0]) == sort + assert str(query._order_by_clauses[0]) == sort def test_empty_filters_query(info_and_user_query): @@ -186,7 +192,7 @@ def test_complex_filters(info_and_user_query): ' OR "user".is_active != true' ' AND is_moderator.id IS NOT NULL' ' OR of_group.name = :name_1' - ' AND "user".username NOT IN (:username_3))' + ' AND ("user".username NOT IN ([POSTCOMPILE_username_3])))' ) where_clause = str(query.whereclause) assert where_clause == ok