Fix Python regressions in 1.9.0beta (#857)
benc-db authored Nov 26, 2024
1 parent b95a04a commit 6398033
Showing 9 changed files with 88 additions and 34 deletions.
16 changes: 12 additions & 4 deletions dbt/adapters/databricks/api_client.py
@@ -88,8 +88,10 @@ def start(self, cluster_id: str) -> None:

        response = self.session.post("/start", json={"cluster_id": cluster_id})
        if response.status_code != 200:
            raise DbtRuntimeError(f"Error starting terminated cluster.\n {response.content!r}")
            logger.debug(f"Cluster start response={response}")
            if self.status(cluster_id) not in ["RUNNING", "PENDING"]:
                raise DbtRuntimeError(f"Error starting terminated cluster.\n {response.content!r}")
            else:
                logger.debug("Presuming race condition, waiting for cluster to start")

        self.wait_for_cluster(cluster_id)
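The new branch treats a failed `/start` as benign when `self.status(cluster_id)` already reports `RUNNING` or `PENDING`, assuming another caller won the race to start the cluster, and then falls through to `wait_for_cluster`. That helper is not part of this diff; below is a minimal, self-contained sketch of the polling pattern it presumably relies on (function name, timeout, and poll interval are illustrative, not the adapter's actual code).

```python
import time
from typing import Callable


def wait_for_cluster(get_status: Callable[[], str], timeout: int = 900) -> None:
    """Poll a status callable until it reports RUNNING, or give up after `timeout` seconds."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        if get_status() == "RUNNING":
            return
        time.sleep(5)  # avoid hammering the clusters API
    raise TimeoutError(f"Cluster did not reach RUNNING within {timeout}s")
```

In the adapter, the equivalent loop would call `self.status(cluster_id)` and raise `DbtRuntimeError` rather than `TimeoutError`.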

@@ -289,7 +291,7 @@ def cancel(self, command: CommandExecution) -> None:
            raise DbtRuntimeError(f"Cancel command {command} failed.\n {response.content!r}")

    def poll_for_completion(self, command: CommandExecution) -> None:
        self._poll_api(
        response = self._poll_api(
            url="/status",
            params={
                "clusterId": command.cluster_id,
@@ -300,7 +302,13 @@ def poll_for_completion(self, command: CommandExecution) -> None:
            terminal_states={"Finished", "Error", "Cancelled"},
            expected_end_state="Finished",
            unexpected_end_state_func=self._get_exception,
        )
        ).json()

        if response["results"]["resultType"] == "error":
            raise DbtRuntimeError(
                f"Python model failed with traceback as:\n"
                f"{utils.remove_ansi(response['results']['cause'])}"
            )

    def _get_exception(self, response: Response) -> None:
        response_json = response.json()
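The new error branch in `poll_for_completion` runs the returned traceback through `utils.remove_ansi` before raising, so notebook color codes don't pollute dbt's logs. That helper is not shown in this diff; the following is a minimal sketch of a regex-based ANSI stripper (the exact pattern in `utils.py` may differ).

```python
import re

# CSI escape sequences (colors, cursor movement) that Databricks command output can
# embed in Python tracebacks.
ANSI_ESCAPE = re.compile(r"\x1b\[[0-9;]*[A-Za-z]")


def remove_ansi(text: str) -> str:
    """Return `text` with ANSI escape codes stripped so tracebacks read cleanly."""
    return ANSI_ESCAPE.sub("", text)
```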
2 changes: 0 additions & 2 deletions dbt/adapters/databricks/behaviors/columns.py
@@ -7,8 +7,6 @@
from dbt.adapters.databricks.utils import handle_missing_objects
from dbt.adapters.sql import SQLAdapter

GET_COLUMNS_COMMENTS_MACRO_NAME = "get_columns_comments"


class GetColumnsBehavior(ABC):
    @classmethod
2 changes: 1 addition & 1 deletion dbt/adapters/databricks/impl.py
@@ -85,7 +85,7 @@
SHOW_TABLE_EXTENDED_MACRO_NAME = "show_table_extended"
SHOW_TABLES_MACRO_NAME = "show_tables"
SHOW_VIEWS_MACRO_NAME = "show_views"
GET_COLUMNS_COMMENTS_MACRO_NAME = "get_columns_comments"


USE_INFO_SCHEMA_FOR_COLUMNS = BehaviorFlag(
    name="use_info_schema_for_columns",
7 changes: 7 additions & 0 deletions dbt/adapters/databricks/python_models/python_submissions.py
@@ -8,6 +8,7 @@
from dbt.adapters.base import PythonJobHelper
from dbt.adapters.databricks.api_client import CommandExecution, DatabricksApiClient, WorkflowJobApi
from dbt.adapters.databricks.credentials import DatabricksCredentials
from dbt.adapters.databricks.logging import logger
from dbt.adapters.databricks.python_models.python_config import ParsedPythonModel
from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker

@@ -70,6 +71,8 @@ def __init__(

    @override
    def submit(self, compiled_code: str) -> None:
        logger.debug("Submitting Python model using the Command API.")

        context_id = self.api_client.command_contexts.create(self.cluster_id)
        command_exec: Optional[CommandExecution] = None
        try:
@@ -263,6 +266,8 @@ def create(

    @override
    def submit(self, compiled_code: str) -> None:
        logger.debug("Submitting Python model using the Job Run API.")

        file_path = self.uploader.upload(compiled_code)
        job_config = self.config_compiler.compile(file_path)

@@ -494,6 +499,8 @@ def create(

    @override
    def submit(self, compiled_code: str) -> None:
        logger.debug("Submitting Python model using the Workflow API.")

        file_path = self.uploader.upload(compiled_code)

        workflow_config, existing_job_id = self.config_compiler.compile(file_path)
27 changes: 27 additions & 0 deletions dbt/include/databricks/macros/adapters/columns.sql
@@ -0,0 +1,27 @@

{% macro get_columns_comments(relation) -%}
  {% call statement('get_columns_comments', fetch_result=True) -%}
    describe table {{ relation|lower }}
  {% endcall %}

  {% do return(load_result('get_columns_comments').table) %}
{% endmacro %}

{% macro get_columns_comments_via_information_schema(relation) -%}
  {% call statement('repair_table', fetch_result=False) -%}
    REPAIR TABLE {{ relation|lower }} SYNC METADATA
  {% endcall %}
  {% call statement('get_columns_comments_via_information_schema', fetch_result=True) -%}
    select
      column_name,
      full_data_type,
      comment
    from `system`.`information_schema`.`columns`
    where
      table_catalog = '{{ relation.database|lower }}' and
      table_schema = '{{ relation.schema|lower }}' and
      table_name = '{{ relation.identifier|lower }}'
  {% endcall %}

  {% do return(load_result('get_columns_comments_via_information_schema').table) %}
{% endmacro %}
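These macros hand the result table of `describe table` (or the information-schema query) back to the adapter by name. As a rough illustration only, and assuming dbt's standard `execute_macro` interface, the Python side could consume that output as sketched below; the real call site lives in `behaviors/columns.py`, which this diff only touches to drop the duplicated constant.

```python
GET_COLUMNS_COMMENTS_MACRO_NAME = "get_columns_comments"


def columns_with_comments(adapter, relation) -> dict:
    # execute_macro returns an agate.Table; `describe table` yields rows of
    # (col_name, data_type, comment), plus section rows whose first column starts with '#'.
    table = adapter.execute_macro(GET_COLUMNS_COMMENTS_MACRO_NAME, kwargs={"relation": relation})
    return {
        row[0]: row[2]
        for row in table.rows
        if row[0] and not str(row[0]).startswith("#")
    }
```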
27 changes: 0 additions & 27 deletions dbt/include/databricks/macros/adapters/persist_docs.sql
@@ -20,33 +20,6 @@
  {% do run_query(comment_query) %}
{% endmacro %}

{% macro get_columns_comments(relation) -%}
  {% call statement('get_columns_comments', fetch_result=True) -%}
    describe table {{ relation|lower }}
  {% endcall %}

  {% do return(load_result('get_columns_comments').table) %}
{% endmacro %}

{% macro get_columns_comments_via_information_schema(relation) -%}
  {% call statement('repair_table', fetch_result=False) -%}
    REPAIR TABLE {{ relation|lower }} SYNC METADATA
  {% endcall %}
  {% call statement('get_columns_comments_via_information_schema', fetch_result=True) -%}
    select
      column_name,
      full_data_type,
      comment
    from `system`.`information_schema`.`columns`
    where
      table_catalog = '{{ relation.database|lower }}' and
      table_schema = '{{ relation.schema|lower }}' and
      table_name = '{{ relation.identifier|lower }}'
  {% endcall %}

  {% do return(load_result('get_columns_comments_via_information_schema').table) %}
{% endmacro %}

{% macro databricks__persist_docs(relation, model, for_relation, for_columns) -%}
  {%- if for_relation and config.persist_relation_docs() and model.description %}
    {% do alter_table_comment(relation, model) %}
9 changes: 9 additions & 0 deletions tests/functional/adapter/python_model/fixtures.py
@@ -9,6 +9,15 @@ def model(dbt, spark):
    return spark.createDataFrame(data, schema=['test', 'test2'])
"""

python_error_model = """
import pandas as pd
def model(dbt, spark):
    raise Exception("This is an error")
    return pd.DataFrame()
"""

serverless_schema = """version: 2
models:
19 changes: 19 additions & 0 deletions tests/functional/adapter/python_model/test_python_model.py
@@ -17,6 +17,25 @@ class TestPythonModel(BasePythonModelTests):
    pass


@pytest.mark.python
@pytest.mark.skip_profile("databricks_uc_sql_endpoint")
class TestPythonFailureModel:
    @pytest.fixture(scope="class")
    def models(self):
        return {"my_failure_model.py": override_fixtures.python_error_model}

    def test_failure_model(self, project):
        util.run_dbt(["run"], expect_pass=False)


@pytest.mark.python
@pytest.mark.skip_profile("databricks_uc_sql_endpoint")
class TestPythonFailureModelNotebook(TestPythonFailureModel):
    @pytest.fixture(scope="class")
    def project_config_update(self):
        return {"models": {"+create_notebook": "true"}}


@pytest.mark.python
@pytest.mark.skip_profile("databricks_uc_sql_endpoint")
class TestPythonIncrementalModel(BasePythonIncrementalTests):
13 changes: 13 additions & 0 deletions tests/unit/api_client/test_command_api.py
@@ -86,6 +86,7 @@ def test_poll_for_completion__200(self, _, api, session, host, execution):
        session.get.return_value.status_code = 200
        session.get.return_value.json.return_value = {
            "status": "Finished",
            "results": {"resultType": "finished"},
        }

        api.poll_for_completion(execution)
@@ -99,3 +100,15 @@ def test_poll_for_completion__200(self, _, api, session, host, execution):
            },
            json=None,
        )

    @freezegun.freeze_time("2020-01-01")
    @patch("dbt.adapters.databricks.api_client.time.sleep")
    def test_poll_for_completion__200_with_error(self, _, api, session, host, execution):
        session.get.return_value.status_code = 200
        session.get.return_value.json.return_value = {
            "status": "Finished",
            "results": {"resultType": "error", "cause": "race condition"},
        }

        with pytest.raises(DbtRuntimeError, match="Python model failed"):
            api.poll_for_completion(execution)
