From e2b7dd1ad7d635b07d2056f2568cb18e8f0eeb65 Mon Sep 17 00:00:00 2001 From: Vinicius Date: Sun, 7 Apr 2024 19:43:39 -0300 Subject: [PATCH] feat: add graphql query authentication --- bd_api/apps/account/models.py | 79 ++++++++++++++-------------- bd_api/custom/graphql_auto.py | 96 ++++++++++++++++++++--------------- 2 files changed, 96 insertions(+), 79 deletions(-) diff --git a/bd_api/apps/account/models.py b/bd_api/apps/account/models.py index a3ed71e3..e31022c9 100644 --- a/bd_api/apps/account/models.py +++ b/bd_api/apps/account/models.py @@ -19,23 +19,6 @@ from bd_api.custom.storage import OverwriteStorage, upload_to, validate_image -def split_password(password: str) -> Tuple[str, str, str, str]: - """Split a password into four parts: algorithm, iterations, salt, and hash""" - algorithm, iterations, salt, hash = password.split("$", 3) - return algorithm, iterations, salt, hash - - -def is_valid_encoded_password(password: str) -> bool: - """Check if a password is valid""" - double_encoded = make_password(password) - try: - target_algorithm, target_iterations, _, _ = split_password(double_encoded) - algorithm, iterations, _, _ = split_password(password) - except ValueError: - return False - return algorithm == target_algorithm and iterations == target_iterations - - class RegistrationToken(BaseModel): token = models.CharField(max_length=255, unique=True, default=uuid4) created_at = models.DateTimeField(auto_now_add=True) @@ -296,6 +279,7 @@ class Account(BaseModel, AbstractBaseUser, PermissionsMixin): ] graphql_filter_fields_blacklist = ["internal_subscription"] graphql_nested_filter_fields_whitelist = ["email", "username"] + graphql_query_decorator = ownership_required graphql_mutation_decorator = ownership_required USERNAME_FIELD = "email" @@ -327,28 +311,6 @@ def pro_member_subscription(self): sub = [s for s in sub if s.is_pro] return sub[0] if sub else None - def __str__(self): - return self.email - - def get_short_name(self): - return self.first_name - - get_short_name.short_description = "nome" - - def get_full_name(self): - if self.first_name and self.last_name: - return f"{self.first_name} {self.last_name}" - if self.first_name: - return self.first_name - return self.username - - get_full_name.short_description = "nome completo" - - def get_organization(self): - return ", ".join(self.organizations.all().values_list("name", flat=True)) - - get_organization.short_description = "organização" - @property def pro_subscription(self) -> str: """BD Pro subscription role, one of bd_pro or bd_pro_empresas""" @@ -381,6 +343,28 @@ def pro_subscription_status(self) -> str: if self.pro_member_subscription: return self.pro_member_subscription.stripe_subscription_status + def __str__(self): + return self.email + + def get_short_name(self): + return self.first_name + + get_short_name.short_description = "nome" + + def get_full_name(self): + if self.first_name and self.last_name: + return f"{self.first_name} {self.last_name}" + if self.first_name: + return self.first_name + return self.username + + get_full_name.short_description = "nome completo" + + def get_organization(self): + return ", ".join(self.organizations.all().values_list("name", flat=True)) + + get_organization.short_description = "organização" + def save(self, *args, **kwargs) -> None: # If self._password is set and check_password(self._password, self.password) is True, then # just save the model without changing the password. @@ -485,3 +469,20 @@ def stripe_subscription_created_at(self): @property def is_pro(self): return "bd_pro" in self.subscription.plan.product.metadata.get("code", "") + + +def split_password(password: str) -> Tuple[str, str, str, str]: + """Split a password into four parts: algorithm, iterations, salt, and hash""" + algorithm, iterations, salt, hash = password.split("$", 3) + return algorithm, iterations, salt, hash + + +def is_valid_encoded_password(password: str) -> bool: + """Check if a password is valid""" + double_encoded = make_password(password) + try: + target_algorithm, target_iterations, _, _ = split_password(double_encoded) + algorithm, iterations, _, _ = split_password(password) + except ValueError: + return False + return algorithm == target_algorithm and iterations == target_iterations diff --git a/bd_api/custom/graphql_auto.py b/bd_api/custom/graphql_auto.py index f9f095fb..e2bd00d6 100644 --- a/bd_api/custom/graphql_auto.py +++ b/bd_api/custom/graphql_auto.py @@ -49,17 +49,14 @@ def build_schema(applications: list[str], extra_queries=[], extra_mutations=[]): queries = [build_query_schema(app) for app in applications] + extra_queries mutations = [build_mutation_schema(app) for app in applications] + extra_mutations - class Query(*queries): - pass - - class Mutation(*mutations): - pass + Query = type("Query", queries) + Mutation = type("Mutation", mutations) schema = Schema(query=Query, mutation=Mutation) return schema -### Query +### Query ##################################################################### def build_query_schema(application_name: str): @@ -68,38 +65,6 @@ def build_query_schema(application_name: str): def build_query_objs(application_name: str): - def get_type(model, attr): - """Get type of an attribute of a class""" - try: - func = getattr(model, attr) - func = getattr(func, "fget") - hint = get_type_hints(func) - name = hint.get("return") - return str(name) - except Exception: - return "" - - def match_type(model, attr): - """Match python types to graphene types""" - if get_type(model, attr).startswith("int"): - return Int - if get_type(model, attr).startswith("str"): - return String - if get_type(model, attr).startswith("list[int]"): - return partial(List, of_type=Int) - if get_type(model, attr).startswith("list[str]"): - return partial(List, of_type=String) - return GenericScalar - - def build_custom_attrs(model, attrs): - for attr in dir(model): - attr_func = getattr(model, attr) - if isinstance(attr_func, property): - if attr not in model.graphql_fields_blacklist: - attr_type = match_type(model, attr) - attrs.update({attr: attr_type(source=attr, description=attr_func.__doc__)}) - return attrs - queries = {} models = apps.get_app_config(application_name).get_models() models = [m for m in models if getattr(m, "graphql_visible", False)] @@ -113,7 +78,7 @@ def build_custom_attrs(model, attrs): resolve__id=lambda self, _: self.id, ) attr = build_custom_attrs(model, attr) - node = type(f"{model_name}Node", (DjangoObjectType,), attr) + node = create_node_factory(model, attr) queries.update({f"{model_name}Node": PlainTextNode.Field(node)}) queries.update({f"all_{model_name}": DjangoFilterConnectionField(node)}) return queries @@ -218,7 +183,58 @@ def get_filter_fields( return filter_fields -### Mutation +def build_custom_attrs(model, attrs): + def get_type(model, attr): + """Get type of an attribute of a class""" + try: + func = getattr(model, attr) + func = getattr(func, "fget") + hint = get_type_hints(func) + name = hint.get("return") + return str(name) + except Exception: + return "" + + def match_type(model, attr): + """Match python types to graphene types""" + if get_type(model, attr).startswith("int"): + return Int + if get_type(model, attr).startswith("str"): + return String + if get_type(model, attr).startswith("list[int]"): + return partial(List, of_type=Int) + if get_type(model, attr).startswith("list[str]"): + return partial(List, of_type=String) + return GenericScalar + + for attr in dir(model): + attr_func = getattr(model, attr) + if isinstance(attr_func, property): + if attr not in model.graphql_fields_blacklist: + attr_type = match_type(model, attr) + attrs.update({attr: attr_type(source=attr, description=attr_func.__doc__)}) + return attrs + + +def create_node_factory(model: BaseModel, attr: dict): + def get_queryset(): + @model.graphql_query_decorator + def _get_queryset(cls, queryset, info): + return super(cls, cls).get_queryset(queryset, info) + + return classmethod(_get_queryset) + + return type( + f"{model.__name__}Node", + (DjangoObjectType,), + { + **attr, + "get_queryset": get_queryset(), + }, + ) + + +### Mutation ################################################################## def build_mutation_schema(application_name: str):