Skip to content

Commit

Permalink
refactor: add automatic camelCase mashumaro aliases for `snake_case…
Browse files Browse the repository at this point in the history
…` attrs (#26)

This commit adds automatic mashumaro camel case aliases for all fields
in our models. This will make them more scalable and less error-prone.
  • Loading branch information
serramatutu authored Jun 28, 2024
1 parent 34782ef commit 2062850
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 17 deletions.
3 changes: 3 additions & 0 deletions .changes/unreleased/Under the Hood-20240628-164812.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Under the Hood
body: Changed how field aliases to make it easier to define new models
time: 2024-06-28T16:48:12.280623+02:00
3 changes: 3 additions & 0 deletions dbtsl/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
generated code from our GraphQL schema.
"""

from .base import BaseModel
from .dimension import Dimension, DimensionType
from .measure import AggregationType, Measure
from .metric import Metric, MetricType

BaseModel._apply_aliases()

__all__ = [
"Dimension",
"DimensionType",
Expand Down
32 changes: 32 additions & 0 deletions dbtsl/models/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from dataclasses import fields, is_dataclass
from types import MappingProxyType

from mashumaro import DataClassDictMixin, field_options
from mashumaro.config import BaseConfig


def snake_case_to_camel_case(s: str) -> str:
"""Convert a snake_case_string into a camelCaseString."""
tokens = s.split("_")
return tokens[0] + "".join(t.title() for t in tokens[1:])


class BaseModel(DataClassDictMixin):
"""Base class for all serializable models.
Adds some functionality like automatically creating camelCase aliases.
"""

class Config(BaseConfig): # noqa: D106
lazy_compilation = True

@classmethod
def _apply_aliases(cls) -> None:
"""Apply camelCase aliases to all subclasses."""
for subclass in cls.__subclasses__():
assert is_dataclass(subclass), "Subclass of BaseModel must be dataclass"

for field in fields(subclass):
camel_name = snake_case_to_camel_case(field.name)
if field.name != camel_name:
field.metadata = MappingProxyType(field_options(alias=camel_name))
4 changes: 2 additions & 2 deletions dbtsl/models/dimension.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from enum import Enum

from mashumaro import DataClassDictMixin
from dbtsl.models.base import BaseModel


class DimensionType(str, Enum):
Expand All @@ -12,7 +12,7 @@ class DimensionType(str, Enum):


@dataclass(frozen=True)
class Dimension(DataClassDictMixin):
class Dimension(BaseModel):
"""A metric dimension."""

name: str
Expand Down
8 changes: 4 additions & 4 deletions dbtsl/models/measure.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Enum
from typing import Optional

from mashumaro import DataClassDictMixin, field_options
from dbtsl.models.base import BaseModel


class AggregationType(str, Enum):
Expand All @@ -20,10 +20,10 @@ class AggregationType(str, Enum):


@dataclass(frozen=True)
class Measure(DataClassDictMixin):
class Measure(BaseModel):
"""A measure."""

name: str
agg_time_dimension: Optional[str] = field(metadata=field_options(alias="aggTimeDimension"))
agg_time_dimension: Optional[str]
agg: AggregationType
expr: str
7 changes: 2 additions & 5 deletions dbtsl/models/metric.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from dataclasses import dataclass
from enum import Enum

from mashumaro import DataClassDictMixin

# TODO @serramatutu: replace this file with codegen from GraphQL API
# See: https://strawberry.rocks/docs/codegen/query-codegen
from dbtsl.models.base import BaseModel


class MetricType(str, Enum):
Expand All @@ -28,7 +25,7 @@ def missing(cls, _: str) -> "MetricType":


@dataclass(frozen=True)
class Metric(DataClassDictMixin):
class Metric(BaseModel):
"""A metric."""

name: str
Expand Down
13 changes: 7 additions & 6 deletions dbtsl/models/query.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import base64
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Enum
from functools import cached_property
from typing import NewType, Optional

import pyarrow as pa
from mashumaro import DataClassDictMixin, field_options

from dbtsl.models.base import BaseModel

QueryId = NewType("QueryId", str)

Expand All @@ -21,15 +22,15 @@ class QueryStatus(str, Enum):


@dataclass(frozen=True)
class QueryResult(DataClassDictMixin):
class QueryResult(BaseModel):
"""A query result containing its status, SQL and error/results."""

query_id: QueryId = field(metadata=field_options(alias="queryId"))
query_id: QueryId
status: QueryStatus
sql: Optional[str]
error: Optional[str]
total_pages: Optional[int] = field(metadata=field_options(alias="totalPages"))
arrow_result: Optional[str] = field(metadata=field_options(alias="arrowResult"))
total_pages: Optional[int]
arrow_result: Optional[str]

@cached_property
def result_table(self) -> pa.Table:
Expand Down
32 changes: 32 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from dataclasses import dataclass

from mashumaro.codecs.basic import decode

from dbtsl.models.base import BaseModel
from dbtsl.models.base import snake_case_to_camel_case as stc


def test_snake_case_to_camel_case() -> None:
assert stc("hello") == "hello"
assert stc("hello_world") == "helloWorld"
assert stc("Hello_world") == "HelloWorld"
assert stc("hello world") == "hello world"
assert stc("helloWorld") == "helloWorld"


def test_base_model_auto_alias() -> None:
@dataclass
class SubModel(BaseModel):
hello_world: str

BaseModel._apply_aliases()

data = {
"helloWorld": "asdf",
}

model = SubModel.from_dict(data)
assert model.hello_world == "asdf"

codec_model = decode(data, SubModel)
assert codec_model.hello_world == "asdf"

0 comments on commit 2062850

Please sign in to comment.