Skip to content

Commit

Permalink
feat: support for custom grain (#55)
Browse files Browse the repository at this point in the history
* feat: add functionality to deprecate attributes

This commit adds a new `BaseModel.DEPRECATED` key that can be added to
the dataclasses `metadata` dict to indicate that this field is
deprecated.

* feat: new `DeprecatedMixin` for whole classes

This commit introduces a new `DeprecatedMixin` class that can be added
to any model class to mark that class as deprecated. It will throw a
warning if the user tries to instantiate the class.

* feat: mark `TimeGranularity` as deprecated

Since we introduced custom grains, the old `queryable_granularities` is
deprecated in favor of the new `queryable_time_granilarities`, which is
just a simple list of strings. This commit marks the `TimeGranularity`
class and all granularity fields that return it as deprecated.

* fix: change `OrderByGroupBy` to use `str` as grain

This commit changes the `OrderByGroupBy` class to use `str` instead of
the deprecated `TimeGranularity` enum as its input grain. We can do this
without a deprecation because we haven't released the SDK since the
order by refactor, so we can just change it.

* docs: changelog entry

* fix: preload models in `__init__`

We need to call `BaseModel._register_subclasses` otherwise models will
fail to use `camelCase` and raise deprecation warnings. That is done in
`dbtsl.models.__init__`. If the user never explicitly imports that, this
won't get called, and they might get an error.

This fixes that by adding an explicit call to it on the library init.

* fix: catch deprecation warnings in GQL client

We are raising deprecation warnings from the GQL client when we
instantiate the models. To avoid the warning spam, we filter those
warnings out. They should only be display if the user uses any
deprecated class, not us.
  • Loading branch information
serramatutu authored Oct 21, 2024
1 parent 9896a1c commit 7ab6312
Show file tree
Hide file tree
Showing 17 changed files with 173 additions and 25 deletions.
3 changes: 3 additions & 0 deletions .changes/unreleased/Deprecations-20241017-163158.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Deprecations
body: Deprecate `TimeGranularity` enum and all other fields that used it
time: 2024-10-17T16:31:58.091095+02:00
3 changes: 3 additions & 0 deletions .changes/unreleased/Features-20241017-163057.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Features
body: Add support for custom time granularity
time: 2024-10-17T16:30:57.023867+02:00
3 changes: 3 additions & 0 deletions .changes/unreleased/Under the Hood-20241017-163037.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Under the Hood
body: Add new mechanisms to deprecate fields and classes
time: 2024-10-17T16:30:37.793294+02:00
4 changes: 3 additions & 1 deletion .changie.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ kindFormat: '### {{.Kind}}'
changeFormat: '* {{.Body}}'
kinds:
- label: Breaking Changes
auto: major
auto: minor
- label: Deprecations
auto: minor
- label: Features
auto: minor
- label: Fixes
Expand Down
2 changes: 2 additions & 0 deletions dbtsl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pyright: reportUnusedImport=false
try:
from dbtsl.client.sync import SyncSemanticLayerClient

Expand All @@ -13,6 +14,7 @@ def err_factory(*args, **kwargs) -> None: # noqa: D103

SemanticLayerClient = err_factory

import dbtsl.models # noqa: F401
from dbtsl.api.shared.query_params import OrderByGroupBy, OrderByMetric

__all__ = ["SemanticLayerClient", "OrderByMetric", "OrderByGroupBy"]
2 changes: 1 addition & 1 deletion dbtsl/api/adbc/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _serialize_val(cls, val: Any) -> str:
if isinstance(val, OrderByGroupBy):
d = f'Dimension("{val.name}")'
if val.grain:
grain_str = val.grain.name.lower()
grain_str = val.grain.lower()
d += f'.grain("{grain_str}")'
if val.descending:
d += ".descending(True)"
Expand Down
11 changes: 7 additions & 4 deletions dbtsl/api/graphql/client/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import warnings
from abc import abstractmethod
from typing import Any, Dict, Generic, Optional, Protocol, TypeVar, Union

Expand Down Expand Up @@ -102,10 +103,12 @@ def __getattr__(self, attr: str) -> Any:
if op is None:
raise AttributeError()

return functools.partial(
self._run,
op=op,
)
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
return functools.partial(
self._run,
op=op,
)


TClient = TypeVar("TClient", bound=BaseGraphQLClient, covariant=True)
Expand Down
4 changes: 1 addition & 3 deletions dbtsl/api/shared/query_params.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from dataclasses import dataclass
from typing import List, Optional, TypedDict, Union

from dbtsl.models.time import TimeGranularity


@dataclass(frozen=True)
class OrderByMetric:
Expand All @@ -20,7 +18,7 @@ class OrderByGroupBy:
"""

name: str
grain: Optional[TimeGranularity]
grain: Optional[str]
descending: bool = False


Expand Down
2 changes: 2 additions & 0 deletions dbtsl/asyncio.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pyright: reportUnusedImport=false
try:
from dbtsl.client.asyncio import AsyncSemanticLayerClient
except ImportError:
Expand All @@ -11,6 +12,7 @@ def err_factory(*args, **kwargs) -> None: # noqa: D103

AsyncSemanticLayerClient = err_factory

import dbtsl.models # noqa: F401

__all__ = [
"AsyncSemanticLayerClient",
Expand Down
2 changes: 1 addition & 1 deletion dbtsl/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# Only importing this so it registers aliases
_ = QueryResult

BaseModel._apply_aliases()
BaseModel._register_subclasses()

__all__ = [
"AggregationType",
Expand Down
58 changes: 54 additions & 4 deletions dbtsl/models/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import inspect
import warnings
from dataclasses import dataclass, fields, is_dataclass
from dataclasses import field as dc_field
from functools import cache
from types import MappingProxyType
from typing import Any, List, Set, Type, Union
from typing import Any, ClassVar, Dict, List, Set, Type, Union
from typing import get_args as get_type_args
from typing import get_origin as get_type_origin

Expand All @@ -25,19 +26,68 @@ class BaseModel(DataClassDictMixin):
Adds some functionality like automatically creating camelCase aliases.
"""

DEPRECATED: ClassVar[str] = "dbtsl_deprecated"

# Mapping of "subclass.field" to "deprecation reason"
_deprecated_fields: ClassVar[Dict[str, str]] = dict()

@staticmethod
def _get_deprecation_key(class_name: str, field_name: str) -> str:
return f"{class_name}.{field_name}"

@classmethod
def _warn_if_deprecated(cls, field_name: str) -> None:
key = BaseModel._get_deprecation_key(cls.__name__, field_name)
reason = BaseModel._deprecated_fields.get(key)
if reason is not None:
warnings.warn(reason, DeprecationWarning)

class Config(BaseConfig): # noqa: D106
lazy_compilation = True

@classmethod
def _apply_aliases(cls) -> None:
"""Apply camelCase aliases to all subclasses."""
def _register_subclasses(cls) -> None:
"""Process fields of all subclasses.
This will:
- Apply camelCase aliases
- Pre-populate the _deprecated_fields dict with the deprecated fields
"""
for subclass in cls.__subclasses__():
assert is_dataclass(subclass), "Subclass of BaseModel must be dataclass"

for field in fields(subclass):
camel_name = snake_case_to_camel_case(field.name)
if field.name != camel_name:
field.metadata = MappingProxyType(field_options(alias=camel_name))
opts = field_options(alias=camel_name)
if field.metadata is not None:
opts = {**opts, **field.metadata}
field.metadata = MappingProxyType(opts)

if cls.DEPRECATED in field.metadata:
reason = field.metadata[cls.DEPRECATED]
key = BaseModel._get_deprecation_key(subclass.__name__, field.name)
cls._deprecated_fields[key] = reason

def __getattribute__(self, name: str) -> Any: # noqa: D105
v = object.__getattribute__(self, name)
if not name.startswith("__") and not callable(v):
self._warn_if_deprecated(name)

return v


class DeprecatedMixin:
"""Add this to any deprecated model."""

@classmethod
def _deprecation_message(cls) -> str:
"""The deprecation message that will get displayed."""
return f"{cls.__name__} is deprecated"

def __init__(self, *args, **kwargs) -> None: # noqa: D107
warnings.warn(self._deprecation_message(), DeprecationWarning)
super(DeprecatedMixin, self).__init__()


@dataclass(frozen=True, eq=True)
Expand Down
13 changes: 11 additions & 2 deletions dbtsl/models/dimension.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional

Expand All @@ -13,6 +13,12 @@ class DimensionType(str, Enum):
TIME = "TIME"


QUERYABLE_GRANULARITIES_DEPRECATION = (
"Since the introduction of custom time granularities, `Dimension.queryable_granularities` is deprecated. "
"Use `queryable_time_granularities` instead."
)


@dataclass(frozen=True)
class Dimension(BaseModel, GraphQLFragmentMixin):
"""A metric dimension."""
Expand All @@ -24,4 +30,7 @@ class Dimension(BaseModel, GraphQLFragmentMixin):
label: Optional[str]
is_partition: bool
expr: Optional[str]
queryable_granularities: List[TimeGranularity]
queryable_granularities: List[TimeGranularity] = field(
metadata={BaseModel.DEPRECATED: QUERYABLE_GRANULARITIES_DEPRECATION}
)
queryable_time_granularities: List[str]
13 changes: 11 additions & 2 deletions dbtsl/models/metric.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional

Expand All @@ -19,6 +19,12 @@ class MetricType(str, Enum):
CONVERSION = "CONVERSION"


QUERYABLE_GRANULARITIES_DEPRECATION = (
"Since the introduction of custom time granularities, `Metric.queryable_granularities` is deprecated. "
"Use `queryable_time_granularities` instead."
)


@dataclass(frozen=True)
class Metric(BaseModel, GraphQLFragmentMixin):
"""A metric."""
Expand All @@ -29,6 +35,9 @@ class Metric(BaseModel, GraphQLFragmentMixin):
dimensions: List[Dimension]
measures: List[Measure]
entities: List[Entity]
queryable_granularities: List[TimeGranularity]
queryable_granularities: List[TimeGranularity] = field(
metadata={BaseModel.DEPRECATED: QUERYABLE_GRANULARITIES_DEPRECATION}
)
queryable_time_granularities: List[str]
label: str
requires_metric_time: bool
10 changes: 9 additions & 1 deletion dbtsl/models/saved_query.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from dataclasses import field as dc_field
from enum import Enum
from typing import List, Optional

Expand Down Expand Up @@ -37,12 +38,19 @@ class SavedQueryMetricParam(BaseModel, GraphQLFragmentMixin):
name: str


GRAIN_DEPRECATION = (
"Since the introduction of custom time granularities, `SavedQueryGroupByParam.grain` is deprecated. "
"Use `time_granularity` instead."
)


@dataclass(frozen=True)
class SavedQueryGroupByParam(BaseModel, GraphQLFragmentMixin):
"""The groupBy param of a saved query."""

name: str
grain: Optional[TimeGranularity]
grain: Optional[TimeGranularity] = dc_field(metadata={BaseModel.DEPRECATED: GRAIN_DEPRECATION})
time_granularity: Optional[str]
date_part: Optional[DatePart]


Expand Down
14 changes: 13 additions & 1 deletion dbtsl/models/time.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
from enum import Enum

from typing_extensions import override

class TimeGranularity(str, Enum):
from dbtsl.models.base import DeprecatedMixin


class TimeGranularity(str, DeprecatedMixin, Enum):
"""A time granularity."""

@override
@classmethod
def _deprecation_message(cls) -> str:
return (
"Since the introduction of custom time granularity, the `TimeGranularity` enum is deprecated. "
"Please just use strings to represent time grains."
)

NANOSECOND = "NANOSECOND"
MICROSECOND = "MICROSECOND"
MILLISECOND = "MILLISECOND"
Expand Down
5 changes: 2 additions & 3 deletions tests/api/adbc/test_protocol.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from dbtsl.api.adbc.protocol import ADBCProtocol
from dbtsl.api.shared.query_params import OrderByGroupBy, OrderByMetric
from dbtsl.models.time import TimeGranularity


def test_serialize_val_basic_values() -> None:
Expand All @@ -23,11 +22,11 @@ def test_serialize_val_OrderByGroupBy() -> None:
== 'Dimension("m").descending(True)'
)
assert (
ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain=TimeGranularity.DAY, descending=False))
ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain="day", descending=False))
== 'Dimension("m").grain("day")'
)
assert (
ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain=TimeGranularity.WEEK, descending=True))
ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain="week", descending=True))
== 'Dimension("m").grain("week").descending(True)'
)

Expand Down
49 changes: 47 additions & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import warnings
from dataclasses import dataclass
from dataclasses import field as dc_field
from typing import List

import pytest
from mashumaro.codecs.basic import decode
from typing_extensions import override

from dbtsl.api.graphql.util import normalize_query
from dbtsl.api.shared.query_params import (
Expand All @@ -14,7 +17,7 @@
validate_order_by,
validate_query_parameters,
)
from dbtsl.models.base import BaseModel, GraphQLFragmentMixin
from dbtsl.models.base import BaseModel, DeprecatedMixin, GraphQLFragmentMixin
from dbtsl.models.base import snake_case_to_camel_case as stc


Expand All @@ -31,7 +34,7 @@ def test_base_model_auto_alias() -> None:
class SubModel(BaseModel):
hello_world: str

BaseModel._apply_aliases()
BaseModel._register_subclasses()

data = {
"helloWorld": "asdf",
Expand Down Expand Up @@ -89,6 +92,48 @@ class B(BaseModel, GraphQLFragmentMixin):
assert b_fragments[1] == a_fragment


def test_DeprecatedMixin() -> None:
msg = "i am deprecated :("

class MyDeprecatedClass(DeprecatedMixin):
@override
@classmethod
def _deprecation_message(cls) -> str:
return msg

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")

_ = MyDeprecatedClass()
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert msg == str(w[0].message)


def test_attr_deprecation_warning() -> None:
msg = "i am deprecated :("

@dataclass(frozen=True)
class MyClassWithDeprecatedField(BaseModel):
its_fine: bool = True
oh_no: bool = dc_field(default=False, metadata={BaseModel.DEPRECATED: msg})

BaseModel._register_subclasses()

m = MyClassWithDeprecatedField()

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")

_ = m.its_fine
assert len(w) == 0

_ = m.oh_no
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert msg == str(w[0].message)


def test_validate_order_by_params_passthrough_OrderByMetric() -> None:
i = OrderByMetric(name="asdf", descending=True)
r = validate_order_by([], [], i)
Expand Down

0 comments on commit 7ab6312

Please sign in to comment.