diff --git a/.changes/unreleased/Breaking Changes-20241017-144053.yaml b/.changes/unreleased/Breaking Changes-20241017-144053.yaml new file mode 100644 index 0000000..727e3ee --- /dev/null +++ b/.changes/unreleased/Breaking Changes-20241017-144053.yaml @@ -0,0 +1,3 @@ +kind: Breaking Changes +body: '`Dimension` and `Metric`''s `queryable_granularities` can now contain strings that correspond to custom grains. Elements can still be `TimeGranularity`, though' +time: 2024-10-17T14:40:53.023434+02:00 diff --git a/.changes/unreleased/Features-20241017-143959.yaml b/.changes/unreleased/Features-20241017-143959.yaml new file mode 100644 index 0000000..819e22d --- /dev/null +++ b/.changes/unreleased/Features-20241017-143959.yaml @@ -0,0 +1,3 @@ +kind: Features +body: Add support for custom time granularities +time: 2024-10-17T14:39:59.367812+02:00 diff --git a/dbtsl/api/adbc/protocol.py b/dbtsl/api/adbc/protocol.py index c49d176..2fbacb0 100644 --- a/dbtsl/api/adbc/protocol.py +++ b/dbtsl/api/adbc/protocol.py @@ -9,6 +9,7 @@ QueryParameters, validate_query_parameters, ) +from dbtsl.models.time import TimeGranularity class ADBCProtocol: @@ -32,7 +33,7 @@ def _serialize_val(cls, val: Any) -> str: if isinstance(val, OrderByGroupBy): d = f'Dimension("{val.name}")' if val.grain: - grain_str = val.grain.name.lower() + grain_str = val.grain.name.lower() if isinstance(val.grain, TimeGranularity) else val.grain.lower() d += f'.grain("{grain_str}")' if val.descending: d += ".descending(True)" diff --git a/dbtsl/api/shared/query_params.py b/dbtsl/api/shared/query_params.py index 083065b..a18759c 100644 --- a/dbtsl/api/shared/query_params.py +++ b/dbtsl/api/shared/query_params.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from typing import List, Optional, TypedDict, Union -from dbtsl.models.time import TimeGranularity +from dbtsl.models.time import Grain @dataclass(frozen=True) @@ -20,7 +20,7 @@ class OrderByGroupBy: """ name: str - grain: Optional[TimeGranularity] + grain: Optional[Grain] descending: bool = False diff --git a/dbtsl/models/__init__.py b/dbtsl/models/__init__.py index 7884bb2..88f5d18 100644 --- a/dbtsl/models/__init__.py +++ b/dbtsl/models/__init__.py @@ -20,7 +20,7 @@ SavedQueryQueryParams, SavedQueryWhereParam, ) -from .time import DatePart, TimeGranularity +from .time import DatePart, Grain, TimeGranularity # Only importing this so it registers aliases _ = QueryResult @@ -37,6 +37,7 @@ "Export", "ExportConfig", "ExportDestinationType", + "Grain", "Measure", "Metric", "MetricType", diff --git a/dbtsl/models/base.py b/dbtsl/models/base.py index ac044a3..cd1953d 100644 --- a/dbtsl/models/base.py +++ b/dbtsl/models/base.py @@ -56,6 +56,15 @@ def gql_model_name(cls) -> str: """The model's name in the GraphQL schema. Defaults to same as class name.""" return cls.__name__ + @classmethod + def extra_gql_fields(cls) -> List[str]: + """Any extra GraphQL fields that the mixin requires. + + This can be paired with `dataclasses.InitVar` to create fields that are queried via GraphQL + and passed into `__post_init__`, but are not a part of the final dataclass object. + """ + return [] + # NOTE: this will overflow the stack if we add any circular dependencies in our GraphQL schema, like # Metric -> Dimension -> Metric -> Dimension ... # @@ -85,9 +94,13 @@ def gql_fragments(cls) -> List[GraphQLFragment]: gql_model_name = cls.gql_model_name() fragment_name = f"fragment{cls.__name__}" + # NOTE: for some reason MyPy and pyright freak out when calling `extra_gql_fields` from + # `cls` after this assertion, even though it's guaranteed the method exists. If we proxy + # it through this `cls_ref` variable then they stop complaining + ref = cls assert is_dataclass(cls), "Subclass of GraphQLFragmentMixin must be dataclass" - query_elements: List[str] = [] + query_elements: List[str] = ref.extra_gql_fields() dependencies: Set[GraphQLFragment] = set() for field in fields(cls): frag_or_field = GraphQLFragmentMixin._get_fragments_for_field(field.type, field.name) diff --git a/dbtsl/models/dimension.py b/dbtsl/models/dimension.py index 05e44fa..9962dde 100644 --- a/dbtsl/models/dimension.py +++ b/dbtsl/models/dimension.py @@ -1,9 +1,11 @@ -from dataclasses import dataclass +from dataclasses import InitVar, dataclass from enum import Enum from typing import List, Optional +from typing_extensions import override + from dbtsl.models.base import BaseModel, GraphQLFragmentMixin -from dbtsl.models.time import TimeGranularity +from dbtsl.models.time import Grain class DimensionType(str, Enum): @@ -24,4 +26,25 @@ class Dimension(BaseModel, GraphQLFragmentMixin): label: Optional[str] is_partition: bool expr: Optional[str] - queryable_granularities: List[TimeGranularity] + queryable_granularities: List[Grain] + + queryable_time_granilarities: InitVar[List[str]] + + @override + @classmethod + def extra_gql_fields(cls) -> List[str]: + return ["queryableTimeGranularities"] + + def __post_init__(self, queryable_time_granilarities: List[str]) -> None: + """Initialize queryable_granularities from queryable_time_granilarities. + + In GraphQL, the standard time granularities are in `queryableGranularities` + but the custom time granularities are in `queryableTimeGranularities`. + + Here' we're setting `queryable_time_granilarities` as an `InitVar`, and + making `queryable_granularities` contain both standard and non standard + `Grain`s. + + This method is what merges both of them. + """ + self.queryable_granularities.extend(queryable_time_granilarities) diff --git a/dbtsl/models/metric.py b/dbtsl/models/metric.py index 034d4df..4e508a5 100644 --- a/dbtsl/models/metric.py +++ b/dbtsl/models/metric.py @@ -1,12 +1,14 @@ -from dataclasses import dataclass +from dataclasses import InitVar, dataclass from enum import Enum from typing import List, Optional +from typing_extensions import override + from dbtsl.models.base import BaseModel, GraphQLFragmentMixin from dbtsl.models.dimension import Dimension from dbtsl.models.entity import Entity from dbtsl.models.measure import Measure -from dbtsl.models.time import TimeGranularity +from dbtsl.models.time import Grain class MetricType(str, Enum): @@ -29,6 +31,27 @@ class Metric(BaseModel, GraphQLFragmentMixin): dimensions: List[Dimension] measures: List[Measure] entities: List[Entity] - queryable_granularities: List[TimeGranularity] + queryable_granularities: List[Grain] label: str requires_metric_time: bool + + queryable_time_granilarities: InitVar[List[str]] + + @override + @classmethod + def extra_gql_fields(cls) -> List[str]: + return ["queryableTimeGranularities"] + + def __post_init__(self, queryable_time_granilarities: List[str]) -> None: + """Initialize queryable_granularities from queryable_time_granilarities. + + In GraphQL, the standard time granularities are in `queryableGranularities` + but the custom time granularities are in `queryableTimeGranularities`. + + Here' we're setting `queryable_time_granilarities` as an `InitVar`, and + making `queryable_granularities` contain both standard and non standard + `Grain`s. + + This method is what merges both of them. + """ + self.queryable_granularities.extend(queryable_time_granilarities) diff --git a/dbtsl/models/saved_query.py b/dbtsl/models/saved_query.py index ce0068d..cbbcb82 100644 --- a/dbtsl/models/saved_query.py +++ b/dbtsl/models/saved_query.py @@ -3,7 +3,7 @@ from typing import List, Optional from dbtsl.models.base import BaseModel, GraphQLFragmentMixin -from dbtsl.models.time import DatePart, TimeGranularity +from dbtsl.models.time import DatePart, Grain class ExportDestinationType(str, Enum): @@ -42,7 +42,7 @@ class SavedQueryGroupByParam(BaseModel, GraphQLFragmentMixin): """The groupBy param of a saved query.""" name: str - grain: Optional[TimeGranularity] + grain: Optional[Grain] date_part: Optional[DatePart] diff --git a/dbtsl/models/time.py b/dbtsl/models/time.py index 81631de..20a332d 100644 --- a/dbtsl/models/time.py +++ b/dbtsl/models/time.py @@ -1,4 +1,5 @@ from enum import Enum +from typing import Union class TimeGranularity(str, Enum): @@ -17,6 +18,10 @@ class TimeGranularity(str, Enum): YEAR = "YEAR" +Grain = Union[TimeGranularity, str] +"""Either a standard TimeGranularity or a custom grain.""" + + class DatePart(str, Enum): """Date part.""" diff --git a/tests/api/adbc/test_protocol.py b/tests/api/adbc/test_protocol.py index fec3e74..66d2424 100644 --- a/tests/api/adbc/test_protocol.py +++ b/tests/api/adbc/test_protocol.py @@ -30,6 +30,10 @@ def test_serialize_val_OrderByGroupBy() -> None: ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain=TimeGranularity.WEEK, descending=True)) == 'Dimension("m").grain("week").descending(True)' ) + assert ( + ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain="custom_grain")) + == 'Dimension("m").grain("custom_grain")' + ) def test_serialize_query_params_metrics() -> None: diff --git a/tests/test_models.py b/tests/test_models.py index 50f1957..fcc2c24 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -3,6 +3,7 @@ import pytest from mashumaro.codecs.basic import decode +from typing_extensions import override from dbtsl.api.graphql.util import normalize_query from dbtsl.api.shared.query_params import ( @@ -16,6 +17,9 @@ ) from dbtsl.models.base import BaseModel, GraphQLFragmentMixin from dbtsl.models.base import snake_case_to_camel_case as stc +from dbtsl.models.dimension import Dimension, DimensionType +from dbtsl.models.metric import Metric, MetricType +from dbtsl.models.time import TimeGranularity def test_snake_case_to_camel_case() -> None: @@ -56,6 +60,11 @@ class B(BaseModel, GraphQLFragmentMixin): a: A many_a: List[A] + @override + @classmethod + def extra_gql_fields(cls) -> List[str]: + return ["myExtraGqlField"] + a_fragments = A.gql_fragments() assert len(a_fragments) == 1 a_fragment = a_fragments[0] @@ -74,6 +83,7 @@ class B(BaseModel, GraphQLFragmentMixin): b_expect = normalize_query(""" fragment fragmentB on B { + myExtraGqlField helloWorld baz a { @@ -188,3 +198,34 @@ def test_validate_query_params_no_query() -> None: p: QueryParameters = {"limit": 1, "where": ["1=1"], "order_by": ["a"], "read_cache": False} with pytest.raises(ValueError): validate_query_parameters(p) + + +def test_Metric_custom_granularity() -> None: + m = Metric( + name="metric", + description="my metric", + label="lala", + type=MetricType.SIMPLE, + dimensions=[], + entities=[], + measures=[], + queryable_granularities=[TimeGranularity.DAY, TimeGranularity.WEEK], + queryable_time_granilarities=["custom_grain"], + requires_metric_time=True, + ) + assert m.queryable_granularities == [TimeGranularity.DAY, TimeGranularity.WEEK, "custom_grain"] + + +def test_Dimension_custom_granularity() -> None: + d = Dimension( + name="dimension", + qualified_name="full_name__dimension", + description="my dimension", + label="lala", + type=DimensionType.TIME, + is_partition=True, + expr="a - b", + queryable_granularities=[TimeGranularity.DAY, TimeGranularity.WEEK], + queryable_time_granilarities=["custom_grain"], + ) + assert d.queryable_granularities == [TimeGranularity.DAY, TimeGranularity.WEEK, "custom_grain"]