Skip to content

Commit

Permalink
update fewshot selection, add location, tests
Browse files Browse the repository at this point in the history
Signed-off-by: Wesley M. Gifford <[email protected]>
  • Loading branch information
wgifford committed Mar 25, 2024
1 parent 83a6f99 commit 2f54d42
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 5 deletions.
29 changes: 28 additions & 1 deletion tests/toolkit/test_time_series_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
create_timestamps,
extend_time_series,
)
from tsfm_public.toolkit.util import FractionLocation


def test_standard_scaler(sample_data):
Expand Down Expand Up @@ -192,9 +193,35 @@ def test_get_datasets(ts_data):
"test": [125, 150],
},
fewshot_fraction=0.2,
fewshot_location=FractionLocation.LAST.value,
)

# new train length should be 20% of 100, minus the usual for context length and prediction length
assert len(train) == (int(100 * 0.2) - (tsp.context_length + tsp.prediction_length) + 1)
fewshot_train_size = int(100 * 0.2) - (tsp.context_length + tsp.prediction_length) + 1
assert len(train) == fewshot_train_size

assert len(valid) == len(test)

# no id columns, so treat as one big time series
tsp = TimeSeriesPreprocessor(
id_columns=[],
target_columns=["value1", "value2"],
prediction_length=5,
context_length=10,
)

train, valid, test = tsp.get_datasets(
ts_data,
split_config={
"train": [0, 100],
"valid": [100, 125],
"test": [125, 150],
},
fewshot_fraction=0.2,
fewshot_location=FractionLocation.FIRST.value,
)

# new train length should be 20% of 100, minus the usual for context length and prediction length
assert len(train) == fewshot_train_size

assert len(valid) == len(test)
17 changes: 14 additions & 3 deletions tsfm_public/toolkit/time_series_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@

from .dataset import ForecastDFDataset
from .util import (
FractionLocation,
get_split_params,
join_list_without_repeat,
select_by_relative_fraction,
select_by_fixed_fraction,
)


Expand Down Expand Up @@ -596,6 +597,7 @@ def get_datasets(
dataset: Union[Dataset, pd.DataFrame],
split_config: Dict[str, Any],
fewshot_fraction: Optional[float] = None,
fewshot_location: str = FractionLocation.LAST.value,
) -> Tuple[Any]:
"""Creates the preprocessed pytorch datasets needed for training and evaluation
using the HuggingFace trainer
Expand All @@ -613,6 +615,9 @@ def get_datasets(
fewshot_fraction (float, optional): When not None, return only this fraction (greater than 0 and at most 1) of the
original training dataset. This is done to support fewshot fine-tuning. Which end of the training
dataset the fraction is taken from is controlled by `fewshot_location`.
fewshot_location (str): Determines where the fewshot data is chosen. Valid options are "first" and "last"
as described in the enum FractionLocation. Default is to choose the fewshot data at the end
of the training dataset (i.e., "last").
Returns:
Tuple of pytorch datasets, including: train, validation, test.
Expand Down Expand Up @@ -646,9 +651,15 @@ def get_datasets(

# handle fewshot operation
if fewshot_fraction is not None:
if not ((fewshot_fraction <= 1) and (fewshot_fraction > 0)):
if not ((fewshot_fraction <= 1.0) and (fewshot_fraction > 0.0)):
raise ValueError(f"Fewshot fraction should be between 0 and 1, received {fewshot_fraction}")
train_data = select_by_relative_fraction(train_data, start_fraction=1 - fewshot_fraction, end_fraction=1)

train_data = select_by_fixed_fraction(
train_data,
id_columns=self.id_columns,
fraction=fewshot_fraction,
location=fewshot_location,
)

params = column_specifiers
params["context_length"] = self.context_length
Expand Down
80 changes: 79 additions & 1 deletion tsfm_public/toolkit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,21 @@
"""Basic functions and utilities"""

import copy
import enum
from datetime import datetime
from distutils.util import strtobool
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import pandas as pd


class FractionLocation(enum.Enum):
    """Identifies which end of a dataset a fractional selection is taken from.

    The string held in each member's ``.value`` is what the fraction-selection
    helpers accept as their ``location`` argument.
    """

    # Take the fraction from the beginning of the data.
    FIRST = "first"
    # Take the fraction from the end of the data.
    LAST = "last"


def select_by_timestamp(
df: pd.DataFrame,
timestamp_column: str = "timestamp",
Expand Down Expand Up @@ -145,6 +153,53 @@ def select_by_relative_fraction(
return pd.concat(result)


def select_by_fixed_fraction(
    df: pd.DataFrame,
    id_columns: Optional[List[str]] = None,
    fraction: float = 1.0,
    location: str = FractionLocation.FIRST.value,
) -> pd.DataFrame:
    """Select a portion of a dataset based on a fixed fraction of the data.

    The fraction can either be located at the start (location = FractionLocation.FIRST.value)
    or at the end (location = FractionLocation.LAST.value). When `id_columns` are given, the
    selection is made independently within each group of rows sharing the same ID values,
    so every group contributes `int(fraction * len(group))` rows.

    Args:
        df (pd.DataFrame): Input dataframe.
        id_columns (List[str], optional): Columns which specify the IDs in the dataset. When None
            or empty, the whole dataframe is treated as a single group. Defaults to None.
        fraction (float): The fraction to select, between 0 and 1 inclusive. Defaults to 1.0.
        location (str): Location of where to select the fraction. Defaults to
            FractionLocation.FIRST.value.

    Raises:
        ValueError: Raised when the fraction is not between 0 and 1, or when `location` is not
            one of the values defined by FractionLocation.

    Returns:
        pd.DataFrame: Subset of the dataframe.
    """

    if fraction < 0 or fraction > 1:
        raise ValueError("The value of fraction should be between 0 and 1.")

    # Validate here as well as in the per-group helper, so that an invalid location is
    # rejected even when the input produces no groups and the helper never runs.
    if location not in (FractionLocation.FIRST.value, FractionLocation.LAST.value):
        raise ValueError(
            f"`location` should be either `{FractionLocation.FIRST.value}` or `{FractionLocation.LAST.value}`"
        )

    if not id_columns:
        # No IDs: treat the entire dataframe as a single group.
        return _split_group_by_fixed_fraction(
            df,
            fraction=fraction,
            location=location,
        ).copy()

    groups = df.groupby(_get_groupby_columns(id_columns))
    result = [
        _split_group_by_fixed_fraction(
            group,
            name=name,
            fraction=fraction,
            location=location,
        )
        for name, group in groups
    ]

    if not result:
        # pd.concat raises "No objects to concatenate" on an empty list, which happens
        # when the groupby yields no groups (e.g., an empty input dataframe). Return an
        # empty selection that keeps the input's columns instead.
        return df.iloc[0:0].copy()

    return pd.concat(result)


def _get_groupby_columns(id_columns: List[str]) -> Union[List[str], str]:
if not isinstance(id_columns, (List)):
raise ValueError("id_columns must be a list")
Expand Down Expand Up @@ -202,7 +257,30 @@ def _split_group_by_fraction(
else:
end_index = None

return _split_group_by_index(group_df=group_df, start_index=start_index, end_index=end_index)
return _split_group_by_index(group_df=group_df, name=name, start_index=start_index, end_index=end_index)


def _split_group_by_fixed_fraction(
    group_df: pd.DataFrame,
    name: Optional[str] = None,
    fraction: float = 1.0,
    location: Optional[str] = None,
):
    """Select a fixed fraction of one group, taken from its start or its end.

    The number of rows requested is `int(fraction * len(group_df))`; the resulting
    start/end indices are handed to `_split_group_by_index`, which does the selection.

    Args:
        group_df (pd.DataFrame): Dataframe holding a single group.
        name (str, optional): Name of the group, forwarded to `_split_group_by_index`.
            Defaults to None.
        fraction (float): Fraction of the group to select. Defaults to 1.0.
        location (str, optional): Either FractionLocation.FIRST.value or
            FractionLocation.LAST.value. Defaults to None, which is rejected.

    Raises:
        ValueError: Raised when `location` is not one of the two supported values.

    Returns:
        Whatever `_split_group_by_index` returns for the computed index range.
    """
    n_rows = len(group_df)
    # int() truncates toward zero, so the selected size is rounded down.
    n_selected = int(fraction * n_rows)

    first_value = FractionLocation.FIRST.value
    last_value = FractionLocation.LAST.value

    if location == first_value:
        bounds = (0, n_selected)
    elif location == last_value:
        bounds = (n_rows - n_selected, n_rows)
    else:
        raise ValueError(f"`location` should be either `{first_value}` or `{last_value}`")

    start_index, end_index = bounds
    return _split_group_by_index(group_df=group_df, name=name, start_index=start_index, end_index=end_index)


def convert_tsf_to_dataframe(
Expand Down

0 comments on commit 2f54d42

Please sign in to comment.