Skip to content

Commit

Permalink
Render EXTRACT columns in semantic model dataset by default
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Sep 8, 2023
1 parent a26a412 commit 4fd25c1
Show file tree
Hide file tree
Showing 486 changed files with 76,103 additions and 12,046 deletions.
1 change: 1 addition & 0 deletions metricflow/dag/id_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
SQL_EXPR_IS_NULL_PREFIX = "isn"
SQL_EXPR_CAST_TO_TIMESTAMP_PREFIX = "ctt"
SQL_EXPR_DATE_TRUNC = "dt"
SQL_EXPR_EXTRACT = "ex"
SQL_EXPR_RATIO_COMPUTATION = "rc"
SQL_EXPR_BETWEEN_PREFIX = "betw"
SQL_EXPR_WINDOW_FUNCTION_ID_PREFIX = "wfnc"
Expand Down
48 changes: 32 additions & 16 deletions metricflow/dataset/convert_semantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
SqlColumnReferenceExpression,
SqlDateTruncExpression,
SqlExpressionNode,
SqlExtractExpression,
SqlStringExpression,
)
from metricflow.sql.sql_plan import (
Expand All @@ -46,6 +47,7 @@
SqlSelectStatementNode,
SqlTableFromClauseNode,
)
from metricflow.time.date_part import DatePart

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -102,12 +104,14 @@ def _create_time_dimension_instance(
time_dimension: Dimension,
entity_links: Tuple[EntityReference, ...],
time_granularity: TimeGranularity = DEFAULT_TIME_GRANULARITY,
date_part: Optional[DatePart] = None,
) -> TimeDimensionInstance:
"""Create a time dimension instance from the dimension object from a semantic model in the model."""
time_dimension_spec = TimeDimensionSpec(
element_name=time_dimension.reference.element_name,
entity_links=entity_links,
time_granularity=time_granularity,
date_part=date_part,
)

return TimeDimensionInstance(
Expand Down Expand Up @@ -219,6 +223,11 @@ def _convert_dimensions(
select_columns = []

for dimension in dimensions or []:
dimension_select_expr = SemanticModelToDataSetConverter._make_element_sql_expr(
table_alias=table_alias,
element_name=dimension.reference.element_name,
element_expr=dimension.expr,
)
if dimension.type == DimensionType.CATEGORICAL:
dimension_instance = self._create_dimension_instance(
semantic_model_name=semantic_model_name,
Expand All @@ -228,11 +237,7 @@ def _convert_dimensions(
dimension_instances.append(dimension_instance)
select_columns.append(
SqlSelectColumn(
expr=SemanticModelToDataSetConverter._make_element_sql_expr(
table_alias=table_alias,
element_name=dimension.reference.element_name,
element_expr=dimension.expr,
),
expr=dimension_select_expr,
column_alias=dimension_instance.associated_column.column_name,
)
)
Expand All @@ -251,11 +256,7 @@ def _convert_dimensions(
time_dimension_instances.append(time_dimension_instance)
select_columns.append(
SqlSelectColumn(
expr=SemanticModelToDataSetConverter._make_element_sql_expr(
table_alias=table_alias,
element_name=dimension.reference.element_name,
element_expr=dimension.expr,
),
expr=dimension_select_expr,
column_alias=time_dimension_instance.associated_column.column_name,
)
)
Expand All @@ -274,16 +275,31 @@ def _convert_dimensions(
select_columns.append(
SqlSelectColumn(
expr=SqlDateTruncExpression(
time_granularity=time_granularity,
arg=SemanticModelToDataSetConverter._make_element_sql_expr(
table_alias=table_alias,
element_name=dimension.reference.element_name,
element_expr=dimension.expr,
),
time_granularity=time_granularity, arg=dimension_select_expr
),
column_alias=time_dimension_instance.associated_column.column_name,
)
)

# Add all date part options for easy query resolution
for date_part in DatePart:
if date_part.to_int() >= defined_time_granularity.to_int():
time_dimension_instance = self._create_time_dimension_instance(
semantic_model_name=semantic_model_name,
time_dimension=dimension,
entity_links=entity_links,
time_granularity=defined_time_granularity,
date_part=date_part,
)
time_dimension_instances.append(time_dimension_instance)

select_columns.append(
SqlSelectColumn(
expr=SqlExtractExpression(date_part=date_part, arg=dimension_select_expr),
column_alias=time_dimension_instance.associated_column.column_name,
)
)

else:
assert False, f"Unhandled dimension type: {dimension.type}"

Expand Down
47 changes: 42 additions & 5 deletions metricflow/naming/linkable_spec_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

from metricflow.time.date_part import DatePart

DUNDER = "__"

logger = logging.getLogger(__name__)
Expand All @@ -24,6 +26,7 @@ class StructuredLinkableSpecName:
entity_link_names: Tuple[str, ...]
element_name: str
time_granularity: Optional[TimeGranularity] = None
date_part: Optional[DatePart] = None

@staticmethod
def from_name(qualified_name: str) -> StructuredLinkableSpecName:
Expand All @@ -32,7 +35,26 @@ def from_name(qualified_name: str) -> StructuredLinkableSpecName:

# No dunder, e.g. "ds"
if len(name_parts) == 1:
return StructuredLinkableSpecName((), name_parts[0])
return StructuredLinkableSpecName(entity_link_names=(), element_name=name_parts[0])

associated_date_part: Optional[DatePart] = None
for date_part in DatePart:
if name_parts[-1] == StructuredLinkableSpecName.date_part_suffix(date_part):
associated_date_part = date_part

# Has a date_part
if associated_date_part:
# e.g. "ds__extract_month"
if len(name_parts) == 2:
return StructuredLinkableSpecName(
entity_link_names=(), element_name=name_parts[0], date_part=associated_date_part
)
# e.g. "messages__ds__extract_month"
return StructuredLinkableSpecName(
entity_link_names=tuple(name_parts[:-2]),
element_name=name_parts[-2],
date_part=associated_date_part,
)

associated_granularity = None
granularity: TimeGranularity
Expand All @@ -44,19 +66,29 @@ def from_name(qualified_name: str) -> StructuredLinkableSpecName:
if associated_granularity:
# e.g. "ds__month"
if len(name_parts) == 2:
return StructuredLinkableSpecName((), name_parts[0], associated_granularity)
return StructuredLinkableSpecName(
entity_link_names=(), element_name=name_parts[0], time_granularity=associated_granularity
)
# e.g. "messages__ds__month"
return StructuredLinkableSpecName(tuple(name_parts[:-2]), name_parts[-2], associated_granularity)
return StructuredLinkableSpecName(
entity_link_names=tuple(name_parts[:-2]),
element_name=name_parts[-2],
time_granularity=associated_granularity,
)

# e.g. "messages__ds"
else:
return StructuredLinkableSpecName(tuple(name_parts[:-1]), name_parts[-1])
return StructuredLinkableSpecName(entity_link_names=tuple(name_parts[:-1]), element_name=name_parts[-1])

@property
def qualified_name(self) -> str:
"""Return the full name form. e.g. ds or listing__ds__month."""
items = list(self.entity_link_names) + [self.element_name]
if self.time_granularity:
if self.date_part:
items.append(self.date_part_suffix(date_part=self.date_part))
elif self.time_granularity:
items.append(self.time_granularity.value)

return DUNDER.join(items)

@property
Expand All @@ -66,3 +98,8 @@ def entity_prefix(self) -> Optional[str]:
return DUNDER.join(self.entity_link_names)

return None

@staticmethod
def date_part_suffix(date_part: DatePart) -> str:
    """Return the dunder-name suffix used for a spec that carries a date_part (e.g. "extract_month")."""
    return "extract_" + date_part.value
1 change: 1 addition & 0 deletions metricflow/plan_conversion/column_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def visit_time_dimension_spec(self, time_dimension_spec: TimeDimensionSpec) -> C
entity_link_names=tuple(x.element_name for x in time_dimension_spec.entity_links),
element_name=time_dimension_spec.element_name,
time_granularity=time_dimension_spec.time_granularity,
date_part=time_dimension_spec.date_part,
).qualified_name

return ColumnAssociation(
Expand Down
5 changes: 4 additions & 1 deletion metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,6 +1116,7 @@ def visit_metric_time_dimension_transform_node(
if (
len(time_dimension_instance.spec.entity_links) == 0
and time_dimension_instance.spec.reference == node.aggregation_time_dimension_reference
and time_dimension_instance.spec.date_part is None
):
matching_time_dimension_instances.append(time_dimension_instance)

Expand All @@ -1141,6 +1142,7 @@ def visit_metric_time_dimension_transform_node(
output_column_to_input_column[
metric_time_dimension_column_association.column_name
] = matching_time_dimension_instance.associated_column.column_name

output_instance_set = InstanceSet(
measure_instances=tuple(output_measure_instances),
dimension_instances=input_data_set.instance_set.dimension_instances,
Expand Down Expand Up @@ -1359,7 +1361,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode[SourceDataSetT
SqlColumnReference(table_alias=time_spine_alias, column_name=original_time_dim_instance.spec.qualified_name)
)

# Add requested granularities (skip for default granularity).
# Add requested granularities (skip for default granularity) and date_parts.
metric_time_select_columns = []
metric_time_dimension_instances = []
where: Optional[SqlExpressionNode] = None
Expand All @@ -1376,6 +1378,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode[SourceDataSetT
element_name=original_time_dim_instance.spec.element_name,
entity_links=original_time_dim_instance.spec.entity_links,
time_granularity=metric_time_dimension_spec.time_granularity,
date_part=metric_time_dimension_spec.date_part,
aggregation_state=original_time_dim_instance.spec.aggregation_state,
)
time_dim_instance = TimeDimensionInstance(
Expand Down
1 change: 1 addition & 0 deletions metricflow/plan_conversion/instance_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D
+ time_dimension_instance.spec.entity_links
),
time_granularity=time_dimension_instance.spec.time_granularity,
date_part=time_dimension_instance.spec.date_part,
)
time_dimension_instances_with_additional_link.append(
TimeDimensionInstance(
Expand Down
4 changes: 4 additions & 0 deletions metricflow/specs/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql.sql_column_type import SqlColumnType
from metricflow.time.date_part import DatePart
from metricflow.visitor import VisitorOutputT


Expand Down Expand Up @@ -284,6 +285,7 @@ def accept(self, visitor: InstanceSpecVisitor[VisitorOutputT]) -> VisitorOutputT
@dataclass(frozen=True)
class TimeDimensionSpec(DimensionSpec): # noqa: D
time_granularity: TimeGranularity = DEFAULT_TIME_GRANULARITY
date_part: Optional[DatePart] = None

# Used for semi-additive joins. Some more thought is needed, but this may be useful in InstanceSpec.
aggregation_state: Optional[AggregationState] = None
Expand All @@ -295,6 +297,7 @@ def without_first_entity_link(self) -> TimeDimensionSpec: # noqa: D
element_name=self.element_name,
entity_links=self.entity_links[1:],
time_granularity=self.time_granularity,
date_part=self.date_part,
)

@property
Expand Down Expand Up @@ -338,6 +341,7 @@ def with_aggregation_state(self, aggregation_state: AggregationState) -> TimeDim
element_name=self.element_name,
entity_links=self.entity_links,
time_granularity=self.time_granularity,
date_part=self.date_part,
aggregation_state=aggregation_state,
)

Expand Down
2 changes: 2 additions & 0 deletions metricflow/specs/where_filter_time_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,6 @@ def _convert_to_time_dimension_spec(
element_name=parameter_set.time_dimension_reference.element_name,
entity_links=parameter_set.entity_path,
time_granularity=parameter_set.time_granularity,
# TODO: add date_part to TimeDimensionCallParameterSet in DSI
# date_part=parameter_set.date_part,
)
8 changes: 8 additions & 0 deletions metricflow/sql/render/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer
from metricflow.sql.sql_exprs import SqlPercentileExpression, SqlPercentileFunctionType
from metricflow.time.date_part import DatePart


class DatabricksSqlExpressionRenderer(DefaultSqlExpressionRenderer):
Expand Down Expand Up @@ -56,6 +57,13 @@ def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionR
bind_parameters=params,
)

@override
def render_date_part(self, date_part: DatePart) -> str:
    """Render the date-part keyword for an EXTRACT expression on Databricks.

    Maps DAYOFYEAR to "DOY" — presumably because Databricks' EXTRACT does not
    accept the DAYOFYEAR keyword; confirm against the Databricks SQL reference.
    All other date parts defer to the default renderer's keyword.
    """
    if date_part == DatePart.DAYOFYEAR:
        return "DOY"

    return super().render_date_part(date_part)


class DatabricksSqlQueryPlanRenderer(DefaultSqlQueryPlanRenderer):
"""Plan renderer for the Snowflake engine."""
Expand Down
14 changes: 14 additions & 0 deletions metricflow/sql/render/expr_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
SqlDateTruncExpression,
SqlExpressionNode,
SqlExpressionNodeVisitor,
SqlExtractExpression,
SqlFunction,
SqlGenerateUuidExpression,
SqlIsNullExpression,
Expand All @@ -36,6 +37,7 @@
SqlWindowFunctionExpression,
)
from metricflow.sql.sql_plan import SqlSelectColumn
from metricflow.time.date_part import DatePart

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -267,6 +269,18 @@ def visit_date_trunc_expr(self, node: SqlDateTruncExpression) -> SqlExpressionRe
bind_parameters=arg_rendered.bind_parameters,
)

def visit_extract_expr(self, node: SqlExtractExpression) -> SqlExpressionRenderResult:
    """Render an EXTRACT(<date part> FROM <expr>) expression.

    The date-part keyword comes from render_date_part (overridable per engine);
    the inner expression's bind parameters are passed through unchanged.
    """
    arg_rendered = self.render_sql_expr(node.arg)

    return SqlExpressionRenderResult(
        sql=f"EXTRACT({self.render_date_part(node.date_part)} FROM {arg_rendered.sql})",
        bind_parameters=arg_rendered.bind_parameters,
    )

def render_date_part(self, date_part: DatePart) -> str:
    """Render the date-part keyword for an EXTRACT expression.

    Default implementation uses the enum member's name (e.g. MONTH);
    engine-specific renderers override this where a dialect uses a
    different keyword.
    """
    return date_part.name

def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> SqlExpressionRenderResult: # noqa: D
arg_rendered = node.arg.accept(self)
if node.grain_to_date:
Expand Down
Loading

0 comments on commit 4fd25c1

Please sign in to comment.