diff --git a/tests/integration/flows/test_adls_to_azure_sql.py b/tests/integration/flows/test_adls_to_azure_sql.py
index e13dc31b2..de20910d4 100644
--- a/tests/integration/flows/test_adls_to_azure_sql.py
+++ b/tests/integration/flows/test_adls_to_azure_sql.py
@@ -6,8 +6,30 @@ from prefect.engine import signals
 
 from viadot.flows import ADLSToAzureSQL
-from viadot.flows.adls_to_azure_sql import check_dtypes_sort, df_to_csv_task
-
+from viadot.flows.adls_to_azure_sql import check_dtypes_sort, df_to_csv_task, len_from_dtypes, check_hardcoded_dtypes_len, get_real_sql_dtypes_from_df
+
+TEST_DF = pd.DataFrame(
+    {
+        "Date": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04", "2023-01-05"],
+        "User ID": ["1a34", "1d34$56", "1a3456&8", "1d3456789!", "1s3"],  # max length = 10
+        "Web ID": ["4321", "1234$56", "123", "0", "12"],  # max length = 7
+        "User name": ["Ada", "aaaaadAA", "Adulkaaa", "A", " "],  # max length = 8
+        "User country": ["Poland", "USA", "Norway", "USA", "USA"],  # max length = 6
+        "All Users": [1234, 123456, 12345678, 123456789, 123],
+        "Age": [0, 12, 123, 89, 23],
+        "Last varchar": ["Last", " ", "varchar", "of this ", "df"],  # max length = 8
+    }
+)
+REAL_SQL_DTYPES = {
+    "Date": "DATE",
+    "User ID": "VARCHAR(10)",
+    "Web ID": "VARCHAR(7)",
+    "User name": "VARCHAR(8)",
+    "User country": "VARCHAR(6)",
+    "All Users": "INT",
+    "Age": "INT",
+    "Last varchar": "VARCHAR(8)",
+}
 
 
 def test_get_promoted_adls_path_csv_file():
     adls_path_file = "raw/supermetrics/adls_ga_load_times_fr_test/2021-07-14T13%3A09%3A02.997357%2B00%3A00.csv"
@@ -101,3 +123,73 @@ def test_check_dtypes_sort():
         assert False
     except signals.FAIL:
         assert True
+
+
+def test_get_real_sql_dtypes_from_df():
+    assert get_real_sql_dtypes_from_df(TEST_DF) == REAL_SQL_DTYPES
+
+
+def test_len_from_dtypes():
+    real_df_lengths = {
+        "Date": "DATE",
+        "User ID": 10,
+        "Web ID": 7,
+        "User name": 8,
+        "User country": 6,
+        "All Users": "INT",
+        "Age": "INT",
+        "Last varchar": 8,
+    }
+    assert len_from_dtypes(REAL_SQL_DTYPES) == real_df_lengths
+
+
+def test_check_hardcoded_dtypes_len_userid(caplog):
+    smaller_dtype_userid = {
+        "Date": "DateTime",
+        "User ID": "varchar(1)",
+        "Web ID": "varchar(10)",
+        "User name": "varchar(10)",
+        "User country": "varchar(10)",
+        "All Users": "int",
+        "Age": "int",
+        "Last varchar": "varchar(10)",
+    }
+    with pytest.raises(ValueError):
+        check_hardcoded_dtypes_len(TEST_DF, smaller_dtype_userid)
+    assert (
+        "The length of the column User ID is too big, some data could be lost. Please change the length of the provided dtypes to 10"
+        in caplog.text
+    )
+
+
+def test_check_hardcoded_dtypes_len_usercountry(caplog):
+    smaller_dtype_usercountry = {
+        "Date": "DateTime",
+        "User ID": "varchar(10)",
+        "Web ID": "varchar(10)",
+        "User name": "varchar(10)",
+        "User country": "varchar(5)",
+        "All Users": "int",
+        "Age": "int",
+        "Last varchar": "varchar(10)",
+    }
+    with pytest.raises(ValueError):
+        check_hardcoded_dtypes_len(TEST_DF, smaller_dtype_usercountry)
+    assert (
+        "The length of the column User country is too big, some data could be lost. Please change the length of the provided dtypes to 6"
+        in caplog.text
+    )
+
+
+def test_check_hardcoded_dtypes_len():
+    good_dtypes = {
+        "Date": "DateTime",
+        "User ID": "varchar(10)",
+        "Web ID": "varchar(10)",
+        "User name": "varchar(10)",
+        "User country": "varchar(10)",
+        "All Users": "int",
+        "Age": "int",
+        "Last varchar": "varchar(10)",
+    }
+    assert check_hardcoded_dtypes_len(TEST_DF, good_dtypes) is None
\ No newline at end of file
diff --git a/viadot/flows/adls_to_azure_sql.py b/viadot/flows/adls_to_azure_sql.py
index a9e49c6b6..c696f410f 100644
--- a/viadot/flows/adls_to_azure_sql.py
+++ b/viadot/flows/adls_to_azure_sql.py
@@ -1,7 +1,9 @@
 import json
 import os
+import re
 from typing import Any, Dict, List, Literal
-
+from visions.typesets.complete_set import CompleteSet
+from visions.functional import infer_type
 import pandas as pd
 from prefect import Flow, task
 from prefect.backend import get_key_value
@@ -26,6 +28,90 @@ def union_dfs_task(dfs: List[pd.DataFrame]):
     return pd.concat(dfs, ignore_index=True)
 
 
+def get_real_sql_dtypes_from_df(df: pd.DataFrame) -> Dict[str, Any]:
+    """Infer SQL data types from a pandas DataFrame, sizing VARCHAR columns
+    by the real maximum length of the data they hold.
+
+    Args:
+        df (pd.DataFrame): DataFrame from the original ADLS file.
+
+    Returns:
+        Dict[str, Any]: Dictionary mapping column names to SQL data types with real maximum lengths.
+    """
+
+    typeset = CompleteSet()
+    dtypes = infer_type(df.head(10000), typeset)
+    dtypes_dict = {k: str(v) for k, v in dtypes.items()}
+    max_length_list = df.applymap(lambda x: len(str(x))).max().to_dict()
+    dict_mapping = {
+        "Float": "REAL",
+        "Image": None,
+        "Time": "TIME",
+        "Boolean": "VARCHAR(5)",  # Bool is True/False, Microsoft expects 0/1
+        "DateTime": "DATETIMEOFFSET",  # DATETIMEOFFSET is the only timezone-aware dtype in T-SQL
+        "File": None,
+        "Geometry": "GEOMETRY",
+        "Ordinal": "INT",
+        "Integer": "INT",
+        "Complex": None,
+        "Date": "DATE",
+        "Count": "INT",
+    }
+    dict_dtypes_mapped = {}
+    for k in dtypes_dict:
+        # TimeDelta - datetime.timedelta, e.g. '1 days 11:00:00'
+        if dtypes_dict[k] in ("Categorical", "Ordinal", "Object", "EmailAddress", "Generic", "UUID", "String", "IPAddress", "Path", "TimeDelta", "URL"):
+            dict_dtypes_mapped[k] = f"VARCHAR({max_length_list[k]})"
+        else:
+            dict_dtypes_mapped[k] = dict_mapping[dtypes_dict[k]]
+
+    return dict_dtypes_mapped
+
+def len_from_dtypes(dtypes: Dict[str, Any]) -> Dict[str, Any]:
+    """Turn a dictionary of column names and dtypes into a dictionary of column names
+    and either the VARCHAR lengths (as ints) or the dtype names (as strings).
+
+    Args:
+        dtypes (Dict[str, Any]): Dictionary of columns and the data types to apply
+            to the downloaded DataFrame.
+
+    Returns:
+        Dict[str, Any]: Dictionary of the columns with their dtypes as strings
+            or their VARCHAR lengths as ints.
+    """
+    dtypes_lens = {}
+    for k, v in dtypes.items():
+        if "varchar" in v.lower():
+            num = re.findall(r"\d+", v)
+            dtypes_lens[k] = int(num.pop())
+        else:
+            dtypes_lens[k] = str(v)
+    return dtypes_lens
+
+def check_hardcoded_dtypes_len(real_data_df: pd.DataFrame, given_dtypes: Dict[str, Any]) -> None:
+    """Check that the column lengths provided in the hard-coded dtypes are not too small
+    to hold the real data in the DataFrame.
+
+    Args:
+        real_data_df (pd.DataFrame): DataFrame from the original ADLS file.
+        given_dtypes (Dict[str, Any]): Dictionary of columns and the data types
+            to apply to the downloaded DataFrame.
+
+    Raises:
+        ValueError: Raised whenever the length of the hardcoded dtypes is too small to contain the full data.
+
+    Returns:
+        None
+    """
+    real_column_lengths = len_from_dtypes(get_real_sql_dtypes_from_df(real_data_df))
+    given_column_lengths = len_from_dtypes(given_dtypes)
+
+    for (column_given, len_given), (column_real, len_real) in zip(given_column_lengths.items(), real_column_lengths.items()):
+        # Checking only the columns whose dtypes carry a varchar length
+        if isinstance(given_column_lengths[column_real], int) and isinstance(real_column_lengths[column_real], int):
+            if len_real > len_given:
+                logger.error(f"The length of the column {column_real} is too big, some data could be lost. Please change the length of the provided dtypes to {len_real}")
+                raise ValueError("Datatype length is incorrect!")
 
 
 @task(timeout=3600)
 def map_data_types_task(json_shema_path: str):
@@ -122,7 +208,8 @@ def check_dtypes_sort(
             )
     else:
         new_dtypes = dtypes.copy()
-
+
+    check_hardcoded_dtypes_len(df, new_dtypes)
     return new_dtypes
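
A minimal usage sketch of the new helpers added in this diff, assuming viadot is installed with this change; the DataFrame and dtypes below are made up for illustration:

    import pandas as pd

    from viadot.flows.adls_to_azure_sql import (
        check_hardcoded_dtypes_len,
        get_real_sql_dtypes_from_df,
        len_from_dtypes,
    )

    df = pd.DataFrame({"City": ["Gdansk", "Warszawa"], "Visits": [3, 15]})

    # Infer SQL dtypes from the data: "City" becomes VARCHAR(8) because the longest
    # value is 8 characters, "Visits" becomes INT.
    real_dtypes = get_real_sql_dtypes_from_df(df)

    # Reduce dtypes to varchar lengths (ints) and pass-through dtype names (strings).
    len_from_dtypes(real_dtypes)  # {"City": 8, "Visits": "INT"}

    # Logs an error and raises ValueError: VARCHAR(5) cannot hold "Warszawa" (8 characters).
    check_hardcoded_dtypes_len(df, {"City": "varchar(5)", "Visits": "int"})

This is the same check that check_dtypes_sort now runs on user-provided dtypes before they are applied to the Azure SQL table.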