diff --git a/tests/integration/flows/test_mysql_to_adls.py b/tests/integration/flows/test_mysql_to_adls.py index a66fca5db..942bab99d 100644 --- a/tests/integration/flows/test_mysql_to_adls.py +++ b/tests/integration/flows/test_mysql_to_adls.py @@ -1,5 +1,4 @@ from unittest import mock - from viadot.flows.mysql_to_adls import MySqlToADLS query = """SELECT * FROM `example-views`.`sales`""" @@ -23,3 +22,18 @@ def test_adls_gen1_to_azure_sql_new_mock(TEST_PARQUET_FILE_PATH): ) flow.run() mock_method.assert_called_with() + + +def test_validate_df(TEST_PARQUET_FILE_PATH): + with mock.patch.object(MySqlToADLS, "run", return_value=True) as mock_method: + flow = MySqlToADLS( + "test validate_df", + country_short="DE", + query=query, + file_path=TEST_PARQUET_FILE_PATH, + sp_credentials_secret="App-Azure-CR-DatalakeGen2-AIA", + to_path=f"raw/examples/{TEST_PARQUET_FILE_PATH}", + validate_df_dict={"column_size": {"sales_org": 3}}, + ) + flow.run() + mock_method.assert_called_with() diff --git a/viadot/flows/mysql_to_adls.py b/viadot/flows/mysql_to_adls.py index 4c18148fe..afe594e47 100644 --- a/viadot/flows/mysql_to_adls.py +++ b/viadot/flows/mysql_to_adls.py @@ -2,7 +2,7 @@ from prefect import Flow -from viadot.task_utils import df_to_csv +from viadot.task_utils import df_to_csv, validate_df from viadot.tasks import AzureDataLakeUpload from viadot.tasks.mysql_to_df import MySqlToDf @@ -20,6 +20,7 @@ def __init__( to_path: str = None, if_exists: Literal["replace", "append", "delete"] = "replace", overwrite_adls: bool = True, + validate_df_dict: dict = None, sp_credentials_secret: str = None, credentials_secret: str = None, timeout: int = 3600, @@ -41,6 +42,8 @@ def __init__( to_path (str): The path to an ADLS file. Defaults to None. if_exists (Literal, optional): What to do if the table exists. Defaults to "replace". overwrite_adls (str, optional): Whether to overwrite_adls the destination file. Defaults to True. + validate_df_dict (Dict[str], optional): A dictionary with optional list of tests to verify the output dataframe. + If defined, triggers the `validate_df` task from task_utils. Defaults to None. sp_credentials_secret (str, optional): The name of the Azure Key Vault secret containing a dictionary with ACCOUNT_NAME and Service Principal credentials (TENANT_ID, CLIENT_ID, CLIENT_SECRET). Defaults to None. credentials_secret (str, optional): Key Vault name. Defaults to None. @@ -57,6 +60,9 @@ def __init__( self.vault_name = vault_name self.overwrite_adls = overwrite_adls + # validate df + self.validate_df_dict = validate_df_dict + # Upload to ADLS self.file_path = file_path self.sep = sep @@ -76,6 +82,12 @@ def gen_flow(self) -> Flow: credentials_secret=self.credentials_secret, query=self.query, flow=self ) + if self.validate_df_dict: + validation_task = validate_df.bind( + df, tests=self.validate_df_dict, flow=self + ) + validation_task.set_upstream(df, flow=self) + create_csv = df_to_csv.bind( df, path=self.file_path,