Skip to content

Commit

Permalink
MismatchedSeriesLength Datacheck (#4296)
Browse files Browse the repository at this point in the history
* Added mismatch series length data check
  • Loading branch information
MichaelFu512 authored Sep 5, 2023
1 parent 93e7b97 commit 1329988
Show file tree
Hide file tree
Showing 5 changed files with 285 additions and 1 deletion.
3 changes: 2 additions & 1 deletion docs/source/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ Release Notes
**Future Releases**
* Enhancements
* Extended STLDecomposer to Support Multiseries :pr:`4253`
* Extended TimeSeriesImputer to handle multiseries :pr:`4291`
* Added datacheck to check for mismatched series length in multiseries :pr:`4296`
* Fixes
* Changes
* Documentation Changes
Expand All @@ -17,7 +19,6 @@ Release Notes
* Enhancements
* Added support for prediction intervals for VARMAX regressor :pr:`4267`
* Integrated multiseries time series into AutoMLSearch :pr:`4270`
* Extended TimeSeriesImputer to handle multiple series :pr:`4291`
* Fixes
* Fixed error when stacking data with no exogenous variables :pr:`4275`
* Changes
Expand Down
3 changes: 3 additions & 0 deletions evalml/data_checks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,6 @@
from evalml.data_checks.datetime_format_data_check import DateTimeFormatDataCheck
from evalml.data_checks.ts_parameters_data_check import TimeSeriesParametersDataCheck
from evalml.data_checks.ts_splitting_data_check import TimeSeriesSplittingDataCheck
from evalml.data_checks.mismatched_series_length_data_check import (
MismatchedSeriesLengthDataCheck,
)
6 changes: 6 additions & 0 deletions evalml/data_checks/data_check_message_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,9 @@ class DataCheckMessageCode(Enum):
"timeseries_target_not_compatible_with_split"
)
"""Message code when any training and validation split of the time series target doesn't contain all classes."""

# Emitted as a warning by MismatchedSeriesLengthDataCheck.validate when the series in X do not all have the same number of rows.
MISMATCHED_SERIES_LENGTH = "mismatched_series_length"
"""Message code for when one or more unique series in a multiseries dataset is of a different length than the others"""

# Emitted as an error by MismatchedSeriesLengthDataCheck.validate when the configured series_id column is not found in X.
INVALID_SERIES_ID_COL = "invalid_series_id_col"
"""Message code for when given series_id is invalid"""
184 changes: 184 additions & 0 deletions evalml/data_checks/mismatched_series_length_data_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
"""Data check that checks if one or more unique series in a multiseres data is a different length than the others."""

from evalml.data_checks import (
DataCheck,
DataCheckError,
DataCheckMessageCode,
DataCheckWarning,
)


class MismatchedSeriesLengthDataCheck(DataCheck):
    """Check if one or more unique series in a multiseries dataset is of a different length than the others.

    Currently works specifically on stacked data.

    Args:
        series_id (str): The name of the series_id column for the dataset.

    Raises:
        ValueError: If series_id is None.
    """

    def __init__(self, series_id):
        if series_id is None:
            raise ValueError(
                "series_id must be set to the series_id column in the dataset and not None",
            )
        self.series_id = series_id

    def validate(self, X, y=None):
        """Check if one or more unique series in a multiseries dataset is of a different length than the others.

        Currently works specifically on stacked data.

        Args:
            X (pd.DataFrame, np.ndarray): The input features to check. Must have a series_id column.
            y (pd.Series): The target. Defaults to None. Ignored.

        Returns:
            list[dict]: List with a DataCheckWarning (as a dict) if there are mismatched series lengths in the dataset,
                a list with a DataCheckError (as a dict) if the given series_id is not in the dataset,
                or an empty list if every series has the same length or the dataset has no rows.

        Examples:
            >>> import pandas as pd

            For multiseries time series datasets, each seriesID should ideally have the same number of datetime entries as
            each other. If they don't, then a warning will be raised denoting which seriesID have mismatched lengths.

            >>> X = pd.DataFrame(
            ...     {
            ...         "date": pd.date_range(start="1/1/2018", periods=20).repeat(5),
            ...         "series_id": pd.Series(list(range(5)) * 20, dtype="str"),
            ...         "feature_a": range(100),
            ...         "feature_b": reversed(range(100)),
            ...     },
            ... )
            >>> X = X.drop(labels=0, axis=0)
            >>> mismatched_series_length_check = MismatchedSeriesLengthDataCheck("series_id")
            >>> assert mismatched_series_length_check.validate(X) == [
            ...     {
            ...         "message": "Series ID ['0'] do not match the majority length of the other series, which is 20",
            ...         "data_check_name": "MismatchedSeriesLengthDataCheck",
            ...         "level": "warning",
            ...         "details": {
            ...             "columns": None,
            ...             "rows": None,
            ...             "series_id": ['0'],
            ...             "majority_length": 20
            ...         },
            ...         "code": "MISMATCHED_SERIES_LENGTH",
            ...         "action_options": [],
            ...     }
            ... ]

            If MismatchedSeriesLengthDataCheck is passed in an invalid series_id column name, then an error will be raised.

            >>> X = pd.DataFrame(
            ...     {
            ...         "date": pd.date_range(start="1/1/2018", periods=20).repeat(5),
            ...         "series_id": pd.Series(list(range(5)) * 20, dtype="str"),
            ...         "feature_a": range(100),
            ...         "feature_b": reversed(range(100)),
            ...     },
            ... )
            >>> X = X.drop(labels=0, axis=0)
            >>> mismatched_series_length_check = MismatchedSeriesLengthDataCheck("not_series_id")
            >>> assert mismatched_series_length_check.validate(X) == [
            ...     {
            ...         "message": "series_id 'not_series_id' is not in the dataset.",
            ...         "data_check_name": "MismatchedSeriesLengthDataCheck",
            ...         "level": "error",
            ...         "details": {
            ...             "columns": None,
            ...             "rows": None,
            ...             "series_id": "not_series_id",
            ...         },
            ...         "code": "INVALID_SERIES_ID_COL",
            ...         "action_options": [],
            ...     }
            ... ]

            If there are multiple lengths that have the same number of series (e.g. two series have length 20 and two series have length 19),
            this datacheck will consider the higher length to be the majority length (e.g. from the previous example length 20 would be the majority length)

            >>> X = pd.DataFrame(
            ...     {
            ...         "date": pd.date_range(start="1/1/2018", periods=20).repeat(4),
            ...         "series_id": pd.Series(list(range(4)) * 20, dtype="str"),
            ...         "feature_a": range(80),
            ...         "feature_b": reversed(range(80)),
            ...     },
            ... )
            >>> X = X.drop(labels=[0, 1], axis=0)
            >>> mismatched_series_length_check = MismatchedSeriesLengthDataCheck("series_id")
            >>> assert mismatched_series_length_check.validate(X) == [
            ...     {
            ...         "message": "Series ID ['0', '1'] do not match the majority length of the other series, which is 20",
            ...         "data_check_name": "MismatchedSeriesLengthDataCheck",
            ...         "level": "warning",
            ...         "details": {
            ...             "columns": None,
            ...             "rows": None,
            ...             "series_id": ['0', '1'],
            ...             "majority_length": 20
            ...         },
            ...         "code": "MISMATCHED_SERIES_LENGTH",
            ...         "action_options": [],
            ...     }
            ... ]
        """
        messages = []
        if self.series_id not in X:
            messages.append(
                DataCheckError(
                    message=f"series_id '{self.series_id}' is not in the dataset.",
                    data_check_name=self.name,
                    message_code=DataCheckMessageCode.INVALID_SERIES_ID_COL,
                    details={"series_id": self.series_id},
                ).to_dict(),
            )
            return messages

        # Number of rows per series, keyed in order of first appearance in X.
        # Summing the boolean mask counts the matching rows without building
        # a filtered copy of the whole frame for every series.
        series_col = X[self.series_id]
        series_id_len = {
            series: int((series_col == series).sum())
            for series in series_col.unique()
        }

        # {length: number of series that have exactly that length}
        length_counts = {}
        for series_length in series_id_len.values():
            length_counts[series_length] = length_counts.get(series_length, 0) + 1

        # Zero distinct lengths (no rows at all) or one distinct length means
        # there is nothing to report; this also keeps max() below from being
        # called on an empty mapping.
        if len(length_counts) <= 1:
            return messages

        # The most common length wins; on a tie the larger length is chosen,
        # as documented, rather than whichever length appears first in X.
        majority_len = max(
            length_counts,
            key=lambda length: (length_counts[length], length),
        )

        # get the series_id's that aren't the majority length
        not_majority = [
            series
            for series, series_length in series_id_len.items()
            if series_length != majority_len
        ]

        if len(not_majority) == len(series_id_len) - 1:
            # Only one series has the "majority" length, so every length occurs
            # exactly once and no real majority exists.
            warning_msg = "All series ID have different lengths than each other"
            messages.append(
                DataCheckWarning(
                    message=warning_msg,
                    data_check_name=self.name,
                    message_code=DataCheckMessageCode.MISMATCHED_SERIES_LENGTH,
                    action_options=[],
                ).to_dict(),
            )
        else:
            warning_msg = f"Series ID {not_majority} do not match the majority length of the other series, which is {majority_len}"
            messages.append(
                DataCheckWarning(
                    message=warning_msg,
                    data_check_name=self.name,
                    message_code=DataCheckMessageCode.MISMATCHED_SERIES_LENGTH,
                    details={
                        "series_id": not_majority,
                        "majority_length": majority_len,
                    },
                    action_options=[],
                ).to_dict(),
            )
        return messages
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import pytest

from evalml.data_checks import (
DataCheckError,
DataCheckMessageCode,
DataCheckWarning,
MismatchedSeriesLengthDataCheck,
)

# Name of the data check under test; passed as data_check_name when building the expected messages in the tests below.
mismatch_series_length_dc_name = MismatchedSeriesLengthDataCheck.name


def test_mismatched_series_length_data_check_raises_value_error(
    multiseries_ts_data_stacked,
):
    """Constructing the data check with series_id=None raises a ValueError."""
    expected_error = (
        "series_id must be set to the series_id column in the dataset and not None"
    )
    with pytest.raises(ValueError, match=expected_error):
        MismatchedSeriesLengthDataCheck(None)


def test_mistmatched_series_length_data_check_data_check_error(
    multiseries_ts_data_stacked,
):
    """A series_id column that is not in X produces exactly one INVALID_SERIES_ID_COL error."""
    features, _ = multiseries_ts_data_stacked
    expected_error = DataCheckError(
        message="series_id 'not_series_id' is not in the dataset.",
        data_check_name=mismatch_series_length_dc_name,
        message_code=DataCheckMessageCode.INVALID_SERIES_ID_COL,
        details={"series_id": "not_series_id"},
        action_options=[],
    ).to_dict()

    data_check = MismatchedSeriesLengthDataCheck("not_series_id")
    messages = data_check.validate(features)

    assert len(messages) == 1
    assert messages == [expected_error]


@pytest.mark.parametrize(
    "num_drop, not_majority, majority_length",
    [(1, ["0"], 20), (2, ["0", "1"], 20), (3, ["3", "4"], 19)],
)
def test_mismatched_series_length_data_check(
    multiseries_ts_data_stacked,
    num_drop,
    not_majority,
    majority_length,
):
    """Dropping the leading rows yields one warning naming the minority series and the majority length."""
    features, _ = multiseries_ts_data_stacked
    # Remove the row labelled 0 `num_drop` times; the index is rebuilt after
    # every removal so the next row takes over label 0.
    for _ in range(num_drop):
        features = features.drop(labels=0, axis=0).reset_index(drop=True)

    expected_warning = DataCheckWarning(
        message=f"Series ID {not_majority} do not match the majority length of the other series, which is {majority_length}",
        data_check_name=mismatch_series_length_dc_name,
        message_code=DataCheckMessageCode.MISMATCHED_SERIES_LENGTH,
        details={"series_id": not_majority, "majority_length": majority_length},
        action_options=[],
    ).to_dict()

    data_check = MismatchedSeriesLengthDataCheck("series_id")
    messages = data_check.validate(features)

    assert len(messages) == 1
    assert messages == [expected_warning]


def test_mismatched_series_length_data_check_all(multiseries_ts_data_stacked):
    """Dropping the selected rows produces the single "all lengths differ" warning."""
    features, _ = multiseries_ts_data_stacked
    rows_to_drop = [0, 1, 2, 3, 5, 6, 7, 10, 11, 15]
    features = features.drop(rows_to_drop).reset_index(drop=True)

    expected_warning = DataCheckWarning(
        message="All series ID have different lengths than each other",
        data_check_name=mismatch_series_length_dc_name,
        message_code=DataCheckMessageCode.MISMATCHED_SERIES_LENGTH,
        action_options=[],
    ).to_dict()

    data_check = MismatchedSeriesLengthDataCheck("series_id")
    messages = data_check.validate(features)

    assert len(messages) == 1
    assert messages == [expected_warning]


def test_mismatched_series_length_data_check_no_mismatch(multiseries_ts_data_stacked):
    """The unmodified fixture data produces no messages."""
    features, _ = multiseries_ts_data_stacked
    data_check = MismatchedSeriesLengthDataCheck("series_id")

    messages = data_check.validate(features)

    assert len(messages) == 0
    assert messages == []

0 comments on commit 1329988

Please sign in to comment.