From 101efa6317430d92241774e6706b1cdac9be167c Mon Sep 17 00:00:00 2001
From: gwieloch
Date: Wed, 25 Oct 2023 17:23:21 +0200
Subject: [PATCH] added `validate_df` task to `CustomerGaugeToADLS` class

---
 .../flows/test_customer_gauge_to_adls.py | 53 +++++++++++++++++++
 viadot/flows/customer_gauge_to_adls.py   | 13 +++++
 2 files changed, 66 insertions(+)

diff --git a/tests/integration/flows/test_customer_gauge_to_adls.py b/tests/integration/flows/test_customer_gauge_to_adls.py
index e6cdf1545..0e7afd3e2 100644
--- a/tests/integration/flows/test_customer_gauge_to_adls.py
+++ b/tests/integration/flows/test_customer_gauge_to_adls.py
@@ -5,6 +5,7 @@
 import pytest
 
 from viadot.flows import CustomerGaugeToADLS
+from viadot.exceptions import ValidationError
 
 DATA = {
     "user_name": ["Jane", "Bob"],
@@ -15,6 +16,7 @@
     "user_address_country_name": "United States",
     "user_address_country_code": "US",
 }
+
 COLUMNS = ["user_name", "user_address_street"]
 ADLS_FILE_NAME = "test_customer_gauge.parquet"
 ADLS_DIR_PATH = "raw/tests/"
@@ -40,3 +42,54 @@ def test_customer_gauge_to_adls_run_flow(mocked_class):
     assert result.is_successful()
     os.remove("test_customer_gauge_to_adls_flow_run.parquet")
     os.remove("test_customer_gauge_to_adls_flow_run.json")
+
+
+@mock.patch(
+    "viadot.tasks.CustomerGaugeToDF.run",
+    return_value=pd.DataFrame(data=DATA),
+)
+@pytest.mark.run
+def test_customer_gauge_to_adls_run_flow_validation_success(mocked_class):
+    flow = CustomerGaugeToADLS(
+        "test_customer_gauge_to_adls_run_flow_validation_success",
+        endpoint="responses",
+        total_load=False,
+        anonymize=True,
+        columns_to_anonymize=COLUMNS,
+        adls_dir_path=ADLS_DIR_PATH,
+        adls_file_name=ADLS_FILE_NAME,
+        overwrite_adls=True,
+        validate_df_dict={"column_size": {"user_address_state": 2}},
+    )
+    result = flow.run()
+    assert result.is_successful()
+    assert len(flow.tasks) == 11
+
+    os.remove("test_customer_gauge_to_adls_run_flow_validation_success.parquet")
+    os.remove("test_customer_gauge_to_adls_run_flow_validation_success.json")
+
+
+@mock.patch(
+    "viadot.tasks.CustomerGaugeToDF.run",
+    return_value=pd.DataFrame(data=DATA),
+)
+@pytest.mark.run
+def test_customer_gauge_to_adls_run_flow_validation_failure(mocked_class):
+    flow = CustomerGaugeToADLS(
+        "test_customer_gauge_to_adls_run_flow_validation_failure",
+        endpoint="responses",
+        total_load=False,
+        anonymize=True,
+        columns_to_anonymize=COLUMNS,
+        adls_dir_path=ADLS_DIR_PATH,
+        adls_file_name=ADLS_FILE_NAME,
+        overwrite_adls=True,
+        validate_df_dict={"column_size": {"user_name": 5}},
+    )
+    try:
+        flow.run()
+    except ValidationError:
+        pass
+
+    os.remove("test_customer_gauge_to_adls_run_flow_validation_failure.parquet")
+    os.remove("test_customer_gauge_to_adls_run_flow_validation_failure.json")
diff --git a/viadot/flows/customer_gauge_to_adls.py b/viadot/flows/customer_gauge_to_adls.py
index 8053aeda3..6c54a0704 100644
--- a/viadot/flows/customer_gauge_to_adls.py
+++ b/viadot/flows/customer_gauge_to_adls.py
@@ -17,6 +17,7 @@
     df_to_parquet,
     dtypes_to_json_task,
     update_dtypes_dict,
+    validate_df,
 )
 
 from viadot.tasks import AzureDataLakeUpload, CustomerGaugeToDF
@@ -52,6 +53,7 @@ def __init__(
         adls_sp_credentials_secret: str = None,
         overwrite_adls: bool = False,
         if_exists: str = "replace",
+        validate_df_dict: dict = None,
         timeout: int = 3600,
         *args: List[Any],
         **kwargs: Dict[str, Any]
     )
@@ -92,6 +94,8 @@
                 Defaults to None.
             overwrite_adls (bool, optional): Whether to overwrite files in the lake. Defaults to False.
             if_exists (str, optional): What to do if the file exists. Defaults to "replace".
+            validate_df_dict (dict, optional): A dictionary with an optional list of tests to verify the output
+                dataframe. If defined, triggers the `validate_df` task from task_utils. Defaults to None.
             timeout (int, optional): The time (in seconds) to wait while running this task before a timeout occurs. Defaults to 3600.
         """
         # CustomerGaugeToDF
@@ -105,6 +109,9 @@
         self.end_date = end_date
         self.customer_gauge_credentials_secret = customer_gauge_credentials_secret
 
+        # validate_df
+        self.validate_df_dict = validate_df_dict
+
         # anonymize_df
         self.anonymize = anonymize
         self.columns_to_anonymize = columns_to_anonymize
@@ -169,6 +176,12 @@ def gen_flow(self) -> Flow:
             flow=self,
         )
 
+        if self.validate_df_dict:
+            validation_task = validate_df.bind(
+                customerg_df, tests=self.validate_df_dict, flow=self
+            )
+            validation_task.set_upstream(customerg_df, flow=self)
+
         if self.anonymize == True:
             anonymized_df = anonymize_df.bind(
                 customerg_df,
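For reference, a minimal usage sketch of the new `validate_df_dict` parameter, assuming the patched class above. The flow name, ADLS paths, and the column chosen for validation are illustrative; `column_size` is the only test name confirmed by this patch's tests, and per those tests it appears to assert the string length of each value in the named column. Any other test names accepted by `validate_df` depend on its implementation in `task_utils` and are not shown here.

from viadot.flows import CustomerGaugeToADLS

# Illustrative values throughout; only validate_df_dict is new in this patch.
flow = CustomerGaugeToADLS(
    "customer_gauge_to_adls_with_validation",  # hypothetical flow name
    endpoint="responses",
    adls_dir_path="raw/tests/",
    adls_file_name="customer_gauge.parquet",
    overwrite_adls=True,
    # When validate_df_dict is not None, gen_flow() binds the validate_df
    # task with the extracted dataframe as its upstream; with the default
    # of None, the validation step is skipped entirely.
    validate_df_dict={"column_size": {"user_address_country_code": 2}},
)
flow.run()

As the failure test above shows, a failing validation surfaces as a `ValidationError` from `viadot.exceptions`, so callers that want to tolerate bad data must catch it around `flow.run()`.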