Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add graphql query authentication #587

Merged
merged 3 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 42 additions & 41 deletions bd_api/apps/account/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,11 @@
from django.db.models.query import QuerySet
from django.utils import timezone

from bd_api.custom.graphql_jwt import ownership_required
from bd_api.custom.graphql_jwt import owner_required
from bd_api.custom.model import BaseModel
from bd_api.custom.storage import OverwriteStorage, upload_to, validate_image


def split_password(password: str) -> Tuple[str, str, str, str]:
"""Split a password into four parts: algorithm, iterations, salt, and hash"""
algorithm, iterations, salt, hash = password.split("$", 3)
return algorithm, iterations, salt, hash


def is_valid_encoded_password(password: str) -> bool:
"""Check if a password is valid"""
double_encoded = make_password(password)
try:
target_algorithm, target_iterations, _, _ = split_password(double_encoded)
algorithm, iterations, _, _ = split_password(password)
except ValueError:
return False
return algorithm == target_algorithm and iterations == target_iterations


class RegistrationToken(BaseModel):
token = models.CharField(max_length=255, unique=True, default=uuid4)
created_at = models.DateTimeField(auto_now_add=True)
Expand Down Expand Up @@ -296,7 +279,8 @@ class Account(BaseModel, AbstractBaseUser, PermissionsMixin):
]
graphql_filter_fields_blacklist = ["internal_subscription"]
graphql_nested_filter_fields_whitelist = ["email", "username"]
graphql_mutation_decorator = ownership_required
graphql_query_decorator = owner_required(allow_anonymous=False)
graphql_mutation_decorator = owner_required(allow_anonymous=True)

USERNAME_FIELD = "email"
REQUIRED_FIELDS = ["username", "first_name", "last_name"]
Expand Down Expand Up @@ -327,28 +311,6 @@ def pro_member_subscription(self):
sub = [s for s in sub if s.is_pro]
return sub[0] if sub else None

def __str__(self):
return self.email

def get_short_name(self):
return self.first_name

get_short_name.short_description = "nome"

def get_full_name(self):
if self.first_name and self.last_name:
return f"{self.first_name} {self.last_name}"
if self.first_name:
return self.first_name
return self.username

get_full_name.short_description = "nome completo"

def get_organization(self):
return ", ".join(self.organizations.all().values_list("name", flat=True))

get_organization.short_description = "organização"

@property
def pro_subscription(self) -> str:
"""BD Pro subscription role, one of bd_pro or bd_pro_empresas"""
Expand Down Expand Up @@ -381,6 +343,28 @@ def pro_subscription_status(self) -> str:
if self.pro_member_subscription:
return self.pro_member_subscription.stripe_subscription_status

def __str__(self):
return self.email

def get_short_name(self):
return self.first_name

get_short_name.short_description = "nome"

def get_full_name(self):
if self.first_name and self.last_name:
return f"{self.first_name} {self.last_name}"
if self.first_name:
return self.first_name
return self.username

get_full_name.short_description = "nome completo"

def get_organization(self):
return ", ".join(self.organizations.all().values_list("name", flat=True))

get_organization.short_description = "organização"

def save(self, *args, **kwargs) -> None:
# If self._password is set and check_password(self._password, self.password) is True, then
# just save the model without changing the password.
Expand Down Expand Up @@ -485,3 +469,20 @@ def stripe_subscription_created_at(self):
@property
def is_pro(self):
return "bd_pro" in self.subscription.plan.product.metadata.get("code", "")


def split_password(password: str) -> Tuple[str, str, str, str]:
"""Split a password into four parts: algorithm, iterations, salt, and hash"""
algorithm, iterations, salt, hash = password.split("$", 3)
return algorithm, iterations, salt, hash


def is_valid_encoded_password(password: str) -> bool:
"""Check if a password is valid"""
double_encoded = make_password(password)
try:
target_algorithm, target_iterations, _, _ = split_password(double_encoded)
algorithm, iterations, _, _ = split_password(password)
except ValueError:
return False
return algorithm == target_algorithm and iterations == target_iterations
110 changes: 66 additions & 44 deletions bd_api/custom/graphql_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from functools import partial
from typing import Iterable, Optional, get_type_hints

import graphql_jwt
from django.apps import apps
from django.core.exceptions import ValidationError
from django.db import models
Expand Down Expand Up @@ -40,26 +39,25 @@
)
from graphene_django.registry import get_global_registry
from graphene_file_upload.scalars import Upload
from graphql_jwt import ObtainJSONWebToken, Refresh, Verify

from bd_api.custom.graphql_base import CountableConnection, FileFieldScalar, PlainTextNode
from bd_api.custom.graphql_jwt import ObtainJSONWebTokenWithUser
from bd_api.custom.model import BaseModel


def build_schema(applications: list[str], extra_queries=[], extra_mutations=[]):
queries = [build_query_schema(app) for app in applications] + extra_queries
mutations = [build_mutation_schema(app) for app in applications] + extra_mutations

class Query(*queries):
pass

class Mutation(*mutations):
pass
Query = type("Query", tuple(queries), {})
Mutation = type("Mutation", tuple(mutations), {})

schema = Schema(query=Query, mutation=Mutation)
return schema


### Query
### Query #####################################################################


def build_query_schema(application_name: str):
Expand All @@ -68,38 +66,6 @@ def build_query_schema(application_name: str):


def build_query_objs(application_name: str):
def get_type(model, attr):
"""Get type of an attribute of a class"""
try:
func = getattr(model, attr)
func = getattr(func, "fget")
hint = get_type_hints(func)
name = hint.get("return")
return str(name)
except Exception:
return ""

def match_type(model, attr):
"""Match python types to graphene types"""
if get_type(model, attr).startswith("int"):
return Int
if get_type(model, attr).startswith("str"):
return String
if get_type(model, attr).startswith("list[int]"):
return partial(List, of_type=Int)
if get_type(model, attr).startswith("list[str]"):
return partial(List, of_type=String)
return GenericScalar

def build_custom_attrs(model, attrs):
for attr in dir(model):
attr_func = getattr(model, attr)
if isinstance(attr_func, property):
if attr not in model.graphql_fields_blacklist:
attr_type = match_type(model, attr)
attrs.update({attr: attr_type(source=attr, description=attr_func.__doc__)})
return attrs

queries = {}
models = apps.get_app_config(application_name).get_models()
models = [m for m in models if getattr(m, "graphql_visible", False)]
Expand All @@ -113,7 +79,7 @@ def build_custom_attrs(model, attrs):
resolve__id=lambda self, _: self.id,
)
attr = build_custom_attrs(model, attr)
node = type(f"{model_name}Node", (DjangoObjectType,), attr)
node = create_node_factory(model, attr)
queries.update({f"{model_name}Node": PlainTextNode.Field(node)})
queries.update({f"all_{model_name}": DjangoFilterConnectionField(node)})
return queries
Expand Down Expand Up @@ -218,16 +184,72 @@ def get_filter_fields(
return filter_fields


### Mutation
def build_custom_attrs(model, attrs):
def get_type(model, attr):
"""Get type of an attribute of a class"""
try:
func = getattr(model, attr)
func = getattr(func, "fget")
hint = get_type_hints(func)
name = hint.get("return")
return str(name)
except Exception:
return ""

def match_type(model, attr):
"""Match python types to graphene types"""
if get_type(model, attr).startswith("int"):
return Int
if get_type(model, attr).startswith("str"):
return String
if get_type(model, attr).startswith("list[int]"):
return partial(List, of_type=Int)
if get_type(model, attr).startswith("list[str]"):
return partial(List, of_type=String)
return GenericScalar

for attr in dir(model):
attr_func = getattr(model, attr)
if isinstance(attr_func, property):
if attr not in model.graphql_fields_blacklist:
attr_type = match_type(model, attr)
attrs.update({attr: attr_type(source=attr, description=attr_func.__doc__)})
return attrs


def create_node_factory(model: BaseModel, attr: dict):
"""Create graphql relay node"""

def get_queryset():
"""Create query endpoints with authorization"""

@model.graphql_query_decorator
def _get_queryset(cls, queryset, info):
return super(cls, cls).get_queryset(queryset, info)

return classmethod(_get_queryset)

return type(
f"{model.__name__}Node",
(DjangoObjectType,),
{
**attr,
"get_queryset": get_queryset(),
},
)


### Mutation ##################################################################


def build_mutation_schema(application_name: str):
base_mutations = build_mutation_objs(application_name)
base_mutations.update(
{
"token_auth": graphql_jwt.ObtainJSONWebToken.Field(),
"verify_token": graphql_jwt.Verify.Field(),
"refresh_token": graphql_jwt.Refresh.Field(),
"token_auth": ObtainJSONWebToken.Field(),
"auth_token": ObtainJSONWebTokenWithUser.Field(),
"verify_token": Verify.Field(),
"refresh_token": Refresh.Field(),
}
)
mutation = type("Mutation", (ObjectType,), base_mutations)
Expand Down
56 changes: 37 additions & 19 deletions bd_api/custom/graphql_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,27 @@
from functools import wraps
from re import findall

from graphene import Field, ObjectType, String
from graphql_jwt import exceptions
from graphql_jwt.compat import get_operation_name
from graphql_jwt.decorators import context
from graphql_jwt.relay import JSONWebTokenMutation
from graphql_jwt.settings import jwt_settings


class User(ObjectType):
id = String()
email = String()


class ObtainJSONWebTokenWithUser(JSONWebTokenMutation):
user = Field(User)

@classmethod
def resolve(cls, root, info, **kwargs):
return cls(user=info.context.user)


def allow_any(info, **kwargs):
"""Custom function to determine the non-authentication per-field

Expand Down Expand Up @@ -41,7 +56,7 @@ def wrapper(*args, **kwargs):
return wrapper


def ownership_required(f, exc=exceptions.PermissionDenied):
def owner_required(allow_anonymous=False, exc=exceptions.PermissionDenied):
"""Decorator to limit graphql queries and mutations

- Super users are allowed to edit all resources
Expand All @@ -58,23 +73,26 @@ def get_uid(context, exp=r"id:\s[\"]?(\d+)[\"]?"):
query = context.body.decode("utf-8").replace('\\"', "").lower()
except Exception:
query = str(context._post).replace('\\"', "").lower()

return [int(uid) for uid in findall(exp, query)]

@wraps(f)
@context(f)
def wrapper(context, *args, **kwargs):
if context.user.is_staff:
return f(*args, **kwargs)
if context.user.is_superuser:
return f(*args, **kwargs)
uid = get_uid(context)
if context.user.is_anonymous:
if not uid:
uid = [int(uid) for uid in findall(exp, query)]
return uid[0] if uid else None

def decorator(f):
@wraps(f)
@context(f)
def wrapper(context, *args, **kwargs):
if context.user.is_staff:
return f(*args, **kwargs)
if context.user.is_authenticated:
if context.user.id == uid[0]:
if context.user.is_superuser:
return f(*args, **kwargs)
raise exc

return wrapper
uid = get_uid(context)
if context.user.is_authenticated:
if context.user.id == uid:
return f(*args, **kwargs)
if context.user.is_anonymous:
if allow_anonymous and not uid:
return f(*args, **kwargs)
raise exc

return wrapper

return decorator
Loading