✨ Add hardcoded dtypes length check #759

Closed
wants to merge 3 commits
96 changes: 94 additions & 2 deletions tests/integration/flows/test_adls_to_azure_sql.py
@@ -6,8 +6,30 @@
from prefect.engine import signals

from viadot.flows import ADLSToAzureSQL
from viadot.flows.adls_to_azure_sql import (
    check_dtypes_sort,
    df_to_csv_task,
    len_from_dtypes,
    check_hardcoded_dtypes_len,
    get_real_sql_dtypes_from_df,
)

test_df = pd.DataFrame(
    {
        "Date": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04", "2023-01-05"],
        "User ID": ["1a34", "1d34$56", "1a3456&8", "1d3456789!", "1s3"],  # max length = 10
        "Web ID": ["4321", "1234$56", "123", "0", "12"],  # max length = 7
        "User name": ["Ada", "aaaaadAA", "Adulkaaa", "A", " "],  # max length = 8
        "User country": ["Poland", "USA", "Norway", "USA", "USA"],  # max length = 6
        "All Users": [1234, 123456, 12345678, 123456789, 123],
        "Age": [0, 12, 123, 89, 23],
        "Last varchar": ["Last", " ", "varchar", "of this ", "df"],  # max length = 8
    }
)

Real_Sql_Dtypes = {
    "Date": "DATE",
    "User ID": "VARCHAR(10)",
    "Web ID": "VARCHAR(7)",
    "User name": "VARCHAR(8)",
    "User country": "VARCHAR(6)",
    "All Users": "INT",
    "Age": "INT",
    "Last varchar": "VARCHAR(8)",
}

def test_get_promoted_adls_path_csv_file():
    adls_path_file = "raw/supermetrics/adls_ga_load_times_fr_test/2021-07-14T13%3A09%3A02.997357%2B00%3A00.csv"

@@ -101,3 +123,73 @@ def test_check_dtypes_sort():
        assert False
    except signals.FAIL:
        assert True


def test_get_real_sql_dtypes_from_df():
    assert get_real_sql_dtypes_from_df(test_df) == Real_Sql_Dtypes


def test_len_from_dtypes():
    real_df_lengths = {
        "Date": "DATE",
        "User ID": 10,
        "Web ID": 7,
        "User name": 8,
        "User country": 6,
        "All Users": "INT",
        "Age": "INT",
        "Last varchar": 8,
    }
    assert len_from_dtypes(Real_Sql_Dtypes) == real_df_lengths


def test_check_hardcoded_dtypes_len_userid(caplog):
    smaller_dtype_userid = {
        "Date": "DateTime",
        "User ID": "varchar(1)",
        "Web ID": "varchar(10)",
        "User name": "varchar(10)",
        "User country": "varchar(10)",
        "All Users": "int",
        "Age": "int",
        "Last varchar": "varchar(10)",
    }
    with pytest.raises(ValueError):
        check_hardcoded_dtypes_len(test_df, smaller_dtype_userid)
    assert (
        "The length of the column User ID is too big, some data could be lost. Please change the length of the provided dtypes to 10"
        in caplog.text
    )


def test_check_hardcoded_dtypes_len_usercountry(caplog):
    smaller_dtype_usercountry = {
        "Date": "DateTime",
        "User ID": "varchar(10)",
        "Web ID": "varchar(10)",
        "User name": "varchar(10)",
        "User country": "varchar(5)",
        "All Users": "int",
        "Age": "int",
        "Last varchar": "varchar(10)",
    }
    with pytest.raises(ValueError):
        check_hardcoded_dtypes_len(test_df, smaller_dtype_usercountry)
    assert (
        "The length of the column User country is too big, some data could be lost. Please change the length of the provided dtypes to 6"
        in caplog.text
    )


def test_check_hardcoded_dtypes_len():
    good_dtypes = {
        "Date": "DateTime",
        "User ID": "varchar(10)",
        "Web ID": "varchar(10)",
        "User name": "varchar(10)",
        "User country": "varchar(10)",
        "All Users": "int",
        "Age": "int",
        "Last varchar": "varchar(10)",
    }
    assert check_hardcoded_dtypes_len(test_df, good_dtypes) is None
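
These new tests should be runnable on their own; assuming the repository's standard pytest layout, an invocation like the following selects just them:

pytest tests/integration/flows/test_adls_to_azure_sql.py -k "hardcoded_dtypes or len_from_dtypes or real_sql_dtypes"
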
89 changes: 87 additions & 2 deletions viadot/flows/adls_to_azure_sql.py
@@ -1,7 +1,9 @@
import json
import os
import re
from typing import Any, Dict, List, Literal

from visions.typesets.complete_set import CompleteSet
from visions.functional import infer_type
import pandas as pd
from prefect import Flow, task
from prefect.backend import get_key_value

@@ -26,6 +28,88 @@
def union_dfs_task(dfs: List[pd.DataFrame]):
    return pd.concat(dfs, ignore_index=True)

def get_real_sql_dtypes_from_df(df: pd.DataFrame) -> Dict[str, Any]:
    """Obtain SQL data types from a pandas DataFrame, with VARCHAR lengths
    based on the real maximum length of the data in each column.

    Args:
        df (pd.DataFrame): DataFrame from the original ADLS file.

    Returns:
        Dict[str, Any]: Dictionary with the columns' data types and their real maximum lengths.
    """

    typeset = CompleteSet()
    dtypes = infer_type(df.head(10000), typeset)
    dtypes_dict = {k: str(v) for k, v in dtypes.items()}
    max_length_list = df.applymap(lambda x: len(str(x))).max().to_dict()
    dict_mapping = {
        "Float": "REAL",
        "Image": None,
        "Time": "TIME",
        "Boolean": "VARCHAR(5)",  # bool is True/False, Microsoft expects 0/1
        "DateTime": "DATETIMEOFFSET",  # DATETIMEOFFSET is the only timezone-aware dtype in T-SQL
        "File": None,
        "Geometry": "GEOMETRY",
        "Ordinal": "INT",
        "Integer": "INT",
        "Complex": None,
        "Date": "DATE",
        "Count": "INT",
    }
    dict_dtypes_mapped = {}
    for k in dtypes_dict:
        # TimeDelta - datetime.timedelta, e.g. '1 days 11:00:00'
        if dtypes_dict[k] in (
            "Categorical",
            "Ordinal",
            "Object",
            "EmailAddress",
            "Generic",
            "UUID",
            "String",
            "IPAddress",
            "Path",
            "TimeDelta",
            "URL",
        ):
            dict_dtypes_mapped[k] = f"VARCHAR({max_length_list[k]})"
        else:
            dict_dtypes_mapped[k] = dict_mapping[dtypes_dict[k]]

    return dict_dtypes_mapped
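
For orientation, a quick sketch of what this inference yields for the test_df defined in the tests above (the expected values come from the PR's own Real_Sql_Dtypes assertion):

get_real_sql_dtypes_from_df(test_df)
# {"Date": "DATE", "User ID": "VARCHAR(10)", "Web ID": "VARCHAR(7)", "User name": "VARCHAR(8)",
#  "User country": "VARCHAR(6)", "All Users": "INT", "Age": "INT", "Last varchar": "VARCHAR(8)"}
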

def len_from_dtypes(dtypes: Dict[str, Any]) -> Dict[str, Any]:
    """Turn a dictionary of column names and their dtypes into a dictionary
    of column names and either the VARCHAR length (as an int) or the dtype itself (as a string).

    Args:
        dtypes (Dict[str, Any]): Dictionary of columns and data types to apply
            to the downloaded DataFrame.

    Returns:
        Dict[str, Any]: Dictionary of the columns and either their dtypes, as strings,
        or the lengths of their VARCHARs, as ints.
    """
    dtypes_lens = {}
    for k, v in dtypes.items():
        if "varchar" in v.lower():
            num = re.findall(r"\d+", v)
            dtypes_lens[k] = int(num.pop())
        else:
            dtypes_lens[k] = str(v)
    return dtypes_lens
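
A small usage sketch (values as in test_len_from_dtypes above): VARCHAR lengths come back as ints, every other dtype passes through as a string.

len_from_dtypes({"User ID": "VARCHAR(10)", "Age": "INT"})
# {"User ID": 10, "Age": "INT"}
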

def check_hardcoded_dtypes_len(real_data_df: pd.DataFrame, given_dtypes: Dict[str, Any]) -> None:
    """Check that the column lengths provided by the hard-coded dtypes are not too small
    to hold the real data in the DataFrame's columns.

    Args:
        real_data_df (pd.DataFrame): DataFrame from the original ADLS file.
        given_dtypes (Dict[str, Any]): Dictionary of columns and data types to apply
            to the downloaded DataFrame.

    Raises:
        ValueError: Raised whenever a hardcoded dtype length is too small to contain the full data.

    Returns:
        None
    """
    real_column_lengths = len_from_dtypes(get_real_sql_dtypes_from_df(real_data_df))
    given_column_lengths = len_from_dtypes(given_dtypes)

    for column, len_given in given_column_lengths.items():
        len_real = real_column_lengths.get(column)
        # compare only VARCHAR columns, whose lengths are ints
        if isinstance(len_given, int) and isinstance(len_real, int) and len_real > len_given:
            logger.error(
                f"The length of the column {column} is too big, some data could be lost. Please change the length of the provided dtypes to {len_real}"
            )
            raise ValueError("Datatype length is incorrect!")
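
A minimal failure sketch, reusing test_df and the smaller_dtype_usercountry dict from the tests above: "Poland" needs 6 characters, so a hardcoded varchar(5) is rejected.

check_hardcoded_dtypes_len(test_df, smaller_dtype_usercountry)
# logs: "The length of the column User country is too big, some data could be lost.
#        Please change the length of the provided dtypes to 6"
# then raises ValueError("Datatype length is incorrect!")
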

@task(timeout=3600)
def map_data_types_task(json_shema_path: str):
@@ -122,7 +206,8 @@ def check_dtypes_sort(
        )
    else:
        new_dtypes = dtypes.copy()

    check_hardcoded_dtypes_len(df, new_dtypes)
    return new_dtypes

