diff --git a/CHANGELOG.md b/CHANGELOG.md
index 80c1882e1..7f81230a5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 ### Added
 - Added logic for if_empty param: `check_if_df_empty` task to `ADLSToAzureSQL` flow
+- Added new parameter `validate_df_dict` to `ADLSToAzureSQL` class
 
 ### Fixed
diff --git a/tests/integration/flows/test_adls_to_azure_sql.py b/tests/integration/flows/test_adls_to_azure_sql.py
index e13dc31b2..e3ae45623 100644
--- a/tests/integration/flows/test_adls_to_azure_sql.py
+++ b/tests/integration/flows/test_adls_to_azure_sql.py
@@ -101,3 +101,48 @@ def test_check_dtypes_sort():
         assert False
     except signals.FAIL:
         assert True
+
+
+def test_adls_to_azure_sql_mocked(TEST_CSV_FILE_PATH):
+    with mock.patch.object(ADLSToAzureSQL, "run", return_value=True) as mock_method:
+        instance = ADLSToAzureSQL(
+            name="test_adls_to_azure_sql_flow",
+            adls_path=TEST_CSV_FILE_PATH,
+            schema="sandbox",
+            table="test_bcp",
+            dtypes={"test_str": "VARCHAR(25)", "test_int": "INT"},
+            if_exists="replace",
+        )
+        instance.run()
+        mock_method.assert_called_with()
+
+
+def test_adls_to_azure_sql_mocked_validate_df_param(TEST_CSV_FILE_PATH):
+    with mock.patch.object(ADLSToAzureSQL, "run", return_value=True) as mock_method:
+        instance = ADLSToAzureSQL(
+            name="test_adls_to_azure_sql_flow",
+            adls_path=TEST_CSV_FILE_PATH,
+            schema="sandbox",
+            table="test_bcp",
+            dtypes={"test_str": "VARCHAR(25)", "test_int": "INT"},
+            if_exists="replace",
+            validate_df_dict={"column_list_to_match": ["test_str", "test_int"]},
+        )
+        instance.run()
+        mock_method.assert_called_with()
+
+
+def test_adls_to_azure_sql_mocked_wrong_param(TEST_CSV_FILE_PATH):
+    with pytest.raises(TypeError) as excinfo:
+        instance = ADLSToAzureSQL(
+            name="test_adls_to_azure_sql_flow",
+            adls_path=TEST_CSV_FILE_PATH,
+            schema="sandbox",
+            table="test_bcp",
+            dtypes={"test_str": "VARCHAR(25)", "test_int": "INT"},
+            if_exists="replace",
+            validate_df_dit={"column_list_to_match": ["test_str", "test_int"]},
+        )
+        instance.run()
+
+    assert "validate_df_dit" in str(excinfo)
diff --git a/viadot/flows/adls_to_azure_sql.py b/viadot/flows/adls_to_azure_sql.py
index 95b1340ca..c12cc7e1d 100644
--- a/viadot/flows/adls_to_azure_sql.py
+++ b/viadot/flows/adls_to_azure_sql.py
@@ -17,6 +17,7 @@
     CheckColumnOrder,
     DownloadGitHubFile,
 )
+from viadot.task_utils import validate_df
 from viadot.tasks.azure_data_lake import AzureDataLakeDownload
 from viadot.task_utils import check_if_df_empty
 
@@ -151,6 +152,7 @@ def __init__(
         tags: List[str] = ["promotion"],
         vault_name: str = None,
         timeout: int = 3600,
+        validate_df_dict: Dict[str, Any] = None,
         *args: List[any],
         **kwargs: Dict[str, Any],
     ):
@@ -187,6 +189,8 @@ def __init__(
             vault_name (str, optional): The name of the vault from which to obtain the secrets. Defaults to None.
            timeout(int, optional): The amount of time (in seconds) to wait while running this task before
                a timeout occurs. Defaults to 3600.
+            validate_df_dict (Dict[str,Any], optional): A dictionary with optional list of tests to verify the output dataframe.
+                If defined, triggers the `validate_df` task from task_utils. Defaults to None.
         """
 
         adls_path = adls_path.strip("/")
@@ -237,6 +241,7 @@ def __init__(
         self.tags = tags
         self.vault_name = vault_name
         self.timeout = timeout
+        self.validate_df_dict = validate_df_dict
 
         super().__init__(*args, name=name, **kwargs)
@@ -360,6 +365,11 @@ def gen_flow(self) -> Flow:
             flow=self,
         )
 
+        # data validation function (optional)
+        if self.validate_df_dict:
+            validate_df.bind(df=df, tests=self.validate_df_dict, flow=self)
+            validate_df.set_upstream(lake_to_df_task, flow=self)
+
         df_reorder.set_upstream(lake_to_df_task, flow=self)
         df_to_csv.set_upstream(df_reorder, flow=self)
         promote_to_conformed_task.set_upstream(df_to_csv, flow=self)