diff --git a/viadot/flows/adls_to_azure_sql.py b/viadot/flows/adls_to_azure_sql.py
index a9e49c6b6..466c1156c 100644
--- a/viadot/flows/adls_to_azure_sql.py
+++ b/viadot/flows/adls_to_azure_sql.py
@@ -17,6 +17,7 @@
     CheckColumnOrder,
     DownloadGitHubFile,
 )
+from viadot.task_utils import validate_df
 from viadot.tasks.azure_data_lake import AzureDataLakeDownload
 
 logger = logging.get_logger(__name__)
@@ -150,6 +151,7 @@ def __init__(
         tags: List[str] = ["promotion"],
         vault_name: str = None,
         timeout: int = 3600,
+        validate_df_dict: Dict[str, Any] = None,
         *args: List[any],
         **kwargs: Dict[str, Any],
     ):
@@ -186,6 +188,8 @@ def __init__(
             vault_name (str, optional): The name of the vault from which to obtain the secrets. Defaults to None.
             timeout(int, optional): The amount of time (in seconds) to wait while running this task before a timeout occurs. Defaults to 3600.
+            validate_df_dict (Dict[str,Any], optional): A dictionary with optional list of tests to verify the output dataframe.
+                If defined, triggers the `validate_df` task from task_utils. Defaults to None.
         """
 
         adls_path = adls_path.strip("/")
@@ -236,6 +240,7 @@ def __init__(
         self.tags = tags
         self.vault_name = vault_name
         self.timeout = timeout
+        self.validate_df_dict = validate_df_dict
 
         super().__init__(*args, name=name, **kwargs)
@@ -356,6 +361,11 @@ def gen_flow(self) -> Flow:
             flow=self,
         )
 
+        # data validation function (optional)
+        if self.validate_df_dict:
+            validate_df.bind(df=df, tests=self.validate_df_dict, flow=self)
+            validate_df.set_upstream(lake_to_df_task, flow=self)
+
         df_reorder.set_upstream(lake_to_df_task, flow=self)
         df_to_csv.set_upstream(df_reorder, flow=self)
         promote_to_conformed_task.set_upstream(df_to_csv, flow=self)
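For reference, a minimal usage sketch of the new parameter. The constructor arguments and test keys below are illustrative assumptions, not part of this change; see `validate_df` in `viadot.task_utils` for the tests it actually supports.

```python
# Sketch only: flow arguments and test keys are assumed for illustration,
# not taken from this diff.
from viadot.flows import ADLSToAzureSQL

flow = ADLSToAzureSQL(
    name="adls to azure sql with validation",    # hypothetical flow name
    adls_path="raw/example/data.csv",            # hypothetical lake path
    schema="sandbox",                            # hypothetical target schema
    table="example",                             # hypothetical target table
    # When a test dictionary is passed, gen_flow() binds validate_df on the
    # DataFrame produced by lake_to_df_task; when omitted, the flow graph
    # is unchanged.
    validate_df_dict={
        "column_list_to_match": ["id", "name"],  # assumed test key
        "dataset_row_count": {"min": 1},         # assumed test key
    },
)
flow.run()  # Prefect 1.x flow run
```

When `validate_df_dict` stays at its default of `None`, the validation branch in `gen_flow` is skipped and the flow behaves exactly as before.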