Skip to content

Commit

Permalink
Alias unions & update docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Sep 25, 2023
1 parent df6a8f7 commit 55896cd
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 35 deletions.
23 changes: 10 additions & 13 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Sequence, Tuple, Union
from typing import List, Optional, Sequence, Tuple

import pandas as pd
from dbt_semantic_interfaces.implementations.elements.dimension import PydanticDimensionTypeParams
Expand Down Expand Up @@ -47,12 +47,7 @@
DataflowToExecutionPlanConverter,
)
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
from metricflow.protocols.query_parameter import (
GroupByQueryParameter,
MetricQueryParameter,
OrderByQueryParameter,
TimeDimensionQueryParameter,
)
from metricflow.protocols.query_parameter import GroupByParameter, MetricQueryParameter, OrderByQueryParameter
from metricflow.protocols.sql_client import SqlClient
from metricflow.query.query_exceptions import InvalidQueryException
from metricflow.query.query_parser import MetricFlowQueryParser
Expand Down Expand Up @@ -90,13 +85,15 @@ class MetricFlowQueryRequest:
"""Encapsulates the parameters for a metric query.
metric_names: Names of the metrics to query.
metrics: Metric objects to query.
group_by_names: Names of the dimensions and entities to query.
group_by: Dimension or entity objects to query.
limit: Limit the result to this many rows.
time_constraint_start: Get data for the start of this time range.
time_constraint_end: Get data for the end of this time range.
where_constraint: A SQL string using group by names that can be used like a where clause on the output data.
order_by_names: metric and group by names to order by. A "-" can be used to specify reverse order e.g. "-ds"
order_by: metric and group by objects to order by
order_by_names: metric and group by names to order by. A "-" can be used to specify reverse order e.g. "-ds".
order_by: metric, dimension, or entity objects to order by.
output_table: If specified, output the result data to this table instead of a result dataframe.
sql_optimization_level: The level of optimization for the generated SQL.
query_type: Type of MetricFlow query.
Expand All @@ -106,7 +103,7 @@ class MetricFlowQueryRequest:
metric_names: Optional[Sequence[str]] = None
metrics: Optional[Sequence[MetricQueryParameter]] = None
group_by_names: Optional[Sequence[str]] = None
group_by: Optional[Tuple[Union[GroupByQueryParameter, TimeDimensionQueryParameter], ...]] = None
group_by: Optional[Tuple[GroupByParameter, ...]] = None
limit: Optional[int] = None
time_constraint_start: Optional[datetime.datetime] = None
time_constraint_end: Optional[datetime.datetime] = None
Expand All @@ -122,7 +119,7 @@ def create_with_random_request_id( # noqa: D
metric_names: Optional[Sequence[str]] = None,
metrics: Optional[Sequence[MetricQueryParameter]] = None,
group_by_names: Optional[Sequence[str]] = None,
group_by: Optional[Tuple[Union[GroupByQueryParameter, TimeDimensionQueryParameter], ...]] = None,
group_by: Optional[Tuple[GroupByParameter, ...]] = None,
limit: Optional[int] = None,
time_constraint_start: Optional[datetime.datetime] = None,
time_constraint_end: Optional[datetime.datetime] = None,
Expand Down Expand Up @@ -294,7 +291,7 @@ def explain_get_dimension_values( # noqa: D
metric_names: Optional[List[str]] = None,
metrics: Optional[Sequence[MetricQueryParameter]] = None,
get_group_by_values: Optional[str] = None,
group_by: Optional[Union[GroupByQueryParameter, TimeDimensionQueryParameter]] = None,
group_by: Optional[GroupByParameter] = None,
time_constraint_start: Optional[datetime.datetime] = None,
time_constraint_end: Optional[datetime.datetime] = None,
) -> MetricFlowExplainResult:
Expand Down Expand Up @@ -691,7 +688,7 @@ def explain_get_dimension_values( # noqa: D
metric_names: Optional[List[str]] = None,
metrics: Optional[Sequence[MetricQueryParameter]] = None,
get_group_by_values: Optional[str] = None,
group_by: Optional[Union[GroupByQueryParameter, TimeDimensionQueryParameter]] = None,
group_by: Optional[GroupByParameter] = None,
time_constraint_start: Optional[datetime.datetime] = None,
time_constraint_end: Optional[datetime.datetime] = None,
) -> MetricFlowExplainResult:
Expand Down
8 changes: 6 additions & 2 deletions metricflow/protocols/query_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def name(self) -> str:


@runtime_checkable
class GroupByQueryParameter(Protocol):
class DimensionOrEntityQueryParameter(Protocol):
"""Generic group by parameter for queries. Might be an entity or a dimension."""

@property
Expand All @@ -45,11 +45,15 @@ def date_part(self) -> Optional[DatePart]:
raise NotImplementedError


GroupByParameter = Union[DimensionOrEntityQueryParameter, TimeDimensionQueryParameter]
InputOrderByParameter = Union[MetricQueryParameter, GroupByParameter]


class OrderByQueryParameter(Protocol):
"""Parameter to order by, specifying ascending or descending."""

@property
def order_by(self) -> Union[MetricQueryParameter, GroupByQueryParameter, TimeDimensionQueryParameter]:
def order_by(self) -> InputOrderByParameter:
"""Parameter to order results by."""
raise NotImplementedError

Expand Down
10 changes: 5 additions & 5 deletions metricflow/query/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing import Dict, List, Optional, Sequence, Tuple

from dbt_semantic_interfaces.call_parameter_sets import ParseWhereFilterException
from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter
Expand All @@ -29,7 +29,7 @@
from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow.protocols.query_parameter import (
GroupByQueryParameter,
GroupByParameter,
MetricQueryParameter,
OrderByQueryParameter,
TimeDimensionQueryParameter,
Expand Down Expand Up @@ -177,7 +177,7 @@ def parse_and_validate_query(
metric_names: Optional[Sequence[str]] = None,
metrics: Optional[Sequence[MetricQueryParameter]] = None,
group_by_names: Optional[Sequence[str]] = None,
group_by: Optional[Tuple[Union[GroupByQueryParameter, TimeDimensionQueryParameter], ...]] = None,
group_by: Optional[Tuple[GroupByParameter, ...]] = None,
limit: Optional[int] = None,
time_constraint_start: Optional[datetime.datetime] = None,
time_constraint_end: Optional[datetime.datetime] = None,
Expand Down Expand Up @@ -316,7 +316,7 @@ def _parse_and_validate_query(
metric_names: Optional[Sequence[str]] = None,
metrics: Optional[Sequence[MetricQueryParameter]] = None,
group_by_names: Optional[Sequence[str]] = None,
group_by: Optional[Tuple[Union[GroupByQueryParameter, TimeDimensionQueryParameter], ...]] = None,
group_by: Optional[Tuple[GroupByParameter, ...]] = None,
limit: Optional[int] = None,
time_constraint_start: Optional[datetime.datetime] = None,
time_constraint_end: Optional[datetime.datetime] = None,
Expand Down Expand Up @@ -665,7 +665,7 @@ def _parse_group_by(
self,
metric_references: Sequence[MetricReference],
group_by_names: Optional[Sequence[str]] = None,
group_by: Optional[Tuple[Union[GroupByQueryParameter, TimeDimensionQueryParameter], ...]] = None,
group_by: Optional[Tuple[GroupByParameter, ...]] = None,
) -> QueryTimeLinkableSpecSet:
"""Convert the linkable spec names into the respective specification objects."""
# TODO: refactor to only support group_by object inputs (removing group_by_names param)
Expand Down
12 changes: 4 additions & 8 deletions metricflow/specs/query_param_implementations.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Union
from typing import Optional

from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow.protocols.query_parameter import (
GroupByQueryParameter,
MetricQueryParameter,
TimeDimensionQueryParameter,
)
from metricflow.protocols.query_parameter import InputOrderByParameter
from metricflow.time.date_part import DatePart


Expand All @@ -29,7 +25,7 @@ def __post_init__(self) -> None: # noqa: D


@dataclass(frozen=True)
class GroupByParameter:
class DimensionOrEntityParameter:
"""Group by parameter requested in a query.
Might represent an entity or a dimension.
Expand All @@ -49,5 +45,5 @@ class MetricParameter:
class OrderByParameter:
"""Order by requested in a query."""

order_by: Union[MetricQueryParameter, GroupByQueryParameter, TimeDimensionQueryParameter]
order_by: InputOrderByParameter
descending: bool = False
8 changes: 4 additions & 4 deletions metricflow/test/integration/test_configured_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from metricflow.plan_conversion.column_resolver import (
DunderColumnAssociationResolver,
)
from metricflow.protocols.query_parameter import GroupByQueryParameter
from metricflow.protocols.query_parameter import DimensionOrEntityQueryParameter
from metricflow.protocols.sql_client import SqlClient
from metricflow.specs.query_param_implementations import GroupByParameter, TimeDimensionParameter
from metricflow.specs.query_param_implementations import DimensionOrEntityParameter, TimeDimensionParameter
from metricflow.sql.sql_exprs import (
SqlCastToTimestampExpression,
SqlColumnReference,
Expand Down Expand Up @@ -256,7 +256,7 @@ def test_case(

check_query_helpers = CheckQueryHelpers(sql_client)

group_by: List[GroupByQueryParameter] = []
group_by: List[DimensionOrEntityQueryParameter] = []
for group_by_kwargs in case.group_by_objs:
kwargs = copy(group_by_kwargs)
date_part = kwargs.get("date_part")
Expand All @@ -268,7 +268,7 @@ def test_case(
kwargs["grain"] = TimeGranularity(grain)
group_by.append(TimeDimensionParameter(**kwargs))
else:
group_by.append(GroupByParameter(**kwargs))
group_by.append(DimensionOrEntityParameter(**kwargs))
query_result = engine.query(
MetricFlowQueryRequest.create_with_random_request_id(
metric_names=case.metrics,
Expand Down
6 changes: 3 additions & 3 deletions metricflow/test/query/test_query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from metricflow.query.query_exceptions import InvalidQueryException
from metricflow.query.query_parser import MetricFlowQueryParser
from metricflow.specs.query_param_implementations import (
GroupByParameter,
DimensionOrEntityParameter,
MetricParameter,
OrderByParameter,
TimeDimensionParameter,
Expand Down Expand Up @@ -196,8 +196,8 @@ def test_query_parser_with_object_params(bookings_query_parser: MetricFlowQueryP
Metric = namedtuple("Metric", ["name", "descending"])
metric = Metric("bookings", False)
group_by = (
GroupByParameter("booking__is_instant"),
GroupByParameter("listing"),
DimensionOrEntityParameter("booking__is_instant"),
DimensionOrEntityParameter("listing"),
TimeDimensionParameter(MTD),
)
order_by = (
Expand Down

0 comments on commit 55896cd

Please sign in to comment.