Skip to content

Commit

Permalink
Add downcasting to datetimes when comparing two columns.
Browse files Browse the repository at this point in the history
  • Loading branch information
pvk-developer committed Nov 13, 2024
1 parent 3115f6e commit b99571a
Show file tree
Hide file tree
Showing 3 changed files with 265 additions and 4 deletions.
9 changes: 9 additions & 0 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
get_datetime_diff,
get_mappable_combination,
logit,
match_datetime_precision,
matches_datetime_format,
revert_nans_columns,
sigmoid,
Expand Down Expand Up @@ -484,6 +485,14 @@ def is_valid(self, table_data):
low = cast_to_datetime64(low, self._low_datetime_format)
high = cast_to_datetime64(high, self._high_datetime_format)

if self._low_datetime_format != self._high_datetime_format:
low, high = match_datetime_precision(
low=low,
high=high,
low_datetime_format=self._low_datetime_format,
high_datetime_format=self._high_datetime_format,
)

valid = pd.isna(low) | pd.isna(high) | self._operator(high, low)
return valid

Expand Down
131 changes: 131 additions & 0 deletions sdv/constraints/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,39 @@
"""Constraint utility functions."""

import re
from datetime import datetime
from decimal import Decimal

import numpy as np
import pandas as pd

# Rank of each strftime directive: the larger the number, the finer the unit of
# time the directive expresses. A format string's precision is the highest rank
# among the directives it contains.
# NOTE(review): ``%c``, ``%x`` and ``%X`` expand to locale date/time text but are
# ranked 0 here -- confirm that this is intended.
PRECISION_LEVELS = {
    # Level 1: year
    '%Y': 1,  # four-digit year
    '%y': 1,  # two-digit year, same rank as %Y
    # Level 2: month
    '%B': 2,  # month, full name
    '%b': 2,  # month, abbreviated name
    '%m': 2,  # month, numeric
    # Level 3: day / week
    '%d': 3,  # day of the month
    '%j': 3,  # day of the year
    '%U': 3,  # week number, weeks start on Sunday
    '%W': 3,  # week number, weeks start on Monday
    '%A': 3,  # weekday, full name
    '%a': 3,  # weekday, abbreviated name
    '%w': 3,  # weekday, numeric
    # Level 4: hour
    '%H': 4,  # hour, 24-hour clock
    '%I': 4,  # hour, 12-hour clock
    # Levels 5-7: minute, second, microsecond
    '%M': 5,  # minute
    '%S': 6,  # second
    '%f': 7,  # microsecond
    # Level 0: directives treated as adding no precision
    '%p': 0,  # AM/PM marker
    '%z': 0,  # UTC offset
    '%Z': 0,  # time zone name
    '%c': 0,  # locale date and time
    '%x': 0,  # locale date
    '%X': 0,  # locale time
}


def cast_to_datetime64(value, datetime_format=None):
"""Cast a given value to a ``numpy.datetime64`` format.
Expand Down Expand Up @@ -199,6 +227,14 @@ def get_datetime_diff(high, low, high_datetime_format=None, low_datetime_format=
low = cast_to_datetime64(low, low_datetime_format)
high = cast_to_datetime64(high, high_datetime_format)

if low_datetime_format != high_datetime_format:
low, high = match_datetime_precision(
low=low,
high=high,
low_datetime_format=low_datetime_format,
high_datetime_format=high_datetime_format,
)

diff_column = high - low
nan_mask = pd.isna(diff_column)
diff_column = diff_column.astype(np.float64)
Expand All @@ -221,3 +257,98 @@ def get_mappable_combination(combination):
A mappable combination of values.
"""
return tuple(None if pd.isna(x) else x for x in combination)


def match_datetime_precision(low, high, low_datetime_format, high_datetime_format):
    """Downcast whichever of `low` / `high` is more precise to the less precise format.

    The two format strings are ranked by precision and the array whose format is
    NOT the lower-precision one is re-rendered with the lower-precision format,
    so both sides can be compared or subtracted at the same granularity.

    Args:
        low (np.ndarray):
            Array of datetime values for the low column.
        high (np.ndarray):
            Array of datetime values for the high column.
        low_datetime_format (str or None):
            The datetime format of the `low` column.
        high_datetime_format (str or None):
            The datetime format of the `high` column.

    Returns:
        Tuple[np.ndarray, np.ndarray]:
            The `low` and `high` arrays, with the higher-precision one downcast
            to the lower-precision format. If either format is ``None`` both
            arrays are returned unchanged.
    """
    # A missing format cannot be ranked or used to re-render values; previously
    # this path crashed with a ``TypeError`` when only one format was provided.
    if low_datetime_format is None or high_datetime_format is None:
        return low, high

    lower_precision_format = get_lower_precision_format(low_datetime_format, high_datetime_format)
    if lower_precision_format == high_datetime_format:
        # `high` already has the coarser format, so `low` is the one to downcast.
        low = downcast_datetime_to_lower_format(low, lower_precision_format)
    else:
        high = downcast_datetime_to_lower_format(high, lower_precision_format)

    return low, high


def get_datetime_format_precision(format_str):
    """Return the precision level of a datetime format string.

    The level is the highest ``PRECISION_LEVELS`` rank among the directives
    found in ``format_str``.

    Args:
        format_str (str):
            A strftime-style datetime format string.

    Returns:
        int:
            The highest precision level found, or ``0`` when the string
            contains no known directive.
    """
    # Tokenise left to right so that an escaped percent (``%%``) is consumed as a
    # single unit; otherwise ``%%d`` would be misread as the ``%d`` directive.
    # The optional ``-`` / ``#`` covers the platform-specific no-padding flags
    # (``%-d``, ``%#d``), which carry the same precision as the plain directive.
    directive_codes = re.findall(r'%[-#]?([A-Za-z%])', format_str)
    found_levels = (
        PRECISION_LEVELS['%' + code]
        for code in directive_codes
        if code != '%' and '%' + code in PRECISION_LEVELS
    )

    return max(found_levels, default=0)


def get_lower_precision_format(primary_format, secondary_format):
    """Pick the less precise of two datetime format strings.

    Args:
        primary_format (str):
            The first datetime format string to compare.
        secondary_format (str):
            The second datetime format string to compare.

    Returns:
        str:
            The format string with the lower precision level. When both
            formats have the same level, ``secondary_format`` is returned.
    """
    primary_rank = get_datetime_format_precision(primary_format)
    secondary_rank = get_datetime_format_precision(secondary_format)

    # Only a strictly less precise primary wins; ties fall to the secondary.
    return primary_format if primary_rank < secondary_rank else secondary_format

Check warning on line 320 in sdv/constraints/utils.py

View check run for this annotation

Codecov / codecov/patch

sdv/constraints/utils.py#L320

Added line #L320 was not covered by tests


def downcast_datetime_to_lower_format(data, target_format):
    """Round-trip datetime values through ``target_format``.

    Each value is rendered as a string using ``target_format`` and the resulting
    strings are parsed back with ``cast_to_datetime64`` using the same format.

    Args:
        data (np.array):
            The datetime values to downcast.
        target_format (str):
            The datetime format used both to render and to re-parse the values.

    Returns:
        The result of ``cast_to_datetime64`` applied to the formatted strings.
    """
    return cast_to_datetime64(format_datetime_array(data, target_format), target_format)


def format_datetime_array(datetime_array, target_format):
    """Render every element of a datetime array as a string in ``target_format``.

    Args:
        datetime_array (np.ndarray):
            Array of datetime64[ns] elements.
        target_format (str):
            The strftime format applied to each non-missing element.

    Returns:
        np.ndarray:
            Array with one entry per input element: the formatted string, or
            ``pd.NaT`` where the input element is missing.
    """
    rendered = []
    for value in datetime_array:
        if pd.isna(value):
            # Missing values stay missing instead of being formatted.
            rendered.append(pd.NaT)
        else:
            rendered.append(pd.to_datetime(value).strftime(target_format))

    return np.array(rendered)
129 changes: 125 additions & 4 deletions tests/integration/constraints/test_tabular.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numpy as np
import pandas as pd
import pytest

from sdv.metadata import SingleTableMetadata
from sdv.errors import ConstraintsNotMetError
from sdv.metadata import Metadata
from sdv.single_table import GaussianCopulaSynthesizer


Expand All @@ -11,7 +13,7 @@ def test_fixed_combinations_integers():
'A': [1, 2, 3, 1, 2, 1],
'B': [10, 20, 30, 10, 20, 10],
})
metadata = SingleTableMetadata().load_from_dict({
metadata = Metadata.load_from_dict({
'columns': {
'A': {'sdtype': 'categorical'},
'B': {'sdtype': 'categorical'},
Expand Down Expand Up @@ -45,7 +47,7 @@ def test_fixed_combinations_with_nans():
'A': [1, 2, np.nan, 1, 2, 1],
'B': [10, 20, 30, 10, 20, 10],
})
metadata = SingleTableMetadata().load_from_dict({
metadata = Metadata.load_from_dict({
'columns': {
'A': {'sdtype': 'categorical'},
'B': {'sdtype': 'categorical'},
Expand Down Expand Up @@ -81,7 +83,7 @@ def test_fixedincrements_with_nullable_pandas_dtypes():
'UInt32': pd.Series([1, pd.NA, 5], dtype='UInt32') * 10,
'UInt64': pd.Series([1, pd.NA, 6], dtype='UInt64') * 10,
})
metadata = SingleTableMetadata().load_from_dict({
metadata = Metadata.load_from_dict({
'columns': {
'UInt8': {'sdtype': 'numerical', 'computer_representation': 'UInt8'},
'UInt16': {'sdtype': 'numerical', 'computer_representation': 'UInt16'},
Expand All @@ -107,3 +109,122 @@ def test_fixedincrements_with_nullable_pandas_dtypes():
synthetic_data.dtypes.to_dict() == data.dtypes.to_dict()
for column in data.columns:
assert np.all(synthetic_data[column] % 10 == 0)


def test_inequality_constraint_with_timestamp_and_date():
    """Test ``Inequality`` across a timestamp column and a date-only column.

    With ``strict_boundaries`` set to ``False``, fitting must accept rows where
    ``SUBMISSION_TIMESTAMP`` falls on the same calendar day as ``DUE_DATE``, even
    though the two columns use datetime formats of different precision. Sampled
    rows must never have a submission day later than the due day.
    """
    # Setup
    submission_timestamps = [
        '2016-07-10 17:04:00',
        '2016-07-11 13:23:00',
        '2016-07-12 08:45:30',
        '2016-07-11 12:00:00',
        '2016-07-12 10:30:00',
    ]
    due_dates = ['2016-07-10', '2016-07-11', '2016-07-12', '2016-07-13', '2016-07-14']
    data = pd.DataFrame({
        'SUBMISSION_TIMESTAMP': submission_timestamps,
        'DUE_DATE': due_dates,
    })

    column_metadata = {
        'SUBMISSION_TIMESTAMP': {
            'sdtype': 'datetime',
            'datetime_format': '%Y-%m-%d %H:%M:%S',
        },
        'DUE_DATE': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'},
    }
    metadata = Metadata.load_from_dict({'tables': {'table': {'columns': column_metadata}}})

    inequality = {
        'constraint_class': 'Inequality',
        'constraint_parameters': {
            'low_column_name': 'SUBMISSION_TIMESTAMP',
            'high_column_name': 'DUE_DATE',
            'strict_boundaries': False,
        },
    }
    synthesizer = GaussianCopulaSynthesizer(metadata)
    synthesizer.add_constraints([inequality])

    # Run
    synthesizer.fit(data)
    synthetic_data = synthesizer.sample(num_rows=10)

    # Assert
    # Compare at day granularity: the constraint only has to hold per calendar day.
    submitted_on = pd.to_datetime(synthetic_data['SUBMISSION_TIMESTAMP'], errors='coerce').dt.date
    due_on = pd.to_datetime(synthetic_data['DUE_DATE'], errors='coerce').dt.date
    late_rows = synthetic_data[submitted_on > due_on]
    assert late_rows.empty


def test_inequality_constraint_with_timestamp_and_date_strict_boundaries():
    """Test that the inequality constraint fails with strict boundaries.

    This test evaluates the `Inequality` constraint when `strict_boundaries`
    is set to `True`. The `SUBMISSION_TIMESTAMP` column values must be strictly
    less than the `DUE_DATE` values to satisfy the constraint. If any
    `SUBMISSION_TIMESTAMP` matches or exceeds the `DUE_DATE`, an error should
    be raised.
    """
    # Setup
    data = pd.DataFrame(
        data={
            'SUBMISSION_TIMESTAMP': [
                '2016-07-10 17:04:00',
                '2016-07-11 13:23:00',
                '2016-07-12 08:45:30',
                '2016-07-11 12:00:00',
                '2016-07-12 10:30:00',
            ],
            'DUE_DATE': ['2016-07-10', '2016-07-11', '2016-07-12', '2016-07-13', '2016-07-14'],
        }
    )

    metadata = Metadata.load_from_dict({
        'tables': {
            'table': {
                'columns': {
                    'SUBMISSION_TIMESTAMP': {
                        'sdtype': 'datetime',
                        'datetime_format': '%Y-%m-%d %H:%M:%S',
                    },
                    'DUE_DATE': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'},
                }
            }
        }
    })

    synthesizer = GaussianCopulaSynthesizer(metadata)
    constraint = {
        'constraint_class': 'Inequality',
        'constraint_parameters': {
            'low_column_name': 'SUBMISSION_TIMESTAMP',
            'high_column_name': 'DUE_DATE',
            'strict_boundaries': True,
        },
    }
    synthesizer.add_constraints([constraint])

    # Run and Assert
    # The message is checked through ``match`` so the check actually runs: a
    # membership test against the ``ExceptionInfo`` object never inspects the
    # error text. Only the stable prefix is matched, because the punctuation and
    # whitespace that follow it are not pinned down by this test.
    error_msg = "Data is not valid for the 'Inequality' constraint"
    with pytest.raises(ConstraintsNotMetError, match=error_msg):
        synthesizer.fit(data)

0 comments on commit b99571a

Please sign in to comment.