From b99571a684e8407b12f53eed434cd847140aca72 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Wed, 13 Nov 2024 17:34:35 +0100 Subject: [PATCH] Add downcasting to datetimes when comparing two columns. --- sdv/constraints/tabular.py | 9 ++ sdv/constraints/utils.py | 131 ++++++++++++++++++ tests/integration/constraints/test_tabular.py | 129 ++++++++++++++++- 3 files changed, 265 insertions(+), 4 deletions(-) diff --git a/sdv/constraints/tabular.py b/sdv/constraints/tabular.py index af7116899..0386de004 100644 --- a/sdv/constraints/tabular.py +++ b/sdv/constraints/tabular.py @@ -50,6 +50,7 @@ get_datetime_diff, get_mappable_combination, logit, + match_datetime_precision, matches_datetime_format, revert_nans_columns, sigmoid, @@ -484,6 +485,14 @@ def is_valid(self, table_data): low = cast_to_datetime64(low, self._low_datetime_format) high = cast_to_datetime64(high, self._high_datetime_format) + if self._low_datetime_format != self._high_datetime_format: + low, high = match_datetime_precision( + low=low, + high=high, + low_datetime_format=self._low_datetime_format, + high_datetime_format=self._high_datetime_format, + ) + valid = pd.isna(low) | pd.isna(high) | self._operator(high, low) return valid diff --git a/sdv/constraints/utils.py b/sdv/constraints/utils.py index c395d3a29..d90ce59fc 100644 --- a/sdv/constraints/utils.py +++ b/sdv/constraints/utils.py @@ -1,11 +1,39 @@ """Constraint utility functions.""" +import re from datetime import datetime from decimal import Decimal import numpy as np import pandas as pd +PRECISION_LEVELS = { + '%Y': 1, # Year + '%y': 1, # Year without century (same precision as %Y) + '%B': 2, # Full month name + '%b': 2, # Abbreviated month name + '%m': 2, # Month as a number + '%d': 3, # Day of the month + '%j': 3, # Day of the year + '%U': 3, # Week number (Sunday-starting) + '%W': 3, # Week number (Monday-starting) + '%A': 3, # Full weekday name + '%a': 3, # Abbreviated weekday name + '%w': 3, # Weekday as a decimal + '%H': 4, 
# Hour (24-hour clock) + '%I': 4, # Hour (12-hour clock) + '%M': 5, # Minute + '%S': 6, # Second + '%f': 7, # Microsecond + # Formats that don't add precision + '%p': 0, # AM/PM + '%z': 0, # UTC offset + '%Z': 0, # Time zone name + '%c': 0, # Locale-based date/time + '%x': 0, # Locale-based date + '%X': 0, # Locale-based time +} + def cast_to_datetime64(value, datetime_format=None): """Cast a given value to a ``numpy.datetime64`` format. @@ -199,6 +227,14 @@ def get_datetime_diff(high, low, high_datetime_format=None, low_datetime_format= low = cast_to_datetime64(low, low_datetime_format) high = cast_to_datetime64(high, high_datetime_format) + if low_datetime_format != high_datetime_format: + low, high = match_datetime_precision( + low=low, + high=high, + low_datetime_format=low_datetime_format, + high_datetime_format=high_datetime_format, + ) + diff_column = high - low nan_mask = pd.isna(diff_column) diff_column = diff_column.astype(np.float64) @@ -221,3 +257,98 @@ def get_mappable_combination(combination): A mappable combination of values. """ return tuple(None if pd.isna(x) else x for x in combination) + + +def match_datetime_precision(low, high, low_datetime_format, high_datetime_format): + """Downcast the `low` or `high` datetime array to match the lower precision format. + + Args: + low (np.ndarray): + Array of datetime values for the low column. + high (np.ndarray): + Array of datetime values for the high column. + low_datetime_format (str): + The datetime format of the `low` column. + high_datetime_format (str): + The datetime format of the `high` column. + + Returns: + Tuple[np.ndarray, np.ndarray]: + Adjusted `low` and `high` arrays where the higher precision format is + downcast to the lower precision format. 
+ """ + lower_precision_format = get_lower_precision_format(low_datetime_format, high_datetime_format) + if lower_precision_format == high_datetime_format: + low = downcast_datetime_to_lower_format(low, lower_precision_format) + else: + high = downcast_datetime_to_lower_format(high, lower_precision_format) + + return low, high + + +def get_datetime_format_precision(format_str): + """Return the precision level of a datetime format string.""" + # Find all format codes in the format string + found_formats = re.findall(r'%[A-Za-z]', format_str) + found_levels = ( + PRECISION_LEVELS.get(found_format) + for found_format in found_formats + if found_format in PRECISION_LEVELS + ) + + return max(found_levels, default=0) + + +def get_lower_precision_format(primary_format, secondary_format): + """Compare two datetime format strings and return the one with lower precision. + + Args: + primary_format (str): + The first datetime format string to compare. + secondary_format (str): + The second datetime format string to compare. + + Returns: + str: + The datetime format string with the lower precision level. + """ + primary_level = get_datetime_format_precision(primary_format) + secondary_level = get_datetime_format_precision(secondary_format) + if primary_level >= secondary_level: + return secondary_format + + return primary_format + + +def downcast_datetime_to_lower_format(data, target_format): + """Convert datetime data from a higher-precision format to a lower-precision format. + + Args: + data (np.array): + The data to cast to the `target_format`. + target_format (str): + The datetime format to downcast the data to. + + Returns: + np.ndarray: The data cast back to ``numpy.datetime64`` at the lower precision. + """ + downcasted_data = format_datetime_array(data, target_format) + return cast_to_datetime64(downcasted_data, target_format) + + +def format_datetime_array(datetime_array, target_format): + """Format each element in a numpy datetime64 array to a specified string format. 
+ + Args: + datetime_array (np.ndarray): + Array of datetime64[ns] elements. + target_format (str): + The datetime format to cast each element to. + + Returns: + np.ndarray: Array of formatted datetime strings. + """ + return np.array([ + pd.to_datetime(date).strftime(target_format) if not pd.isna(date) else pd.NaT + for date in datetime_array + ]) diff --git a/tests/integration/constraints/test_tabular.py b/tests/integration/constraints/test_tabular.py index 32bb6e7d2..05f6dad27 100644 --- a/tests/integration/constraints/test_tabular.py +++ b/tests/integration/constraints/test_tabular.py @@ -1,7 +1,9 @@ import numpy as np import pandas as pd +import pytest -from sdv.metadata import SingleTableMetadata +from sdv.errors import ConstraintsNotMetError +from sdv.metadata import Metadata from sdv.single_table import GaussianCopulaSynthesizer @@ -11,7 +13,7 @@ def test_fixed_combinations_integers(): 'A': [1, 2, 3, 1, 2, 1], 'B': [10, 20, 30, 10, 20, 10], }) - metadata = SingleTableMetadata().load_from_dict({ + metadata = Metadata.load_from_dict({ 'columns': { 'A': {'sdtype': 'categorical'}, 'B': {'sdtype': 'categorical'}, @@ -45,7 +47,7 @@ def test_fixed_combinations_with_nans(): 'A': [1, 2, np.nan, 1, 2, 1], 'B': [10, 20, 30, 10, 20, 10], }) - metadata = SingleTableMetadata().load_from_dict({ + metadata = Metadata.load_from_dict({ 'columns': { 'A': {'sdtype': 'categorical'}, 'B': {'sdtype': 'categorical'}, @@ -81,7 +83,7 @@ def test_fixedincrements_with_nullable_pandas_dtypes(): 'UInt32': pd.Series([1, pd.NA, 5], dtype='UInt32') * 10, 'UInt64': pd.Series([1, pd.NA, 6], dtype='UInt64') * 10, }) - metadata = SingleTableMetadata().load_from_dict({ + metadata = Metadata.load_from_dict({ 'columns': { 'UInt8': {'sdtype': 'numerical', 'computer_representation': 'UInt8'}, 'UInt16': {'sdtype': 'numerical', 'computer_representation': 'UInt16'}, @@ -107,3 +109,122 @@ def test_fixedincrements_with_nullable_pandas_dtypes(): synthetic_data.dtypes.to_dict() == data.dtypes.to_dict() 
for column in data.columns: assert np.all(synthetic_data[column] % 10 == 0) + + +def test_inequality_constraint_with_timestamp_and_date(): + """Test that the inequality constraint passes without strict boundaries. + + This test checks if the `Inequality` constraint can handle two columns + with different datetime formats when `strict_boundaries` is set to `False`. + The constraint allows the `SUBMISSION_TIMESTAMP` column to be less than + or equal to the `DUE_DATE` column, even when they differ in precision but end + within the same day. + """ + # Setup + data = pd.DataFrame( + data={ + 'SUBMISSION_TIMESTAMP': [ + '2016-07-10 17:04:00', + '2016-07-11 13:23:00', + '2016-07-12 08:45:30', + '2016-07-11 12:00:00', + '2016-07-12 10:30:00', + ], + 'DUE_DATE': ['2016-07-10', '2016-07-11', '2016-07-12', '2016-07-13', '2016-07-14'], + } + ) + + metadata = Metadata.load_from_dict({ + 'tables': { + 'table': { + 'columns': { + 'SUBMISSION_TIMESTAMP': { + 'sdtype': 'datetime', + 'datetime_format': '%Y-%m-%d %H:%M:%S', + }, + 'DUE_DATE': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + } + } + } + }) + synthesizer = GaussianCopulaSynthesizer(metadata) + + constraint = { + 'constraint_class': 'Inequality', + 'constraint_parameters': { + 'low_column_name': 'SUBMISSION_TIMESTAMP', + 'high_column_name': 'DUE_DATE', + 'strict_boundaries': False, + }, + } + + synthesizer.add_constraints([constraint]) + + # Run + synthesizer.fit(data) + synthetic_data = synthesizer.sample(num_rows=10) + + # Assert + synthetic_data['SUBMISSION_TIMESTAMP'] = pd.to_datetime( + synthetic_data['SUBMISSION_TIMESTAMP'], errors='coerce' + ) + synthetic_data['DUE_DATE'] = pd.to_datetime(synthetic_data['DUE_DATE'], errors='coerce') + invalid_rows = synthetic_data[ + synthetic_data['SUBMISSION_TIMESTAMP'].dt.date > synthetic_data['DUE_DATE'].dt.date + ] + assert invalid_rows.empty + + +def test_inequality_constraint_with_timestamp_and_date_strict_boundaries(): + """Test that the inequality constraint 
fails with strict boundaries. + + This test evaluates the `Inequality` constraint when `strict_boundaries` + is set to `True`. The `SUBMISSION_TIMESTAMP` column values must be strictly + less than the `DUE_DATE` values to satisfy the constraint. If any + `SUBMISSION_TIMESTAMP` matches or exceeds the `DUE_DATE`, an error should + be raised. + """ + # Setup + data = pd.DataFrame( + data={ + 'SUBMISSION_TIMESTAMP': [ + '2016-07-10 17:04:00', + '2016-07-11 13:23:00', + '2016-07-12 08:45:30', + '2016-07-11 12:00:00', + '2016-07-12 10:30:00', + ], + 'DUE_DATE': ['2016-07-10', '2016-07-11', '2016-07-12', '2016-07-13', '2016-07-14'], + } + ) + + metadata = Metadata.load_from_dict({ + 'tables': { + 'table': { + 'columns': { + 'SUBMISSION_TIMESTAMP': { + 'sdtype': 'datetime', + 'datetime_format': '%Y-%m-%d %H:%M:%S', + }, + 'DUE_DATE': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + } + } + } + }) + + synthesizer = GaussianCopulaSynthesizer(metadata) + constraint = { + 'constraint_class': 'Inequality', + 'constraint_parameters': { + 'low_column_name': 'SUBMISSION_TIMESTAMP', + 'high_column_name': 'DUE_DATE', + 'strict_boundaries': True, + }, + } + synthesizer.add_constraints([constraint]) + + # Run and Assert + error_msg = "Data is not valid for the 'Inequality' constraint: " + with pytest.raises(ConstraintsNotMetError) as error: + synthesizer.fit(data) + assert error_msg in error