Skip to content

Commit

Permalink
added validate_df task to CustomerGaugeToADLS class
Browse files Browse the repository at this point in the history
  • Loading branch information
gwieloch committed Oct 25, 2023
1 parent eab1c93 commit 101efa6
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 0 deletions.
53 changes: 53 additions & 0 deletions tests/integration/flows/test_customer_gauge_to_adls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from viadot.flows import CustomerGaugeToADLS
from viadot.exceptions import ValidationError

DATA = {
"user_name": ["Jane", "Bob"],
Expand All @@ -15,6 +16,7 @@
"user_address_country_name": "United States",
"user_address_country_code": "US",
}

COLUMNS = ["user_name", "user_address_street"]
ADLS_FILE_NAME = "test_customer_gauge.parquet"
ADLS_DIR_PATH = "raw/tests/"
Expand All @@ -40,3 +42,54 @@ def test_customer_gauge_to_adls_run_flow(mocked_class):
assert result.is_successful()
os.remove("test_customer_gauge_to_adls_flow_run.parquet")
os.remove("test_customer_gauge_to_adls_flow_run.json")


@mock.patch(
"viadot.tasks.CustomerGaugeToDF.run",
return_value=pd.DataFrame(data=DATA),
)
@pytest.mark.run
def test_customer_gauge_to_adls_run_flow_validation_success(mocked_class):
flow = CustomerGaugeToADLS(
"test_customer_gauge_to_adls_run_flow_validation_success",
endpoint="responses",
total_load=False,
anonymize=True,
columns_to_anonymize=COLUMNS,
adls_dir_path=ADLS_DIR_PATH,
adls_file_name=ADLS_FILE_NAME,
overwrite_adls=True,
validate_df_dict={"column_size": {"user_address_state": 2}},
)
result = flow.run()
assert result.is_successful()
assert len(flow.tasks) == 11

os.remove("test_customer_gauge_to_adls_run_flow_validation_success.parquet")
os.remove("test_customer_gauge_to_adls_run_flow_validation_success.json")


@mock.patch(
"viadot.tasks.CustomerGaugeToDF.run",
return_value=pd.DataFrame(data=DATA),
)
@pytest.mark.run
def test_customer_gauge_to_adls_run_flow_validation_failure(mocked_class):
flow = CustomerGaugeToADLS(
"test_customer_gauge_to_adls_run_flow_validation_failure",
endpoint="responses",
total_load=False,
anonymize=True,
columns_to_anonymize=COLUMNS,
adls_dir_path=ADLS_DIR_PATH,
adls_file_name=ADLS_FILE_NAME,
overwrite_adls=True,
validate_df_dict={"column_size": {"user_name": 5}},
)
try:
flow.run()
except ValidationError:
pass

os.remove("test_customer_gauge_to_adls_run_flow_validation_failure.parquet")
os.remove("test_customer_gauge_to_adls_run_flow_validation_failure.json")
13 changes: 13 additions & 0 deletions viadot/flows/customer_gauge_to_adls.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
df_to_parquet,
dtypes_to_json_task,
update_dtypes_dict,
validate_df,
)
from viadot.tasks import AzureDataLakeUpload, CustomerGaugeToDF

Expand Down Expand Up @@ -52,6 +53,7 @@ def __init__(
adls_sp_credentials_secret: str = None,
overwrite_adls: bool = False,
if_exists: str = "replace",
validate_df_dict: dict = None,
timeout: int = 3600,
*args: List[Any],
**kwargs: Dict[str, Any]
Expand Down Expand Up @@ -92,6 +94,8 @@ def __init__(
Defaults to None.
overwrite_adls (bool, optional): Whether to overwrite files in the lake. Defaults to False.
if_exists (str, optional): What to do if the file exists. Defaults to "replace".
validate_df_dict (Dict[str], optional): A dictionary with optional list of tests to verify the output dataframe.
If defined, triggers the `validate_df` task from task_utils. Defaults to None.
timeout (int, optional): The time (in seconds) to wait while running this task before a timeout occurs. Defaults to 3600.
"""
# CustomerGaugeToDF
Expand All @@ -105,6 +109,9 @@ def __init__(
self.end_date = end_date
self.customer_gauge_credentials_secret = customer_gauge_credentials_secret

# validate_df
self.validate_df_dict = validate_df_dict

# anonymize_df
self.anonymize = anonymize
self.columns_to_anonymize = columns_to_anonymize
Expand Down Expand Up @@ -169,6 +176,12 @@ def gen_flow(self) -> Flow:
flow=self,
)

if self.validate_df_dict:
validation_task = validate_df.bind(
customerg_df, tests=self.validate_df_dict, flow=self
)
validation_task.set_upstream(customerg_df, flow=self)

if self.anonymize == True:
anonymized_df = anonymize_df.bind(
customerg_df,
Expand Down

0 comments on commit 101efa6

Please sign in to comment.