Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support for custom time granularities #54

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .changes/unreleased/Breaking Changes-20241017-144053.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Breaking Changes
body: '`Dimension` and `Metric`''s `queryable_granularities` can now contain strings that correspond to custom grains. Elements can still be `TimeGranularity`, though'
time: 2024-10-17T14:40:53.023434+02:00
3 changes: 3 additions & 0 deletions .changes/unreleased/Features-20241017-143959.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Features
body: Add support for custom time granularities
time: 2024-10-17T14:39:59.367812+02:00
3 changes: 2 additions & 1 deletion dbtsl/api/adbc/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
QueryParameters,
validate_query_parameters,
)
from dbtsl.models.time import TimeGranularity


class ADBCProtocol:
Expand All @@ -32,7 +33,7 @@ def _serialize_val(cls, val: Any) -> str:
if isinstance(val, OrderByGroupBy):
d = f'Dimension("{val.name}")'
if val.grain:
grain_str = val.grain.name.lower()
grain_str = val.grain.name.lower() if isinstance(val.grain, TimeGranularity) else val.grain.lower()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@courtneyholcomb This supposes custom grains are not case-sensitive. Is this true?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They will be lowercase! This is enforced with a validation error in DSI

d += f'.grain("{grain_str}")'
if val.descending:
d += ".descending(True)"
Expand Down
4 changes: 2 additions & 2 deletions dbtsl/api/shared/query_params.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import List, Optional, TypedDict, Union

from dbtsl.models.time import TimeGranularity
from dbtsl.models.time import Grain


@dataclass(frozen=True)
Expand All @@ -20,7 +20,7 @@ class OrderByGroupBy:
"""

name: str
grain: Optional[TimeGranularity]
grain: Optional[Grain]
descending: bool = False


Expand Down
3 changes: 2 additions & 1 deletion dbtsl/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
SavedQueryQueryParams,
SavedQueryWhereParam,
)
from .time import DatePart, TimeGranularity
from .time import DatePart, Grain, TimeGranularity

# Only importing this so it registers aliases
_ = QueryResult
Expand All @@ -37,6 +37,7 @@
"Export",
"ExportConfig",
"ExportDestinationType",
"Grain",
"Measure",
"Metric",
"MetricType",
Expand Down
15 changes: 14 additions & 1 deletion dbtsl/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ def gql_model_name(cls) -> str:
"""The model's name in the GraphQL schema. Defaults to same as class name."""
return cls.__name__

@classmethod
def extra_gql_fields(cls) -> List[str]:
"""Any extra GraphQL fields that the mixin requires.

This can be paired with `dataclasses.InitVar` to create fields that are queried via GraphQL
and passed into `__post_init__`, but are not a part of the final dataclass object.
"""
return []

# NOTE: this will overflow the stack if we add any circular dependencies in our GraphQL schema, like
# Metric -> Dimension -> Metric -> Dimension ...
#
Expand Down Expand Up @@ -85,9 +94,13 @@ def gql_fragments(cls) -> List[GraphQLFragment]:
gql_model_name = cls.gql_model_name()
fragment_name = f"fragment{cls.__name__}"

# NOTE: for some reason MyPy and pyright freak out when calling `extra_gql_fields` from
# `cls` after this assertion, even though it's guaranteed the method exists. If we proxy
# it through this `cls_ref` variable then they stop complaining
ref = cls
assert is_dataclass(cls), "Subclass of GraphQLFragmentMixin must be dataclass"

query_elements: List[str] = []
query_elements: List[str] = ref.extra_gql_fields()
dependencies: Set[GraphQLFragment] = set()
for field in fields(cls):
frag_or_field = GraphQLFragmentMixin._get_fragments_for_field(field.type, field.name)
Expand Down
29 changes: 26 additions & 3 deletions dbtsl/models/dimension.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from dataclasses import dataclass
from dataclasses import InitVar, dataclass
from enum import Enum
from typing import List, Optional

from typing_extensions import override

from dbtsl.models.base import BaseModel, GraphQLFragmentMixin
from dbtsl.models.time import TimeGranularity
from dbtsl.models.time import Grain


class DimensionType(str, Enum):
Expand All @@ -24,4 +26,25 @@ class Dimension(BaseModel, GraphQLFragmentMixin):
label: Optional[str]
is_partition: bool
expr: Optional[str]
queryable_granularities: List[TimeGranularity]
queryable_granularities: List[Grain]

queryable_time_granilarities: InitVar[List[str]]

@override
@classmethod
def extra_gql_fields(cls) -> List[str]:
return ["queryableTimeGranularities"]

def __post_init__(self, queryable_time_granilarities: List[str]) -> None:
"""Initialize queryable_granularities from queryable_time_granilarities.

In GraphQL, the standard time granularities are in `queryableGranularities`
but the custom time granularities are in `queryableTimeGranularities`.

Here' we're setting `queryable_time_granilarities` as an `InitVar`, and
making `queryable_granularities` contain both standard and non standard
`Grain`s.

This method is what merges both of them.
"""
self.queryable_granularities.extend(queryable_time_granilarities)
29 changes: 26 additions & 3 deletions dbtsl/models/metric.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from dataclasses import dataclass
from dataclasses import InitVar, dataclass
from enum import Enum
from typing import List, Optional

from typing_extensions import override

from dbtsl.models.base import BaseModel, GraphQLFragmentMixin
from dbtsl.models.dimension import Dimension
from dbtsl.models.entity import Entity
from dbtsl.models.measure import Measure
from dbtsl.models.time import TimeGranularity
from dbtsl.models.time import Grain


class MetricType(str, Enum):
Expand All @@ -29,6 +31,27 @@ class Metric(BaseModel, GraphQLFragmentMixin):
dimensions: List[Dimension]
measures: List[Measure]
entities: List[Entity]
queryable_granularities: List[TimeGranularity]
queryable_granularities: List[Grain]
label: str
requires_metric_time: bool

queryable_time_granilarities: InitVar[List[str]]

@override
@classmethod
def extra_gql_fields(cls) -> List[str]:
return ["queryableTimeGranularities"]

def __post_init__(self, queryable_time_granilarities: List[str]) -> None:
"""Initialize queryable_granularities from queryable_time_granilarities.

In GraphQL, the standard time granularities are in `queryableGranularities`
but the custom time granularities are in `queryableTimeGranularities`.

Here' we're setting `queryable_time_granilarities` as an `InitVar`, and
making `queryable_granularities` contain both standard and non standard
`Grain`s.

This method is what merges both of them.
"""
self.queryable_granularities.extend(queryable_time_granilarities)
4 changes: 2 additions & 2 deletions dbtsl/models/saved_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import List, Optional

from dbtsl.models.base import BaseModel, GraphQLFragmentMixin
from dbtsl.models.time import DatePart, TimeGranularity
from dbtsl.models.time import DatePart, Grain


class ExportDestinationType(str, Enum):
Expand Down Expand Up @@ -42,7 +42,7 @@ class SavedQueryGroupByParam(BaseModel, GraphQLFragmentMixin):
"""The groupBy param of a saved query."""

name: str
grain: Optional[TimeGranularity]
grain: Optional[Grain]
date_part: Optional[DatePart]


Expand Down
5 changes: 5 additions & 0 deletions dbtsl/models/time.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from enum import Enum
from typing import Union


class TimeGranularity(str, Enum):
Expand All @@ -17,6 +18,10 @@ class TimeGranularity(str, Enum):
YEAR = "YEAR"


Grain = Union[TimeGranularity, str]
"""Either a standard TimeGranularity or a custom grain."""


class DatePart(str, Enum):
"""Date part."""

Expand Down
4 changes: 4 additions & 0 deletions tests/api/adbc/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ def test_serialize_val_OrderByGroupBy() -> None:
ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain=TimeGranularity.WEEK, descending=True))
== 'Dimension("m").grain("week").descending(True)'
)
assert (
ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain="custom_grain"))
== 'Dimension("m").grain("custom_grain")'
)


def test_serialize_query_params_metrics() -> None:
Expand Down
41 changes: 41 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
from mashumaro.codecs.basic import decode
from typing_extensions import override

from dbtsl.api.graphql.util import normalize_query
from dbtsl.api.shared.query_params import (
Expand All @@ -16,6 +17,9 @@
)
from dbtsl.models.base import BaseModel, GraphQLFragmentMixin
from dbtsl.models.base import snake_case_to_camel_case as stc
from dbtsl.models.dimension import Dimension, DimensionType
from dbtsl.models.metric import Metric, MetricType
from dbtsl.models.time import TimeGranularity


def test_snake_case_to_camel_case() -> None:
Expand Down Expand Up @@ -56,6 +60,11 @@ class B(BaseModel, GraphQLFragmentMixin):
a: A
many_a: List[A]

@override
@classmethod
def extra_gql_fields(cls) -> List[str]:
return ["myExtraGqlField"]

a_fragments = A.gql_fragments()
assert len(a_fragments) == 1
a_fragment = a_fragments[0]
Expand All @@ -74,6 +83,7 @@ class B(BaseModel, GraphQLFragmentMixin):

b_expect = normalize_query("""
fragment fragmentB on B {
myExtraGqlField
helloWorld
baz
a {
Expand Down Expand Up @@ -188,3 +198,34 @@ def test_validate_query_params_no_query() -> None:
p: QueryParameters = {"limit": 1, "where": ["1=1"], "order_by": ["a"], "read_cache": False}
with pytest.raises(ValueError):
validate_query_parameters(p)


def test_Metric_custom_granularity() -> None:
m = Metric(
name="metric",
description="my metric",
label="lala",
type=MetricType.SIMPLE,
dimensions=[],
entities=[],
measures=[],
queryable_granularities=[TimeGranularity.DAY, TimeGranularity.WEEK],
queryable_time_granilarities=["custom_grain"],
requires_metric_time=True,
)
assert m.queryable_granularities == [TimeGranularity.DAY, TimeGranularity.WEEK, "custom_grain"]


def test_Dimension_custom_granularity() -> None:
d = Dimension(
name="dimension",
qualified_name="full_name__dimension",
description="my dimension",
label="lala",
type=DimensionType.TIME,
is_partition=True,
expr="a - b",
queryable_granularities=[TimeGranularity.DAY, TimeGranularity.WEEK],
queryable_time_granilarities=["custom_grain"],
)
assert d.queryable_granularities == [TimeGranularity.DAY, TimeGranularity.WEEK, "custom_grain"]
Loading