From 87d6deaaf81c1fa1fd582e2bfb9fcc758a23fbee Mon Sep 17 00:00:00 2001 From: "Wesley M. Gifford" Date: Wed, 20 Mar 2024 15:34:59 -0400 Subject: [PATCH 01/10] add get datasets function Signed-off-by: Wesley M. Gifford --- .../toolkit/time_series_preprocessor.py | 54 ++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/tsfm_public/toolkit/time_series_preprocessor.py b/tsfm_public/toolkit/time_series_preprocessor.py index e583e588..4868ae13 100644 --- a/tsfm_public/toolkit/time_series_preprocessor.py +++ b/tsfm_public/toolkit/time_series_preprocessor.py @@ -21,7 +21,8 @@ PreTrainedFeatureExtractor, ) -from .util import join_list_without_repeat +from .dataset import ForecastDFDataset +from .util import join_list_without_repeat, select_by_index INTERNAL_ID_COLUMN = "__id" @@ -586,6 +587,57 @@ def scale_func(grp, id_columns): return df + def get_datasets( + self, dataset: Union[Dataset, pd.DataFrame], config: Dict[str, Any] + ): # load data, assume data file is in csv format + data = self._standardize_dataframe(dataset) + + # to do: get split_params + # split_params = get_split_params(config, self.context_length, len(data)) + split_params = {} + + # specify columns + column_specifiers = { + "id_columns": config["data"]["id_columns"], + "timestamp_column": config["data"]["timestamp_column"], + "target_columns": config["data"]["target_columns"], + "observable_columns": config["data"]["observable_columns"], + "control_columns": config["data"]["control_columns"], + "conditional_columns": config["data"]["conditional_columns"], + "static_categorical_columns": config["data"]["static_categorical_columns"], + } + + # split data + train_data = select_by_index(data, id_columns=column_specifiers["id_columns"], **split_params["train"]) + valid_data = select_by_index(data, id_columns=column_specifiers["id_columns"], **split_params["valid"]) + test_data = select_by_index(data, id_columns=column_specifiers["id_columns"], **split_params["test"]) + + # # data preprocessing + # tsp = TimeSeriesPreprocessor( + # **column_specifiers, + # scaling=config["scale"]["scaling"], + # encode_categorical=config["encode_categorical"], + # scaler_type=config["scale"]["scaler_type"], + # freq=config["data"]["freq"], + # ) + self.train(train_data) + + params = column_specifiers + params["context_length"] = self.context_length + params["prediction_length"] = self.prediction_length + + # get torch datasets + test_dataset = ForecastDFDataset( + self.preprocess(test_data), + **params, + ) + train_dataset = ForecastDFDataset(self.preprocess(train_data), **params) + valid_dataset = ForecastDFDataset( + self.preprocess(valid_data), + **params, + ) + return train_dataset, valid_dataset, test_dataset + def create_timestamps( last_timestamp: Union[datetime.datetime, pd.Timestamp], From f1091f9f0ebf5b7fc491ae9e397f484c33e1d836 Mon Sep 17 00:00:00 2001 From: "Wesley M. Gifford" Date: Wed, 20 Mar 2024 19:42:55 -0400 Subject: [PATCH 02/10] Add code to produce datasests directly Co-authored-by: Nam Nguyen Signed-off-by: Wesley M. Gifford --- .../toolkit/time_series_preprocessor.py | 51 +++++++++++++------ tsfm_public/toolkit/util.py | 38 +++++++++++++- 2 files changed, 72 insertions(+), 17 deletions(-) diff --git a/tsfm_public/toolkit/time_series_preprocessor.py b/tsfm_public/toolkit/time_series_preprocessor.py index 4868ae13..2beb2275 100644 --- a/tsfm_public/toolkit/time_series_preprocessor.py +++ b/tsfm_public/toolkit/time_series_preprocessor.py @@ -22,7 +22,7 @@ ) from .dataset import ForecastDFDataset -from .util import join_list_without_repeat, select_by_index +from .util import get_split_params, join_list_without_repeat INTERNAL_ID_COLUMN = "__id" @@ -587,30 +587,49 @@ def scale_func(grp, id_columns): return df - def get_datasets( - self, dataset: Union[Dataset, pd.DataFrame], config: Dict[str, Any] - ): # load data, assume data file is in csv format + def get_datasets(self, dataset: Union[Dataset, pd.DataFrame], split_config: Dict[str, Any]) -> Tuple[Any]: + """Creates the preprocessed pytorch datasets needed for training and evaluation + using the HuggingFace trainer + + Args: + dataset (Union[Dataset, pd.DataFrame]): Loaded pandas dataframe + split_config (Dict[str, Any]): Dictionary of dictionaries containing + split parameters. For example: + { + train: {start: 0, end: 50}, + valid: {start: 50, end: 70}, + test: {start: 70, end: 100} + } + end value is not inclusive + + Returns: + Tuple of pytorch datasets, including: train, validation, test. + + + """ + data = self._standardize_dataframe(dataset) - # to do: get split_params + # get split_params # split_params = get_split_params(config, self.context_length, len(data)) - split_params = {} + + split_params, split_function = get_split_params(split_config) # specify columns column_specifiers = { - "id_columns": config["data"]["id_columns"], - "timestamp_column": config["data"]["timestamp_column"], - "target_columns": config["data"]["target_columns"], - "observable_columns": config["data"]["observable_columns"], - "control_columns": config["data"]["control_columns"], - "conditional_columns": config["data"]["conditional_columns"], - "static_categorical_columns": config["data"]["static_categorical_columns"], + "id_columns": self.id_columns, + "timestamp_column": self.timestamp_column, + "target_columns": self.target_columns, + "observable_columns": self.observable_columns, + "control_columns": self.control_columns, + "conditional_columns": self.conditional_columns, + "static_categorical_columns": self.static_categorical_columns, } # split data - train_data = select_by_index(data, id_columns=column_specifiers["id_columns"], **split_params["train"]) - valid_data = select_by_index(data, id_columns=column_specifiers["id_columns"], **split_params["valid"]) - test_data = select_by_index(data, id_columns=column_specifiers["id_columns"], **split_params["test"]) + train_data = split_function["train"](data, id_columns=self.id_columns, **split_params["train"]) + valid_data = split_function["valid"](data, id_columns=self.id_columns, **split_params["valid"]) + test_data = split_function["test"](data, id_columns=self.id_columns, **split_params["test"]) # # data preprocessing # tsp = TimeSeriesPreprocessor( diff --git a/tsfm_public/toolkit/util.py b/tsfm_public/toolkit/util.py index 270b100a..06126261 100644 --- a/tsfm_public/toolkit/util.py +++ b/tsfm_public/toolkit/util.py @@ -5,7 +5,7 @@ import copy from datetime import datetime from distutils.util import strtobool -from typing import Any, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import pandas as pd @@ -336,6 +336,42 @@ def convert_tsf_to_dataframe( ) +def get_split_params( + split_config: Dict[str, List[int]], context_length=None +) -> Dict[str, Dict[str, Union[int, float]]]: + """_summary_ + + Args: + split_config (Dict[str, List[int]]): _description_ + context_length (_type_, optional): _description_. Defaults to None. + + Returns: + Dict[str, Dict[str, Union[int, float]]]: _description_ + """ + + split_params = {} + split_function = {} + + for group in ["train", "test", "valid"]: + if split_config[group][1] < 1: + split_params[group] = { + "start_fraction": split_config[group][0], + "end_fraction": split_config[group][1], + "start_offset": (context_length if (context_length and group != "train") else 0), + } + split_function[group] = select_by_relative_fraction + else: + split_params[group] = { + "start_index": ( + split_config[group][0] - (context_length if (context_length and group != "train") else 0) + ), + "end_index": split_config[group][1], + } + split_function[group] = select_by_index + + return split_params, split_function + + def convert_tsf(filename: str) -> pd.DataFrame: """Converts a tsf format file into a pandas dataframe. Returns the result in canonical multi-time series format, with an ID column, and timestamp. From fec158e6382db7648ad691d03945850783dfc784 Mon Sep 17 00:00:00 2001 From: "Wesley M. Gifford" Date: Wed, 20 Mar 2024 19:45:55 -0400 Subject: [PATCH 03/10] minimal example Signed-off-by: Wesley M. Gifford --- hacking/datasets_from_tsp.py | 38 ++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 hacking/datasets_from_tsp.py diff --git a/hacking/datasets_from_tsp.py b/hacking/datasets_from_tsp.py new file mode 100644 index 00000000..8cafc100 --- /dev/null +++ b/hacking/datasets_from_tsp.py @@ -0,0 +1,38 @@ +# %% +import pandas as pd + +from tsfm_public.toolkit.time_series_preprocessor import TimeSeriesPreprocessor + + +split_config = {"train": [0, 8640], "valid": [8640, 11520], "test": [11520, 14400]} + + +dataset_path = ( + "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh1.csv" +) + +timestamp_column = "date" + +df = pd.read_csv( + dataset_path, + parse_dates=[timestamp_column], +) + +tsp = TimeSeriesPreprocessor( + id_columns=[], + timestamp_column=timestamp_column, + target_columns=["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"], + observable_columns=[], + control_columns=[], + conditional_columns=[], + static_categorical_columns=[], + scaling=True, + scaler_type="standard", + encode_categorical=False, + prediction_length=10, + context_length=96, +) + +train, valid, test = tsp.get_datasets(df, split_config) + +# %% From 9dfe4dd1a76487a4a5665518e9124c7419c4ecdd Mon Sep 17 00:00:00 2001 From: "Wesley M. Gifford" Date: Thu, 21 Mar 2024 17:21:36 -0400 Subject: [PATCH 04/10] update logic Signed-off-by: Wesley M. Gifford --- hacking/datasets_from_tsp.py | 5 +++ tsfm_public/toolkit/util.py | 78 ++++++++++++++++++++++++++---------- 2 files changed, 61 insertions(+), 22 deletions(-) diff --git a/hacking/datasets_from_tsp.py b/hacking/datasets_from_tsp.py index 8cafc100..a4df0701 100644 --- a/hacking/datasets_from_tsp.py +++ b/hacking/datasets_from_tsp.py @@ -36,3 +36,8 @@ train, valid, test = tsp.get_datasets(df, split_config) # %% +split_config = {"train": [0, 0.7], "valid": [0.7, 0.9], "test": [0.9, 1]} + +train, valid, test = tsp.get_datasets(df, split_config) + +# %% diff --git a/tsfm_public/toolkit/util.py b/tsfm_public/toolkit/util.py index 06126261..90e596dd 100644 --- a/tsfm_public/toolkit/util.py +++ b/tsfm_public/toolkit/util.py @@ -35,7 +35,9 @@ def select_by_timestamp( """ if not start_timestamp and not end_timestamp: - raise ValueError("At least one of start_timestamp or end_timestamp must be specified.") + raise ValueError( + "At least one of start_timestamp or end_timestamp must be specified." + ) if not start_timestamp: return df[df[timestamp_column] < end_timestamp] @@ -43,7 +45,10 @@ def select_by_timestamp( if not end_timestamp: return df[df[timestamp_column] >= start_timestamp] - return df[(df[timestamp_column] >= start_timestamp) & (df[timestamp_column] < end_timestamp)] + return df[ + (df[timestamp_column] >= start_timestamp) + & (df[timestamp_column] < end_timestamp) + ] def select_by_index( @@ -74,12 +79,18 @@ def select_by_index( raise ValueError("At least one of start_index or end_index must be specified.") if not id_columns: - return _split_group_by_index(df, start_index=start_index, end_index=end_index).copy() + return _split_group_by_index( + df, start_index=start_index, end_index=end_index + ).copy() groups = df.groupby(_get_groupby_columns(id_columns)) result = [] for name, group in groups: - result.append(_split_group_by_index(group, name=name, start_index=start_index, end_index=end_index)) + result.append( + _split_group_by_index( + group, name=name, start_index=start_index, end_index=end_index + ) + ) return pd.concat(result) @@ -116,7 +127,9 @@ def select_by_relative_fraction( pd.DataFrame: Subset of the dataframe. """ if not start_fraction and not end_fraction: - raise ValueError("At least one of start_fraction or end_fraction must be specified.") + raise ValueError( + "At least one of start_fraction or end_fraction must be specified." + ) if start_offset < 0: raise ValueError("The value of start_offset should ne non-negative.") @@ -202,7 +215,9 @@ def _split_group_by_fraction( else: end_index = None - return _split_group_by_index(group_df=group_df, start_index=start_index, end_index=end_index) + return _split_group_by_index( + group_df=group_df, start_index=start_index, end_index=end_index + ) def convert_tsf_to_dataframe( @@ -232,13 +247,17 @@ def convert_tsf_to_dataframe( if not line.startswith("@data"): line_content = line.split(" ") if line.startswith("@attribute"): - if len(line_content) != 3: # Attributes have both name and type + if ( + len(line_content) != 3 + ): # Attributes have both name and type raise Exception("Invalid meta-data specification.") col_names.append(line_content[1]) col_types.append(line_content[2]) else: - if len(line_content) != 2: # Other meta-data have only values + if ( + len(line_content) != 2 + ): # Other meta-data have only values raise Exception("Invalid meta-data specification.") if line.startswith("@frequency"): @@ -246,18 +265,24 @@ def convert_tsf_to_dataframe( elif line.startswith("@horizon"): forecast_horizon = int(line_content[1]) elif line.startswith("@missing"): - contain_missing_values = bool(strtobool(line_content[1])) + contain_missing_values = bool( + strtobool(line_content[1]) + ) elif line.startswith("@equallength"): contain_equal_length = bool(strtobool(line_content[1])) else: if len(col_names) == 0: - raise Exception("Missing attribute section. Attribute section must come before data.") + raise Exception( + "Missing attribute section. Attribute section must come before data." + ) found_data_tag = True elif not line.startswith("#"): if len(col_names) == 0: - raise Exception("Missing attribute section. Attribute section must come before data.") + raise Exception( + "Missing attribute section. Attribute section must come before data." + ) elif not found_data_tag: raise Exception("Missing @data tag.") else: @@ -290,7 +315,9 @@ def convert_tsf_to_dataframe( else: numeric_series.append(float(val)) - if numeric_series.count(replace_missing_vals_with) == len(numeric_series): + if numeric_series.count(replace_missing_vals_with) == len( + numeric_series + ): raise Exception( "All series values are missing. A given series should contains a set of comma separated numeric values. At least one numeric value should be there in a series." ) @@ -304,7 +331,9 @@ def convert_tsf_to_dataframe( elif col_types[i] == "string": att_val = str(full_info[i]) elif col_types[i] == "date": - att_val = datetime.strptime(full_info[i], "%Y-%m-%d %H-%M-%S") + att_val = datetime.strptime( + full_info[i], "%Y-%m-%d %H-%M-%S" + ) else: raise Exception( "Invalid attribute type." @@ -353,21 +382,26 @@ def get_split_params( split_function = {} for group in ["train", "test", "valid"]: - if split_config[group][1] < 1: - split_params[group] = { - "start_fraction": split_config[group][0], - "end_fraction": split_config[group][1], - "start_offset": (context_length if (context_length and group != "train") else 0), - } - split_function[group] = select_by_relative_fraction - else: + if isinstance(split_config[group][0], int) and isinstance( + split_config[group][1], int + ): split_params[group] = { "start_index": ( - split_config[group][0] - (context_length if (context_length and group != "train") else 0) + split_config[group][0] + - (context_length if (context_length and group != "train") else 0) ), "end_index": split_config[group][1], } split_function[group] = select_by_index + else: + split_params[group] = { + "start_fraction": split_config[group][0], + "end_fraction": split_config[group][1], + "start_offset": ( + context_length if (context_length and group != "train") else 0 + ), + } + split_function[group] = select_by_relative_fraction return split_params, split_function From 257deae02d0f11b6222978b445235f116ca97d3f Mon Sep 17 00:00:00 2001 From: "Wesley M. Gifford" Date: Thu, 21 Mar 2024 20:25:34 -0400 Subject: [PATCH 05/10] update splitting, add tests Signed-off-by: Wesley M. Gifford --- .../toolkit/time_series_preprocessor.py | 107 ++++++++++++------ tsfm_public/toolkit/util.py | 38 ++++--- 2 files changed, 93 insertions(+), 52 deletions(-) diff --git a/tsfm_public/toolkit/time_series_preprocessor.py b/tsfm_public/toolkit/time_series_preprocessor.py index 2beb2275..cd0f0ceb 100644 --- a/tsfm_public/toolkit/time_series_preprocessor.py +++ b/tsfm_public/toolkit/time_series_preprocessor.py @@ -51,7 +51,9 @@ def to_json(self) -> str: return json.dumps(self.to_dict()) @classmethod - def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> "SKLearnFeatureExtractionBase": + def from_dict( + cls, feature_extractor_dict: Dict[str, Any], **kwargs + ) -> "SKLearnFeatureExtractionBase": """ """ t = cls() @@ -120,7 +122,9 @@ def __init__( # note base class __init__ methods sets all arguments as attributes if not isinstance(id_columns, list): - raise ValueError(f"Invalid argument provided for `id_columns`: {id_columns}") + raise ValueError( + f"Invalid argument provided for `id_columns`: {id_columns}" + ) self.id_columns = id_columns self.timestamp_column = timestamp_column @@ -213,7 +217,10 @@ def recursive_check_ndarray(dictionary): elif isinstance(value, np.int64): dictionary[key] = int(value) elif isinstance(value, list): - dictionary[key] = [vv.tolist() if isinstance(vv, np.ndarray) else vv for vv in value] + dictionary[key] = [ + vv.tolist() if isinstance(vv, np.ndarray) else vv + for vv in value + ] elif isinstance(value, dict): dictionary[key] = recursive_check_ndarray(value) return dictionary @@ -229,7 +236,9 @@ def recursive_check_ndarray(dictionary): return json.dumps(dictionary, indent=2, sort_keys=True) + "\n" @classmethod - def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> "PreTrainedFeatureExtractor": + def from_dict( + cls, feature_extractor_dict: Dict[str, Any], **kwargs + ) -> "PreTrainedFeatureExtractor": """ Instantiates a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a Python dictionary of parameters. @@ -347,7 +356,9 @@ def _get_groups( Generator[Any, pd.DataFrame]: Group name and resulting pandas dataframe for the group. """ if self.id_columns: - group_by_columns = self.id_columns if len(self.id_columns) > 1 else self.id_columns[0] + group_by_columns = ( + self.id_columns if len(self.id_columns) > 1 else self.id_columns[0] + ) else: group_by_columns = INTERNAL_ID_COLUMN @@ -411,7 +422,9 @@ def get_frequency_token(self, token_name: str): token = self.frequency_mapping.get(token_name, None) if token is None: - warn(f"Frequency token {token_name} was not found in the frequncy token mapping.") + warn( + f"Frequency token {token_name} was not found in the frequncy token mapping." + ) token = self.frequency_mapping["oov"] return token @@ -444,7 +457,11 @@ def exogenous_channel_indices(self) -> List[int]: @property def prediction_channel_indices(self) -> List[int]: - return [i for i, c in enumerate(self._get_real_valued_dynamic_channels()) if c in self.target_columns] + return [ + i + for i, c in enumerate(self._get_real_valued_dynamic_channels()) + if c in self.target_columns + ] def _check_dataset(self, dataset: Union[Dataset, pd.DataFrame]): """Basic checks for input dataset. @@ -468,7 +485,10 @@ def _estimate_frequency(self, df: pd.DataFrame): df_subset = df # to do: make more robust - self.freq = df_subset[self.timestamp_column].iloc[-1] - df_subset[self.timestamp_column].iloc[-2] + self.freq = ( + df_subset[self.timestamp_column].iloc[-1] + - df_subset[self.timestamp_column].iloc[-2] + ) else: # no timestamp, assume sequential count? self.freq = 1 @@ -519,11 +539,15 @@ def inverse_scale_func(grp, id_columns): name = tuple(grp.iloc[0][id_columns].tolist()) else: name = grp.iloc[0][id_columns] - grp[cols_to_scale] = self.target_scaler_dict[name].inverse_transform(grp[cols_to_scale]) + grp[cols_to_scale] = self.target_scaler_dict[name].inverse_transform( + grp[cols_to_scale] + ) return grp if self.id_columns: - id_columns = self.id_columns if len(self.id_columns) > 1 else self.id_columns[0] + id_columns = ( + self.id_columns if len(self.id_columns) > 1 else self.id_columns[0] + ) else: id_columns = INTERNAL_ID_COLUMN @@ -562,14 +586,20 @@ def scale_func(grp, id_columns): name = tuple(grp.iloc[0][id_columns].tolist()) else: name = grp.iloc[0][id_columns] - grp[self.target_columns] = self.target_scaler_dict[name].transform(grp[self.target_columns]) + grp[self.target_columns] = self.target_scaler_dict[name].transform( + grp[self.target_columns] + ) if other_cols_to_scale: - grp[other_cols_to_scale] = self.scaler_dict[name].transform(grp[other_cols_to_scale]) + grp[other_cols_to_scale] = self.scaler_dict[name].transform( + grp[other_cols_to_scale] + ) return grp if self.id_columns: - id_columns = self.id_columns if len(self.id_columns) > 1 else self.id_columns[0] + id_columns = ( + self.id_columns if len(self.id_columns) > 1 else self.id_columns[0] + ) else: id_columns = INTERNAL_ID_COLUMN @@ -582,12 +612,16 @@ def scale_func(grp, id_columns): cols_to_encode = self._get_columns_to_encode() if self.encode_categorical and cols_to_encode: if not self.categorical_encoder: - raise RuntimeError("Attempt to encode categorical columns, but the encoder has not been trained yet.") + raise RuntimeError( + "Attempt to encode categorical columns, but the encoder has not been trained yet." + ) df[cols_to_encode] = self.categorical_encoder.transform(df[cols_to_encode]) return df - def get_datasets(self, dataset: Union[Dataset, pd.DataFrame], split_config: Dict[str, Any]) -> Tuple[Any]: + def get_datasets( + self, dataset: Union[Dataset, pd.DataFrame], split_config: Dict[str, Any] + ) -> Tuple[Any]: """Creates the preprocessed pytorch datasets needed for training and evaluation using the HuggingFace trainer @@ -596,16 +630,14 @@ def get_datasets(self, dataset: Union[Dataset, pd.DataFrame], split_config: Dict split_config (Dict[str, Any]): Dictionary of dictionaries containing split parameters. For example: { - train: {start: 0, end: 50}, - valid: {start: 50, end: 70}, - test: {start: 70, end: 100} + train: [0, 50], + valid: [50, 70], + test: [70, 100] } end value is not inclusive Returns: Tuple of pytorch datasets, including: train, validation, test. - - """ data = self._standardize_dataframe(dataset) @@ -627,18 +659,17 @@ def get_datasets(self, dataset: Union[Dataset, pd.DataFrame], split_config: Dict } # split data - train_data = split_function["train"](data, id_columns=self.id_columns, **split_params["train"]) - valid_data = split_function["valid"](data, id_columns=self.id_columns, **split_params["valid"]) - test_data = split_function["test"](data, id_columns=self.id_columns, **split_params["test"]) - - # # data preprocessing - # tsp = TimeSeriesPreprocessor( - # **column_specifiers, - # scaling=config["scale"]["scaling"], - # encode_categorical=config["encode_categorical"], - # scaler_type=config["scale"]["scaler_type"], - # freq=config["data"]["freq"], - # ) + train_data = split_function["train"]( + data, id_columns=self.id_columns, **split_params["train"] + ) + valid_data = split_function["valid"]( + data, id_columns=self.id_columns, **split_params["valid"] + ) + test_data = split_function["test"]( + data, id_columns=self.id_columns, **split_params["test"] + ) + + # data preprocessing self.train(train_data) params = column_specifiers @@ -661,13 +692,17 @@ def get_datasets(self, dataset: Union[Dataset, pd.DataFrame], split_config: Dict def create_timestamps( last_timestamp: Union[datetime.datetime, pd.Timestamp], freq: Optional[Union[int, float, datetime.timedelta, pd.Timedelta, str]] = None, - time_sequence: Optional[Union[List[int], List[float], List[datetime.datetime], List[pd.Timestamp]]] = None, + time_sequence: Optional[ + Union[List[int], List[float], List[datetime.datetime], List[pd.Timestamp]] + ] = None, periods: int = 1, ): """Simple utility to create a list of timestamps based on start, delta and number of periods""" if freq is None and time_sequence is None: - raise ValueError("Neither `freq` nor `time_sequence` provided, cannot determine frequency.") + raise ValueError( + "Neither `freq` nor `time_sequence` provided, cannot determine frequency." + ) if freq is None: # to do: make more robust @@ -730,7 +765,9 @@ def augment_one_series(group: Union[pd.Series, pd.DataFrame]): if grouping_columns == []: new_time_series = augment_one_series(time_series) else: - new_time_series = time_series.groupby(grouping_columns).apply(augment_one_series, include_groups=False) + new_time_series = time_series.groupby(grouping_columns).apply( + augment_one_series, include_groups=False + ) idx_names = list(new_time_series.index.names) idx_names[-1] = "__delete" new_time_series = new_time_series.reset_index(names=idx_names) diff --git a/tsfm_public/toolkit/util.py b/tsfm_public/toolkit/util.py index 90e596dd..80fd8016 100644 --- a/tsfm_public/toolkit/util.py +++ b/tsfm_public/toolkit/util.py @@ -5,7 +5,7 @@ import copy from datetime import datetime from distutils.util import strtobool -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import pandas as pd @@ -367,33 +367,28 @@ def convert_tsf_to_dataframe( def get_split_params( split_config: Dict[str, List[int]], context_length=None -) -> Dict[str, Dict[str, Union[int, float]]]: - """_summary_ +) -> Tuple[Dict[str, Dict[str, Union[int, float]]], Dict[str, Callable]]: + """Get split parameters Args: - split_config (Dict[str, List[int]]): _description_ - context_length (_type_, optional): _description_. Defaults to None. + split_config (Dict[str, List[int]]): Dictionary containing keys for + train, valid, test. Each value consists of a list of length two, indicating + the boundaries of a split. + context_length (_type_, optional): Context length, used only when offseting + the split so predictions can be made for all elements of split. Defaults to None. Returns: - Dict[str, Dict[str, Union[int, float]]]: _description_ + Tuple[Dict[str, Dict[str, Union[int, float]]], Dict[str, Callable]]: Tuple of split parameters + and split functions to use to split the data. """ split_params = {} split_function = {} for group in ["train", "test", "valid"]: - if isinstance(split_config[group][0], int) and isinstance( - split_config[group][1], int + if ((split_config[group][0] < 1) and (split_config[group][0] != 0)) or ( + split_config[group][1] < 1 ): - split_params[group] = { - "start_index": ( - split_config[group][0] - - (context_length if (context_length and group != "train") else 0) - ), - "end_index": split_config[group][1], - } - split_function[group] = select_by_index - else: split_params[group] = { "start_fraction": split_config[group][0], "end_fraction": split_config[group][1], @@ -402,6 +397,15 @@ def get_split_params( ), } split_function[group] = select_by_relative_fraction + else: + split_params[group] = { + "start_index": ( + split_config[group][0] + - (context_length if (context_length and group != "train") else 0) + ), + "end_index": split_config[group][1], + } + split_function[group] = select_by_index return split_params, split_function From 1f4e1227baba5cf9f0638ff4ba355673f9aeccdd Mon Sep 17 00:00:00 2001 From: "Wesley M. Gifford" Date: Thu, 21 Mar 2024 20:26:26 -0400 Subject: [PATCH 06/10] minor updates Signed-off-by: Wesley M. Gifford --- hacking/datasets_from_tsp.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/hacking/datasets_from_tsp.py b/hacking/datasets_from_tsp.py index a4df0701..bb4312e7 100644 --- a/hacking/datasets_from_tsp.py +++ b/hacking/datasets_from_tsp.py @@ -39,5 +39,3 @@ split_config = {"train": [0, 0.7], "valid": [0.7, 0.9], "test": [0.9, 1]} train, valid, test = tsp.get_datasets(df, split_config) - -# %% From ccba0ea674b33db77f0ff15fe494de0058830888 Mon Sep 17 00:00:00 2001 From: "Wesley M. Gifford" Date: Fri, 22 Mar 2024 10:18:30 -0400 Subject: [PATCH 07/10] update tests, support finetuning fraction Signed-off-by: Wesley M. Gifford --- .../toolkit/test_time_series_preprocessor.py | 44 ++++++- .../toolkit/time_series_preprocessor.py | 120 +++++++----------- tsfm_public/toolkit/util.py | 73 +++-------- 3 files changed, 109 insertions(+), 128 deletions(-) diff --git a/tests/toolkit/test_time_series_preprocessor.py b/tests/toolkit/test_time_series_preprocessor.py index 42069611..2de78b49 100644 --- a/tests/toolkit/test_time_series_preprocessor.py +++ b/tests/toolkit/test_time_series_preprocessor.py @@ -72,7 +72,7 @@ def test_time_series_preprocessor_encodes(sample_data): static_categorical_columns = ["cat", "cat2"] tsp = TimeSeriesPreprocessor( - input_columns=["val", "val2"], + target_columns=["val", "val2"], static_categorical_columns=static_categorical_columns, ) tsp.train(sample_data) @@ -156,3 +156,45 @@ def test_create_timestamps(): # it is an error to provide neither freq or sequence with pytest.raises(ValueError): ts = create_timestamps(start, periods=periods) + + +def test_get_datasets(ts_data): + tsp = TimeSeriesPreprocessor( + id_columns=["id"], + target_columns=["value1", "value2"], + prediction_length=5, + context_length=10, + ) + + train, valid, test = tsp.get_datasets( + ts_data, + split_config={"train": [0, 1 / 3], "valid": [1 / 3, 2 / 3], "test": [2 / 3, 1]}, + ) + + # 3 time series of length 50 + assert len(train) == 3 * (int((1 / 3) * 50) - (tsp.context_length + tsp.prediction_length) + 1) + + assert len(valid) == len(test) + + # no id columns, so treat as one big time series + tsp = TimeSeriesPreprocessor( + id_columns=[], + target_columns=["value1", "value2"], + prediction_length=5, + context_length=10, + ) + + train, valid, test = tsp.get_datasets( + ts_data, + split_config={ + "train": [0, 100], + "valid": [100, 125], + "test": [125, 150], + }, + fewshot_fraction=0.2, + ) + + # new train length should be 20% of 100, minus the usual for context length and prediction length + assert len(train) == (int(100 * 0.2) - (tsp.context_length + tsp.prediction_length) + 1) + + assert len(valid) == len(test) diff --git a/tsfm_public/toolkit/time_series_preprocessor.py b/tsfm_public/toolkit/time_series_preprocessor.py index cd0f0ceb..eb598047 100644 --- a/tsfm_public/toolkit/time_series_preprocessor.py +++ b/tsfm_public/toolkit/time_series_preprocessor.py @@ -22,7 +22,11 @@ ) from .dataset import ForecastDFDataset -from .util import get_split_params, join_list_without_repeat +from .util import ( + get_split_params, + join_list_without_repeat, + select_by_relative_fraction, +) INTERNAL_ID_COLUMN = "__id" @@ -51,9 +55,7 @@ def to_json(self) -> str: return json.dumps(self.to_dict()) @classmethod - def from_dict( - cls, feature_extractor_dict: Dict[str, Any], **kwargs - ) -> "SKLearnFeatureExtractionBase": + def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> "SKLearnFeatureExtractionBase": """ """ t = cls() @@ -122,9 +124,7 @@ def __init__( # note base class __init__ methods sets all arguments as attributes if not isinstance(id_columns, list): - raise ValueError( - f"Invalid argument provided for `id_columns`: {id_columns}" - ) + raise ValueError(f"Invalid argument provided for `id_columns`: {id_columns}") self.id_columns = id_columns self.timestamp_column = timestamp_column @@ -217,10 +217,7 @@ def recursive_check_ndarray(dictionary): elif isinstance(value, np.int64): dictionary[key] = int(value) elif isinstance(value, list): - dictionary[key] = [ - vv.tolist() if isinstance(vv, np.ndarray) else vv - for vv in value - ] + dictionary[key] = [vv.tolist() if isinstance(vv, np.ndarray) else vv for vv in value] elif isinstance(value, dict): dictionary[key] = recursive_check_ndarray(value) return dictionary @@ -236,9 +233,7 @@ def recursive_check_ndarray(dictionary): return json.dumps(dictionary, indent=2, sort_keys=True) + "\n" @classmethod - def from_dict( - cls, feature_extractor_dict: Dict[str, Any], **kwargs - ) -> "PreTrainedFeatureExtractor": + def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> "PreTrainedFeatureExtractor": """ Instantiates a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a Python dictionary of parameters. @@ -356,9 +351,7 @@ def _get_groups( Generator[Any, pd.DataFrame]: Group name and resulting pandas dataframe for the group. """ if self.id_columns: - group_by_columns = ( - self.id_columns if len(self.id_columns) > 1 else self.id_columns[0] - ) + group_by_columns = self.id_columns if len(self.id_columns) > 1 else self.id_columns[0] else: group_by_columns = INTERNAL_ID_COLUMN @@ -422,9 +415,7 @@ def get_frequency_token(self, token_name: str): token = self.frequency_mapping.get(token_name, None) if token is None: - warn( - f"Frequency token {token_name} was not found in the frequncy token mapping." - ) + warn(f"Frequency token {token_name} was not found in the frequncy token mapping.") token = self.frequency_mapping["oov"] return token @@ -457,11 +448,7 @@ def exogenous_channel_indices(self) -> List[int]: @property def prediction_channel_indices(self) -> List[int]: - return [ - i - for i, c in enumerate(self._get_real_valued_dynamic_channels()) - if c in self.target_columns - ] + return [i for i, c in enumerate(self._get_real_valued_dynamic_channels()) if c in self.target_columns] def _check_dataset(self, dataset: Union[Dataset, pd.DataFrame]): """Basic checks for input dataset. @@ -485,10 +472,7 @@ def _estimate_frequency(self, df: pd.DataFrame): df_subset = df # to do: make more robust - self.freq = ( - df_subset[self.timestamp_column].iloc[-1] - - df_subset[self.timestamp_column].iloc[-2] - ) + self.freq = df_subset[self.timestamp_column].iloc[-1] - df_subset[self.timestamp_column].iloc[-2] else: # no timestamp, assume sequential count? self.freq = 1 @@ -539,15 +523,11 @@ def inverse_scale_func(grp, id_columns): name = tuple(grp.iloc[0][id_columns].tolist()) else: name = grp.iloc[0][id_columns] - grp[cols_to_scale] = self.target_scaler_dict[name].inverse_transform( - grp[cols_to_scale] - ) + grp[cols_to_scale] = self.target_scaler_dict[name].inverse_transform(grp[cols_to_scale]) return grp if self.id_columns: - id_columns = ( - self.id_columns if len(self.id_columns) > 1 else self.id_columns[0] - ) + id_columns = self.id_columns if len(self.id_columns) > 1 else self.id_columns[0] else: id_columns = INTERNAL_ID_COLUMN @@ -586,20 +566,14 @@ def scale_func(grp, id_columns): name = tuple(grp.iloc[0][id_columns].tolist()) else: name = grp.iloc[0][id_columns] - grp[self.target_columns] = self.target_scaler_dict[name].transform( - grp[self.target_columns] - ) + grp[self.target_columns] = self.target_scaler_dict[name].transform(grp[self.target_columns]) if other_cols_to_scale: - grp[other_cols_to_scale] = self.scaler_dict[name].transform( - grp[other_cols_to_scale] - ) + grp[other_cols_to_scale] = self.scaler_dict[name].transform(grp[other_cols_to_scale]) return grp if self.id_columns: - id_columns = ( - self.id_columns if len(self.id_columns) > 1 else self.id_columns[0] - ) + id_columns = self.id_columns if len(self.id_columns) > 1 else self.id_columns[0] else: id_columns = INTERNAL_ID_COLUMN @@ -612,29 +586,33 @@ def scale_func(grp, id_columns): cols_to_encode = self._get_columns_to_encode() if self.encode_categorical and cols_to_encode: if not self.categorical_encoder: - raise RuntimeError( - "Attempt to encode categorical columns, but the encoder has not been trained yet." - ) + raise RuntimeError("Attempt to encode categorical columns, but the encoder has not been trained yet.") df[cols_to_encode] = self.categorical_encoder.transform(df[cols_to_encode]) return df def get_datasets( - self, dataset: Union[Dataset, pd.DataFrame], split_config: Dict[str, Any] + self, + dataset: Union[Dataset, pd.DataFrame], + split_config: Dict[str, Any], + fewshot_fraction: Optional[float] = None, ) -> Tuple[Any]: """Creates the preprocessed pytorch datasets needed for training and evaluation using the HuggingFace trainer Args: dataset (Union[Dataset, pd.DataFrame]): Loaded pandas dataframe - split_config (Dict[str, Any]): Dictionary of dictionaries containing - split parameters. For example: - { - train: [0, 50], - valid: [50, 70], - test: [70, 100] - } - end value is not inclusive + split_config (Dict[str, Any]): Dictionary of dictionaries containing + split parameters. For example: + { + train: [0, 50], + valid: [50, 70], + test: [70, 100] + } + end value is not inclusive + fewshot_fraction (float, optional): When non-null, return this percent of the original training + dataset. This is done to support fewshot fine-tuning. The fraction of data chosen is at the + end of the training dataset. Returns: Tuple of pytorch datasets, including: train, validation, test. @@ -659,19 +637,19 @@ def get_datasets( } # split data - train_data = split_function["train"]( - data, id_columns=self.id_columns, **split_params["train"] - ) - valid_data = split_function["valid"]( - data, id_columns=self.id_columns, **split_params["valid"] - ) - test_data = split_function["test"]( - data, id_columns=self.id_columns, **split_params["test"] - ) + train_data = split_function["train"](data, id_columns=self.id_columns, **split_params["train"]) + valid_data = split_function["valid"](data, id_columns=self.id_columns, **split_params["valid"]) + test_data = split_function["test"](data, id_columns=self.id_columns, **split_params["test"]) # data preprocessing self.train(train_data) + # handle fewshot operation + if fewshot_fraction is not None: + if not ((fewshot_fraction <= 1) and (fewshot_fraction > 0)): + raise ValueError(f"Fewshot fraction should be between 0 and 1, received {fewshot_fraction}") + train_data = select_by_relative_fraction(train_data, start_fraction=1 - fewshot_fraction, end_fraction=1) + params = column_specifiers params["context_length"] = self.context_length params["prediction_length"] = self.prediction_length @@ -692,17 +670,13 @@ def get_datasets( def create_timestamps( last_timestamp: Union[datetime.datetime, pd.Timestamp], freq: Optional[Union[int, float, datetime.timedelta, pd.Timedelta, str]] = None, - time_sequence: Optional[ - Union[List[int], List[float], List[datetime.datetime], List[pd.Timestamp]] - ] = None, + time_sequence: Optional[Union[List[int], List[float], List[datetime.datetime], List[pd.Timestamp]]] = None, periods: int = 1, ): """Simple utility to create a list of timestamps based on start, delta and number of periods""" if freq is None and time_sequence is None: - raise ValueError( - "Neither `freq` nor `time_sequence` provided, cannot determine frequency." - ) + raise ValueError("Neither `freq` nor `time_sequence` provided, cannot determine frequency.") if freq is None: # to do: make more robust @@ -765,9 +739,7 @@ def augment_one_series(group: Union[pd.Series, pd.DataFrame]): if grouping_columns == []: new_time_series = augment_one_series(time_series) else: - new_time_series = time_series.groupby(grouping_columns).apply( - augment_one_series, include_groups=False - ) + new_time_series = time_series.groupby(grouping_columns).apply(augment_one_series, include_groups=False) idx_names = list(new_time_series.index.names) idx_names[-1] = "__delete" new_time_series = new_time_series.reset_index(names=idx_names) diff --git a/tsfm_public/toolkit/util.py b/tsfm_public/toolkit/util.py index 80fd8016..044ceed3 100644 --- a/tsfm_public/toolkit/util.py +++ b/tsfm_public/toolkit/util.py @@ -35,9 +35,7 @@ def select_by_timestamp( """ if not start_timestamp and not end_timestamp: - raise ValueError( - "At least one of start_timestamp or end_timestamp must be specified." - ) + raise ValueError("At least one of start_timestamp or end_timestamp must be specified.") if not start_timestamp: return df[df[timestamp_column] < end_timestamp] @@ -45,10 +43,7 @@ def select_by_timestamp( if not end_timestamp: return df[df[timestamp_column] >= start_timestamp] - return df[ - (df[timestamp_column] >= start_timestamp) - & (df[timestamp_column] < end_timestamp) - ] + return df[(df[timestamp_column] >= start_timestamp) & (df[timestamp_column] < end_timestamp)] def select_by_index( @@ -79,18 +74,12 @@ def select_by_index( raise ValueError("At least one of start_index or end_index must be specified.") if not id_columns: - return _split_group_by_index( - df, start_index=start_index, end_index=end_index - ).copy() + return _split_group_by_index(df, start_index=start_index, end_index=end_index).copy() groups = df.groupby(_get_groupby_columns(id_columns)) result = [] for name, group in groups: - result.append( - _split_group_by_index( - group, name=name, start_index=start_index, end_index=end_index - ) - ) + result.append(_split_group_by_index(group, name=name, start_index=start_index, end_index=end_index)) return pd.concat(result) @@ -127,9 +116,7 @@ def select_by_relative_fraction( pd.DataFrame: Subset of the dataframe. """ if not start_fraction and not end_fraction: - raise ValueError( - "At least one of start_fraction or end_fraction must be specified." - ) + raise ValueError("At least one of start_fraction or end_fraction must be specified.") if start_offset < 0: raise ValueError("The value of start_offset should ne non-negative.") @@ -215,9 +202,7 @@ def _split_group_by_fraction( else: end_index = None - return _split_group_by_index( - group_df=group_df, start_index=start_index, end_index=end_index - ) + return _split_group_by_index(group_df=group_df, start_index=start_index, end_index=end_index) def convert_tsf_to_dataframe( @@ -247,17 +232,13 @@ def convert_tsf_to_dataframe( if not line.startswith("@data"): line_content = line.split(" ") if line.startswith("@attribute"): - if ( - len(line_content) != 3 - ): # Attributes have both name and type + if len(line_content) != 3: # Attributes have both name and type raise Exception("Invalid meta-data specification.") col_names.append(line_content[1]) col_types.append(line_content[2]) else: - if ( - len(line_content) != 2 - ): # Other meta-data have only values + if len(line_content) != 2: # Other meta-data have only values raise Exception("Invalid meta-data specification.") if line.startswith("@frequency"): @@ -265,24 +246,18 @@ def convert_tsf_to_dataframe( elif line.startswith("@horizon"): forecast_horizon = int(line_content[1]) elif line.startswith("@missing"): - contain_missing_values = bool( - strtobool(line_content[1]) - ) + contain_missing_values = bool(strtobool(line_content[1])) elif line.startswith("@equallength"): contain_equal_length = bool(strtobool(line_content[1])) else: if len(col_names) == 0: - raise Exception( - "Missing attribute section. Attribute section must come before data." - ) + raise Exception("Missing attribute section. Attribute section must come before data.") found_data_tag = True elif not line.startswith("#"): if len(col_names) == 0: - raise Exception( - "Missing attribute section. Attribute section must come before data." - ) + raise Exception("Missing attribute section. Attribute section must come before data.") elif not found_data_tag: raise Exception("Missing @data tag.") else: @@ -315,9 +290,7 @@ def convert_tsf_to_dataframe( else: numeric_series.append(float(val)) - if numeric_series.count(replace_missing_vals_with) == len( - numeric_series - ): + if numeric_series.count(replace_missing_vals_with) == len(numeric_series): raise Exception( "All series values are missing. A given series should contains a set of comma separated numeric values. At least one numeric value should be there in a series." ) @@ -331,9 +304,7 @@ def convert_tsf_to_dataframe( elif col_types[i] == "string": att_val = str(full_info[i]) elif col_types[i] == "date": - att_val = datetime.strptime( - full_info[i], "%Y-%m-%d %H-%M-%S" - ) + att_val = datetime.strptime(full_info[i], "%Y-%m-%d %H-%M-%S") else: raise Exception( "Invalid attribute type." @@ -366,15 +337,16 @@ def convert_tsf_to_dataframe( def get_split_params( - split_config: Dict[str, List[int]], context_length=None + split_config: Dict[str, List[Union[int, float]]], + context_length: Optional[int] = None, ) -> Tuple[Dict[str, Dict[str, Union[int, float]]], Dict[str, Callable]]: """Get split parameters Args: - split_config (Dict[str, List[int]]): Dictionary containing keys for + split_config (Dict[str, List[int, float]]): Dictionary containing keys for train, valid, test. Each value consists of a list of length two, indicating the boundaries of a split. - context_length (_type_, optional): Context length, used only when offseting + context_length (int, optional): Context length, used only when offseting the split so predictions can be made for all elements of split. Defaults to None. Returns: @@ -386,22 +358,17 @@ def get_split_params( split_function = {} for group in ["train", "test", "valid"]: - if ((split_config[group][0] < 1) and (split_config[group][0] != 0)) or ( - split_config[group][1] < 1 - ): + if ((split_config[group][0] < 1) and (split_config[group][0] != 0)) or (split_config[group][1] < 1): split_params[group] = { "start_fraction": split_config[group][0], "end_fraction": split_config[group][1], - "start_offset": ( - context_length if (context_length and group != "train") else 0 - ), + "start_offset": (context_length if (context_length and group != "train") else 0), } split_function[group] = select_by_relative_fraction else: split_params[group] = { "start_index": ( - split_config[group][0] - - (context_length if (context_length and group != "train") else 0) + split_config[group][0] - (context_length if (context_length and group != "train") else 0) ), "end_index": split_config[group][1], } From 83a6f99b0217f0086f218e524f86dab0a6840286 Mon Sep 17 00:00:00 2001 From: "Wesley M. Gifford" Date: Fri, 22 Mar 2024 13:46:33 -0400 Subject: [PATCH 08/10] pass context length Signed-off-by: Wesley M. Gifford --- tsfm_public/toolkit/time_series_preprocessor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tsfm_public/toolkit/time_series_preprocessor.py b/tsfm_public/toolkit/time_series_preprocessor.py index eb598047..eb8137b0 100644 --- a/tsfm_public/toolkit/time_series_preprocessor.py +++ b/tsfm_public/toolkit/time_series_preprocessor.py @@ -623,7 +623,7 @@ def get_datasets( # get split_params # split_params = get_split_params(config, self.context_length, len(data)) - split_params, split_function = get_split_params(split_config) + split_params, split_function = get_split_params(split_config, context_length=self.context_length) # specify columns column_specifiers = { From 2f54d423b3b773eafea3aa7ea440af58bb86d76b Mon Sep 17 00:00:00 2001 From: "Wesley M. Gifford" Date: Mon, 25 Mar 2024 11:12:09 -0400 Subject: [PATCH 09/10] update fewshot selection, add location, tests Signed-off-by: Wesley M. Gifford --- .../toolkit/test_time_series_preprocessor.py | 29 ++++++- .../toolkit/time_series_preprocessor.py | 17 +++- tsfm_public/toolkit/util.py | 80 ++++++++++++++++++- 3 files changed, 121 insertions(+), 5 deletions(-) diff --git a/tests/toolkit/test_time_series_preprocessor.py b/tests/toolkit/test_time_series_preprocessor.py index 2de78b49..73ac6bf6 100644 --- a/tests/toolkit/test_time_series_preprocessor.py +++ b/tests/toolkit/test_time_series_preprocessor.py @@ -16,6 +16,7 @@ create_timestamps, extend_time_series, ) +from tsfm_public.toolkit.util import FractionLocation def test_standard_scaler(sample_data): @@ -192,9 +193,35 @@ def test_get_datasets(ts_data): "test": [125, 150], }, fewshot_fraction=0.2, + fewshot_location=FractionLocation.LAST.value, ) # new train length should be 20% of 100, minus the usual for context length and prediction length - assert len(train) == (int(100 * 0.2) - (tsp.context_length + tsp.prediction_length) + 1) + fewshot_train_size = int(100 * 0.2) - (tsp.context_length + tsp.prediction_length) + 1 + assert len(train) == fewshot_train_size + + assert len(valid) == len(test) + + # no id columns, so treat as one big time series + tsp = TimeSeriesPreprocessor( + id_columns=[], + target_columns=["value1", "value2"], + prediction_length=5, + context_length=10, + ) + + train, valid, test = tsp.get_datasets( + ts_data, + split_config={ + "train": [0, 100], + "valid": [100, 125], + "test": [125, 150], + }, + fewshot_fraction=0.2, + fewshot_location=FractionLocation.FIRST.value, + ) + + # new train length should be 20% of 100, minus the usual for context length and prediction length + assert len(train) == fewshot_train_size assert len(valid) == len(test) diff --git a/tsfm_public/toolkit/time_series_preprocessor.py b/tsfm_public/toolkit/time_series_preprocessor.py index eb8137b0..89dbe279 100644 --- a/tsfm_public/toolkit/time_series_preprocessor.py +++ b/tsfm_public/toolkit/time_series_preprocessor.py @@ -23,9 +23,10 @@ from .dataset import ForecastDFDataset from .util import ( + FractionLocation, get_split_params, join_list_without_repeat, - select_by_relative_fraction, + select_by_fixed_fraction, ) @@ -596,6 +597,7 @@ def get_datasets( dataset: Union[Dataset, pd.DataFrame], split_config: Dict[str, Any], fewshot_fraction: Optional[float] = None, + fewshot_location: str = FractionLocation.LAST.value, ) -> Tuple[Any]: """Creates the preprocessed pytorch datasets needed for training and evaluation using the HuggingFace trainer @@ -613,6 +615,9 @@ def get_datasets( fewshot_fraction (float, optional): When non-null, return this percent of the original training dataset. This is done to support fewshot fine-tuning. The fraction of data chosen is at the end of the training dataset. + fewshot_location (str): Determines where the fewshot data is chosen. Valid options are "first" and "last" + as described in the enum FewshotLocation. Default is to choose the fewshot data at the end + of the training dataset (i.e., "last"). Returns: Tuple of pytorch datasets, including: train, validation, test. @@ -646,9 +651,15 @@ def get_datasets( # handle fewshot operation if fewshot_fraction is not None: - if not ((fewshot_fraction <= 1) and (fewshot_fraction > 0)): + if not ((fewshot_fraction <= 1.0) and (fewshot_fraction > 0.0)): raise ValueError(f"Fewshot fraction should be between 0 and 1, received {fewshot_fraction}") - train_data = select_by_relative_fraction(train_data, start_fraction=1 - fewshot_fraction, end_fraction=1) + + train_data = select_by_fixed_fraction( + train_data, + id_columns=self.id_columns, + fraction=fewshot_fraction, + location=fewshot_location, + ) params = column_specifiers params["context_length"] = self.context_length diff --git a/tsfm_public/toolkit/util.py b/tsfm_public/toolkit/util.py index 044ceed3..71359b1a 100644 --- a/tsfm_public/toolkit/util.py +++ b/tsfm_public/toolkit/util.py @@ -3,6 +3,7 @@ """Basic functions and utilities""" import copy +import enum from datetime import datetime from distutils.util import strtobool from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -10,6 +11,13 @@ import pandas as pd +class FractionLocation(enum.Enum): + """`Enum` for the different locations where a fraction of data can be chosen.""" + + FIRST = "first" + LAST = "last" + + def select_by_timestamp( df: pd.DataFrame, timestamp_column: str = "timestamp", @@ -145,6 +153,53 @@ def select_by_relative_fraction( return pd.concat(result) +def select_by_fixed_fraction( + df: pd.DataFrame, + id_columns: Optional[List[str]] = None, + fraction: float = 1.0, + location: str = FractionLocation.FIRST.value, +) -> pd.DataFrame: + """Select a portion of a dataset based on a fraction of the data. + Fraction can either be located at the start (location = FractionLocation.FIRST) or at the end (location = FractionLocation.LAST) + + Args: + df (pd.DataFrame): Input dataframe. + id_columns (List[str], optional): Columns which specify the IDs in the dataset. Defaults to None. + fraction (float): The fraction to select. + location (str): Location of where to select the fraction Defaults to FractionLocation.FIRST.value. + + Raises: + ValueError: Raised when the + + Returns: + pd.DataFrame: Subset of the dataframe. + """ + + if fraction < 0 or fraction > 1: + raise ValueError("The value of fraction should be between 0 and 1.") + + if not id_columns: + return _split_group_by_fixed_fraction( + df, + fraction=fraction, + location=location, + ).copy() + + groups = df.groupby(_get_groupby_columns(id_columns)) + result = [] + for name, group in groups: + result.append( + _split_group_by_fixed_fraction( + group, + name=name, + fraction=fraction, + location=location, + ) + ) + + return pd.concat(result) + + def _get_groupby_columns(id_columns: List[str]) -> Union[List[str], str]: if not isinstance(id_columns, (List)): raise ValueError("id_columns must be a list") @@ -202,7 +257,30 @@ def _split_group_by_fraction( else: end_index = None - return _split_group_by_index(group_df=group_df, start_index=start_index, end_index=end_index) + return _split_group_by_index(group_df=group_df, name=name, start_index=start_index, end_index=end_index) + + +def _split_group_by_fixed_fraction( + group_df: pd.DataFrame, + name: Optional[str] = None, + fraction: float = 1.0, + location: Optional[str] = None, +): + l = len(group_df) + fraction_size = int(fraction * l) + + if location == FractionLocation.FIRST.value: + start_index = 0 + end_index = fraction_size + elif location == FractionLocation.LAST.value: + start_index = l - fraction_size + end_index = l + else: + raise ValueError( + f"`location` should be either `{FractionLocation.FIRST.value}` or `{FractionLocation.LAST.value}`" + ) + + return _split_group_by_index(group_df=group_df, name=name, start_index=start_index, end_index=end_index) def convert_tsf_to_dataframe( From 19fe429c866230995572b98f38e15d1f8943b685 Mon Sep 17 00:00:00 2001 From: "Wesley M. Gifford" Date: Mon, 25 Mar 2024 11:36:36 -0400 Subject: [PATCH 10/10] add tests Signed-off-by: Wesley M. Gifford --- tests/toolkit/test_util.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 tests/toolkit/test_util.py diff --git a/tests/toolkit/test_util.py b/tests/toolkit/test_util.py new file mode 100644 index 00000000..9c01321b --- /dev/null +++ b/tests/toolkit/test_util.py @@ -0,0 +1,26 @@ +"""Tests for util functions""" + +import pytest + +from tsfm_public.toolkit.util import get_split_params + + +split_cases = [ + (0, 1, "select_by_index"), + (0, 0.1, "select_by_relative_fraction"), + (0.0, 0.1, "select_by_relative_fraction"), + (0.0, 200.0, "select_by_index"), + (0.0, 200, "select_by_index"), + (0.5, 1, "select_by_relative_fraction"), + (0.5, 1.0, "select_by_relative_fraction"), + (10, 100.0, "select_by_index"), +] + + +@pytest.mark.parametrize("left_arg,right_arg,expected", split_cases) +def test_get_split_params(left_arg, right_arg, expected): + """Test that get_split_params gives the right split function""" + + split_config, split_function = get_split_params({"train": [left_arg, right_arg], "valid": [0, 1], "test": [0, 1]}) + + assert split_function["train"].__name__ == expected