diff --git a/tests/toolkit/test_time_series_preprocessor.py b/tests/toolkit/test_time_series_preprocessor.py
index 2de78b49..73ac6bf6 100644
--- a/tests/toolkit/test_time_series_preprocessor.py
+++ b/tests/toolkit/test_time_series_preprocessor.py
@@ -16,6 +16,7 @@
     create_timestamps,
     extend_time_series,
 )
+from tsfm_public.toolkit.util import FractionLocation
 
 
 def test_standard_scaler(sample_data):
@@ -192,9 +193,35 @@ def test_get_datasets(ts_data):
             "test": [125, 150],
         },
         fewshot_fraction=0.2,
+        fewshot_location=FractionLocation.LAST.value,
     )
 
     # new train length should be 20% of 100, minus the usual for context length and prediction length
-    assert len(train) == (int(100 * 0.2) - (tsp.context_length + tsp.prediction_length) + 1)
+    fewshot_train_size = int(100 * 0.2) - (tsp.context_length + tsp.prediction_length) + 1
+    assert len(train) == fewshot_train_size
+
+    assert len(valid) == len(test)
+
+    # no id columns, so treat as one big time series
+    tsp = TimeSeriesPreprocessor(
+        id_columns=[],
+        target_columns=["value1", "value2"],
+        prediction_length=5,
+        context_length=10,
+    )
+
+    train, valid, test = tsp.get_datasets(
+        ts_data,
+        split_config={
+            "train": [0, 100],
+            "valid": [100, 125],
+            "test": [125, 150],
+        },
+        fewshot_fraction=0.2,
+        fewshot_location=FractionLocation.FIRST.value,
+    )
+
+    # new train length should be 20% of 100, minus the usual for context length and prediction length
+    assert len(train) == fewshot_train_size
 
     assert len(valid) == len(test)
diff --git a/tsfm_public/toolkit/time_series_preprocessor.py b/tsfm_public/toolkit/time_series_preprocessor.py
index eb8137b0..89dbe279 100644
--- a/tsfm_public/toolkit/time_series_preprocessor.py
+++ b/tsfm_public/toolkit/time_series_preprocessor.py
@@ -23,9 +23,10 @@
 
 from .dataset import ForecastDFDataset
 from .util import (
+    FractionLocation,
     get_split_params,
     join_list_without_repeat,
-    select_by_relative_fraction,
+    select_by_fixed_fraction,
 )
 
 
@@ -596,6 +597,7 @@ def get_datasets(
         dataset: Union[Dataset, pd.DataFrame],
         split_config: Dict[str, Any],
         fewshot_fraction: Optional[float] = None,
+        fewshot_location: str = FractionLocation.LAST.value,
     ) -> Tuple[Any]:
         """Creates the preprocessed pytorch datasets needed for training and evaluation
         using the HuggingFace trainer
@@ -613,6 +615,9 @@ def get_datasets(
             fewshot_fraction (float, optional): When non-null, return this percent of the original training
                 dataset. This is done to support fewshot fine-tuning. The fraction of data chosen is at the
                 end of the training dataset.
+            fewshot_location (str): Determines where the fewshot data is chosen. Valid options are "first" and "last"
+                as described in the enum FractionLocation. Default is to choose the fewshot data at the end
+                of the training dataset (i.e., "last").
 
         Returns:
             Tuple of pytorch datasets, including: train, validation, test.
@@ -646,9 +651,15 @@ def get_datasets(
 
         # handle fewshot operation
         if fewshot_fraction is not None:
-            if not ((fewshot_fraction <= 1) and (fewshot_fraction > 0)):
+            if not ((fewshot_fraction <= 1.0) and (fewshot_fraction > 0.0)):
                 raise ValueError(f"Fewshot fraction should be between 0 and 1, received {fewshot_fraction}")
-            train_data = select_by_relative_fraction(train_data, start_fraction=1 - fewshot_fraction, end_fraction=1)
+
+            train_data = select_by_fixed_fraction(
+                train_data,
+                id_columns=self.id_columns,
+                fraction=fewshot_fraction,
+                location=fewshot_location,
+            )
 
         params = column_specifiers
         params["context_length"] = self.context_length
diff --git a/tsfm_public/toolkit/util.py b/tsfm_public/toolkit/util.py
index 044ceed3..71359b1a 100644
--- a/tsfm_public/toolkit/util.py
+++ b/tsfm_public/toolkit/util.py
@@ -3,6 +3,7 @@
 """Basic functions and utilities"""
 
 import copy
+import enum
 from datetime import datetime
 from distutils.util import strtobool
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -10,6 +11,13 @@
 import pandas as pd
 
 
+class FractionLocation(enum.Enum):
+    """`Enum` for the different locations where a fraction of data can be chosen."""
+
+    FIRST = "first"
+    LAST = "last"
+
+
 def select_by_timestamp(
     df: pd.DataFrame,
     timestamp_column: str = "timestamp",
@@ -145,6 +153,53 @@ def select_by_relative_fraction(
     return pd.concat(result)
 
 
+def select_by_fixed_fraction(
+    df: pd.DataFrame,
+    id_columns: Optional[List[str]] = None,
+    fraction: float = 1.0,
+    location: str = FractionLocation.FIRST.value,
+) -> pd.DataFrame:
+    """Select a portion of a dataset based on a fraction of the data.
+    Fraction can either be located at the start (location = FractionLocation.FIRST) or at the end (location = FractionLocation.LAST)
+
+    Args:
+        df (pd.DataFrame): Input dataframe.
+        id_columns (List[str], optional): Columns which specify the IDs in the dataset. Defaults to None.
+        fraction (float): The fraction to select.
+        location (str): Location of where to select the fraction. Defaults to FractionLocation.FIRST.value.
+
+    Raises:
+        ValueError: Raised when the fraction is not within the range [0, 1].
+
+    Returns:
+        pd.DataFrame: Subset of the dataframe.
+    """
+
+    if fraction < 0 or fraction > 1:
+        raise ValueError("The value of fraction should be between 0 and 1.")
+
+    if not id_columns:
+        return _split_group_by_fixed_fraction(
+            df,
+            fraction=fraction,
+            location=location,
+        ).copy()
+
+    groups = df.groupby(_get_groupby_columns(id_columns))
+    result = []
+    for name, group in groups:
+        result.append(
+            _split_group_by_fixed_fraction(
+                group,
+                name=name,
+                fraction=fraction,
+                location=location,
+            )
+        )
+
+    return pd.concat(result)
+
+
 def _get_groupby_columns(id_columns: List[str]) -> Union[List[str], str]:
     if not isinstance(id_columns, (List)):
         raise ValueError("id_columns must be a list")
@@ -202,7 +257,31 @@ def _split_group_by_fraction(
     else:
         end_index = None
 
-    return _split_group_by_index(group_df=group_df, start_index=start_index, end_index=end_index)
+    return _split_group_by_index(group_df=group_df, name=name, start_index=start_index, end_index=end_index)
+
+
+def _split_group_by_fixed_fraction(
+    group_df: pd.DataFrame,
+    name: Optional[str] = None,
+    fraction: float = 1.0,
+    location: Optional[str] = None,
+):
+    """Compute the index window for the first or last `fraction` of a group and slice it via `_split_group_by_index`."""
+    group_length = len(group_df)
+    fraction_size = int(fraction * group_length)
+
+    if location == FractionLocation.FIRST.value:
+        start_index = 0
+        end_index = fraction_size
+    elif location == FractionLocation.LAST.value:
+        start_index = group_length - fraction_size
+        end_index = group_length
+    else:
+        raise ValueError(
+            f"`location` should be either `{FractionLocation.FIRST.value}` or `{FractionLocation.LAST.value}`"
+        )
+
+    return _split_group_by_index(group_df=group_df, name=name, start_index=start_index, end_index=end_index)
 
 
 def convert_tsf_to_dataframe(