Add Trino support to MetricFlow #810

Closed · wants to merge 18 commits
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231008-195608.yaml
@@ -0,0 +1,6 @@
kind: Features
body: Add Trino support to MetricFlow.
time: 2023-10-08T19:56:08.427006-06:00
custom:
  Author: sarbmeetka
  Issue: "207"
3 changes: 3 additions & 0 deletions Makefile
@@ -59,6 +59,9 @@ test-snowflake:
populate-persistent-source-schema-snowflake:
	hatch -v run snowflake-env:pytest -vv $(ADDITIONAL_PYTEST_OPTIONS) $(USE_PERSISTENT_SOURCE_SCHEMA) $(POPULATE_PERSISTENT_SOURCE_SCHEMA)

.PHONY: test-trino
test-trino:
	hatch -v run trino-env:pytest -vv -n $(PARALLELISM) $(ADDITIONAL_PYTEST_OPTIONS) metricflow/test/

.PHONY: lint
lint:
5 changes: 4 additions & 1 deletion dbt-metricflow/pyproject.toml
@@ -50,6 +50,9 @@ redshift = [
snowflake = [
    "dbt-snowflake~=1.7.0"
]
trino = [
    "dbt-trino~=1.7.0"
]

[tool.hatch.build.targets.sdist]
exclude = [
@@ -60,4 +63,4 @@ exclude = [
    ".pre-commit-config.yaml",
    "CONTRIBUTING.md",
    "MAKEFILE",
]
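(Usage note, not part of the diff: with this extra in place, Trino support would presumably be installed via `pip install "dbt-metricflow[trino]"`, pulling dbt-trino 1.7.x in alongside the core package.)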
19 changes: 18 additions & 1 deletion metricflow/cli/dbt_connectors/adapter_backed_client.py
@@ -21,6 +21,7 @@
from metricflow.sql.render.redshift import RedshiftSqlQueryPlanRenderer
from metricflow.sql.render.snowflake import SnowflakeSqlQueryPlanRenderer
from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer
from metricflow.sql.render.trino import TrinoSqlQueryPlanRenderer
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql_request.sql_request_attributes import SqlJsonTag, SqlRequestId, SqlRequestTagSet
from metricflow.sql_request.sql_statement_metadata import CombinedSqlTags, SqlStatementCommentMetadata
@@ -42,6 +43,7 @@ class SupportedAdapterTypes(enum.Enum):
    REDSHIFT = "redshift"
    BIGQUERY = "bigquery"
    DUCKDB = "duckdb"
    TRINO = "trino"

    @property
    def sql_engine_type(self) -> SqlEngine:
@@ -58,6 +60,8 @@ def sql_engine_type(self) -> SqlEngine:
            return SqlEngine.SNOWFLAKE
        elif self is SupportedAdapterTypes.DUCKDB:
            return SqlEngine.DUCKDB
        elif self is SupportedAdapterTypes.TRINO:
            return SqlEngine.TRINO
        else:
            assert_values_exhausted(self)

@@ -76,6 +80,8 @@ def sql_query_plan_renderer(self) -> SqlQueryPlanRenderer:
            return SnowflakeSqlQueryPlanRenderer()
        elif self is SupportedAdapterTypes.DUCKDB:
            return DuckDbSqlQueryPlanRenderer()
        elif self is SupportedAdapterTypes.TRINO:
            return TrinoSqlQueryPlanRenderer()
        else:
            assert_values_exhausted(self)

@@ -213,7 +219,18 @@ def dry_run(
        request_id = SqlRequestId(f"mf_rid__{random_id()}")
        connection_name = f"MetricFlow_dry_run_request_{request_id}"
        # TODO - consolidate to self._adapter.validate_sql() when all implementations work from within MetricFlow

        # Trino has a bug where the EXPLAIN command actually creates the table, so wrap the
        # statement in EXPLAIN (type validate) to avoid this.
        # See https://github.com/trinodb/trino/issues/130
        if self.sql_engine_type is SqlEngine.TRINO:
            with self._adapter.connection_named(connection_name):
                # The response is either a bool or a string containing the error message from Trino.
                result = self._adapter.execute(f"EXPLAIN (type validate) {stmt}", auto_begin=True, fetch=True)
                has_error = str(result[0]) != "SUCCESS"
                if has_error:
                    raise DbtDatabaseError("Encountered error in Trino dry run.")

        elif self.sql_engine_type is SqlEngine.BIGQUERY:
            with self._adapter.connection_named(connection_name):
                self._adapter.validate_sql(stmt)
        else:
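For orientation, here is a minimal sketch (not part of the diff) of how the new adapter type is expected to resolve, assuming the enum members and module paths shown in the hunks above:

from metricflow.cli.dbt_connectors.adapter_backed_client import SupportedAdapterTypes
from metricflow.protocols.sql_client import SqlEngine

# Look up the enum member by its dbt adapter type string.
adapter_type = SupportedAdapterTypes("trino")
assert adapter_type.sql_engine_type is SqlEngine.TRINO
renderer = adapter_type.sql_query_plan_renderer  # a TrinoSqlQueryPlanRenderer

# The dry-run path then validates statements by wrapping them, per the hunk above:
stmt = "SELECT 1"
validate_stmt = f"EXPLAIN (type validate) {stmt}"
# Trino is expected to answer with a single "SUCCESS" cell when the statement is valid.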
1 change: 1 addition & 0 deletions metricflow/protocols/sql_client.py
@@ -23,6 +23,7 @@ class SqlEngine(Enum):
    POSTGRES = "Postgres"
    SNOWFLAKE = "Snowflake"
    DATABRICKS = "Databricks"
    TRINO = "Trino"


class SqlClient(Protocol):
126 changes: 126 additions & 0 deletions metricflow/sql/render/trino.py
@@ -0,0 +1,126 @@
from __future__ import annotations

from typing import Collection

from dateutil.parser import parse
from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from typing_extensions import override

from metricflow.sql.render.expr_renderer import (
    DefaultSqlExpressionRenderer,
    SqlExpressionRenderer,
    SqlExpressionRenderResult,
)
from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql.sql_exprs import (
    SqlBetweenExpression,
    SqlGenerateUuidExpression,
    SqlPercentileExpression,
    SqlPercentileFunctionType,
    SqlSubtractTimeIntervalExpression,
)


class TrinoSqlExpressionRenderer(DefaultSqlExpressionRenderer):
"""Expression renderer for the Trino engine."""

@property
@override
def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctionType]:
return {
SqlPercentileFunctionType.APPROXIMATE_CONTINUOUS,
}

@override
def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpressionRenderResult:
return SqlExpressionRenderResult(
sql="uuid()",
bind_parameters=SqlBindParameters(),
)

    @override
    def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult:
        """Render a time delta for Trino, which requires the granularity quoted and a different function name."""
        arg_rendered = node.arg.accept(self)

        count = node.count
        granularity = node.granularity
        if granularity == TimeGranularity.QUARTER:
            granularity = TimeGranularity.MONTH
            count *= 3
        return SqlExpressionRenderResult(
            sql=f"DATE_ADD('{granularity.value}', -{count}, {arg_rendered.sql})",
            bind_parameters=arg_rendered.bind_parameters,
        )
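    # Illustrative example (not part of the original file): with the QUARTER-to-MONTH
    # rewrite above, subtracting 2 quarters from a column `ds` renders as:
    #   DATE_ADD('month', -6, ds)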

    @override
    def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionRenderResult:
        """Render a percentile expression for Trino."""
        arg_rendered = self.render_sql_expr(node.order_by_arg)
        params = arg_rendered.bind_parameters
        percentile = node.percentile_args.percentile

        if node.percentile_args.function_type is SqlPercentileFunctionType.APPROXIMATE_CONTINUOUS:
            return SqlExpressionRenderResult(
                sql=f"approx_percentile({arg_rendered.sql}, {percentile})",
                bind_parameters=params,
            )
        elif (
            node.percentile_args.function_type is SqlPercentileFunctionType.APPROXIMATE_DISCRETE
            or node.percentile_args.function_type is SqlPercentileFunctionType.DISCRETE
            or node.percentile_args.function_type is SqlPercentileFunctionType.CONTINUOUS
        ):
            raise RuntimeError(
                "Discrete, continuous, and approximate discrete percentile aggregates are not supported for Trino. "
                "Set use_approximate_percentile and disable use_discrete_percentile in all percentile measures."
            )
        else:
            assert_values_exhausted(node.percentile_args.function_type)

    @override
    def visit_between_expr(self, node: SqlBetweenExpression) -> SqlExpressionRenderResult:
        """Render a BETWEEN expression for Trino.

        If the operands are timestamp literals, prefix them with the timestamp keyword.
        """
        rendered_column_arg = self.render_sql_expr(node.column_arg)
        rendered_start_expr = self.render_sql_expr(node.start_expr)
        rendered_end_expr = self.render_sql_expr(node.end_expr)

        bind_parameters = SqlBindParameters()
        bind_parameters = bind_parameters.combine(rendered_column_arg.bind_parameters)
        bind_parameters = bind_parameters.combine(rendered_start_expr.bind_parameters)
        bind_parameters = bind_parameters.combine(rendered_end_expr.bind_parameters)

        # Handle timestamp literals differently. Note that dateutil's parse() raises on
        # strings it cannot interpret as a date, so a failed parse means the operands
        # are treated as non-literals.
        try:
            parse(rendered_start_expr.sql)
            is_timestamp_literal = True
        except ValueError:
            is_timestamp_literal = False

        if is_timestamp_literal:
            sql = f"{rendered_column_arg.sql} BETWEEN timestamp {rendered_start_expr.sql} AND timestamp {rendered_end_expr.sql}"
        else:
            sql = f"{rendered_column_arg.sql} BETWEEN {rendered_start_expr.sql} AND {rendered_end_expr.sql}"

        return SqlExpressionRenderResult(
            sql=sql,
            bind_parameters=bind_parameters,
        )
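    # Illustrative example (not part of the original file): if both operands parse as
    # timestamp literals, the rendered SQL looks like
    #   ds BETWEEN timestamp '2020-01-01' AND timestamp '2020-01-31'
    # while non-literal operands render without the timestamp keyword.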

    @override
    def render_date_part(self, date_part: DatePart) -> str:
        """Render the date part for an EXTRACT expression.

        Override DOW to Trino's ISO DAY_OF_WEEK so that all engines return consistent results.
        """
        if date_part is DatePart.DOW:
            return "DAY_OF_WEEK"

        return date_part.value
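    # Illustrative note (not part of the original file): Trino's EXTRACT(DAY_OF_WEEK ...)
    # uses ISO numbering (Monday = 1 through Sunday = 7), which is what keeps the
    # engines consistent here.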


class TrinoSqlQueryPlanRenderer(DefaultSqlQueryPlanRenderer):
    """Plan renderer for the Trino engine."""

    EXPR_RENDERER = TrinoSqlExpressionRenderer()

    @property
    @override
    def expr_renderer(self) -> SqlExpressionRenderer:
        return self.EXPR_RENDERER
Contributor: These GitHub missing-newline markers annoy me. If you can, please add one here (and wherever else; there are a couple of other places).

@@ -67,3 +67,14 @@ duckdb:
    dev:
      type: duckdb
      schema: "{{ env_var('DBT_ENV_SECRET_SCHEMA') }}"
trino:
  target: dev
  outputs:
    dev:
      type: trino
      host: "{{ env_var('DBT_ENV_SECRET_HOST') }}"
      port: "{{ env_var('DBT_PROFILE_PORT') | int }}"
      user: "{{ env_var('DBT_ENV_SECRET_USER') }}"
      password: "{{ env_var('DBT_ENV_SECRET_PASSWORD') }}"
      catalog: "{{ env_var('DBT_ENV_SECRET_CATALOG') }}"
      schema: "{{ env_var('DBT_ENV_SECRET_SCHEMA') }}"
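A sketch of the environment this profile expects; every value below is hypothetical and only the variable names come from the YAML above:

import os

os.environ["DBT_ENV_SECRET_HOST"] = "localhost"         # hypothetical Trino host
os.environ["DBT_PROFILE_PORT"] = "8080"                 # hypothetical port
os.environ["DBT_ENV_SECRET_USER"] = "mf_test"           # hypothetical user
os.environ["DBT_ENV_SECRET_PASSWORD"] = ""              # hypothetical (empty) password
os.environ["DBT_ENV_SECRET_CATALOG"] = "memory"         # hypothetical catalog
os.environ["DBT_ENV_SECRET_SCHEMA"] = "mf_test_schema"  # hypothetical schema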
7 changes: 7 additions & 0 deletions metricflow/test/fixtures/sql_client_fixtures.py
@@ -42,6 +42,9 @@
DBT_ENV_SECRET_PROJECT_ID = "DBT_ENV_SECRET_PROJECT_ID"
DBT_ENV_SECRET_TOKEN_URI = "DBT_ENV_SECRET_TOKEN_URI"

# Trino needs a catalog on top of the standard connection parameters, so it gets its own env var.
# Keeping it split out here for consistency.
DBT_ENV_SECRET_CATALOG = "DBT_ENV_SECRET_CATALOG"


def __configure_test_env_from_url(url: str, password: str, schema: str) -> sqlalchemy.engine.URL:
    """Populates default env var mapping from a sqlalchemy URL string.
@@ -163,6 +166,10 @@ def make_test_sql_client(url: str, password: str, schema: str) -> SqlClientWithD
        __configure_databricks_env_from_url(url, password=password, schema=schema)
        __initialize_dbt()
        return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("databricks"))
    elif dialect == SqlDialect.TRINO:
        __configure_test_env_from_url(url, password=password, schema=schema)
        __initialize_dbt()
        return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("trino"))
    else:
        raise ValueError(f"Unknown dialect: `{dialect}` in URL {url}")

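A hypothetical call into the new branch; the URL shape is an assumption based on the `trino://...` placeholder in generate_snapshots.py below:

client = make_test_sql_client(
    url="trino://mf_test@localhost:8080/memory",  # assumed URL format
    password="",
    schema="mf_test_schema",
)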
@@ -39,6 +39,7 @@ def create_table_from_dataframe(
        # This mirrors the SQLAlchemy schema detection logic in pandas.io.sql
        df = df.convert_dtypes()
        columns = df.columns

        columns_to_insert = []
        for i in range(len(df.columns)):
            # Format as "column_name column_type"
@@ -63,7 +64,12 @@ def create_table_from_dataframe(
            elif type(cell) in [str, pd.Timestamp]:
                # Wrap cell in quotes & escape existing single quotes
                escaped_cell = self._quote_escape_value(str(cell))
                # Trino requires timestamp literals to be prefixed with the timestamp keyword.
                # There is probably a better way to handle this.
                if self.sql_engine_type is SqlEngine.TRINO and type(cell) is pd.Timestamp:
                    cells.append(f"timestamp '{escaped_cell}'")
                else:
                    cells.append(f"'{escaped_cell}'")
            else:
                cells.append(str(cell))

@@ -93,6 +99,8 @@ def _get_type_from_pandas_dtype(self, dtype: str) -> str:
        if dtype == "string" or dtype == "object":
            if self.sql_engine_type is SqlEngine.DATABRICKS or self.sql_engine_type is SqlEngine.BIGQUERY:
                return "string"
            if self.sql_engine_type is SqlEngine.TRINO:
                return "varchar"
            return "text"
        elif dtype == "boolean" or dtype == "bool":
            return "boolean"
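A small sketch (not from the PR) of what the cell-formatting branch above produces for a pandas Timestamp:

import pandas as pd

cell = pd.Timestamp("2020-01-01 00:00:00")
escaped_cell = str(cell)  # assume no single quotes to escape in this value
trino_cell = f"timestamp '{escaped_cell}'"  # -> timestamp '2020-01-01 00:00:00'
other_cell = f"'{escaped_cell}'"            # -> '2020-01-01 00:00:00' on other engines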
1 change: 1 addition & 0 deletions metricflow/test/fixtures/sql_clients/common_client.py
@@ -14,6 +14,7 @@ class SqlDialect(ExtendedEnum):
    SNOWFLAKE = "snowflake"
    BIGQUERY = "bigquery"
    DATABRICKS = "databricks"
    TRINO = "trino"


T = TypeVar("T")
10 changes: 10 additions & 0 deletions metricflow/test/generate_snapshots.py
@@ -28,6 +28,10 @@
"engine_url": postgres://...",
"engine_password": "..."
},
"trino": {
"engine_url": trino://...",
"engine_password": "..."
},
}
EOF
)
@@ -69,6 +73,7 @@ class MetricFlowTestCredentialSetForAllEngines(FrozenBaseModel):  # noqa: D
    big_query: MetricFlowTestCredentialSet
    databricks: MetricFlowTestCredentialSet
    postgres: MetricFlowTestCredentialSet
    trino: MetricFlowTestCredentialSet

    @property
    def as_configurations(self) -> Sequence[MetricFlowTestConfiguration]:  # noqa: D
@@ -97,6 +102,10 @@ def as_configurations(self) -> Sequence[MetricFlowTestConfiguration]:  # noqa: D
                engine=SqlEngine.POSTGRES,
                credential_set=self.postgres,
            ),
            MetricFlowTestConfiguration(
                engine=SqlEngine.TRINO,
                credential_set=self.trino,
            ),
        )


@@ -137,6 +146,7 @@ def run_tests(test_configuration: MetricFlowTestConfiguration) -> None:  # noqa:
        or test_configuration.engine is SqlEngine.BIGQUERY
        or test_configuration.engine is SqlEngine.DATABRICKS
        or test_configuration.engine is SqlEngine.POSTGRES
        or test_configuration.engine is SqlEngine.TRINO
    ):
        engine_name = test_configuration.engine.value.lower()
        os.environ["MF_TEST_ADAPTER_TYPE"] = engine_name
@@ -133,15 +133,19 @@ def test_cumulative_metric_with_non_adjustable_filter(
    it_helpers: IntegrationTestHelpers,
) -> None:
    """Tests a cumulative metric with a filter that cannot be adjusted to ensure all data is included."""
    # Build the ds expressions via the engine's renderer so the timestamp literals work on Trino.
    first_ds_expr = f"CAST('2020-03-15' AS {sql_client.sql_query_plan_renderer.expr_renderer.timestamp_data_type})"
    second_ds_expr = f"CAST('2020-04-30' AS {sql_client.sql_query_plan_renderer.expr_renderer.timestamp_data_type})"
    where_constraint = f"{{{{ TimeDimension('metric_time', 'day') }}}} = {first_ds_expr} or"
    where_constraint += f" {{{{ TimeDimension('metric_time', 'day') }}}} = {second_ds_expr}"
Comment on lines +138 to +141

Contributor: Oh wow. OK, Trino makes this interface issue worse, but addressing it is way outside the scope of this PR. I'll open an internal issue for us to consider here.

Contributor (author): Yeah, but this was blocking the test from succeeding; I wasn't sure how to get around it.
    query_result = it_helpers.mf_engine.query(
        MetricFlowQueryRequest.create_with_random_request_id(
            metric_names=["trailing_2_months_revenue"],
            group_by_names=["metric_time"],
            order_by_names=["metric_time"],
            where_constraint=where_constraint,
            time_constraint_end=as_datetime("2020-12-31"),
        )
    )
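For illustration, if the Trino expression renderer's `timestamp_data_type` is `TIMESTAMP` (an assumption; the value comes from the renderer at runtime), the constraint built above expands to:

timestamp_data_type = "TIMESTAMP"  # assumed value for Trino
first_ds_expr = f"CAST('2020-03-15' AS {timestamp_data_type})"
# where_constraint becomes:
#   {{ TimeDimension('metric_time', 'day') }} = CAST('2020-03-15' AS TIMESTAMP) or
#   {{ TimeDimension('metric_time', 'day') }} = CAST('2020-04-30' AS TIMESTAMP)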
8 changes: 4 additions & 4 deletions metricflow/test/integration/test_cases/itest_constraints.yaml
@@ -56,7 +56,7 @@ integration_test:
      SELECT SUM(booking_value) AS booking_value
        , ds AS metric_time__day
      FROM {{ source_schema }}.fct_bookings b
      WHERE {{ render_time_constraint("ds", "2020-01-01", "2020-01-01") }}
      GROUP BY
        ds
---
@@ -73,7 +73,7 @@ integration_test:
        , ds AS metric_time__day
      FROM {{ source_schema }}.fct_bookings b
      WHERE is_instant
        and {{ render_time_constraint("ds", "2020-01-01", "2020-01-01") }}
      GROUP BY ds
---
integration_test:
@@ -142,11 +142,11 @@ integration_test:
  model: SIMPLE_MODEL
  metrics: ["bookings"]
  group_bys: ["metric_time"]
  where_filter: "{{ render_time_dimension_template('metric_time') }} = {{ cast_to_ts('2020-01-01') }}"
  check_query: |
    SELECT
      SUM(1) AS bookings
      , ds AS metric_time__day
    FROM {{ source_schema }}.fct_bookings
    WHERE ds = {{ cast_to_ts('2020-01-01') }}
    GROUP BY ds
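The `render_time_constraint` and `cast_to_ts` helpers are defined in the test framework rather than in this diff; presumably they render engine-appropriate timestamp comparisons, along these lines on Trino (hypothetical expansions):

# render_time_constraint("ds", "2020-01-01", "2020-01-01")
#   -> "ds BETWEEN timestamp '2020-01-01' AND timestamp '2020-01-01'"  (assumed)
# cast_to_ts('2020-01-01')
#   -> "timestamp '2020-01-01'"  (assumed)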