Skip to content

Commit

Permalink
add configuration options for enable_list_inference and intermediate_…
Browse files Browse the repository at this point in the history
…format for python models
  • Loading branch information
mikealfare committed Apr 26, 2024
1 parent 3b8c6e8 commit 7245a59
Show file tree
Hide file tree
Showing 7 changed files with 241 additions and 0 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20240426-105319.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Add configuration options `enable_list_inference` and `intermediate_format` for python
models
time: 2024-04-26T10:53:19.874239-04:00
custom:
Author: mikealfare
Issue: 1047 1114
7 changes: 7 additions & 0 deletions .changes/unreleased/Fixes-20240426-105224.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Fixes
body: Default `enableListInference` to `True` for python models to support nested
lists
time: 2024-04-26T10:52:24.827314-04:00
custom:
Author: mikealfare
Issue: 1047 1114
2 changes: 2 additions & 0 deletions dbt/adapters/bigquery/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class BigqueryConfig(AdapterConfig):
enable_refresh: Optional[bool] = None
refresh_interval_minutes: Optional[int] = None
max_staleness: Optional[str] = None
enable_list_inference: Optional[bool] = None
intermediate_format: Optional[str] = None


class BigQueryAdapter(BaseAdapter):
Expand Down
7 changes: 7 additions & 0 deletions dbt/include/bigquery/macros/materializations/table.sql
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,19 @@
from pyspark.sql import SparkSession
{%- set raw_partition_by = config.get('partition_by', none) -%}
{%- set raw_cluster_by = config.get('cluster_by', none) -%}
{%- set enable_list_inference = config.get('enable_list_inference', true) -%}
{%- set intermediate_format = config.get('intermediate_format', none) -%}

{%- set partition_config = adapter.parse_partition_by(raw_partition_by) %}

spark = SparkSession.builder.appName('smallTest').getOrCreate()

spark.conf.set("viewsEnabled","true")
spark.conf.set("temporaryGcsBucket","{{target.gcs_bucket}}")
spark.conf.set("enableListInference", "{{ enable_list_inference }}")
{% if intermediate_format %}
spark.conf.set("intermediateFormat", "{{ intermediate_format }}")
{% endif %}

{{ compiled_code }}
dbt = dbtObj(spark.read.format("bigquery").load)
Expand Down
Empty file.
148 changes: 148 additions & 0 deletions tests/functional/python_model_tests/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Python model source whose only column holds one dict per row (a "name" string
# plus a nested "my_list" list). It sets neither enable_list_inference nor
# intermediate_format.
SINGLE_RECORD = """
import pandas as pd
def model(dbt, session):
dbt.config(
submission_method="serverless",
materialized="table"
)
df = pd.DataFrame(
[
{"column_name": {"name": "hello", "my_list": ["h", "e", "l", "l", "o"]}},
]
)
return df
"""


# Python model source whose column holds a *list* of dicts per row, each dict
# carrying its own nested "my_list" list. It sets neither enable_list_inference
# nor intermediate_format, so whatever defaults the materialization applies are
# what get exercised.
MULTI_RECORD_DEFAULT = """
import pandas as pd
def model(dbt, session):
dbt.config(
submission_method="serverless",
materialized="table",
)
df = pd.DataFrame(
[
{"column_name": [{"name": "hello", "my_list": ["h", "e", "l", "l", "o"]}]},
]
)
return df
"""


# Same list-of-dicts data as the other multi-record models, but with
# intermediate_format="orc" set explicitly; enable_list_inference is left unset.
ORC_FORMAT = """
import pandas as pd
def model(dbt, session):
dbt.config(
submission_method="serverless",
materialized="table",
intermediate_format="orc",
)
df = pd.DataFrame(
[
{"column_name": [{"name": "hello", "my_list": ["h", "e", "l", "l", "o"]}]},
]
)
return df
"""


# Same list-of-dicts data, with enable_list_inference="true" set explicitly
# (passed as a string, not a bool); intermediate_format is left unset.
ENABLE_LIST_INFERENCE = """
import pandas as pd
def model(dbt, session):
dbt.config(
submission_method="serverless",
materialized="table",
enable_list_inference="true",
)
df = pd.DataFrame(
[
{"column_name": [{"name": "hello", "my_list": ["h", "e", "l", "l", "o"]}]},
]
)
return df
"""


# This model is expected to FAIL when run: it sets enable_list_inference="false"
# on the list-of-dicts data and leaves intermediate_format unset.
# NOTE(review): presumably this reproduces the pre-fix behaviour, when list
# inference was off by default - confirm against the linked issue.
DISABLE_LIST_INFERENCE = """
import pandas as pd
def model(dbt, session):
dbt.config(
submission_method="serverless",
materialized="table",
enable_list_inference="false",
)
df = pd.DataFrame(
[
{"column_name": [{"name": "hello", "my_list": ["h", "e", "l", "l", "o"]}]},
]
)
return df
"""


# Same list-of-dicts data with both options set explicitly:
# enable_list_inference="true" and intermediate_format="parquet".
ENABLE_LIST_INFERENCE_PARQUET_FORMAT = """
import pandas as pd
def model(dbt, session):
dbt.config(
submission_method="serverless",
materialized="table",
enable_list_inference="true",
intermediate_format="parquet",
)
df = pd.DataFrame(
[
{"column_name": [{"name": "hello", "my_list": ["h", "e", "l", "l", "o"]}]},
]
)
return df
"""


# Same list-of-dicts data with enable_list_inference="false" combined with
# intermediate_format="orc", i.e. list inference off but a non-default
# intermediate format chosen explicitly.
DISABLE_LIST_INFERENCE_ORC_FORMAT = """
import pandas as pd
def model(dbt, session):
dbt.config(
submission_method="serverless",
materialized="table",
enable_list_inference="false",
intermediate_format="orc",
)
df = pd.DataFrame(
[
{"column_name": [{"name": "hello", "my_list": ["h", "e", "l", "l", "o"]}]},
]
)
return df
"""
70 changes: 70 additions & 0 deletions tests/functional/python_model_tests/test_list_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""
This test case addresses this regression: https://github.com/dbt-labs/dbt-bigquery/issues/1047
As the comments point out, the issue appears to be that the default settings are:
- list inference: off
- intermediate format: parquet
Adjusting either of these alleviates the issue.
When the regression was first reported, `models.MULTI_RECORD_DEFAULT` failed while the other three models passed.
"""
from dbt.tests.util import run_dbt_and_capture
import pytest

from tests.functional.python_model_tests import models


class ListInference:
    """Shared test body for the list-inference python model scenarios.

    Subclasses supply a ``models`` fixture containing a single python model and
    may override ``expect_pass`` when the ``dbt run`` is expected to fail.
    """

    # Whether the `dbt run` invocation is expected to succeed.
    expect_pass = True

    def test_model(self, project):
        """Run the project and check that exactly one model result is reported.

        ``project`` is requested only so the dbt test project is set up before
        the run; it is not used directly.
        """
        # The captured log output is not needed here, so it is discarded.
        results, _ = run_dbt_and_capture(["run"], expect_pass=self.expect_pass)
        assert len(results) == 1, f"expected exactly one model result, got {len(results)}"


class TestSingleRecord(ListInference):
    """Scenario: the column holds a single record rather than a list of records."""

    @pytest.fixture(scope="class")
    def models(self):
        """Provide the one python model file for this scenario."""
        model_files = {"model.py": models.SINGLE_RECORD}
        return model_files


class TestMultiRecordDefault(ListInference):
    """Scenario: a list of records, with no list-inference or format overrides."""

    @pytest.fixture(scope="class")
    def models(self):
        """Provide the one python model file for this scenario."""
        # This is the model from the original report: it was the one failing.
        model_files = {"model.py": models.MULTI_RECORD_DEFAULT}
        return model_files


class TestDisableListInference(ListInference):
    """Scenario: list inference explicitly turned off; the run must fail."""

    # The `dbt run` for this model is expected to fail.
    expect_pass = False

    @pytest.fixture(scope="class")
    def models(self):
        """Provide the one python model file for this scenario."""
        # Mimics the behaviour from before enable_list_inference defaulted to True.
        model_files = {"model.py": models.DISABLE_LIST_INFERENCE}
        return model_files


class TestEnableListInference(ListInference):
    """Scenario: list inference explicitly turned on, format left unset."""

    @pytest.fixture(scope="class")
    def models(self):
        """Provide the one python model file for this scenario."""
        model_files = {"model.py": models.ENABLE_LIST_INFERENCE}
        return model_files


class TestOrcFormat(ListInference):
    """Scenario: intermediate format set to ORC, list inference left unset."""

    @pytest.fixture(scope="class")
    def models(self):
        """Provide the one python model file for this scenario."""
        model_files = {"model.py": models.ORC_FORMAT}
        return model_files


class TestDisableListInferenceOrcFormat(ListInference):
    """Scenario: list inference off, combined with the ORC intermediate format."""

    @pytest.fixture(scope="class")
    def models(self):
        """Provide the one python model file for this scenario."""
        model_files = {"model.py": models.DISABLE_LIST_INFERENCE_ORC_FORMAT}
        return model_files


class TestEnableListInferenceParquetFormat(ListInference):
    """Scenario: list inference on and the parquet intermediate format, both explicit."""

    @pytest.fixture(scope="class")
    def models(self):
        """Provide the one python model file for this scenario."""
        # Unlike MULTI_RECORD_DEFAULT, this is not the model from the original
        # report; it sets enable_list_inference and intermediate_format explicitly
        # instead of relying on defaults.
        return {"model.py": models.ENABLE_LIST_INFERENCE_PARQUET_FORMAT}

0 comments on commit 7245a59

Please sign in to comment.