diff --git a/tests/integration/flows/test_bigquery_to_adls.py b/tests/integration/flows/test_bigquery_to_adls.py
index 1a867b58c..de793344a 100644
--- a/tests/integration/flows/test_bigquery_to_adls.py
+++ b/tests/integration/flows/test_bigquery_to_adls.py
@@ -1,11 +1,16 @@
 import os
 
 import pendulum
-from prefect.tasks.secrets import PrefectSecret
+import pytest
+from unittest import mock
+import pandas as pd
 
+from prefect.tasks.secrets import PrefectSecret
 from viadot.flows import BigQueryToADLS
 from viadot.tasks import AzureDataLakeRemove
+from viadot.exceptions import ValidationError
+
 
 ADLS_DIR_PATH = "raw/tests/"
 ADLS_FILE_NAME = str(pendulum.now("utc")) + ".parquet"
 BIGQ_CREDENTIAL_KEY = "BIGQUERY-TESTS"
@@ -72,6 +77,68 @@ def test_bigquery_to_adls_false():
     assert result.is_failed()
     os.remove("test_bigquery_to_adls_overwrite_false.parquet")
     os.remove("test_bigquery_to_adls_overwrite_false.json")
+
+
+DATA = {
+    "type": ["banner", "banner"],
+    "country": ["PL", "DE"],
+}
+
+
+@mock.patch(
+    "viadot.tasks.BigQueryToDF.run",
+    return_value=pd.DataFrame(data=DATA),
+)
+@pytest.mark.run
+def test_bigquery_to_adls_validate_df_fail(mocked_data):
+    flow_bigquery = BigQueryToADLS(
+        name="Test BigQuery to ADLS validate df fail",
+        dataset_name="official_empty",
+        table_name="space",
+        credentials_key=BIGQ_CREDENTIAL_KEY,
+        adls_file_name=ADLS_FILE_NAME,
+        overwrite_adls=True,
+        adls_dir_path=ADLS_DIR_PATH,
+        adls_sp_credentials_secret=ADLS_CREDENTIAL_SECRET,
+        validate_df_dict={"column_list_to_match": ["type", "country", "test"]},
+    )
+    try:
+        result = flow_bigquery.run()
+    except ValidationError:
+        pass
+
+    os.remove("test_bigquery_to_adls_validate_df_fail.parquet")
+    os.remove("test_bigquery_to_adls_validate_df_fail.json")
+
+
+@mock.patch(
+    "viadot.tasks.BigQueryToDF.run",
+    return_value=pd.DataFrame(data=DATA),
+)
+@pytest.mark.run
+def test_bigquery_to_adls_validate_df_success(mocked_data):
+    flow_bigquery = BigQueryToADLS(
+        name="Test BigQuery to ADLS validate df success",
+        dataset_name="official_empty",
+        table_name="space",
+        credentials_key=BIGQ_CREDENTIAL_KEY,
+        adls_file_name=ADLS_FILE_NAME,
+        overwrite_adls=True,
+        adls_dir_path=ADLS_DIR_PATH,
+        adls_sp_credentials_secret=ADLS_CREDENTIAL_SECRET,
+        validate_df_dict={"column_list_to_match": ["type", "country"]},
+    )
+    result = flow_bigquery.run()
+
+    result = flow_bigquery.run()
+    assert result.is_successful()
+
+    task_results = result.result.values()
+    assert all([task_result.is_successful() for task_result in task_results])
+
+    os.remove("test_bigquery_to_adls_validate_df_success.parquet")
+    os.remove("test_bigquery_to_adls_validate_df_success.json")
+
 rm = AzureDataLakeRemove(
     path=ADLS_DIR_PATH + ADLS_FILE_NAME, vault_name="azuwevelcrkeyv001s"
 )
diff --git a/viadot/flows/bigquery_to_adls.py b/viadot/flows/bigquery_to_adls.py
index 8e8095b5a..935b3f7a1 100644
--- a/viadot/flows/bigquery_to_adls.py
+++ b/viadot/flows/bigquery_to_adls.py
@@ -15,6 +15,7 @@
     df_to_parquet,
     dtypes_to_json_task,
     update_dtypes_dict,
+    validate_df,
 )
 from viadot.tasks import AzureDataLakeUpload, BigQueryToDF
 
@@ -40,6 +41,7 @@ def __init__(
         adls_sp_credentials_secret: str = None,
         overwrite_adls: bool = False,
         if_exists: str = "replace",
+        validate_df_dict: dict = None,
         timeout: int = 3600,
         *args: List[Any],
         **kwargs: Dict[str, Any],
@@ -78,6 +80,8 @@ def __init__(
                 Defaults to None.
            overwrite_adls (bool, optional): Whether to overwrite files in the lake. Defaults to False.
            if_exists (str, optional): What to do if the file exists. Defaults to "replace".
+            validate_df_dict (dict, optional): An optional dictionary to verify the received dataframe.
+                When passed, `validate_df` task validation tests are triggered. Defaults to None.
            timeout(int, optional): The amount of time (in seconds) to wait while running this task before
                a timeout occurs. Defaults to 3600.
        """
@@ -91,6 +95,9 @@ def __init__(
         self.vault_name = vault_name
         self.credentials_key = credentials_key
 
+        # Validate DataFrame
+        self.validate_df_dict = validate_df_dict
+
         # AzureDataLakeUpload
         self.overwrite = overwrite_adls
         self.adls_sp_credentials_secret = adls_sp_credentials_secret
@@ -140,6 +147,12 @@ def gen_flow(self) -> Flow:
             flow=self,
         )
 
+        if self.validate_df_dict:
+            validation_task = validate_df.bind(
+                df, tests=self.validate_df_dict, flow=self
+            )
+            validation_task.set_upstream(df, flow=self)
+
         df_with_metadata = add_ingestion_metadata_task.bind(df, flow=self)
         dtypes_dict = df_get_data_types_task.bind(df_with_metadata, flow=self)
         df_to_be_loaded = df_map_mixed_dtypes_for_parquet(
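
Usage sketch: a minimal example of how the new `validate_df_dict` parameter could be passed when building the flow. The dataset, table, and directory names below are illustrative placeholders, not values taken from this change.

    from viadot.flows import BigQueryToADLS

    flow = BigQueryToADLS(
        name="BigQuery to ADLS with validation",
        dataset_name="my_dataset",        # placeholder dataset
        table_name="my_table",            # placeholder table
        credentials_key="BIGQUERY-TESTS",
        adls_dir_path="raw/tests/",
        overwrite_adls=True,
        # Validation tests run by the validate_df task on the extracted
        # DataFrame; a failing test is expected to fail the flow run
        # (see the ValidationError handling in the tests above).
        validate_df_dict={"column_list_to_match": ["type", "country"]},
    )
    flow.run()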