From 67b1f48466ebefb5f8056967edd56711f912ca4e Mon Sep 17 00:00:00 2001
From: "Wesley M. Gifford"
Date: Fri, 29 Mar 2024 09:19:18 -0400
Subject: [PATCH] add more standard train/test split function

Signed-off-by: Wesley M. Gifford
---
 .../toolkit/test_time_series_preprocessor.py |  23 ++++
 tests/toolkit/test_util.py                   |  27 ++++-
 .../toolkit/time_series_preprocessor.py      |   9 +-
 tsfm_public/toolkit/util.py                  | 100 +++++++++++++++---
 4 files changed, 138 insertions(+), 21 deletions(-)

diff --git a/tests/toolkit/test_time_series_preprocessor.py b/tests/toolkit/test_time_series_preprocessor.py
index f5f2c12d..9f599287 100644
--- a/tests/toolkit/test_time_series_preprocessor.py
+++ b/tests/toolkit/test_time_series_preprocessor.py
@@ -227,3 +227,26 @@ def test_get_datasets(ts_data):
     assert len(train) == fewshot_train_size
 
     assert len(valid) == len(test)
+
+    # fraction splits
+    # no id columns, so treat as one big time series
+    tsp = TimeSeriesPreprocessor(
+        id_columns=[],
+        target_columns=["value1", "value2"],
+        prediction_length=5,
+        context_length=10,
+    )
+
+    train, valid, test = tsp.get_datasets(
+        ts_data,
+        split_config={
+            "train": 0.7,
+            "test": 0.2,
+        },
+    )
+
+    assert len(train) == int(150 * 0.7) - (tsp.context_length + tsp.prediction_length) + 1
+
+    assert len(test) == int(150 * 0.2) - (tsp.context_length + tsp.prediction_length) + 1
+
+    assert len(valid) == 150 - int(150 * 0.2) - int(150 * 0.7) - (tsp.context_length + tsp.prediction_length) + 1
diff --git a/tests/toolkit/test_util.py b/tests/toolkit/test_util.py
index 9c01321b..c2ff1169 100644
--- a/tests/toolkit/test_util.py
+++ b/tests/toolkit/test_util.py
@@ -1,8 +1,9 @@
 """Tests for util functions"""
 
+import pandas as pd
 import pytest
 
-from tsfm_public.toolkit.util import get_split_params
+from tsfm_public.toolkit.util import get_split_params, train_test_split
 
 
 split_cases = [
@@ -24,3 +25,27 @@ def test_get_split_params(left_arg, right_arg, expected):
     split_config, split_function = get_split_params({"train": [left_arg, right_arg], "valid": [0, 1], "test": [0, 1]})
 
     assert split_function["train"].__name__ == expected
+
+
+def test_train_test_split():
+    n = 100
+    df = pd.DataFrame({"date": range(n), "value": range(n)})
+
+    train, valid, test = train_test_split(df, train=0.7, test=0.2)
+
+    assert len(train) == int(n * 0.7)
+    assert len(test) == int(n * 0.2)
+    valid_len_100 = n - int(n * 0.7) - int(n * 0.2)
+    assert len(valid) == valid_len_100
+
+    n = 101
+    df = pd.DataFrame({"date": range(n), "value": range(n)})
+
+    train, valid, test = train_test_split(df, train=0.7, test=0.2)
+
+    assert len(train) == int(n * 0.7)
+    assert len(test) == int(n * 0.2)
+    valid_len_101 = n - int(n * 0.7) - int(n * 0.2)
+    assert len(valid) == valid_len_101
+
+    assert valid_len_100 + 1 == valid_len_101
diff --git a/tsfm_public/toolkit/time_series_preprocessor.py b/tsfm_public/toolkit/time_series_preprocessor.py
index 172a0dc0..61aba7db 100644
--- a/tsfm_public/toolkit/time_series_preprocessor.py
+++ b/tsfm_public/toolkit/time_series_preprocessor.py
@@ -642,9 +642,12 @@ def get_datasets(
         }
 
         # split data
-        train_data = split_function["train"](data, id_columns=self.id_columns, **split_params["train"])
-        valid_data = split_function["valid"](data, id_columns=self.id_columns, **split_params["valid"])
-        test_data = split_function["test"](data, id_columns=self.id_columns, **split_params["test"])
+        if isinstance(split_function, dict):
+            train_data = split_function["train"](data, id_columns=self.id_columns, **split_params["train"])
+            valid_data = split_function["valid"](data, id_columns=self.id_columns, **split_params["valid"])
+            test_data = split_function["test"](data, id_columns=self.id_columns, **split_params["test"])
+        else:
+            train_data, valid_data, test_data = split_function(data, id_columns=self.id_columns, **split_params)
 
         # data preprocessing
         self.train(train_data)
diff --git a/tsfm_public/toolkit/util.py b/tsfm_public/toolkit/util.py
index c8a434f9..7d2a2612 100644
--- a/tsfm_public/toolkit/util.py
+++ b/tsfm_public/toolkit/util.py
@@ -201,6 +201,63 @@ def select_by_fixed_fraction(
     return pd.concat(result)
 
 
+def train_test_split(
+    df: pd.DataFrame,
+    id_columns: Optional[List[str]] = None,
+    train: Union[int, float] = 0.7,
+    test: Union[int, float] = 0.2,
+    valid_test_offset: int = 0,
+):
+    # to do: add validation
+
+    if not id_columns:
+        return tuple([tmp.copy() for tmp in _split_group_train_test(df, train=train, test=test)])
+
+    groups = df.groupby(_get_groupby_columns(id_columns))
+    result = []
+    for name, group in groups:
+        result.append(
+            _split_group_train_test(
+                group,
+                name=name,
+                train=train,
+                test=test,
+                valid_test_offset=valid_test_offset,
+            )
+        )
+
+    result_train, result_valid, result_test = zip(*result)
+    return pd.concat(result_train), pd.concat(result_valid), pd.concat(result_test)
+
+
+def _split_group_train_test(
+    group_df: pd.DataFrame,
+    name: Optional[str] = None,
+    train: Union[int, float] = 0.7,
+    test: Union[int, float] = 0.2,
+    valid_test_offset: int = 0,
+):
+    l = len(group_df)
+
+    train_size = int(l * train)
+    test_size = int(l * test)
+
+    valid_size = l - train_size - test_size
+
+    train_df = _split_group_by_index(group_df, name, start_index=0, end_index=train_size)
+
+    valid_df = _split_group_by_index(
+        group_df,
+        name,
+        start_index=train_size - valid_test_offset,
+        end_index=train_size + valid_size,
+    )
+
+    test_df = _split_group_by_index(group_df, name, start_index=train_size + valid_size - valid_test_offset)
+
+    return train_df, valid_df, test_df
+
+
 def _get_groupby_columns(id_columns: List[str]) -> Union[List[str], str]:
     if not isinstance(id_columns, (List)):
         raise ValueError("id_columns must be a list")
@@ -440,23 +497,32 @@ def get_split_params(
     split_params = {}
     split_function = {}
 
-    for group in ["train", "test", "valid"]:
-        if ((split_config[group][0] < 1) and (split_config[group][0] != 0)) or (split_config[group][1] < 1):
-            split_params[group] = {
-                "start_fraction": split_config[group][0],
-                "end_fraction": split_config[group][1],
-                "start_offset": (context_length if (context_length and group != "train") else 0),
-            }
-            split_function[group] = select_by_relative_fraction
-        else:
-            split_params[group] = {
-                "start_index": (
-                    split_config[group][0] - (context_length if (context_length and group != "train") else 0)
-                ),
-                "end_index": split_config[group][1],
-            }
-            split_function[group] = select_by_index
-
+    if "valid" in split_config:
+        for group in ["train", "test", "valid"]:
+            if ((split_config[group][0] < 1) and (split_config[group][0] != 0)) or (split_config[group][1] < 1):
+                split_params[group] = {
+                    "start_fraction": split_config[group][0],
+                    "end_fraction": split_config[group][1],
+                    "start_offset": (context_length if (context_length and group != "train") else 0),
+                }
+                split_function[group] = select_by_relative_fraction
+            else:
+                split_params[group] = {
+                    "start_index": (
+                        split_config[group][0] - (context_length if (context_length and group != "train") else 0)
+                    ),
+                    "end_index": split_config[group][1],
+                }
+                split_function[group] = select_by_index
+        return split_params, split_function
+
+    # no valid, assume train/test split
+    split_function = train_test_split
+    split_params = {
+        "train": split_config["train"],
+        "test": split_config["test"],
+        "valid_test_offset": context_length if context_length else 0,
+    }
     return split_params, split_function
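
Usage sketch (illustrative only, not part of the diff): a minimal example of the new
fraction-based split, assuming a single-series frame with no id columns. The
train_test_split signature and the {"train": ..., "test": ...} split_config form come
from the changes above; the "date"/"value" column names are placeholders.

    import pandas as pd

    from tsfm_public.toolkit.util import get_split_params, train_test_split

    # hypothetical single-series frame
    df = pd.DataFrame({"date": range(100), "value": range(100)})

    # 70% train, 20% test, remaining rows become valid; with id_columns given,
    # the same fractions are applied per series via the groupby path added above
    train_df, valid_df, test_df = train_test_split(df, train=0.7, test=0.2)

    # a split_config without a "valid" entry now dispatches to train_test_split,
    # with valid_test_offset derived from the context length
    params, split_fn = get_split_params({"train": 0.7, "test": 0.2}, context_length=10)
    assert split_fn is train_test_split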