Skip to content

Commit

Permalink
add more standard train/test split function
Browse files Browse the repository at this point in the history
Signed-off-by: Wesley M. Gifford <[email protected]>
  • Loading branch information
wgifford committed Mar 29, 2024
1 parent e072c64 commit 67b1f48
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 21 deletions.
23 changes: 23 additions & 0 deletions tests/toolkit/test_time_series_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,26 @@ def test_get_datasets(ts_data):
assert len(train) == fewshot_train_size

assert len(valid) == len(test)

# fraction splits
# no id columns, so treat as one big time series
tsp = TimeSeriesPreprocessor(
id_columns=[],
target_columns=["value1", "value2"],
prediction_length=5,
context_length=10,
)

train, valid, test = tsp.get_datasets(
ts_data,
split_config={
"train": 0.7,
"test": 0.2,
},
)

assert len(train) == int(150 * 0.7) - (tsp.context_length + tsp.prediction_length) + 1

assert len(test) == int(150 * 0.2) - (tsp.context_length + tsp.prediction_length) + 1

assert len(valid) == 150 - int(150 * 0.2) - int(150 * 0.7) - (tsp.context_length + tsp.prediction_length) + 1
27 changes: 26 additions & 1 deletion tests/toolkit/test_util.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Tests for util functions"""

import pandas as pd
import pytest

from tsfm_public.toolkit.util import get_split_params
from tsfm_public.toolkit.util import get_split_params, train_test_split


split_cases = [
Expand All @@ -24,3 +25,27 @@ def test_get_split_params(left_arg, right_arg, expected):
split_config, split_function = get_split_params({"train": [left_arg, right_arg], "valid": [0, 1], "test": [0, 1]})

assert split_function["train"].__name__ == expected


def test_train_test_split():
n = 100
df = pd.DataFrame({"date": range(n), "value": range(n)})

train, valid, test = train_test_split(df, train=0.7, test=0.2)

assert len(train) == int(n * 0.7)
assert len(test) == int(n * 0.2)
valid_len_100 = n - int(n * 0.7) - int(n * 0.2)
assert len(valid) == valid_len_100

n = 101
df = pd.DataFrame({"date": range(n), "value": range(n)})

train, valid, test = train_test_split(df, train=0.7, test=0.2)

assert len(train) == int(n * 0.7)
assert len(test) == int(n * 0.2)
valid_len_101 = n - int(n * 0.7) - int(n * 0.2)
assert len(valid) == valid_len_101

assert valid_len_100 + 1 == valid_len_101
9 changes: 6 additions & 3 deletions tsfm_public/toolkit/time_series_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,9 +642,12 @@ def get_datasets(
}

# split data
train_data = split_function["train"](data, id_columns=self.id_columns, **split_params["train"])
valid_data = split_function["valid"](data, id_columns=self.id_columns, **split_params["valid"])
test_data = split_function["test"](data, id_columns=self.id_columns, **split_params["test"])
if isinstance(split_function, dict):
train_data = split_function["train"](data, id_columns=self.id_columns, **split_params["train"])
valid_data = split_function["valid"](data, id_columns=self.id_columns, **split_params["valid"])
test_data = split_function["test"](data, id_columns=self.id_columns, **split_params["test"])
else:
train_data, valid_data, test_data = split_function(data, id_columns=self.id_columns, **split_params)

# data preprocessing
self.train(train_data)
Expand Down
100 changes: 83 additions & 17 deletions tsfm_public/toolkit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,63 @@ def select_by_fixed_fraction(
return pd.concat(result)


def train_test_split(
df: pd.DataFrame,
id_columns: Optional[List[str]] = None,
train: Union[int, float] = 0.7,
test: Union[int, float] = 0.2,
valid_test_offset: int = 0,
):
# to do: add validation

if not id_columns:
return tuple([tmp.copy() for tmp in _split_group_train_test(df, train=train, test=test)])

groups = df.groupby(_get_groupby_columns(id_columns))
result = []
for name, group in groups:
result.append(
_split_group_train_test(
group,
name=name,
train=train,
test=test,
valid_test_offset=valid_test_offset,
)
)

result_train, result_valid, result_test = zip(**result)
return pd.concat(result_train), pd.concat(result_valid), pd.concat(result_test)


def _split_group_train_test(
group_df: pd.DataFrame,
name: Optional[str] = None,
train: Union[int, float] = 0.7,
test: Union[int, float] = 0.2,
valid_test_offset: int = 0,
):
l = len(group_df)

train_size = int(l * train)
test_size = int(l * test)

valid_size = l - train_size - test_size

train_df = _split_group_by_index(group_df, name, start_index=0, end_index=train_size)

valid_df = _split_group_by_index(
group_df,
name,
start_index=train_size - valid_test_offset,
end_index=train_size + valid_size,
)

test_df = _split_group_by_index(group_df, name, start_index=train_size + valid_size - valid_test_offset)

return train_df, valid_df, test_df


def _get_groupby_columns(id_columns: List[str]) -> Union[List[str], str]:
if not isinstance(id_columns, (List)):
raise ValueError("id_columns must be a list")
Expand Down Expand Up @@ -440,23 +497,32 @@ def get_split_params(
split_params = {}
split_function = {}

for group in ["train", "test", "valid"]:
if ((split_config[group][0] < 1) and (split_config[group][0] != 0)) or (split_config[group][1] < 1):
split_params[group] = {
"start_fraction": split_config[group][0],
"end_fraction": split_config[group][1],
"start_offset": (context_length if (context_length and group != "train") else 0),
}
split_function[group] = select_by_relative_fraction
else:
split_params[group] = {
"start_index": (
split_config[group][0] - (context_length if (context_length and group != "train") else 0)
),
"end_index": split_config[group][1],
}
split_function[group] = select_by_index

if "valid" in split_config:
for group in ["train", "test", "valid"]:
if ((split_config[group][0] < 1) and (split_config[group][0] != 0)) or (split_config[group][1] < 1):
split_params[group] = {
"start_fraction": split_config[group][0],
"end_fraction": split_config[group][1],
"start_offset": (context_length if (context_length and group != "train") else 0),
}
split_function[group] = select_by_relative_fraction
else:
split_params[group] = {
"start_index": (
split_config[group][0] - (context_length if (context_length and group != "train") else 0)
),
"end_index": split_config[group][1],
}
split_function[group] = select_by_index
return split_params, split_function

# no valid, assume train/test split
split_function = train_test_split
split_params = {
"train": split_config["train"],
"test": split_config["test"],
"valid_test_offset": context_length if context_length else 0,
}
return split_params, split_function


Expand Down

0 comments on commit 67b1f48

Please sign in to comment.