diff --git a/tests/integration/flows/test_salesforce_to_adls.py b/tests/integration/flows/test_salesforce_to_adls.py
index 60ec41aaa..ec68a1227 100644
--- a/tests/integration/flows/test_salesforce_to_adls.py
+++ b/tests/integration/flows/test_salesforce_to_adls.py
@@ -4,6 +4,7 @@
 from viadot.flows import SalesforceToADLS
 from viadot.tasks import AzureDataLakeRemove
+from viadot.exceptions import ValidationError
 
 ADLS_FILE_NAME = "test_salesforce.parquet"
 ADLS_DIR_PATH = "raw/tests/"
@@ -32,3 +33,29 @@ def test_salesforce_to_adls():
         vault_name="azuwevelcrkeyv001s",
     )
     rm.run(sp_credentials_secret=credentials_secret)
+
+
+def test_salesforce_to_adls_validate_success():
+    credentials_secret = PrefectSecret(
+        "AZURE_DEFAULT_ADLS_SERVICE_PRINCIPAL_SECRET"
+    ).run()
+
+    flow = SalesforceToADLS(
+        "test_salesforce_to_adls_run_flow",
+        query="SELECT IsDeleted, FiscalYear FROM Opportunity LIMIT 50",
+        adls_sp_credentials_secret=credentials_secret,
+        adls_dir_path=ADLS_DIR_PATH,
+        adls_file_name=ADLS_FILE_NAME,
+        validate_df_dict={"column_list_to_match": ["IsDeleted", "FiscalYear"]},
+    )
+
+    result = flow.run()
+    assert result.is_successful()
+
+    os.remove("test_salesforce_to_adls_run_flow.parquet")
+    os.remove("test_salesforce_to_adls_run_flow.json")
+    rm = AzureDataLakeRemove(
+        path=ADLS_DIR_PATH + ADLS_FILE_NAME,
+        vault_name="azuwevelcrkeyv001s",
+    )
+    rm.run(sp_credentials_secret=credentials_secret)
diff --git a/viadot/flows/aselite_to_adls.py b/viadot/flows/aselite_to_adls.py
index 86e9b215b..bd77cf40f 100644
--- a/viadot/flows/aselite_to_adls.py
+++ b/viadot/flows/aselite_to_adls.py
@@ -2,7 +2,12 @@
 
 from prefect import Flow
 
-from viadot.task_utils import df_clean_column, df_converts_bytes_to_int, df_to_csv
+from viadot.task_utils import (
+    df_clean_column,
+    df_converts_bytes_to_int,
+    df_to_csv,
+    validate_df,
+)
 from viadot.tasks import AzureDataLakeUpload
 from viadot.tasks.aselite import ASELiteToDF
 
@@ -19,6 +24,7 @@ def __init__(
         to_path: str = None,
         if_exists: Literal["replace", "append", "delete"] = "replace",
         overwrite: bool = True,
+        validate_df_dict: Dict[str, Any] = None,
         convert_bytes: bool = False,
         sp_credentials_secret: str = None,
         remove_special_characters: bool = None,
@@ -41,6 +47,8 @@ def __init__(
             to_path (str): The path to an ADLS file. Defaults to None.
             if_exists (Literal, optional): What to do if the table exists. Defaults to "replace".
             overwrite (str, optional): Whether to overwrite the destination file. Defaults to True.
+            validate_df_dict (Dict[str, Any], optional): A dictionary with optional list of tests to verify the output
+                dataframe. If defined, triggers the `validate_df` task from task_utils. Defaults to None.
             sp_credentials_secret (str, optional): The name of the Azure Key Vault secret containing a dictionary with
                 ACCOUNT_NAME and Service Principal credentials (TENANT_ID, CLIENT_ID, CLIENT_SECRET). Defaults to None.
             remove_special_characters (str, optional): Call a function that remove special characters like escape symbols. Defaults to None.
@@ -53,6 +61,7 @@ def __init__(
         self.sqldb_credentials_secret = sqldb_credentials_secret
         self.vault_name = vault_name
         self.overwrite = overwrite
+        self.validate_df_dict = validate_df_dict
         self.file_path = file_path
         self.sep = sep
 
@@ -83,6 +92,12 @@ def gen_flow(self) -> Flow:
         if self.remove_special_characters == True:
             df = df_clean_column(df, columns_to_clean=self.columns_to_clean, flow=self)
 
+        if self.validate_df_dict:
+            validation_task = validate_df.bind(
+                df, tests=self.validate_df_dict, flow=self
+            )
+            validation_task.set_upstream(df, flow=self)
+
         create_csv = df_to_csv.bind(
             df,
             path=self.file_path,
diff --git a/viadot/flows/genesys_to_adls.py b/viadot/flows/genesys_to_adls.py
index 4f8c54f3e..830c02c71 100644
--- a/viadot/flows/genesys_to_adls.py
+++ b/viadot/flows/genesys_to_adls.py
@@ -5,7 +5,7 @@
 import pandas as pd
 from prefect import Flow, task
 
-from viadot.task_utils import add_ingestion_metadata_task, adls_bulk_upload
+from viadot.task_utils import add_ingestion_metadata_task, adls_bulk_upload, validate_df
 from viadot.tasks.genesys import GenesysToCSV
 
 
@@ -95,6 +95,7 @@ def __init__(
         overwrite_adls: bool = True,
         adls_sp_credentials_secret: str = None,
         credentials_genesys: Dict[str, Any] = None,
+        validate_df_dict: Dict[str, Any] = None,
         timeout: int = 3600,
         *args: List[any],
         **kwargs: Dict[str, Any],
@@ -143,6 +144,8 @@ def __init__(
             adls_sp_credentials_secret (str, optional): The name of the Azure Key Vault secret containing a dictionary with
                 ACCOUNT_NAME and Service Principal credentials (TENANT_ID, CLIENT_ID, CLIENT_SECRET). Defaults to None.
             credentials(dict, optional): Credentials for the genesys api. Defaults to None.
+            validate_df_dict (Dict[str,Any], optional): A dictionary with optional list of tests to verify the output dataframe. If defined, triggers
+                the `validate_df` task from task_utils. Defaults to None.
             timeout(int, optional): The amount of time (in seconds) to wait while running this task before a timeout occurs. Defaults to 3600.
         """
@@ -165,6 +168,7 @@ def __init__(
         self.start_date = start_date
         self.end_date = end_date
         self.sep = sep
+        self.validate_df_dict = validate_df_dict
         self.timeout = timeout
 
         # AzureDataLake
@@ -183,6 +187,7 @@ def gen_flow(self) -> Flow:
             timeout=self.timeout,
             local_file_path=self.local_file_path,
             sep=self.sep,
+            validate_df_dict=self.validate_df_dict,
         )
 
         file_names = to_csv.bind(
diff --git a/viadot/flows/salesforce_to_adls.py b/viadot/flows/salesforce_to_adls.py
index fe84be381..1ace9aa5a 100644
--- a/viadot/flows/salesforce_to_adls.py
+++ b/viadot/flows/salesforce_to_adls.py
@@ -16,6 +16,7 @@
     df_to_parquet,
     dtypes_to_json_task,
     update_dtypes_dict,
+    validate_df,
 )
 from viadot.tasks import AzureDataLakeUpload, SalesforceToDF
 
@@ -41,6 +42,7 @@ def __init__(
         adls_file_name: str = None,
         adls_sp_credentials_secret: str = None,
         if_exists: str = "replace",
+        validate_df_dict: Dict[str, Any] = None,
         timeout: int = 3600,
         *args: List[Any],
         **kwargs: Dict[str, Any],
@@ -70,6 +72,8 @@ def __init__(
                 ACCOUNT_NAME and Service Principal credentials (TENANT_ID, CLIENT_ID, CLIENT_SECRET) for the Azure Data Lake.
                 Defaults to None.
             if_exists (str, optional): What to do if the file exists. Defaults to "replace".
+            validate_df_dict (Dict[str,Any], optional): A dictionary with optional list of tests to verify the output
+                dataframe. If defined, triggers the `validate_df` task from task_utils. Defaults to None.
             timeout(int, optional): The amount of time (in seconds) to wait while running this task before a timeout occurs. Defaults to 3600.
""" @@ -82,6 +86,7 @@ def __init__( self.env = env self.vault_name = vault_name self.credentials_secret = credentials_secret + self.validate_df_dict = validate_df_dict # AzureDataLakeUpload self.adls_sp_credentials_secret = adls_sp_credentials_secret @@ -135,6 +140,13 @@ def gen_flow(self) -> Flow: df_clean = df_clean_column.bind(df=df, flow=self) df_with_metadata = add_ingestion_metadata_task.bind(df_clean, flow=self) dtypes_dict = df_get_data_types_task.bind(df_with_metadata, flow=self) + + if self.validate_df_dict: + validation_task = validate_df.bind( + df, tests=self.validate_df_dict, flow=self + ) + validation_task.set_upstream(df, flow=self) + df_to_be_loaded = df_map_mixed_dtypes_for_parquet( df_with_metadata, dtypes_dict, flow=self ) diff --git a/viadot/tasks/genesys.py b/viadot/tasks/genesys.py index f39bff6d2..de47ddebf 100644 --- a/viadot/tasks/genesys.py +++ b/viadot/tasks/genesys.py @@ -10,6 +10,7 @@ from prefect.engine import signals from prefect.utilities import logging from prefect.utilities.tasks import defaults_from_attrs +from viadot.task_utils import * from viadot.exceptions import APIError from viadot.sources import Genesys @@ -33,6 +34,7 @@ def __init__( conversationId_list: List[str] = None, key_list: List[str] = None, credentials_genesys: Dict[str, Any] = None, + validate_df_dict: Dict[str, Any] = None, timeout: int = 3600, *args: List[Any], **kwargs: Dict[str, Any], @@ -54,6 +56,8 @@ def __init__( sep (str, optional): Separator in csv file. Defaults to "\t". conversationId_list (List[str], optional): List of conversationId passed as attribute of GET method. Defaults to None. key_list (List[str], optional): List of keys needed to specify the columns in the GET request method. Defaults to None. + validate_df_dict (Dict[str,Any], optional): A dictionary with optional list of tests to verify the output dataframe. If defined, triggers + the `validate_df` task from task_utils. Defaults to None. timeout(int, optional): The amount of time (in seconds) to wait while running this task before a timeout occurs. Defaults to 3600. """ @@ -72,6 +76,7 @@ def __init__( self.sep = sep self.conversationId_list = conversationId_list self.key_list = key_list + self.validate_df_dict = validate_df_dict super().__init__( name=self.report_name, @@ -293,6 +298,7 @@ def merge_conversations_dfs(self, data_to_merge: list) -> DataFrame: "credentials_genesys", "conversationId_list", "key_list", + "validate_df_dict", ) def run( self, @@ -309,6 +315,7 @@ def run( conversationId_list: List[str] = None, key_list: List[str] = None, credentials_genesys: Dict[str, Any] = None, + validate_df_dict: Dict[str, Any] = None, ) -> List[str]: """ Task for downloading data from the Genesys API to DF. @@ -327,6 +334,8 @@ def run( report_columns (List[str], optional): List of exisiting column in report. Defaults to None. conversationId_list (List[str], optional): List of conversationId passed as attribute of GET method. Defaults to None. key_list (List[str], optional): List of keys needed to specify the columns in the GET request method. Defaults to None. + validate_df_dict (Dict[str,Any], optional): A dictionary with optional list of tests to verify the output dataframe. If defined, triggers + the `validate_df` task from task_utils. Defaults to None. Returns: List[str]: List of file names. 
@@ -450,7 +459,8 @@ def run(
 
             date = start_date.replace("-", "")
             file_name = f"conversations_detail_{date}".upper() + ".csv"
-
+            if validate_df_dict:
+                validate_df.run(df=final_df, tests=validate_df_dict)
             final_df.to_csv(
                 os.path.join(self.local_file_path, file_name),
                 index=False,
@@ -488,6 +498,8 @@ def run(
             end = end_date.replace("-", "")
             file_name = f"WEBMESSAGE_{start}-{end}.csv"
 
+            if validate_df_dict:
+                validate_df.run(df=df, tests=validate_df_dict)
             df.to_csv(
                 os.path.join(file_name),
                 index=False,
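
Usage note (illustrative, not part of the patch): a minimal sketch of how the `validate_df_dict` argument introduced above might be passed to one of the flows touched in this diff. It mirrors the integration test added here; the flow run name and secret value are assumptions, and `column_list_to_match` is the only validation key shown because it is the one used in that test.

# Hypothetical usage sketch based on tests/integration/flows/test_salesforce_to_adls.py.
from viadot.flows import SalesforceToADLS

flow = SalesforceToADLS(
    "salesforce_to_adls_with_validation",  # assumed flow run name
    query="SELECT IsDeleted, FiscalYear FROM Opportunity LIMIT 50",
    adls_dir_path="raw/tests/",
    adls_file_name="test_salesforce.parquet",
    adls_sp_credentials_secret="<adls-sp-secret>",  # assumed Key Vault secret name
    # When validate_df_dict is set, gen_flow() binds the validate_df task from
    # viadot.task_utils on the extracted dataframe before it is written out.
    validate_df_dict={"column_list_to_match": ["IsDeleted", "FiscalYear"]},
)
state = flow.run()
assert state.is_successful()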