Skip to content

Commit

Permalink
Render EXTRACT columns in semantic model dataset by default
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Sep 8, 2023
1 parent a26a412 commit 66c988f
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 16 deletions.
1 change: 1 addition & 0 deletions metricflow/dag/id_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
SQL_EXPR_IS_NULL_PREFIX = "isn"
SQL_EXPR_CAST_TO_TIMESTAMP_PREFIX = "ctt"
SQL_EXPR_DATE_TRUNC = "dt"
SQL_EXPR_EXTRACT = "ex"
SQL_EXPR_RATIO_COMPUTATION = "rc"
SQL_EXPR_BETWEEN_PREFIX = "betw"
SQL_EXPR_WINDOW_FUNCTION_ID_PREFIX = "wfnc"
Expand Down
47 changes: 31 additions & 16 deletions metricflow/dataset/convert_semantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
SqlColumnReferenceExpression,
SqlDateTruncExpression,
SqlExpressionNode,
SqlExtractExpression,
SqlStringExpression,
)
from metricflow.sql.sql_plan import (
Expand All @@ -46,6 +47,7 @@
SqlSelectStatementNode,
SqlTableFromClauseNode,
)
from metricflow.time.date_part import DatePart

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -102,12 +104,14 @@ def _create_time_dimension_instance(
time_dimension: Dimension,
entity_links: Tuple[EntityReference, ...],
time_granularity: TimeGranularity = DEFAULT_TIME_GRANULARITY,
date_part: Optional[DatePart] = None,
) -> TimeDimensionInstance:
"""Create a time dimension instance from the dimension object from a semantic model in the model."""
time_dimension_spec = TimeDimensionSpec(
element_name=time_dimension.reference.element_name,
entity_links=entity_links,
time_granularity=time_granularity,
date_part=date_part,
)

return TimeDimensionInstance(
Expand Down Expand Up @@ -219,6 +223,11 @@ def _convert_dimensions(
select_columns = []

for dimension in dimensions or []:
dimension_select_expr = SemanticModelToDataSetConverter._make_element_sql_expr(
table_alias=table_alias,
element_name=dimension.reference.element_name,
element_expr=dimension.expr,
)
if dimension.type == DimensionType.CATEGORICAL:
dimension_instance = self._create_dimension_instance(
semantic_model_name=semantic_model_name,
Expand All @@ -228,11 +237,7 @@ def _convert_dimensions(
dimension_instances.append(dimension_instance)
select_columns.append(
SqlSelectColumn(
expr=SemanticModelToDataSetConverter._make_element_sql_expr(
table_alias=table_alias,
element_name=dimension.reference.element_name,
element_expr=dimension.expr,
),
expr=dimension_select_expr,
column_alias=dimension_instance.associated_column.column_name,
)
)
Expand All @@ -251,11 +256,7 @@ def _convert_dimensions(
time_dimension_instances.append(time_dimension_instance)
select_columns.append(
SqlSelectColumn(
expr=SemanticModelToDataSetConverter._make_element_sql_expr(
table_alias=table_alias,
element_name=dimension.reference.element_name,
element_expr=dimension.expr,
),
expr=dimension_select_expr,
column_alias=time_dimension_instance.associated_column.column_name,
)
)
Expand All @@ -274,16 +275,30 @@ def _convert_dimensions(
select_columns.append(
SqlSelectColumn(
expr=SqlDateTruncExpression(
time_granularity=time_granularity,
arg=SemanticModelToDataSetConverter._make_element_sql_expr(
table_alias=table_alias,
element_name=dimension.reference.element_name,
element_expr=dimension.expr,
),
time_granularity=time_granularity, arg=dimension_select_expr
),
column_alias=time_dimension_instance.associated_column.column_name,
)
)

for date_part in DatePart:
if date_part.to_int() > defined_time_granularity.to_int():
time_dimension_instance = self._create_time_dimension_instance(
semantic_model_name=semantic_model_name,
time_dimension=dimension,
entity_links=entity_links,
time_granularity=time_granularity,
date_part=date_part,
)
time_dimension_instances.append(time_dimension_instance)

select_columns.append(
SqlSelectColumn(
expr=SqlExtractExpression(date_part=date_part, arg=dimension_select_expr),
column_alias=time_dimension_instance.associated_column.column_name,
)
)

else:
assert False, f"Unhandled dimension type: {dimension.type}"

Expand Down
4 changes: 4 additions & 0 deletions metricflow/specs/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql.sql_column_type import SqlColumnType
from metricflow.time.date_part import DatePart
from metricflow.visitor import VisitorOutputT


Expand Down Expand Up @@ -284,6 +285,7 @@ def accept(self, visitor: InstanceSpecVisitor[VisitorOutputT]) -> VisitorOutputT
@dataclass(frozen=True)
class TimeDimensionSpec(DimensionSpec): # noqa: D
time_granularity: TimeGranularity = DEFAULT_TIME_GRANULARITY
date_part: Optional[DatePart] = None

# Used for semi-additive joins. Some more thought is needed, but this may be useful in InstanceSpec.
aggregation_state: Optional[AggregationState] = None
Expand All @@ -295,6 +297,7 @@ def without_first_entity_link(self) -> TimeDimensionSpec: # noqa: D
element_name=self.element_name,
entity_links=self.entity_links[1:],
time_granularity=self.time_granularity,
date_part=self.date_part,
)

@property
Expand Down Expand Up @@ -338,6 +341,7 @@ def with_aggregation_state(self, aggregation_state: AggregationState) -> TimeDim
element_name=self.element_name,
entity_links=self.entity_links,
time_granularity=self.time_granularity,
date_part=self.date_part,
aggregation_state=aggregation_state,
)

Expand Down
10 changes: 10 additions & 0 deletions metricflow/sql/render/expr_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
SqlDateTruncExpression,
SqlExpressionNode,
SqlExpressionNodeVisitor,
SqlExtractExpression,
SqlFunction,
SqlGenerateUuidExpression,
SqlIsNullExpression,
Expand Down Expand Up @@ -267,6 +268,15 @@ def visit_date_trunc_expr(self, node: SqlDateTruncExpression) -> SqlExpressionRe
bind_parameters=arg_rendered.bind_parameters,
)

def visit_extract_expr(self, node: SqlExtractExpression) -> SqlExpressionRenderResult:
    """Render an EXTRACT expression in the ANSI-standard `EXTRACT(<date_part> FROM <expr>)` form.

    The previous `EXTRACT('<part>', <expr>)` spelling is a function-call form that only some
    engines (e.g. DuckDB, Snowflake) accept; the `FROM` form is the SQL-standard syntax and is
    required by engines such as Postgres. Engines whose date-part keywords differ from
    `DatePart.value` should override this method in their renderer subclass.

    Args:
        node: the extract expression node; `node.date_part.value` supplies the SQL keyword
            and `node.arg` the time expression to extract from.

    Returns:
        The rendered SQL string plus the bind parameters carried by the rendered argument.
    """
    arg_rendered = self.render_sql_expr(node.arg)

    return SqlExpressionRenderResult(
        sql=f"EXTRACT({node.date_part.value} FROM {arg_rendered.sql})",
        bind_parameters=arg_rendered.bind_parameters,
    )

def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> SqlExpressionRenderResult: # noqa: D
arg_rendered = node.arg.accept(self)
if node.grain_to_date:
Expand Down
64 changes: 64 additions & 0 deletions metricflow/sql/sql_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
SQL_EXPR_COLUMN_REFERENCE_ID_PREFIX,
SQL_EXPR_COMPARISON_ID_PREFIX,
SQL_EXPR_DATE_TRUNC,
SQL_EXPR_EXTRACT,
SQL_EXPR_FUNCTION_ID_PREFIX,
SQL_EXPR_GENERATE_UUID_PREFIX,
SQL_EXPR_IS_NULL_PREFIX,
Expand All @@ -32,6 +33,7 @@
)
from metricflow.dag.mf_dag import DagNode, DisplayedProperty, NodeId
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.time.date_part import DatePart
from metricflow.visitor import Visitable, VisitorOutputT


Expand Down Expand Up @@ -218,6 +220,10 @@ def visit_cast_to_timestamp_expr(self, node: SqlCastToTimestampExpression) -> Vi
def visit_date_trunc_expr(self, node: SqlDateTruncExpression) -> VisitorOutputT: # noqa: D
pass

@abstractmethod
def visit_extract_expr(self, node: SqlExtractExpression) -> VisitorOutputT: # noqa: D
pass

@abstractmethod
def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> VisitorOutputT: # noqa: D
pass
Expand Down Expand Up @@ -1416,6 +1422,64 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D
return self.time_granularity == other.time_granularity and self._parents_match(other)


class SqlExtractExpression(SqlExpressionNode):
    """SQL expression node representing extraction of a date part from a time expression.

    Wraps a single parent expression (the time expression) together with the `DatePart`
    to pull out of it, e.g. the year or month component of a timestamp column.
    """

    def __init__(self, date_part: DatePart, arg: SqlExpressionNode) -> None:
        """Constructor.
        Args:
            date_part: the date part to extract.
            arg: the expression to extract from.
        """
        self._date_part = date_part
        super().__init__(node_id=self.create_unique_id(), parent_nodes=[arg])

    @classmethod
    def id_prefix(cls) -> str:  # noqa: D
        return SQL_EXPR_EXTRACT

    def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT:  # noqa: D
        return visitor.visit_extract_expr(self)

    @property
    def date_part(self) -> DatePart:
        """The date part that this node extracts from its argument."""
        return self._date_part

    @property
    def arg(self) -> SqlExpressionNode:
        """The time expression being extracted from — the node's single parent."""
        assert len(self.parent_nodes) == 1
        return self.parent_nodes[0]

    @property
    def requires_parenthesis(self) -> bool:  # noqa: D
        return False

    @property
    def description(self) -> str:  # noqa: D
        return f"Extract {self.date_part.name}"

    def rewrite(
        self,
        column_replacements: Optional[SqlColumnReplacements] = None,
        should_render_table_alias: Optional[bool] = None,
    ) -> SqlExpressionNode:
        """Return a copy of this node with the argument rewritten; the date part is unchanged."""
        rewritten_arg = self.arg.rewrite(column_replacements, should_render_table_alias)
        return SqlExtractExpression(date_part=self.date_part, arg=rewritten_arg)

    @property
    def lineage(self) -> SqlExpressionTreeLineage:  # noqa: D
        parent_lineages = tuple(parent.lineage for parent in self.parent_nodes)
        return SqlExpressionTreeLineage.combine(
            parent_lineages + (SqlExpressionTreeLineage(other_exprs=(self,)),)
        )

    def matches(self, other: SqlExpressionNode) -> bool:  # noqa: D
        return (
            isinstance(other, SqlExtractExpression)
            and self.date_part == other.date_part
            and self._parents_match(other)
        )


class SqlRatioComputationExpression(SqlExpressionNode):
"""Node for expressing Ratio metrics to allow for appropriate casting to float/double in each engine.
Expand Down
40 changes: 40 additions & 0 deletions metricflow/time/date_part.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

from enum import Enum

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity


class DatePart(Enum):
    """Date parts able to be extracted from a time dimension.

    TODO: add support for hour, minute, second once those granularities are available
    """

    YEAR = "year"
    QUARTER = "quarter"
    MONTH = "month"
    WEEK = "week"
    WEEKDAY = "weekday"
    DAYOFYEAR = "dayofyear"
    DAY = "day"

    def to_int(self) -> int:
        """Convert to an int so that the size of the granularity can be easily compared."""
        # DAY, WEEKDAY, and DAYOFYEAR all resolve to day-sized values, so they share a rank.
        if self is DatePart.DAY or self is DatePart.WEEKDAY or self is DatePart.DAYOFYEAR:
            return TimeGranularity.DAY.to_int()
        elif self is DatePart.WEEK:
            return TimeGranularity.WEEK.to_int()
        elif self is DatePart.MONTH:
            return TimeGranularity.MONTH.to_int()
        elif self is DatePart.QUARTER:
            return TimeGranularity.QUARTER.to_int()
        elif self is DatePart.YEAR:
            return TimeGranularity.YEAR.to_int()
        else:
            assert_values_exhausted(self)

0 comments on commit 66c988f

Please sign in to comment.