From 63980336fb46423b9f7c0148451d3acfa324f96a Mon Sep 17 00:00:00 2001
From: Ben Cassell <98852248+benc-db@users.noreply.github.com>
Date: Tue, 26 Nov 2024 12:37:10 -0800
Subject: [PATCH] Fix Python regressions in 1.9.0beta (#857)

---
 dbt/adapters/databricks/api_client.py         | 16 ++++++++---
 dbt/adapters/databricks/behaviors/columns.py  |  2 --
 dbt/adapters/databricks/impl.py               |  2 +-
 .../python_models/python_submissions.py       |  7 +++++
 .../databricks/macros/adapters/columns.sql    | 27 +++++++++++++++++++
 .../macros/adapters/persist_docs.sql          | 27 -------------------
 .../adapter/python_model/fixtures.py          |  9 +++++++
 .../adapter/python_model/test_python_model.py | 19 +++++++++++++
 tests/unit/api_client/test_command_api.py     | 13 +++++++++
 9 files changed, 88 insertions(+), 34 deletions(-)
 create mode 100644 dbt/include/databricks/macros/adapters/columns.sql

diff --git a/dbt/adapters/databricks/api_client.py b/dbt/adapters/databricks/api_client.py
index c2dbefb1..c14af9f5 100644
--- a/dbt/adapters/databricks/api_client.py
+++ b/dbt/adapters/databricks/api_client.py
@@ -88,8 +88,10 @@ def start(self, cluster_id: str) -> None:
 
         response = self.session.post("/start", json={"cluster_id": cluster_id})
 
         if response.status_code != 200:
-            raise DbtRuntimeError(f"Error starting terminated cluster.\n {response.content!r}")
-        logger.debug(f"Cluster start response={response}")
+            if self.status(cluster_id) not in ["RUNNING", "PENDING"]:
+                raise DbtRuntimeError(f"Error starting terminated cluster.\n {response.content!r}")
+            else:
+                logger.debug("Presuming race condition, waiting for cluster to start")
 
         self.wait_for_cluster(cluster_id)
@@ -289,7 +291,7 @@ def cancel(self, command: CommandExecution) -> None:
             raise DbtRuntimeError(f"Cancel command {command} failed.\n {response.content!r}")
 
     def poll_for_completion(self, command: CommandExecution) -> None:
-        self._poll_api(
+        response = self._poll_api(
             url="/status",
             params={
                 "clusterId": command.cluster_id,
@@ -300,7 +302,13 @@ def poll_for_completion(self, command: CommandExecution) -> None:
             terminal_states={"Finished", "Error", "Cancelled"},
             expected_end_state="Finished",
             unexpected_end_state_func=self._get_exception,
-        )
+        ).json()
+
+        if response["results"]["resultType"] == "error":
+            raise DbtRuntimeError(
+                f"Python model failed with traceback as:\n"
+                f"{utils.remove_ansi(response['results']['cause'])}"
+            )
 
     def _get_exception(self, response: Response) -> None:
         response_json = response.json()
diff --git a/dbt/adapters/databricks/behaviors/columns.py b/dbt/adapters/databricks/behaviors/columns.py
index 63d80d2d..8f3d1eca 100644
--- a/dbt/adapters/databricks/behaviors/columns.py
+++ b/dbt/adapters/databricks/behaviors/columns.py
@@ -7,8 +7,6 @@
 from dbt.adapters.databricks.utils import handle_missing_objects
 from dbt.adapters.sql import SQLAdapter
 
-GET_COLUMNS_COMMENTS_MACRO_NAME = "get_columns_comments"
-
 
 class GetColumnsBehavior(ABC):
     @classmethod
diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py
index ad1b28da..3e0288b0 100644
--- a/dbt/adapters/databricks/impl.py
+++ b/dbt/adapters/databricks/impl.py
@@ -85,7 +85,7 @@
 SHOW_TABLE_EXTENDED_MACRO_NAME = "show_table_extended"
 SHOW_TABLES_MACRO_NAME = "show_tables"
 SHOW_VIEWS_MACRO_NAME = "show_views"
-GET_COLUMNS_COMMENTS_MACRO_NAME = "get_columns_comments"
+
 
 USE_INFO_SCHEMA_FOR_COLUMNS = BehaviorFlag(
     name="use_info_schema_for_columns",
diff --git a/dbt/adapters/databricks/python_models/python_submissions.py b/dbt/adapters/databricks/python_models/python_submissions.py
index cbb27b44..afcb383c 100644
--- a/dbt/adapters/databricks/python_models/python_submissions.py
+++ b/dbt/adapters/databricks/python_models/python_submissions.py
@@ -8,6 +8,7 @@
 from dbt.adapters.base import PythonJobHelper
 from dbt.adapters.databricks.api_client import CommandExecution, DatabricksApiClient, WorkflowJobApi
 from dbt.adapters.databricks.credentials import DatabricksCredentials
+from dbt.adapters.databricks.logging import logger
 from dbt.adapters.databricks.python_models.python_config import ParsedPythonModel
 from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker
 
@@ -70,6 +71,8 @@ def __init__(
 
     @override
     def submit(self, compiled_code: str) -> None:
+        logger.debug("Submitting Python model using the Command API.")
+
         context_id = self.api_client.command_contexts.create(self.cluster_id)
         command_exec: Optional[CommandExecution] = None
         try:
@@ -263,6 +266,8 @@ def create(
 
     @override
     def submit(self, compiled_code: str) -> None:
+        logger.debug("Submitting Python model using the Job Run API.")
+
         file_path = self.uploader.upload(compiled_code)
 
         job_config = self.config_compiler.compile(file_path)
@@ -494,6 +499,8 @@ def create(
 
     @override
    def submit(self, compiled_code: str) -> None:
+        logger.debug("Submitting Python model using the Workflow API.")
+
         file_path = self.uploader.upload(compiled_code)
 
         workflow_config, existing_job_id = self.config_compiler.compile(file_path)
diff --git a/dbt/include/databricks/macros/adapters/columns.sql b/dbt/include/databricks/macros/adapters/columns.sql
new file mode 100644
index 00000000..e1fc1d11
--- /dev/null
+++ b/dbt/include/databricks/macros/adapters/columns.sql
@@ -0,0 +1,27 @@
+
+{% macro get_columns_comments(relation) -%}
+  {% call statement('get_columns_comments', fetch_result=True) -%}
+    describe table {{ relation|lower }}
+  {% endcall %}
+
+  {% do return(load_result('get_columns_comments').table) %}
+{% endmacro %}
+
+{% macro get_columns_comments_via_information_schema(relation) -%}
+  {% call statement('repair_table', fetch_result=False) -%}
+    REPAIR TABLE {{ relation|lower }} SYNC METADATA
+  {% endcall %}
+  {% call statement('get_columns_comments_via_information_schema', fetch_result=True) -%}
+    select
+      column_name,
+      full_data_type,
+      comment
+    from `system`.`information_schema`.`columns`
+    where
+      table_catalog = '{{ relation.database|lower }}' and
+      table_schema = '{{ relation.schema|lower }}' and
+      table_name = '{{ relation.identifier|lower }}'
+  {% endcall %}
+
+  {% do return(load_result('get_columns_comments_via_information_schema').table) %}
+{% endmacro %}
diff --git a/dbt/include/databricks/macros/adapters/persist_docs.sql b/dbt/include/databricks/macros/adapters/persist_docs.sql
index f623a8a6..8e959a9f 100644
--- a/dbt/include/databricks/macros/adapters/persist_docs.sql
+++ b/dbt/include/databricks/macros/adapters/persist_docs.sql
@@ -20,33 +20,6 @@
   {% do run_query(comment_query) %}
 {% endmacro %}
 
-{% macro get_columns_comments(relation) -%}
-  {% call statement('get_columns_comments', fetch_result=True) -%}
-    describe table {{ relation|lower }}
-  {% endcall %}
-
-  {% do return(load_result('get_columns_comments').table) %}
-{% endmacro %}
-
-{% macro get_columns_comments_via_information_schema(relation) -%}
-  {% call statement('repair_table', fetch_result=False) -%}
-    REPAIR TABLE {{ relation|lower }} SYNC METADATA
-  {% endcall %}
-  {% call statement('get_columns_comments_via_information_schema', fetch_result=True) -%}
-    select
-      column_name,
-      full_data_type,
-      comment
-    from `system`.`information_schema`.`columns`
-    where
-      table_catalog = '{{ relation.database|lower }}' and
-      table_schema = '{{ relation.schema|lower }}' and
-      table_name = '{{ relation.identifier|lower }}'
-  {% endcall %}
-
-  {% do return(load_result('get_columns_comments_via_information_schema').table) %}
-{% endmacro %}
-
 {% macro databricks__persist_docs(relation, model, for_relation, for_columns) -%}
   {%- if for_relation and config.persist_relation_docs() and model.description %}
     {% do alter_table_comment(relation, model) %}
diff --git a/tests/functional/adapter/python_model/fixtures.py b/tests/functional/adapter/python_model/fixtures.py
index 5ce51702..127bcf74 100644
--- a/tests/functional/adapter/python_model/fixtures.py
+++ b/tests/functional/adapter/python_model/fixtures.py
@@ -9,6 +9,15 @@ def model(dbt, spark):
     return spark.createDataFrame(data, schema=['test', 'test2'])
 """
 
+python_error_model = """
+import pandas as pd
+
+def model(dbt, spark):
+    raise Exception("This is an error")
+
+    return pd.DataFrame()
+"""
+
 serverless_schema = """version: 2
 
 models:
diff --git a/tests/functional/adapter/python_model/test_python_model.py b/tests/functional/adapter/python_model/test_python_model.py
index 482674fe..858214b7 100644
--- a/tests/functional/adapter/python_model/test_python_model.py
+++ b/tests/functional/adapter/python_model/test_python_model.py
@@ -17,6 +17,25 @@ class TestPythonModel(BasePythonModelTests):
     pass
 
 
+@pytest.mark.python
+@pytest.mark.skip_profile("databricks_uc_sql_endpoint")
+class TestPythonFailureModel:
+    @pytest.fixture(scope="class")
+    def models(self):
+        return {"my_failure_model.py": override_fixtures.python_error_model}
+
+    def test_failure_model(self, project):
+        util.run_dbt(["run"], expect_pass=False)
+
+
+@pytest.mark.python
+@pytest.mark.skip_profile("databricks_uc_sql_endpoint")
+class TestPythonFailureModelNotebook(TestPythonFailureModel):
+    @pytest.fixture(scope="class")
+    def project_config_update(self):
+        return {"models": {"+create_notebook": "true"}}
+
+
 @pytest.mark.python
 @pytest.mark.skip_profile("databricks_uc_sql_endpoint")
 class TestPythonIncrementalModel(BasePythonIncrementalTests):
diff --git a/tests/unit/api_client/test_command_api.py b/tests/unit/api_client/test_command_api.py
index 2021bd74..90efaf47 100644
--- a/tests/unit/api_client/test_command_api.py
+++ b/tests/unit/api_client/test_command_api.py
@@ -86,6 +86,7 @@ def test_poll_for_completion__200(self, _, api, session, host, execution):
         session.get.return_value.status_code = 200
         session.get.return_value.json.return_value = {
             "status": "Finished",
+            "results": {"resultType": "finished"},
         }
 
         api.poll_for_completion(execution)
@@ -99,3 +100,15 @@ def test_poll_for_completion__200(self, _, api, session, host, execution):
             },
             json=None,
         )
+
+    @freezegun.freeze_time("2020-01-01")
+    @patch("dbt.adapters.databricks.api_client.time.sleep")
+    def test_poll_for_completion__200_with_error(self, _, api, session, host, execution):
+        session.get.return_value.status_code = 200
+        session.get.return_value.json.return_value = {
+            "status": "Finished",
+            "results": {"resultType": "error", "cause": "race condition"},
+        }
+
+        with pytest.raises(DbtRuntimeError, match="Python model failed"):
+            api.poll_for_completion(execution)