⚡️ Resolve issues and add test of new functions
malgorzatagwinner committed Oct 9, 2023
1 parent 252d799 commit 3496a94
Showing 2 changed files with 117 additions and 22 deletions.
96 changes: 94 additions & 2 deletions tests/integration/flows/test_adls_to_azure_sql.py
@@ -6,8 +6,30 @@
from prefect.engine import signals

from viadot.flows import ADLSToAzureSQL
from viadot.flows.adls_to_azure_sql import (
    check_dtypes_sort,
    check_hardcoded_dtypes_len,
    df_to_csv_task,
    get_real_sql_dtypes_from_df,
    len_from_dtypes,
)

test_df = pd.DataFrame(
    {
        "Date": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04", "2023-01-05"],
        "User ID": ["1a34", "1d34$56", "1a3456&8", "1d3456789!", "1s3"],  # max length = 10
        "Web ID": ["4321", "1234$56", "123", "0", "12"],  # max length = 7
        "User name": ["Ada", "aaaaadAA", "Adulkaaa", "A", " "],  # max length = 8
        "User country": ["Poland", "USA", "Norway", "USA", "USA"],  # max length = 6
        "All Users": [1234, 123456, 12345678, 123456789, 123],
        "Age": [0, 12, 123, 89, 23],
        "Last varchar": ["Last", " ", "varchar", "of this ", "df"],  # max length = 8
    }
)

REAL_SQL_DTYPES = {
    "Date": "DATE",
    "User ID": "VARCHAR(10)",
    "Web ID": "VARCHAR(7)",
    "User name": "VARCHAR(8)",
    "User country": "VARCHAR(6)",
    "All Users": "INT",
    "Age": "INT",
    "Last varchar": "VARCHAR(8)",
}
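As a quick sanity check (not part of the commit), the commented maximum lengths can be reproduced directly from the fixture:

for col in ["User ID", "Web ID", "User name", "User country", "Last varchar"]:
    print(col, test_df[col].str.len().max())  # prints 10, 7, 8, 6 and 8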

def test_get_promoted_adls_path_csv_file():
    adls_path_file = "raw/supermetrics/adls_ga_load_times_fr_test/2021-07-14T13%3A09%3A02.997357%2B00%3A00.csv"
@@ -101,3 +123,73 @@ def test_check_dtypes_sort():
        assert False
    except signals.FAIL:
        assert True


def test_get_real_sql_dtypes_from_df():
    assert get_real_sql_dtypes_from_df(test_df) == REAL_SQL_DTYPES


def test_len_from_dtypes():
    real_df_lengths = {
        "Date": "DATE",
        "User ID": 10,
        "Web ID": 7,
        "User name": 8,
        "User country": 6,
        "All Users": "INT",
        "Age": "INT",
        "Last varchar": 8,
    }
    assert len_from_dtypes(REAL_SQL_DTYPES) == real_df_lengths


def test_check_hardcoded_dtypes_len_userid(caplog):
    smaller_dtype_userid = {
        "Date": "DateTime",
        "User ID": "varchar(1)",
        "Web ID": "varchar(10)",
        "User name": "varchar(10)",
        "User country": "varchar(10)",
        "All Users": "int",
        "Age": "int",
        "Last varchar": "varchar(10)",
    }
    with pytest.raises(ValueError):
        check_hardcoded_dtypes_len(test_df, smaller_dtype_userid)
    assert (
        "The real length of column User ID is larger than the provided dtype length, some data could be lost. Please change the length of the provided dtype to at least 10"
        in caplog.text
    )


def test_check_hardcoded_dtypes_len_usercountry(caplog):
    smaller_dtype_usercountry = {
        "Date": "DateTime",
        "User ID": "varchar(10)",
        "Web ID": "varchar(10)",
        "User name": "varchar(10)",
        "User country": "varchar(5)",
        "All Users": "int",
        "Age": "int",
        "Last varchar": "varchar(10)",
    }
    with pytest.raises(ValueError):
        check_hardcoded_dtypes_len(test_df, smaller_dtype_usercountry)
    assert (
        "The real length of column User country is larger than the provided dtype length, some data could be lost. Please change the length of the provided dtype to at least 6"
        in caplog.text
    )


def test_check_hardcoded_dtypes_len():
    good_dtypes = {
        "Date": "DateTime",
        "User ID": "varchar(10)",
        "Web ID": "varchar(10)",
        "User name": "varchar(10)",
        "User country": "varchar(10)",
        "All Users": "int",
        "Age": "int",
        "Last varchar": "varchar(10)",
    }
    assert check_hardcoded_dtypes_len(test_df, good_dtypes) is None
43 changes: 23 additions & 20 deletions viadot/flows/adls_to_azure_sql.py
@@ -29,7 +29,14 @@ def union_dfs_task(dfs: List[pd.DataFrame]):
    return pd.concat(dfs, ignore_index=True)

def get_real_sql_dtypes_from_df(df: pd.DataFrame) -> Dict[str, Any]:
    """Obtain SQL data types from a pandas DataFrame, with column lengths
    based on the real maximum length of the data in each column.

    Args:
        df (pd.DataFrame): DataFrame from the original ADLS file.

    Returns:
        Dict[str, Any]: Dictionary with the data type of each column and its real maximum length.
    """
    typeset = CompleteSet()
    dtypes = infer_type(df.head(10000), typeset)
    dtypes_dict = {k: str(v) for k, v in dtypes.items()}
@@ -58,18 +65,17 @@ def get_real_sql_dtypes_from_df(df: pd.DataFrame) -> Dict[str, Any]:

    return dict_dtypes_mapped
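The mapping above is inferred with visions' infer_type; as a rough standalone sketch of the idea (an illustration, not the committed implementation), integer columns map to INT and string columns to VARCHAR sized by their longest value:

import pandas as pd

def sketch_sql_dtypes(df: pd.DataFrame) -> dict:
    dtypes = {}
    for col in df.columns:
        series = df[col].dropna()
        if pd.api.types.is_integer_dtype(series):
            dtypes[col] = "INT"
        else:
            # Length of the longest string representation in the column.
            dtypes[col] = f"VARCHAR({int(series.astype(str).str.len().max())})"
    return dtypes

print(sketch_sql_dtypes(pd.DataFrame({"City": ["Oslo", "Warsaw"], "Age": [1, 2]})))
# {'City': 'VARCHAR(6)', 'Age': 'INT'}

The real task also handles further SQL types, such as DATE.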

def len_from_dtypes(dtypes: Dict[str, Any]) -> Dict[str, Any]:
    """Turn a dictionary of column names and their dtypes into a dictionary
    of column names and either the lengths of the varchars (as ints) or the
    dtypes themselves (as strings).

    Args:
        dtypes (Dict[str, Any]): Dictionary of columns and the data types to apply
            to the downloaded DataFrame.

    Returns:
        Dict[str, Any]: Dictionary of the columns and either their dtypes, as strings,
            or the lengths of their varchars, as ints.
    """
    dtypes_lens = {}
    for k, v in dtypes.items():
@@ -80,7 +86,7 @@ def len_from_dtypes(dtypes: Dict[str, Any]) -> Dict[str, Any]:
            dtypes_lens[k] = str(v)
    return dtypes_lens
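The varchar branch of the loop is collapsed above; one plausible way to extract the length (an assumption for illustration, not the committed code) is a regular expression:

import re

def varchar_len(dtype):
    # "VARCHAR(10)" -> 10; anything else is passed through as a string.
    match = re.search(r"varchar\((\d+)\)", str(dtype).lower())
    return int(match.group(1)) if match else str(dtype)

assert varchar_len("VARCHAR(10)") == 10
assert varchar_len("INT") == "INT"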

def check_hardcoded_dtypes_len(real_data_df: pd.DataFrame, given_dtypes: Dict[str, Any]) -> None:
    """Check that the column lengths provided by the hard-coded dtypes are not
    too small compared to the real columns of the df.
@@ -91,17 +97,19 @@ def check_hardcoded_dtypes_len(real_data_df, given_dtypes):
    Raises:
        ValueError: Raised whenever the length of the hard-coded dtypes is too small to contain the full data.

    Returns:
        None
    """
    real_column_lengths = len_from_dtypes(get_real_sql_dtypes_from_df(real_data_df))
    given_column_lengths = len_from_dtypes(given_dtypes)

    for (column_given, len_given), (column_real, len_real) in zip(
        given_column_lengths.items(), real_column_lengths.items()
    ):
        # Compare only the varchar columns, whose entries are int lengths.
        if isinstance(len_given, int) and isinstance(len_real, int):
            if len_real > len_given:
                logger.error(
                    f"The real length of column {column_real} is larger than the provided dtype length, some data could be lost. Please change the length of the provided dtype to at least {len_real}"
                )
                raise ValueError("Datatype length is incorrect!")
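A hypothetical usage sketch, assuming a viadot environment: the check returns None when the provided lengths are large enough and raises ValueError otherwise.

import pandas as pd

df = pd.DataFrame({"City": ["Oslo", "Warsaw"]})  # real max length = 6

check_hardcoded_dtypes_len(df, {"City": "varchar(10)"})  # long enough, returns None

try:
    check_hardcoded_dtypes_len(df, {"City": "varchar(3)"})
except ValueError:
    print("varchar(3) is too short for 6-character values")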

@task(timeout=3600)
def map_data_types_task(json_shema_path: str):
@@ -189,22 +197,17 @@ def check_dtypes_sort(
                new_dtypes = dict()
                for key in df.columns:
                    new_dtypes.update([(key, dtypes[key])])
            else:
                new_dtypes = dtypes.copy()
        else:
            logger.error("There is a discrepancy with one of the columns.")
            raise signals.FAIL(
                "dtype dictionary contains key(s) that do not match the ADLS file column names, or it has a different length."
            )
    else:
        new_dtypes = dtypes.copy()

    check_hardcoded_dtypes_len(df, new_dtypes)
    return new_dtypes
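A condensed illustration (not part of the commit) of the reordering step above: the dtype keys are rebuilt to follow the DataFrame's column order.

import pandas as pd

df = pd.DataFrame({"b": [1], "a": [2]})
dtypes = {"a": "INT", "b": "INT"}

new_dtypes = {key: dtypes[key] for key in df.columns}
print(list(new_dtypes))  # ['b', 'a']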


