From 252d799b6abf9b8d2bdd93a93785ef66dd7b7e30 Mon Sep 17 00:00:00 2001
From: mgwinner
Date: Thu, 5 Oct 2023 13:49:38 +0200
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20hardcoded=20dtypes=20length?=
 =?UTF-8?q?=20check?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 viadot/flows/adls_to_azure_sql.py | 84 ++++++++++++++++++++++++++++++-
 1 file changed, 83 insertions(+), 1 deletion(-)

diff --git a/viadot/flows/adls_to_azure_sql.py b/viadot/flows/adls_to_azure_sql.py
index a9e49c6b6..f95866924 100644
--- a/viadot/flows/adls_to_azure_sql.py
+++ b/viadot/flows/adls_to_azure_sql.py
@@ -1,7 +1,9 @@
 import json
 import os
+import re
 from typing import Any, Dict, List, Literal
-
+from visions.typesets.complete_set import CompleteSet
+from visions.functional import infer_type
 import pandas as pd
 from prefect import Flow, task
 from prefect.backend import get_key_value
@@ -26,6 +28,80 @@ def union_dfs_task(dfs: List[pd.DataFrame]):
     return pd.concat(dfs, ignore_index=True)
 
 
+def get_real_sql_dtypes_from_df(df: pd.DataFrame) -> Dict[str, Any]:
+    """Infer SQL data types from a pandas DataFrame, sizing VARCHARs to the
+    longest value found in each column."""
+    typeset = CompleteSet()
+    dtypes = infer_type(df.head(10000), typeset)
+    dtypes_dict = {k: str(v) for k, v in dtypes.items()}
+    max_length_dict = df.applymap(lambda x: len(str(x))).max().to_dict()
+    dict_mapping = {
+        "Float": "REAL",
+        "Image": None,
+        "Time": "TIME",
+        "Boolean": "VARCHAR(5)",  # Bool is True/False, Microsoft SQL expects 0/1
+        "DateTime": "DATETIMEOFFSET",  # DATETIMEOFFSET is the only timezone-aware dtype in T-SQL
+        "File": None,
+        "Geometry": "GEOMETRY",
+        "Ordinal": "INT",
+        "Integer": "INT",
+        "Complex": None,
+        "Date": "DATE",
+        "Count": "INT",
+    }
+    dict_dtypes_mapped = {}
+    for k in dtypes_dict:
+        # TimeDelta - datetime.timedelta, e.g. '1 days 11:00:00'.
+        if dtypes_dict[k] in ("Categorical", "Ordinal", "Object", "EmailAddress", "Generic", "UUID", "String", "IPAddress", "Path", "TimeDelta", "URL"):
+            dict_dtypes_mapped[k] = f"VARCHAR({max_length_dict[k]})"
+        else:
+            dict_dtypes_mapped[k] = dict_mapping[dtypes_dict[k]]
+
+    return dict_dtypes_mapped
+
+
+def len_from_dtypes(dtypes: Dict[str, Any]) -> Dict[str, Any]:
+    """Turn a dictionary of column names and their dtypes into a dictionary
+    of column names and either the lengths of their VARCHARs or their dtypes.
+
+    Args:
+        dtypes (Dict[str, Any]): Dictionary of columns and the data types to apply
+            to the downloaded DataFrame.
+
+    Returns:
+        Dict[str, Any]: Dictionary of columns and either their dtypes, as strings,
+            or the lengths of their VARCHARs, as ints.
+    """
+    dtypes_lens = {}
+    for k, v in dtypes.items():
+        if "varchar" in v.lower():
+            # Keep length-less VARCHARs, e.g. VARCHAR(MAX), as strings.
+            num = re.findall(r"\d+", v)
+            dtypes_lens[k] = int(num.pop()) if num else str(v)
+        else:
+            dtypes_lens[k] = str(v)
+    return dtypes_lens
+
+
+def check_hardcoded_dtypes_len(real_data_df: pd.DataFrame, given_dtypes: Dict[str, Any]) -> None:
+    """Check that the column lengths provided by the hardcoded dtypes are not too
+    small to hold the actual columns of the DataFrame.
+
+    Args:
+        real_data_df (pd.DataFrame): DataFrame from the original ADLS file.
+        given_dtypes (Dict[str, Any]): Dictionary of columns and the data types to apply
+            to the downloaded DataFrame.
+
+    Raises:
+        ValueError: Raised whenever a hardcoded dtype length is too small to contain the full data.
+    """
+    real_data_len = len_from_dtypes(get_real_sql_dtypes_from_df(real_data_df))
+    given_data_len = len_from_dtypes(given_dtypes)
+
+    for column, len_given in given_data_len.items():
+        len_real = real_data_len.get(column)
+        # Compare lengths only when both dtypes are VARCHARs (ints after len_from_dtypes).
+        if isinstance(len_given, int) and isinstance(len_real, int) and len_real > len_given:
+            logger.error(
+                f"The provided length of column {column} is too small, some data could be lost. "
+                f"Please increase the length of the provided dtype to at least {len_real}."
+            )
+            raise ValueError("Dtype length is incorrect!")
+
+
 @task(timeout=3600)
 def map_data_types_task(json_shema_path: str):
@@ -113,8 +189,13 @@ def check_dtypes_sort(
                 new_dtypes = dict()
                 for key in df.columns:
                     new_dtypes.update([(key, dtypes[key])])
+
+                check_hardcoded_dtypes_len(df, new_dtypes)
+
             else:
                 new_dtypes = dtypes.copy()
+                check_hardcoded_dtypes_len(df, new_dtypes)
+
         else:
             logger.error("There is a discrepancy with any of the columns.")
             raise signals.FAIL(
@@ -122,6 +203,7 @@ def check_dtypes_sort(
             )
     else:
         new_dtypes = dtypes.copy()
+        check_hardcoded_dtypes_len(df, new_dtypes)
 
     return new_dtypes
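
Note: a minimal usage sketch of the new check (hypothetical column names and
values, not part of the patch). Assuming visions infers the "String" type for
the column, get_real_sql_dtypes_from_df() sizes the VARCHAR to the longest
value in the column, and check_hardcoded_dtypes_len() raises when a hardcoded
VARCHAR is smaller than that:

    import pandas as pd

    df = pd.DataFrame({"name": ["Alice", "Christopher"]})    # longest value: 11 chars
    get_real_sql_dtypes_from_df(df)                          # -> {"name": "VARCHAR(11)"}
    check_hardcoded_dtypes_len(df, {"name": "VARCHAR(50)"})  # passes: 50 >= 11
    check_hardcoded_dtypes_len(df, {"name": "VARCHAR(5)"})   # raises ValueError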