From 8750bed58de29bc6d512d92b88787a7c982e0777 Mon Sep 17 00:00:00 2001 From: Vinicius Date: Sat, 6 Apr 2024 16:12:18 -0300 Subject: [PATCH 1/3] feat: add anyone required decorator --- bd_api/custom/graphql_jwt.py | 10 ++++++++++ bd_api/custom/model.py | 5 +++++ 2 files changed, 15 insertions(+) diff --git a/bd_api/custom/graphql_jwt.py b/bd_api/custom/graphql_jwt.py index 25318885..fbe21d01 100644 --- a/bd_api/custom/graphql_jwt.py +++ b/bd_api/custom/graphql_jwt.py @@ -31,6 +31,16 @@ def allow_any(info, **kwargs): return False +def anyone_required(f): + """Decorator to open graphql queries and mutations""" + + @wraps(f) + def wrapper(*args, **kwargs): + return f(*args, **kwargs) + + return wrapper + + def ownership_required(f, exc=exceptions.PermissionDenied): """Custom decorator to limit graphql account mutations diff --git a/bd_api/custom/model.py b/bd_api/custom/model.py index a66f3e09..1e98d29e 100644 --- a/bd_api/custom/model.py +++ b/bd_api/custom/model.py @@ -6,6 +6,8 @@ from django.urls import reverse from graphql_jwt.decorators import staff_member_required +from bd_api.custom.graphql_jwt import anyone_required + default_blacklist_fields = [ "created_at", "updated_at", @@ -15,6 +17,7 @@ "order", "_field_status", ] +default_query_decorator = anyone_required default_mutation_decorator = staff_member_required @@ -25,6 +28,7 @@ class BaseModel(models.Model): Attributes: - graphql_visible: show or hide the model in the documentation - graphql_fields_black_list: list of fields to hide in mutations + - graphql_query_decorator: authentication decorator for queries - graphql_mutation_decorator: authentication decorator for mutations """ @@ -40,6 +44,7 @@ class Meta: graphql_nested_filter_fields_whitelist: List[str] = [] graphql_nested_filter_fields_blacklist: List[str] = [] + graphql_query_decorator: Callable = default_query_decorator graphql_mutation_decorator: Callable = default_mutation_decorator @classmethod From 212fc9fef1578765f73b3e2d1fc2ae5d4c4b6778 Mon Sep 17 00:00:00 2001 From: Vinicius Date: Sat, 6 Apr 2024 16:16:20 -0300 Subject: [PATCH 2/3] feat: update ownership required decorator --- bd_api/custom/graphql_jwt.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/bd_api/custom/graphql_jwt.py b/bd_api/custom/graphql_jwt.py index fbe21d01..bf8529d6 100644 --- a/bd_api/custom/graphql_jwt.py +++ b/bd_api/custom/graphql_jwt.py @@ -42,11 +42,12 @@ def wrapper(*args, **kwargs): def ownership_required(f, exc=exceptions.PermissionDenied): - """Custom decorator to limit graphql account mutations + """Decorator to limit graphql queries and mutations - - Superusers are allowed to edit accounts - - Anonymous users are allowed to create accounts - - Authenticated users are allowed to edit their own account + - Super users are allowed to edit all resources + - Staff users are allowed to edit all resources + - Anonymous users are allowed to create resources + - Authenticated users are allowed to edit their own resources References: - https://django-graphql-jwt.domake.io/decorators.html @@ -63,6 +64,8 @@ def get_uid(context, exp=r"id:\s[\"]?(\d+)[\"]?"): @wraps(f) @context(f) def wrapper(context, *args, **kwargs): + if context.user.is_staff: + return f(*args, **kwargs) if context.user.is_superuser: return f(*args, **kwargs) uid = get_uid(context) @@ -72,7 +75,6 @@ def wrapper(context, *args, **kwargs): if context.user.is_authenticated: if context.user.id == uid[0]: return f(*args, **kwargs) - raise exc return wrapper From 823cba252e6c836a106e5ab3ddbf46618801e24f Mon Sep 17 00:00:00 2001 From: Vinicius Date: Sat, 6 Apr 2024 17:23:25 -0300 Subject: [PATCH 3/3] chore: organize graphql auto generator --- bd_api/custom/graphql_auto.py | 614 +++++++++++++++++----------------- 1 file changed, 310 insertions(+), 304 deletions(-) diff --git a/bd_api/custom/graphql_auto.py b/bd_api/custom/graphql_auto.py index 43bec8ac..f9f095fb 100644 --- a/bd_api/custom/graphql_auto.py +++ b/bd_api/custom/graphql_auto.py @@ -45,167 +45,205 @@ from bd_api.custom.model import BaseModel -@convert_django_field.register(models.FileField) -def convert_file_to_url(field, registry=None): - return FileFieldScalar(description=field.help_text, required=not field.null) +def build_schema(applications: list[str], extra_queries=[], extra_mutations=[]): + queries = [build_query_schema(app) for app in applications] + extra_queries + mutations = [build_mutation_schema(app) for app in applications] + extra_mutations + class Query(*queries): + pass -def fields_for_form(form, only_fields, exclude_fields): - fields = OrderedDict() - for name, field in form.fields.items(): - is_excluded = name in exclude_fields - is_not_in_only = only_fields and name not in only_fields - if is_excluded or is_not_in_only: - continue - if isinstance(field, forms_fields.FileField): - fields[name] = Upload(description=field.help_text) - else: - fields[name] = convert_form_field(field) - return fields + class Mutation(*mutations): + pass + + schema = Schema(query=Query, mutation=Mutation) + return schema -def generate_form_fields(model: BaseModel): - whitelist_field_types = ( - models.DateField, - models.DateTimeField, - models.SlugField, +### Query + + +def build_query_schema(application_name: str): + query = type("Query", (ObjectType,), build_query_objs(application_name)) + return query + + +def build_query_objs(application_name: str): + def get_type(model, attr): + """Get type of an attribute of a class""" + try: + func = getattr(model, attr) + func = getattr(func, "fget") + hint = get_type_hints(func) + name = hint.get("return") + return str(name) + except Exception: + return "" + + def match_type(model, attr): + """Match python types to graphene types""" + if get_type(model, attr).startswith("int"): + return Int + if get_type(model, attr).startswith("str"): + return String + if get_type(model, attr).startswith("list[int]"): + return partial(List, of_type=Int) + if get_type(model, attr).startswith("list[str]"): + return partial(List, of_type=String) + return GenericScalar + + def build_custom_attrs(model, attrs): + for attr in dir(model): + attr_func = getattr(model, attr) + if isinstance(attr_func, property): + if attr not in model.graphql_fields_blacklist: + attr_type = match_type(model, attr) + attrs.update({attr: attr_type(source=attr, description=attr_func.__doc__)}) + return attrs + + queries = {} + models = apps.get_app_config(application_name).get_models() + models = [m for m in models if getattr(m, "graphql_visible", False)] + + for model in models: + model_name = model.__name__ + meta = create_model_object_meta(model) + attr = dict( + Meta=meta, + _id=UUID(name="_id"), + resolve__id=lambda self, _: self.id, + ) + attr = build_custom_attrs(model, attr) + node = type(f"{model_name}Node", (DjangoObjectType,), attr) + queries.update({f"{model_name}Node": PlainTextNode.Field(node)}) + queries.update({f"all_{model_name}": DjangoFilterConnectionField(node)}) + return queries + + +def create_model_object_meta(model: BaseModel): + return type( + "Meta", + (object,), + dict( + model=(model), + interfaces=((PlainTextNode,)), + connection_class=CountableConnection, + filter_fields=(generate_filter_fields(model)), + ), + ) + + +def generate_filter_fields(model: BaseModel): + exempted_field_names = ("_field_status",) + exempted_field_types = ( + models.ImageField, + models.JSONField, + ) + string_field_types = ( models.CharField, - models.URLField, - models.ForeignKey, models.TextField, - models.BooleanField, + ) + comparable_field_types = ( models.BigIntegerField, models.IntegerField, + models.FloatField, + models.DecimalField, + models.DateTimeField, + models.DateField, + models.TimeField, + models.DurationField, + ) + foreign_key_field_types = ( + models.ForeignKey, models.OneToOneField, + models.OneToOneRel, models.ManyToManyField, - models.ImageField, - models.UUIDField, + models.ManyToManyRel, + models.ManyToOneRel, ) - fields = [] - for field in model._meta.get_fields(): - if isinstance(field, whitelist_field_types): - if field.name not in model.graphql_fields_blacklist: - fields.append(field.name) - return fields + def get_filter_fields( + model: models.Model, processed_models: Optional[Iterable[models.Model]] = None + ): + if processed_models is None: + processed_models = [] + if len(processed_models) > 10: + return {}, processed_models -def generate_form(model: BaseModel): - return modelform_factory(model, form=CustomModelForm, fields=generate_form_fields(model)) + if not issubclass(model, BaseModel): + model_fields = model._meta.get_fields() + if issubclass(model, BaseModel) and not processed_models: + model_fields = model.get_graphql_filter_fields_whitelist() + if issubclass(model, BaseModel) and len(processed_models): + model_fields = model.get_graphql_nested_filter_fields_whitelist() + processed_models.append(model) + filter_fields = {"id": ["exact", "isnull", "in"]} -class CustomModelForm(ModelForm): - def __init__(self, *args, **kwargs) -> None: - data = args[0] if args else kwargs.get("data") - # Store raw data, so we can verify whether the user has filled None or hasn't filled - # anything - self.__raw_data = deepcopy(data) - super().__init__(*args, **kwargs) - # Store which fields are required, so we can validate them later - self.__required_fields = set() - for field_name, field in self.fields.items(): - if field.required: - self.__required_fields.add(field_name) - field.required = False + foreign_models = [] + for field in model_fields: + if isinstance(field, foreign_key_field_types): + foreign_models.append(field.related_model) - def _clean_fields(self): - id_provided: bool = self.__raw_data.get("id") is not None - for name, bf in self._bound_items(): - field = bf.field - value = bf.initial - if name in self.__raw_data: - value = self.__raw_data[name] - if value is None and name in self.__required_fields and not id_provided: - self.add_error( - name, - ValidationError(field.error_messages["required"], code="required"), - ) + for field in model_fields: + if ( + False + or "djstripe" in field.name + or field.name in exempted_field_names + or model.__module__.startswith("django") + or isinstance(field, exempted_field_types) + ): continue - try: - if isinstance(field, forms_fields.FileField): - value = field.clean(value, bf.initial) - else: - value = field.clean(value) - self.cleaned_data[name] = value - if hasattr(self, "clean_%s" % name): - value = getattr(self, "clean_%s" % name)() - self.cleaned_data[name] = value - except ValidationError as e: - self.add_error(name, e) + filter_fields[field.name] = ["exact", "isnull", "in"] + if isinstance(field, string_field_types): + filter_fields[field.name] += ["icontains", "istartswith", "iendswith"] + if isinstance(field, comparable_field_types): + filter_fields[field.name] += ["lt", "lte", "gt", "gte", "range"] + if isinstance(field, foreign_key_field_types): + filter_fields[field.name] = [] + if field.related_model in processed_models: + continue + related_model_filter_fields, _ = get_filter_fields( + field.related_model, + processed_models, + ) + for ( + related_model_field_name, + related_model_field_filter, + ) in related_model_filter_fields.items(): + name = f"{field.name}__{related_model_field_name}" + filter_fields[name] = related_model_field_filter + return filter_fields, processed_models + filter_fields, _ = get_filter_fields(model) + return filter_fields -class CreateUpdateMutation(DjangoModelFormMutation): - class Meta: - abstract = True - @classmethod - def __init_subclass_with_meta__( - cls, - form_class=None, - model=None, - return_field_name=None, - only_fields=(), - exclude_fields=(), - **options, - ): - if not form_class: - raise Exception("form_class is required for DjangoModelFormMutation") +### Mutation - if not model: - model = form_class._meta.model - if not model: - raise Exception("model is required for DjangoModelFormMutation") +def build_mutation_schema(application_name: str): + base_mutations = build_mutation_objs(application_name) + base_mutations.update( + { + "token_auth": graphql_jwt.ObtainJSONWebToken.Field(), + "verify_token": graphql_jwt.Verify.Field(), + "refresh_token": graphql_jwt.Refresh.Field(), + } + ) + mutation = type("Mutation", (ObjectType,), base_mutations) + return mutation - form = form_class() - input_fields = fields_for_form(form, only_fields, exclude_fields) - if "id" not in exclude_fields: - input_fields["id"] = ID() - registry = get_global_registry() - model_type = registry.get_type_for_model(model) - if not model_type: - raise Exception("No type registered for model: {}".format(model.__name__)) - - if not return_field_name: - model_name = model.__name__ - return_field_name = model_name[:1].lower() + model_name[1:] - - output_fields = OrderedDict() - output_fields[return_field_name] = Field(model_type) - - _meta = DjangoModelDjangoFormMutationOptions(cls) - _meta.form_class = form_class - _meta.model = model - _meta.return_field_name = return_field_name - _meta.fields = yank_fields_from_attrs(output_fields, _as=Field) - - input_fields = yank_fields_from_attrs(input_fields, _as=InputField) - super(DjangoModelFormMutation, cls).__init_subclass_with_meta__( - _meta=_meta, input_fields=input_fields, **options - ) - - @classmethod - def get_form_kwargs(cls, root, info, **input): - # Get file data - file_fields = [ - field - for field in cls._meta.form_class.base_fields - if isinstance(cls._meta.form_class.base_fields[field], forms_fields.FileField) - ] - file_data = {} - if file_fields: - for field in file_fields: - if field in input: - file_data[field] = input[field] - - kwargs = {"data": input, "files": file_data} - - pk = input.pop("id", None) - if pk: - instance = cls._meta.model._default_manager.get(pk=pk) - kwargs["instance"] = instance +def build_mutation_objs(application_name: str): + mutations = {} + models = apps.get_app_config(application_name).get_models() + models = [m for m in models if getattr(m, "graphql_visible", False)] - return kwargs + for model in models: + model_name = model.__name__ + mutations.update({f"Delete{model_name}": delete_mutation_factory(model).Field()}) + mutations.update({f"CreateUpdate{model_name}": create_mutation_factory(model).Field()}) + return mutations def create_mutation_factory(model: BaseModel): @@ -274,196 +312,164 @@ def _mutate(cls, root, info, id): ) -def generate_filter_fields(model: BaseModel): - exempted_field_names = ("_field_status",) - exempted_field_types = ( - models.ImageField, - models.JSONField, - ) - string_field_types = ( - models.CharField, - models.TextField, - ) - comparable_field_types = ( - models.BigIntegerField, - models.IntegerField, - models.FloatField, - models.DecimalField, - models.DateTimeField, - models.DateField, - models.TimeField, - models.DurationField, - ) - foreign_key_field_types = ( - models.ForeignKey, - models.OneToOneField, - models.OneToOneRel, - models.ManyToManyField, - models.ManyToManyRel, - models.ManyToOneRel, - ) - - def get_filter_fields( - model: models.Model, processed_models: Optional[Iterable[models.Model]] = None - ): - if processed_models is None: - processed_models = [] - if len(processed_models) > 10: - return {}, processed_models - - if not issubclass(model, BaseModel): - model_fields = model._meta.get_fields() - if issubclass(model, BaseModel) and not processed_models: - model_fields = model.get_graphql_filter_fields_whitelist() - if issubclass(model, BaseModel) and len(processed_models): - model_fields = model.get_graphql_nested_filter_fields_whitelist() +def generate_form(model: BaseModel): + return modelform_factory(model, form=CustomModelForm, fields=generate_form_fields(model)) - processed_models.append(model) - filter_fields = {"id": ["exact", "isnull", "in"]} - foreign_models = [] - for field in model_fields: - if isinstance(field, foreign_key_field_types): - foreign_models.append(field.related_model) +class CustomModelForm(ModelForm): + def __init__(self, *args, **kwargs) -> None: + data = args[0] if args else kwargs.get("data") + # Store raw data, so we can verify whether the user has filled None or hasn't filled + # anything + self.__raw_data = deepcopy(data) + super().__init__(*args, **kwargs) + # Store which fields are required, so we can validate them later + self.__required_fields = set() + for field_name, field in self.fields.items(): + if field.required: + self.__required_fields.add(field_name) + field.required = False - for field in model_fields: - if ( - False - or "djstripe" in field.name - or field.name in exempted_field_names - or model.__module__.startswith("django") - or isinstance(field, exempted_field_types) - ): - continue - filter_fields[field.name] = ["exact", "isnull", "in"] - if isinstance(field, string_field_types): - filter_fields[field.name] += ["icontains", "istartswith", "iendswith"] - if isinstance(field, comparable_field_types): - filter_fields[field.name] += ["lt", "lte", "gt", "gte", "range"] - if isinstance(field, foreign_key_field_types): - filter_fields[field.name] = [] - if field.related_model in processed_models: - continue - related_model_filter_fields, _ = get_filter_fields( - field.related_model, - processed_models, + def _clean_fields(self): + id_provided: bool = self.__raw_data.get("id") is not None + for name, bf in self._bound_items(): + field = bf.field + value = bf.initial + if name in self.__raw_data: + value = self.__raw_data[name] + if value is None and name in self.__required_fields and not id_provided: + self.add_error( + name, + ValidationError(field.error_messages["required"], code="required"), ) - for ( - related_model_field_name, - related_model_field_filter, - ) in related_model_filter_fields.items(): - name = f"{field.name}__{related_model_field_name}" - filter_fields[name] = related_model_field_filter - return filter_fields, processed_models + continue + try: + if isinstance(field, forms_fields.FileField): + value = field.clean(value, bf.initial) + else: + value = field.clean(value) + self.cleaned_data[name] = value + if hasattr(self, "clean_%s" % name): + value = getattr(self, "clean_%s" % name)() + self.cleaned_data[name] = value + except ValidationError as e: + self.add_error(name, e) - filter_fields, _ = get_filter_fields(model) - return filter_fields +class CreateUpdateMutation(DjangoModelFormMutation): + class Meta: + abstract = True -def create_model_object_meta(model: BaseModel): - return type( - "Meta", - (object,), - dict( - model=(model), - interfaces=((PlainTextNode,)), - connection_class=CountableConnection, - filter_fields=(generate_filter_fields(model)), - ), - ) + @classmethod + def __init_subclass_with_meta__( + cls, + form_class=None, + model=None, + return_field_name=None, + only_fields=(), + exclude_fields=(), + **options, + ): + if not form_class: + raise Exception("form_class is required for DjangoModelFormMutation") + if not model: + model = form_class._meta.model -def build_query_objs(application_name: str): - def get_type(model, attr): - """Get type of an attribute of a class""" - try: - func = getattr(model, attr) - func = getattr(func, "fget") - hint = get_type_hints(func) - name = hint.get("return") - return str(name) - except Exception: - return "" + if not model: + raise Exception("model is required for DjangoModelFormMutation") - def match_type(model, attr): - """Match python types to graphene types""" - if get_type(model, attr).startswith("int"): - return Int - if get_type(model, attr).startswith("str"): - return String - if get_type(model, attr).startswith("list[int]"): - return partial(List, of_type=Int) - if get_type(model, attr).startswith("list[str]"): - return partial(List, of_type=String) - return GenericScalar + form = form_class() + input_fields = fields_for_form(form, only_fields, exclude_fields) + if "id" not in exclude_fields: + input_fields["id"] = ID() - def build_custom_attrs(model, attrs): - for attr in dir(model): - attr_func = getattr(model, attr) - if isinstance(attr_func, property): - if attr not in model.graphql_fields_blacklist: - attr_type = match_type(model, attr) - attrs.update({attr: attr_type(source=attr, description=attr_func.__doc__)}) - return attrs + registry = get_global_registry() + model_type = registry.get_type_for_model(model) + if not model_type: + raise Exception("No type registered for model: {}".format(model.__name__)) - queries = {} - models = apps.get_app_config(application_name).get_models() - models = [m for m in models if getattr(m, "graphql_visible", False)] + if not return_field_name: + model_name = model.__name__ + return_field_name = model_name[:1].lower() + model_name[1:] - for model in models: - model_name = model.__name__ - meta = create_model_object_meta(model) - attr = dict( - Meta=meta, - _id=UUID(name="_id"), - resolve__id=lambda self, _: self.id, - ) - attr = build_custom_attrs(model, attr) - node = type(f"{model_name}Node", (DjangoObjectType,), attr) - queries.update({f"{model_name}Node": PlainTextNode.Field(node)}) - queries.update({f"all_{model_name}": DjangoFilterConnectionField(node)}) - return queries + output_fields = OrderedDict() + output_fields[return_field_name] = Field(model_type) + _meta = DjangoModelDjangoFormMutationOptions(cls) + _meta.form_class = form_class + _meta.model = model + _meta.return_field_name = return_field_name + _meta.fields = yank_fields_from_attrs(output_fields, _as=Field) -def build_mutation_objs(application_name: str): - mutations = {} - models = apps.get_app_config(application_name).get_models() - models = [m for m in models if getattr(m, "graphql_visible", False)] + input_fields = yank_fields_from_attrs(input_fields, _as=InputField) + super(DjangoModelFormMutation, cls).__init_subclass_with_meta__( + _meta=_meta, input_fields=input_fields, **options + ) - for model in models: - model_name = model.__name__ - mutations.update({f"Delete{model_name}": delete_mutation_factory(model).Field()}) - mutations.update({f"CreateUpdate{model_name}": create_mutation_factory(model).Field()}) - return mutations + @classmethod + def get_form_kwargs(cls, root, info, **input): + # Get file data + file_fields = [ + field + for field in cls._meta.form_class.base_fields + if isinstance(cls._meta.form_class.base_fields[field], forms_fields.FileField) + ] + file_data = {} + if file_fields: + for field in file_fields: + if field in input: + file_data[field] = input[field] + kwargs = {"data": input, "files": file_data} -def build_query_schema(application_name: str): - query = type("Query", (ObjectType,), build_query_objs(application_name)) - return query + pk = input.pop("id", None) + if pk: + instance = cls._meta.model._default_manager.get(pk=pk) + kwargs["instance"] = instance + return kwargs -def build_mutation_schema(application_name: str): - base_mutations = build_mutation_objs(application_name) - base_mutations.update( - { - "token_auth": graphql_jwt.ObtainJSONWebToken.Field(), - "verify_token": graphql_jwt.Verify.Field(), - "refresh_token": graphql_jwt.Refresh.Field(), - } - ) - mutation = type("Mutation", (ObjectType,), base_mutations) - return mutation +def fields_for_form(form, only_fields, exclude_fields): + fields = OrderedDict() + for name, field in form.fields.items(): + is_excluded = name in exclude_fields + is_not_in_only = only_fields and name not in only_fields + if is_excluded or is_not_in_only: + continue + if isinstance(field, forms_fields.FileField): + fields[name] = Upload(description=field.help_text) + else: + fields[name] = convert_form_field(field) + return fields -def build_schema(applications: list[str], extra_queries=[], extra_mutations=[]): - queries = [build_query_schema(app) for app in applications] + extra_queries - mutations = [build_mutation_schema(app) for app in applications] + extra_mutations - class Query(*queries): - pass +def generate_form_fields(model: BaseModel): + whitelist_field_types = ( + models.DateField, + models.DateTimeField, + models.SlugField, + models.CharField, + models.URLField, + models.ForeignKey, + models.TextField, + models.BooleanField, + models.BigIntegerField, + models.IntegerField, + models.OneToOneField, + models.ManyToManyField, + models.ImageField, + models.UUIDField, + ) + fields = [] + for field in model._meta.get_fields(): + if isinstance(field, whitelist_field_types): + if field.name not in model.graphql_fields_blacklist: + fields.append(field.name) + return fields - class Mutation(*mutations): - pass - schema = Schema(query=Query, mutation=Mutation) - return schema +@convert_django_field.register(models.FileField) +def convert_file_to_url(field, registry=None): + return FileFieldScalar(description=field.help_text, required=not field.null)