Skip to content

Commit

Permalink
✨ Add hardcoded dtypes length check
Browse files Browse the repository at this point in the history
  • Loading branch information
malgorzatagwinner committed Oct 5, 2023
1 parent 7d5cfd4 commit 252d799
Showing 1 changed file with 83 additions and 1 deletion.
84 changes: 83 additions & 1 deletion viadot/flows/adls_to_azure_sql.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
import os
import re
from typing import Any, Dict, List, Literal

from visions.typesets.complete_set import CompleteSet
from visions.functional import infer_type
import pandas as pd
from prefect import Flow, task
from prefect.backend import get_key_value
Expand All @@ -26,6 +28,80 @@
def union_dfs_task(dfs: List[pd.DataFrame]):
    """Stack a list of DataFrames vertically into a single DataFrame.

    Args:
        dfs (List[pd.DataFrame]): Data Frames to concatenate, in order.

    Returns:
        pd.DataFrame: The concatenated frame; the original indexes are
        discarded and replaced by a fresh 0..n-1 index.
    """
    combined = pd.concat(dfs, ignore_index=True)
    return combined

def get_real_sql_dtypes_from_df(df: pd.DataFrame) -> Dict[str, Any]:
    """Obtain SQL data types from a pandas DataFrame.

    Column types are inferred with the visions ``CompleteSet`` typeset on the
    first 10000 rows. Text-like types become ``VARCHAR(n)``, where ``n`` is the
    longest ``str()`` representation found in that column across the whole
    DataFrame; every other type is translated through a fixed lookup table.

    Args:
        df (pd.DataFrame): Data Frame to inspect.

    Returns:
        Dict[str, Any]: Column name -> SQL type string, or None for inferred
        types that have no SQL counterpart in the lookup table.

    Raises:
        KeyError: If an inferred type name is neither text-like nor present
            in the lookup table.
    """
    # Inferred types that are stored as text in SQL.
    # TimeDelta - datetime.timedelta, eg. '1 days 11:00:00'
    text_like_types = (
        "Categorical",
        "Ordinal",
        "Object",
        "EmailAddress",
        "Generic",
        "UUID",
        "String",
        "IPAddress",
        "Path",
        "TimeDelta",
        "URL",
    )
    sql_type_by_inferred = {
        "Float": "REAL",
        "Image": None,
        "Time": "TIME",
        "Boolean": "VARCHAR(5)",  # Bool is True/False, Microsoft expects 0/1
        "DateTime": "DATETIMEOFFSET",  # DATETIMEOFFSET is the only timezone-aware dtype in TSQL
        "File": None,
        "Geometry": "GEOMETRY",
        # NOTE: never reached, "Ordinal" is handled as text-like first.
        "Ordinal": "INT",
        "Integer": "INT",
        "Complex": None,
        "Date": "DATE",
        "Count": "INT",
    }

    inferred = infer_type(df.head(10000), CompleteSet())
    inferred_names = {column: str(vtype) for column, vtype in inferred.items()}

    # Longest string representation per column, measured on all rows.
    cell_lengths = df.applymap(lambda cell: len(str(cell)))
    longest_by_column = cell_lengths.max().to_dict()

    sql_dtypes = {}
    for column, type_name in inferred_names.items():
        if type_name in text_like_types:
            sql_dtypes[column] = f"VARCHAR({longest_by_column[column]})"
            continue
        sql_dtypes[column] = sql_type_by_inferred[type_name]

    return sql_dtypes

def len_from_dtypes(dtypes: Dict[str, Any]) -> Dict[str, Any]:
    """Turn a column -> dtype mapping into a column -> varchar-length mapping.

    For every column whose dtype is a varchar with an explicit numeric length
    (e.g. ``VARCHAR(255)``), the value is that length as an int. For every
    other column the value is the dtype converted to a string.

    Args:
        dtypes (Dict[str, Any]): Dictionary of columns and data type to apply
            to the Data Frame downloaded. Values may be non-strings (e.g. None).

    Returns:
        Dict[str, Any]: Dictionary of the columns and their dtypes as strings
        or the lengths of their varchars, as ints.
    """
    dtypes_lens = {}
    for column, dtype in dtypes.items():
        # Non-string dtypes (e.g. None for unmapped types) have no .lower();
        # they fall through to the plain str() conversion below.
        if isinstance(dtype, str) and "varchar" in dtype.lower():
            numbers = re.findall(r"\d+", dtype)
            # A varchar without a numeric length (e.g. VARCHAR(MAX)) has no
            # length to extract, so it is kept as a plain dtype string.
            if numbers:
                dtypes_lens[column] = int(numbers[-1])
                continue
        dtypes_lens[column] = str(dtype)
    return dtypes_lens

def check_hardcoded_dtypes_len(
    real_data_df: pd.DataFrame, given_dtypes: Dict[str, Any]
) -> None:
    """Check that the hard-coded varchar lengths are large enough for the data.

    The lengths actually needed by ``real_data_df`` are compared, column by
    column, with the lengths declared in ``given_dtypes``. Only columns that
    are varchars with a numeric length on both sides are compared; columns
    missing from ``given_dtypes`` or having another dtype are skipped.

    Args:
        real_data_df (pd.DataFrame): Data Frame from original ADLS file.
        given_dtypes (Dict[str, Any]): Dictionary of columns and data type to apply
            to the Data Frame downloaded.

    Raises:
        ValueError: Raised whenever the length of the hardcoded dtypes is too
            small to contain the full data.
    """
    real_data_len = len_from_dtypes(get_real_sql_dtypes_from_df(real_data_df))
    given_data_len = len_from_dtypes(given_dtypes)

    # Match columns by name rather than by position: the two dictionaries are
    # not guaranteed to list the columns in the same order.
    for column_real, len_real in real_data_len.items():
        len_given = given_data_len.get(column_real)
        # check if both of them are varchars (i.e. both resolved to an int length)
        if not (isinstance(len_given, int) and isinstance(len_real, int)):
            continue
        if len_real > len_given:
            logger.error(f"The length of the column {column_real} is too big, some data could be lost. Please change the length of the provided dtypes to {len_real}")
            raise ValueError("Dtype length is incorrect!")


@task(timeout=3600)
def map_data_types_task(json_shema_path: str):
Expand Down Expand Up @@ -113,15 +189,21 @@ def check_dtypes_sort(
new_dtypes = dict()
for key in df.columns:
new_dtypes.update([(key, dtypes[key])])

check_hardcoded_dtypes_len(df, new_dtypes)

else:
new_dtypes = dtypes.copy()
check_hardcoded_dtypes_len(df, new_dtypes)

else:
logger.error("There is a discrepancy with any of the columns.")
raise signals.FAIL(
"dtype dictionary contains key(s) that not matching with the ADLS file columns name, or they have different length."
)
else:
new_dtypes = dtypes.copy()
check_hardcoded_dtypes_len(df, new_dtypes)

return new_dtypes

Expand Down

0 comments on commit 252d799

Please sign in to comment.