diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml new file mode 100644 index 00000000..ee92f043 --- /dev/null +++ b/.github/workflows/python-package.yml @@ -0,0 +1,39 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: Python package + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10"] + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + # python -m pip install flake8 pytest + # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + pip install ".[dev]" + - name: Ruff quality checks + run: | + # use hf-style check + make quality + - name: Test with pytest + run: | + pytest diff --git a/Makefile b/Makefile index b542c94f..571ad16d 100644 --- a/Makefile +++ b/Makefile @@ -1,18 +1,18 @@ # Adapted from HF Transformers: https://github.com/huggingface/transformers/tree/main .PHONY: quality style -check_dirs := tests src tsfm notebooks +check_dirs := tests tsfm_public tsfmhfdemos notebooks # this target runs checks on all files quality: - ruff check $(check_dirs) setup.py - ruff format --check $(check_dirs) setup.py + ruff check $(check_dirs) + ruff format --check $(check_dirs) # this target runs checks on all files and potentially modifies some of them style: - ruff check $(check_dirs) setup.py --fix - ruff format $(check_dirs) setup.py + ruff check $(check_dirs) --fix + ruff format $(check_dirs) diff --git a/tests/toolkit/test_dataset.py b/tests/toolkit/test_dataset.py index 9d48b98b..ee8e2e1f 100644 --- a/tests/toolkit/test_dataset.py +++ b/tests/toolkit/test_dataset.py @@ -52,9 +52,7 @@ def test_ts_padding(ts_data): # test date handled # integer - assert df_padded.iloc[0]["time_int"] == df.iloc[0]["time_int"] - ( - context_length - df.shape[0] - ) + assert df_padded.iloc[0]["time_int"] == df.iloc[0]["time_int"] - (context_length - df.shape[0]) # date df_padded = ts_padding( @@ -64,9 +62,9 @@ def test_ts_padding(ts_data): context_length=context_length, ) - assert df_padded.iloc[0]["time_date"] == df.iloc[0]["time_date"] - ( - context_length - df.shape[0] - ) * timedelta(days=1) + assert df_padded.iloc[0]["time_date"] == df.iloc[0]["time_date"] - (context_length - df.shape[0]) * timedelta( + days=1 + ) def test_pretrain_df_dataset(ts_data): diff --git a/tsfm_public/toolkit/dataset.py b/tsfm_public/toolkit/dataset.py index 72715704..aa418a62 100644 --- a/tsfm_public/toolkit/dataset.py +++ b/tsfm_public/toolkit/dataset.py @@ -47,22 +47,14 @@ def __init__( y_cols = [y_cols] if len(x_cols) > 0: - assert is_cols_in_df( - data_df, x_cols - ), f"one or more {x_cols} is not in the list of data_df columns" + assert is_cols_in_df(data_df, x_cols), f"one or more {x_cols} is not in the list of data_df columns" if len(y_cols) > 0: - assert is_cols_in_df( - data_df, y_cols - ), f"one or more {y_cols} is not in the list of data_df columns" + assert is_cols_in_df(data_df, y_cols), f"one or more {y_cols} is not in the list of data_df columns" if datetime_col: - assert datetime_col in list( - data_df.columns - ), f"{datetime_col} is not in the list of data_df columns" - assert ( - datetime_col not in x_cols - ), f"{datetime_col} should not be in the list of x_cols" + assert datetime_col in list(data_df.columns), f"{datetime_col} is not in the list of data_df columns" + assert datetime_col not in x_cols, f"{datetime_col} should not be in the list of x_cols" self.data_df = data_df self.datetime_col = datetime_col @@ -160,9 +152,7 @@ def __init__( cls=BaseDFDataset, ): if len(id_columns) > 0: - assert is_cols_in_df( - data_df, id_columns - ), f"{id_columns} is not in the data_df columns" + assert is_cols_in_df(data_df, id_columns), f"{id_columns} is not in the data_df columns" self.datetime_col = datetime_col self.id_columns = id_columns @@ -398,9 +388,7 @@ def __getitem__(self, time_id): # seq_x: batch_size x seq_len x num_x_cols seq_x = self.X[time_id : time_id + self.seq_len].values # seq_y: batch_size x pred_len x num_x_cols - seq_y = self.y[ - time_id + self.seq_len : time_id + self.seq_len + self.pred_len - ].values + seq_y = self.y[time_id + self.seq_len : time_id + self.seq_len + self.pred_len].values ret = { "past_values": np_to_torch(seq_x), @@ -490,9 +478,7 @@ def __init__( def __getitem__(self, time_id): # seq_x: batch_size x seq_len x num_x_cols seq_x = self.X[time_id : time_id + self.seq_len].values - seq_y = self.y[ - time_id + self.seq_len - 1 : time_id + self.seq_len - ].values.ravel() + seq_y = self.y[time_id + self.seq_len - 1 : time_id + self.seq_len].values.ravel() # return _torch(seq_x, seq_y) ret = { @@ -582,16 +568,12 @@ def ts_padding( if df[timestamp_column].dtype in [" bool: d6 = PretrainDFDataset(data_df=df, x_cols=["A", "B"], group_ids=["g1"], seq_len=2) print(f"d6: {d6}") - d7 = ForecastDFDataset( - data_df=df, x_cols=["A", "B"], group_ids=["g1"], seq_len=2, pred_len=2 - ) + d7 = ForecastDFDataset(data_df=df, x_cols=["A", "B"], group_ids=["g1"], seq_len=2, pred_len=2) diff --git a/tsfm_public/toolkit/time_series_forecasting_pipeline.py b/tsfm_public/toolkit/time_series_forecasting_pipeline.py index a7a58265..c012e0a9 100644 --- a/tsfm_public/toolkit/time_series_forecasting_pipeline.py +++ b/tsfm_public/toolkit/time_series_forecasting_pipeline.py @@ -63,7 +63,7 @@ def _sanitize_parameters(self, **kwargs): return preprocess_kwargs, {}, postprocess_kwargs - def __call__(self, time_series: Union["pandas.DataFrame", str], **kwargs): + def __call__(self, time_series: Union["pd.DataFrame", str], **kwargs): """Main method of the forecasting pipeline. Takes the input time series data (in tabular format) and produces predictions. @@ -146,11 +146,7 @@ def _forward(self, model_inputs, **kwargs): # copy the other inputs copy_inputs = True - for k in [ - akey - for akey in model_inputs.keys() - if (akey not in model_input_keys) or copy_inputs - ]: + for k in [akey for akey in model_inputs.keys() if (akey not in model_input_keys) or copy_inputs]: model_outputs[k] = model_inputs[k] return model_outputs @@ -162,11 +158,7 @@ def postprocess(self, input, **kwargs): """ out = {} - model_output_key = ( - "prediction_outputs" - if "prediction_outputs" in input.keys() - else "prediction_logits" - ) + model_output_key = "prediction_outputs" if "prediction_outputs" in input.keys() else "prediction_logits" for i, c in enumerate(kwargs["output_columns"]): out[f"{c}_prediction"] = input[model_output_key][:, :, i].numpy().tolist() diff --git a/tsfm_public/toolkit/time_series_preprocessor.py b/tsfm_public/toolkit/time_series_preprocessor.py index cada1e4f..46a64e71 100644 --- a/tsfm_public/toolkit/time_series_preprocessor.py +++ b/tsfm_public/toolkit/time_series_preprocessor.py @@ -11,7 +11,10 @@ import pandas as pd from datasets import Dataset from sklearn.preprocessing import StandardScaler -from transformers.feature_extraction_utils import FeatureExtractionMixin +from transformers.feature_extraction_utils import ( + FeatureExtractionMixin, + PreTrainedFeatureExtractor, +) # Local @@ -37,9 +40,7 @@ def to_dict(self) -> Dict[str, Any]: return output @classmethod - def from_dict( - cls, feature_extractor_dict: Dict[str, Any], **kwargs - ) -> "TimeSeriesScaler": + def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> "TimeSeriesScaler": """ Instantiates a TimeSeriesScaler from a Python dictionary of parameters. @@ -59,18 +60,12 @@ def from_dict( init_param_names = ["copy", "with_mean", "with_std"] init_params = {} - for k, v in [ - (k, v) for k, v in feature_extractor_dict.items() if k in init_param_names - ]: + for k, v in [(k, v) for k, v in feature_extractor_dict.items() if k in init_param_names]: init_params[k] = v t = TimeSeriesScaler(**init_params) - for k, v in [ - (k, v) - for k, v in feature_extractor_dict.items() - if k not in init_param_names - ]: + for k, v in [(k, v) for k, v in feature_extractor_dict.items() if k not in init_param_names]: setattr(t, k, v) return t @@ -104,9 +99,7 @@ def __init__( # note base class __init__ methods sets all arguments as attributes if not isinstance(id_columns, list): - raise ValueError( - f"Invalid argument provided for `id_columns`: {id_columns}" - ) + raise ValueError(f"Invalid argument provided for `id_columns`: {id_columns}") self.timestamp_column = timestamp_column self.input_columns = input_columns @@ -117,7 +110,7 @@ def __init__( self.scaling = scaling self.time_series_task = time_series_task self.scale_outputs = scale_outputs - self.scaler_dict = dict() + self.scaler_dict = {} kwargs["processor_class"] = self.__class__.__name__ @@ -138,9 +131,7 @@ def to_dict(self) -> Dict[str, Any]: return output @classmethod - def from_dict( - cls, feature_extractor_dict: Dict[str, Any], **kwargs - ) -> "PreTrainedFeatureExtractor": + def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> PreTrainedFeatureExtractor: """ Instantiates a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a Python dictionary of parameters. @@ -178,11 +169,7 @@ def _prepare_single_time_series(self, name, d): seq_x = d[self.input_columns].iloc[s_begin:s_end].values if self.time_series_task == TimeSeriesTask.FORECASTING: - seq_y = ( - d[self.output_columns] - .iloc[s_end : s_end + self.prediction_length] - .values - ) + seq_y = d[self.output_columns].iloc[s_end : s_end + self.prediction_length].values else: seq_y = None # to do: add handling of other types @@ -223,9 +210,7 @@ def _get_groups( dataset: pd.DataFrame, ): if self.id_columns: - group_by_columns = ( - self.id_columns if len(self.id_columns) > 1 else self.id_columns[0] - ) + group_by_columns = self.id_columns if len(self.id_columns) > 1 else self.id_columns[0] else: group_by_columns = INTERNAL_ID_COLUMN @@ -245,9 +230,7 @@ def _get_columns_to_scale( """ cols_to_scale = copy.copy(self.input_columns) if self.scale_outputs: - cols_to_scale.extend( - [c for c in self.output_columns if c not in self.input_columns] - ) + cols_to_scale.extend([c for c in self.output_columns if c not in self.input_columns]) return cols_to_scale def train( @@ -312,9 +295,7 @@ def scale_func(grp, id_columns): df = self._standardize_dataframe(dataset) if self.id_columns: - id_columns = ( - self.id_columns if len(self.id_columns) > 1 else self.id_columns[0] - ) + id_columns = self.id_columns if len(self.id_columns) > 1 else self.id_columns[0] else: id_columns = INTERNAL_ID_COLUMN diff --git a/tsfm_public/toolkit/util.py b/tsfm_public/toolkit/util.py index e2f6effd..964c2e6c 100644 --- a/tsfm_public/toolkit/util.py +++ b/tsfm_public/toolkit/util.py @@ -18,18 +18,13 @@ def select_by_timestamp( end_timestamp: Optional[Union[str, datetime]] = None, ): if not start_timestamp and not end_timestamp: - raise ValueError( - "At least one of start_timestamp or end_timestamp must be specified." - ) + raise ValueError("At least one of start_timestamp or end_timestamp must be specified.") elif not start_timestamp: return df[df[timestamp_column] < end_timestamp] elif not end_timestamp: return df[df[timestamp_column] >= start_timestamp] else: - return df[ - (df[timestamp_column] >= start_timestamp) - & (df[timestamp_column] < end_timestamp) - ] + return df[(df[timestamp_column] >= start_timestamp) & (df[timestamp_column] < end_timestamp)] def select_by_index( @@ -54,16 +49,12 @@ def _split_group_by_index( return group_df.iloc[start_index:end_index, :] if not id_columns: - return _split_group_by_index( - df, start_index=start_index, end_index=end_index - ).copy() + return _split_group_by_index(df, start_index=start_index, end_index=end_index).copy() groups = df.groupby(id_columns) result = [] for name, group in groups: - result.append( - _split_group_by_index(group, start_index=start_index, end_index=end_index) - ) + result.append(_split_group_by_index(group, start_index=start_index, end_index=end_index)) return pd.concat(result) @@ -95,17 +86,13 @@ def convert_tsf_to_dataframe( if not line.startswith("@data"): line_content = line.split(" ") if line.startswith("@attribute"): - if ( - len(line_content) != 3 - ): # Attributes have both name and type + if len(line_content) != 3: # Attributes have both name and type raise Exception("Invalid meta-data specification.") col_names.append(line_content[1]) col_types.append(line_content[2]) else: - if ( - len(line_content) != 2 - ): # Other meta-data have only values + if len(line_content) != 2: # Other meta-data have only values raise Exception("Invalid meta-data specification.") if line.startswith("@frequency"): @@ -113,24 +100,18 @@ def convert_tsf_to_dataframe( elif line.startswith("@horizon"): forecast_horizon = int(line_content[1]) elif line.startswith("@missing"): - contain_missing_values = bool( - strtobool(line_content[1]) - ) + contain_missing_values = bool(strtobool(line_content[1])) elif line.startswith("@equallength"): contain_equal_length = bool(strtobool(line_content[1])) else: if len(col_names) == 0: - raise Exception( - "Missing attribute section. Attribute section must come before data." - ) + raise Exception("Missing attribute section. Attribute section must come before data.") found_data_tag = True elif not line.startswith("#"): if len(col_names) == 0: - raise Exception( - "Missing attribute section. Attribute section must come before data." - ) + raise Exception("Missing attribute section. Attribute section must come before data.") elif not found_data_tag: raise Exception("Missing @data tag.") else: @@ -163,9 +144,7 @@ def convert_tsf_to_dataframe( else: numeric_series.append(float(val)) - if numeric_series.count(replace_missing_vals_with) == len( - numeric_series - ): + if numeric_series.count(replace_missing_vals_with) == len(numeric_series): raise Exception( "All series values are missing. A given series should contains a set of comma separated numeric values. At least one numeric value should be there in a series." ) @@ -179,9 +158,7 @@ def convert_tsf_to_dataframe( elif col_types[i] == "string": att_val = str(full_info[i]) elif col_types[i] == "date": - att_val = datetime.strptime( - full_info[i], "%Y-%m-%d %H-%M-%S" - ) + att_val = datetime.strptime(full_info[i], "%Y-%m-%d %H-%M-%S") else: raise Exception( "Invalid attribute type." diff --git a/tsfm_public/toolkit/visualization.py b/tsfm_public/toolkit/visualization.py index 115e9e71..18af9cc8 100644 --- a/tsfm_public/toolkit/visualization.py +++ b/tsfm_public/toolkit/visualization.py @@ -83,9 +83,7 @@ def plot_ts_forecasting( # plot true data if not HAVE_SEABORN and plot_type == "seaborn": - raise ValueError( - "Please install the seaborn package if seaborn plots are needed." - ) + raise ValueError("Please install the seaborn package if seaborn plots are needed.") # if plot_start > len(test_data_updated): # logging.warning( @@ -138,9 +136,7 @@ def plot_ts_forecasting( # index into the predictions so that the end of the prediction coincides with the end of the ground truth # - predictions_end = ( - plot_range[-1] - prediction_length - context_length + 1 - ) # - context_length - prediction_length + predictions_end = plot_range[-1] - prediction_length - context_length + 1 # - context_length - prediction_length predictions_start = plot_range[0] - context_length @@ -154,9 +150,7 @@ def plot_ts_forecasting( if plot_type == "plotly": for i in plot_index: start = forecast_data.iloc[i][timestamp_column] - timestamps = pd.date_range( - start, freq=periodicity, periods=prediction_length + 1 - ) + timestamps = pd.date_range(start, freq=periodicity, periods=prediction_length + 1) timestamp = timestamps[1:] forecast_val = forecast_data.iloc[i][forecast_name] plot_line( @@ -189,7 +183,7 @@ def plot_ts_forecasting( height=fig_size[1], width=fig_size[0], title=title, - xaxis=dict(tickangle=-45), + xaxis={"tickangle": -45}, ) if return_image: return Image(fig.to_image(format="png")) diff --git a/tsfmhfdemos/neurips/app.py b/tsfmhfdemos/neurips/app.py index 8d6539b3..9b89d031 100644 --- a/tsfmhfdemos/neurips/app.py +++ b/tsfmhfdemos/neurips/app.py @@ -42,9 +42,7 @@ def tsforecasting_with_fmdls(): ) st.title(GLOBAL_CONFIG["title"]) - st.write( - "", unsafe_allow_html=True - ) + st.write("", unsafe_allow_html=True) st.write(GLOBAL_CONFIG["intro"]) @@ -103,26 +101,20 @@ def tsforecasting_with_fmdls(): for idx, channel in enumerate(dataset_meta["channel_plots"]): # col = columns[idx % num_cols] st.plotly_chart( - model_util.create_figure( - **dataset_meta, **model_meta, **approach_meta, channel=channel - ), + model_util.create_figure(**dataset_meta, **model_meta, **approach_meta, channel=channel), use_container_width=True, fig_size=(1600, 200), ) with col2: st.subheader("Performance") - df_perf = model_util.get_performance( - metrics=METRICS, **dataset_meta, **model_meta, **approach_meta - ) + df_perf = model_util.get_performance(metrics=METRICS, **dataset_meta, **model_meta, **approach_meta) df_perf_styled = df_perf.style.set_table_styles( [ {"selector": "th", "props": "background-color: whitesmoke;"}, ] - ).format( - precision=3 - ) # .style.hide(axis="index") + ).format(precision=3) # .style.hide(axis="index") st.write(df_perf_styled.to_html(), unsafe_allow_html=True) st.write("") @@ -142,60 +134,58 @@ def tsforecasting_with_fmdls(): # ) table_source = r""" - \begin{tabular}{cc|c|cc|cc|cc|cc|cc|ccc} - \cline{2-15} - &\multicolumn{2}{c|}{Models} & \multicolumn{2}{c}{\textbf{\citsm-Best}} & \multicolumn{2}{c|}{DLinear} & \multicolumn{2}{c|}{PatchTST}& \multicolumn{2}{c|}{FEDformer}& \multicolumn{2}{c|}{Autoformer}& \multicolumn{2}{c}{Informer} \\ - \cline{2-15} - &\multicolumn{2}{c|}{Metric}&MSE&MAE&MSE&MAE&MSE&MAE&MSE&MAE&MSE&MAE&MSE&MAE\\ - \cline{2-15} - &\multirow{4}*{\rotatebox{90}{ETTH1}} & 96 & \textbf{0.368$\pm$0.001} & \textbf{0.398$\pm$0.001} & 0.375 & \uline{0.399} & \uline{0.370} & 0.400 & 0.376 & 0.419 & 0.449 & 0.459 & 0.865 & 0.713 \\ - &\multicolumn{1}{c|}{} & 192 & \textbf{0.399$\pm$0.002} & \uline{0.418$\pm$0.001} & \uline{0.405} & \textbf{0.416} & 0.413 & 0.429 & 0.420 & 0.448 & 0.500 & 0.482 & 1.008 & 0.792 \\ - &\multicolumn{1}{c|}{} & 336 & \textbf{0.421$\pm$0.004} & \textbf{0.436$\pm$0.003} & 0.439 & 0.443 & \uline{0.422} & \uline{0.440} & 0.459 & 0.465 & 0.521 & 0.496 & 1.107 & 0.809 \\ - &\multicolumn{1}{c|}{} & 720 & \textbf{0.444$\pm$0.003} & \textbf{0.467$\pm$0.002} & 0.472 & 0.490 & \uline{0.447} & \uline{0.468} & 0.506 & 0.507 & 0.514 & 0.512 & 1.181 & 0.865 \\ - \cline{2-15} - &\multirow{4}*{\rotatebox{90}{ETTH2}} & 96 & \uline{0.276$\pm$0.006} & \textbf{0.337$\pm$0.003} & 0.289 & \uline{0.353} & \textbf{0.274} & \textbf{0.337} & 0.346 & 0.388 & 0.358 & 0.397 & 3.755 & 1.525 \\ - &\multicolumn{1}{c|}{} & 192 & \textbf{0.330$\pm$0.003} & \textbf{0.374$\pm$0.001} & 0.383 & 0.418 & \uline{0.341} & \uline{0.382} & 0.429 & 0.439 & 0.456 & 0.452 & 5.602 & 1.931 \\ - &\multicolumn{1}{c|}{} & 336 & \uline{0.357$\pm$0.001} & \uline{0.401$\pm$0.002} & 0.448 & 0.465 & \textbf{0.329} & \textbf{0.384} & 0.496 & 0.487 & 0.482 & 0.486 & 4.721 & 1.835 \\ - &\multicolumn{1}{c|}{} & 720 & \uline{0.395$\pm$0.003} & \uline{0.436$\pm$0.003} & 0.605 & 0.551 & \textbf{0.379} & \textbf{0.422} & 0.463 & 0.474 & 0.515 & 0.511 & 3.647 & 1.625 \\ - \cline{2-15} - &\multirow{4}*{\rotatebox{90}{ETTM1}} & 96 & \textbf{0.291$\pm$0.002} & \uline{0.346$\pm$0.002} & 0.299 & \textbf{0.343} & \uline{0.293} & \uline{0.346} & 0.379 & 0.419 & 0.505 & 0.475 & 0.672 & 0.571 \\ - &\multicolumn{1}{c|}{} & 192 & \textbf{0.333$\pm$0.002} & \uline{0.369$\pm$0.002} & \uline{0.335} & \textbf{0.365} & \textbf{0.333} & 0.370 & 0.426 & 0.441 & 0.553 & 0.496 & 0.795 & 0.669 \\ - &\multicolumn{1}{c|}{} & 336 & \textbf{0.365$\pm$0.005} & \textbf{0.385$\pm$0.004} & \uline{0.369} & \uline{0.386} & \uline{0.369} & 0.392 & 0.445 & 0.459 & 0.621 & 0.537 & 1.212 & 0.871 \\ - &\multicolumn{1}{c|}{} & 720 & \textbf{0.416$\pm$0.002} & \textbf{0.413$\pm$0.001} & \uline{0.425} & 0.421 & \textbf{0.416} & \uline{0.420} & 0.543 & 0.490 & 0.671 & 0.561 & 1.166 & 0.823 \\ - \cline{2-15} - &\multirow{4}*{\rotatebox{90}{ETTM2}} & 96 & \textbf{0.164$\pm$0.002} & \textbf{0.255$\pm$0.002} & 0.167 & 0.260 & \uline{0.166} & \uline{0.256} & 0.203 & 0.287 & 0.255 & 0.339 & 0.365 & 0.453 \\ - &\multicolumn{1}{c|}{} & 192 & \textbf{0.219$\pm$0.002} & \textbf{0.293$\pm$0.002} & 0.224 & 0.303 & \uline{0.223} & \uline{0.296} & 0.269 & 0.328 & 0.281 & 0.340 & 0.533 & 0.563 \\ - &\multicolumn{1}{c|}{} & 336 & \textbf{0.273$\pm$0.003} & \textbf{0.329$\pm$0.003} & 0.281 & \uline{0.342} & \uline{0.274} & \textbf{0.329} & 0.325 & 0.366 & 0.339 & 0.372 & 1.363 & 0.887 \\ - &\multicolumn{1}{c|}{} & 720 & \textbf{0.358$\pm$0.002} & \textbf{0.380$\pm$0.001} & 0.397 & 0.421 & \uline{0.362} & \uline{0.385} & 0.421 & 0.415 & 0.433 & 0.432 & 3.379 & 1.338 \\ - \cline{2-15} - &\multirow{4}*{\rotatebox{90}{Electricity}} & 96 & \textbf{0.129$\pm$1e-4} & \uline{0.224$\pm$0.001} & \uline{0.140} & 0.237 & \textbf{0.129} & \textbf{0.222} & 0.193 & 0.308 & 0.201 & 0.317 & 0.274 & 0.368 \\ - &\multicolumn{1}{c|}{} & 192 & \textbf{0.146$\pm$0.001} & \uline{0.242$\pm$1e-4} & 0.153 & 0.249 & \uline{0.147} & \textbf{0.240} & 0.201 & 0.315 & 0.222 & 0.334 & 0.296 & 0.386 \\ - &\multicolumn{1}{c|}{} & 336 & \textbf{0.158$\pm$0.001} & \textbf{0.256$\pm$0.001} & 0.169 & 0.267 & \uline{0.163} & \uline{0.259} & 0.214 & 0.329 & 0.231 & 0.338 & 0.300 & 0.394 \\ - &\multicolumn{1}{c|}{} & 720 & \textbf{0.186$\pm$0.001} & \textbf{0.282$\pm$0.001} & 0.203 & 0.301 & \uline{0.197} & \uline{0.290} & 0.246 & 0.355 & 0.254 & 0.361 & 0.373 & 0.439 \\ - \cline{2-15} - &\multirow{4}*{\rotatebox{90}{Traffic}} & 96 & \textbf{0.356$\pm$0.002} & \textbf{0.248$\pm$0.002} & 0.410 & 0.282 & \uline{0.360} & \uline{0.249} & 0.587 & 0.366 & 0.613 & 0.388 & 0.719 & 0.391 \\ - &\multicolumn{1}{c|}{} & 192 & \textbf{0.377$\pm$0.003} & \uline{0.257$\pm$0.002} & 0.423 & 0.287 & \uline{0.379} & \textbf{0.256} & 0.604 & 0.373 & 0.616 & 0.382 & 0.696 & 0.379 \\ - &\multicolumn{1}{c|}{} & 336 & \textbf{0.385$\pm$0.002} & \textbf{0.262$\pm$0.001} & 0.436 & 0.296 & \uline{0.392} & \uline{0.264} & 0.621 & 0.383 & 0.622 & 0.337 & 0.777 & 0.420 \\ - &\multicolumn{1}{c|}{} & 720 & \textbf{0.424$\pm$0.001} & \textbf{0.283$\pm$0.001} & 0.466 & 0.315 & \uline{0.432} & \uline{0.286} & 0.626 & 0.382 & 0.660 & 0.408 & 0.864 & 0.472 \\ - \cline{2-15} - &\multirow{4}*{\rotatebox{90}{Weather}} & 96 & \textbf{0.146$\pm$0.001} & \textbf{0.197$\pm$0.002} & 0.176 & 0.237 & \uline{0.149} & \uline{0.198} & 0.217 & 0.296 & 0.266 & 0.336 & 0.300 & 0.384 \\ - &\multicolumn{1}{c|}{} & 192 & \textbf{0.191$\pm$0.001} & \textbf{0.240$\pm$0.001} & 0.220 & 0.282 & \uline{0.194} & \uline{0.241} & 0.276 & 0.336 & 0.307 & 0.367 & 0.598 & 0.544 \\ - &\multicolumn{1}{c|}{} & 336 & \textbf{0.243$\pm$0.001} & \textbf{0.279$\pm$0.002} & 0.265 & 0.319 & \uline{0.245} & \uline{0.282} & 0.339 & 0.380 & 0.359 & 0.395 & 0.578 & 0.523 \\ - &\multicolumn{1}{c|}{} & 720 & \uline{0.316$\pm$0.001} & \textbf{0.333$\pm$0.002} & 0.323 & 0.362 & \textbf{0.314} & \uline{0.334} & 0.403 & 0.428 & 0.419 & 0.428 & 1.059 & 0.741 \\ - \cline{2-15} - % &\multicolumn{4}{c|}{\makecell{\textbf{\citsm-Best} \textbf{\% improvement}}}& \textbf{8\%} & \textbf{6.8\%}& \textbf{0.7\%} & \textbf{0.4\%} & \textbf{22.9\%} & \textbf{18.2\%} & \textbf{30.1\%} & \textbf{22.7\%} & \textbf{64\%} & \textbf{50.3\%} \\ - % &\multicolumn{4}{c|}{\makecell{\textbf{\citsm-Best} \textbf{\% improvement (MSE)}}}& \multicolumn{2}{c}{\textbf{8\%}} & \multicolumn{2}{c}{ \textbf{0.7\%}} & \multicolumn{2}{c}{\textbf{22.9\%}} & \multicolumn{2}{c}{\textbf{30.1\%}} & \multicolumn{2}{c}{\textbf{64\%}} \\ - &\multicolumn{4}{c|}{\makecell{\textbf{\citsm-Best} \textbf{\% improvement (MSE)}}}& \multicolumn{2}{c}{\textbf{8\%}} & \multicolumn{2}{c}{ \textbf{1\%}} & \multicolumn{2}{c}{\textbf{23\%}} & \multicolumn{2}{c}{\textbf{30\%}} & \multicolumn{2}{c}{\textbf{64\%}} \\ - \cline{2-15} - \end{tabular} +\begin{tabular}{cc|c|cc|cc|cc|cc|cc|ccc} +\cline{2-15} +&\multicolumn{2}{c|}{Models} & \multicolumn{2}{c}{\textbf{\citsm-Best}} & \multicolumn{2}{c|}{DLinear} & \multicolumn{2}{c|}{PatchTST}& \multicolumn{2}{c|}{FEDformer}& \multicolumn{2}{c|}{Autoformer}& \multicolumn{2}{c}{Informer} \\ +\cline{2-15} +&\multicolumn{2}{c|}{Metric}&MSE&MAE&MSE&MAE&MSE&MAE&MSE&MAE&MSE&MAE&MSE&MAE\\ +\cline{2-15} +&\multirow{4}*{\rotatebox{90}{ETTH1}} & 96 & \textbf{0.368$\pm$0.001} & \textbf{0.398$\pm$0.001} & 0.375 & \uline{0.399} & \uline{0.370} & 0.400 & 0.376 & 0.419 & 0.449 & 0.459 & 0.865 & 0.713 \\ +&\multicolumn{1}{c|}{} & 192 & \textbf{0.399$\pm$0.002} & \uline{0.418$\pm$0.001} & \uline{0.405} & \textbf{0.416} & 0.413 & 0.429 & 0.420 & 0.448 & 0.500 & 0.482 & 1.008 & 0.792 \\ +&\multicolumn{1}{c|}{} & 336 & \textbf{0.421$\pm$0.004} & \textbf{0.436$\pm$0.003} & 0.439 & 0.443 & \uline{0.422} & \uline{0.440} & 0.459 & 0.465 & 0.521 & 0.496 & 1.107 & 0.809 \\ +&\multicolumn{1}{c|}{} & 720 & \textbf{0.444$\pm$0.003} & \textbf{0.467$\pm$0.002} & 0.472 & 0.490 & \uline{0.447} & \uline{0.468} & 0.506 & 0.507 & 0.514 & 0.512 & 1.181 & 0.865 \\ +\cline{2-15} +&\multirow{4}*{\rotatebox{90}{ETTH2}} & 96 & \uline{0.276$\pm$0.006} & \textbf{0.337$\pm$0.003} & 0.289 & \uline{0.353} & \textbf{0.274} & \textbf{0.337} & 0.346 & 0.388 & 0.358 & 0.397 & 3.755 & 1.525 \\ +&\multicolumn{1}{c|}{} & 192 & \textbf{0.330$\pm$0.003} & \textbf{0.374$\pm$0.001} & 0.383 & 0.418 & \uline{0.341} & \uline{0.382} & 0.429 & 0.439 & 0.456 & 0.452 & 5.602 & 1.931 \\ +&\multicolumn{1}{c|}{} & 336 & \uline{0.357$\pm$0.001} & \uline{0.401$\pm$0.002} & 0.448 & 0.465 & \textbf{0.329} & \textbf{0.384} & 0.496 & 0.487 & 0.482 & 0.486 & 4.721 & 1.835 \\ +&\multicolumn{1}{c|}{} & 720 & \uline{0.395$\pm$0.003} & \uline{0.436$\pm$0.003} & 0.605 & 0.551 & \textbf{0.379} & \textbf{0.422} & 0.463 & 0.474 & 0.515 & 0.511 & 3.647 & 1.625 \\ +\cline{2-15} +&\multirow{4}*{\rotatebox{90}{ETTM1}} & 96 & \textbf{0.291$\pm$0.002} & \uline{0.346$\pm$0.002} & 0.299 & \textbf{0.343} & \uline{0.293} & \uline{0.346} & 0.379 & 0.419 & 0.505 & 0.475 & 0.672 & 0.571 \\ +&\multicolumn{1}{c|}{} & 192 & \textbf{0.333$\pm$0.002} & \uline{0.369$\pm$0.002} & \uline{0.335} & \textbf{0.365} & \textbf{0.333} & 0.370 & 0.426 & 0.441 & 0.553 & 0.496 & 0.795 & 0.669 \\ +&\multicolumn{1}{c|}{} & 336 & \textbf{0.365$\pm$0.005} & \textbf{0.385$\pm$0.004} & \uline{0.369} & \uline{0.386} & \uline{0.369} & 0.392 & 0.445 & 0.459 & 0.621 & 0.537 & 1.212 & 0.871 \\ +&\multicolumn{1}{c|}{} & 720 & \textbf{0.416$\pm$0.002} & \textbf{0.413$\pm$0.001} & \uline{0.425} & 0.421 & \textbf{0.416} & \uline{0.420} & 0.543 & 0.490 & 0.671 & 0.561 & 1.166 & 0.823 \\ +\cline{2-15} +&\multirow{4}*{\rotatebox{90}{ETTM2}} & 96 & \textbf{0.164$\pm$0.002} & \textbf{0.255$\pm$0.002} & 0.167 & 0.260 & \uline{0.166} & \uline{0.256} & 0.203 & 0.287 & 0.255 & 0.339 & 0.365 & 0.453 \\ +&\multicolumn{1}{c|}{} & 192 & \textbf{0.219$\pm$0.002} & \textbf{0.293$\pm$0.002} & 0.224 & 0.303 & \uline{0.223} & \uline{0.296} & 0.269 & 0.328 & 0.281 & 0.340 & 0.533 & 0.563 \\ +&\multicolumn{1}{c|}{} & 336 & \textbf{0.273$\pm$0.003} & \textbf{0.329$\pm$0.003} & 0.281 & \uline{0.342} & \uline{0.274} & \textbf{0.329} & 0.325 & 0.366 & 0.339 & 0.372 & 1.363 & 0.887 \\ +&\multicolumn{1}{c|}{} & 720 & \textbf{0.358$\pm$0.002} & \textbf{0.380$\pm$0.001} & 0.397 & 0.421 & \uline{0.362} & \uline{0.385} & 0.421 & 0.415 & 0.433 & 0.432 & 3.379 & 1.338 \\ +\cline{2-15} +&\multirow{4}*{\rotatebox{90}{Electricity}} & 96 & \textbf{0.129$\pm$1e-4} & \uline{0.224$\pm$0.001} & \uline{0.140} & 0.237 & \textbf{0.129} & \textbf{0.222} & 0.193 & 0.308 & 0.201 & 0.317 & 0.274 & 0.368 \\ +&\multicolumn{1}{c|}{} & 192 & \textbf{0.146$\pm$0.001} & \uline{0.242$\pm$1e-4} & 0.153 & 0.249 & \uline{0.147} & \textbf{0.240} & 0.201 & 0.315 & 0.222 & 0.334 & 0.296 & 0.386 \\ +&\multicolumn{1}{c|}{} & 336 & \textbf{0.158$\pm$0.001} & \textbf{0.256$\pm$0.001} & 0.169 & 0.267 & \uline{0.163} & \uline{0.259} & 0.214 & 0.329 & 0.231 & 0.338 & 0.300 & 0.394 \\ +&\multicolumn{1}{c|}{} & 720 & \textbf{0.186$\pm$0.001} & \textbf{0.282$\pm$0.001} & 0.203 & 0.301 & \uline{0.197} & \uline{0.290} & 0.246 & 0.355 & 0.254 & 0.361 & 0.373 & 0.439 \\ +\cline{2-15} +&\multirow{4}*{\rotatebox{90}{Traffic}} & 96 & \textbf{0.356$\pm$0.002} & \textbf{0.248$\pm$0.002} & 0.410 & 0.282 & \uline{0.360} & \uline{0.249} & 0.587 & 0.366 & 0.613 & 0.388 & 0.719 & 0.391 \\ +&\multicolumn{1}{c|}{} & 192 & \textbf{0.377$\pm$0.003} & \uline{0.257$\pm$0.002} & 0.423 & 0.287 & \uline{0.379} & \textbf{0.256} & 0.604 & 0.373 & 0.616 & 0.382 & 0.696 & 0.379 \\ +&\multicolumn{1}{c|}{} & 336 & \textbf{0.385$\pm$0.002} & \textbf{0.262$\pm$0.001} & 0.436 & 0.296 & \uline{0.392} & \uline{0.264} & 0.621 & 0.383 & 0.622 & 0.337 & 0.777 & 0.420 \\ +&\multicolumn{1}{c|}{} & 720 & \textbf{0.424$\pm$0.001} & \textbf{0.283$\pm$0.001} & 0.466 & 0.315 & \uline{0.432} & \uline{0.286} & 0.626 & 0.382 & 0.660 & 0.408 & 0.864 & 0.472 \\ +\cline{2-15} +&\multirow{4}*{\rotatebox{90}{Weather}} & 96 & \textbf{0.146$\pm$0.001} & \textbf{0.197$\pm$0.002} & 0.176 & 0.237 & \uline{0.149} & \uline{0.198} & 0.217 & 0.296 & 0.266 & 0.336 & 0.300 & 0.384 \\ +&\multicolumn{1}{c|}{} & 192 & \textbf{0.191$\pm$0.001} & \textbf{0.240$\pm$0.001} & 0.220 & 0.282 & \uline{0.194} & \uline{0.241} & 0.276 & 0.336 & 0.307 & 0.367 & 0.598 & 0.544 \\ +&\multicolumn{1}{c|}{} & 336 & \textbf{0.243$\pm$0.001} & \textbf{0.279$\pm$0.002} & 0.265 & 0.319 & \uline{0.245} & \uline{0.282} & 0.339 & 0.380 & 0.359 & 0.395 & 0.578 & 0.523 \\ +&\multicolumn{1}{c|}{} & 720 & \uline{0.316$\pm$0.001} & \textbf{0.333$\pm$0.002} & 0.323 & 0.362 & \textbf{0.314} & \uline{0.334} & 0.403 & 0.428 & 0.419 & 0.428 & 1.059 & 0.741 \\ +\cline{2-15} +% &\multicolumn{4}{c|}{\makecell{\textbf{\citsm-Best} \textbf{\% improvement}}}& \textbf{8\%} & \textbf{6.8\%}& \textbf{0.7\%} & \textbf{0.4\%} & \textbf{22.9\%} & \textbf{18.2\%} & \textbf{30.1\%} & \textbf{22.7\%} & \textbf{64\%} & \textbf{50.3\%} \\ +% &\multicolumn{4}{c|}{\makecell{\textbf{\citsm-Best} \textbf{\% improvement (MSE)}}}& \multicolumn{2}{c}{\textbf{8\%}} & \multicolumn{2}{c}{ \textbf{0.7\%}} & \multicolumn{2}{c}{\textbf{22.9\%}} & \multicolumn{2}{c}{\textbf{30.1\%}} & \multicolumn{2}{c}{\textbf{64\%}} \\ +&\multicolumn{4}{c|}{\makecell{\textbf{\citsm-Best} \textbf{\% improvement (MSE)}}}& \multicolumn{2}{c}{\textbf{8\%}} & \multicolumn{2}{c}{ \textbf{1\%}} & \multicolumn{2}{c}{\textbf{23\%}} & \multicolumn{2}{c}{\textbf{30\%}} & \multicolumn{2}{c}{\textbf{64\%}} \\ +\cline{2-15} +\end{tabular} """ out = re.sub(r"\\textbf{([^&]*)}", r"\1", table_source) out = re.sub(r"\\uline{([^&]*)}", r"\1", out) out = re.sub(r"\s*|\$\\pm\$[^&]*|\\cline{.*}", "", out) - vals = np.array([r.split("&")[3:] for r in out.split(r"\\")[2:30]]).astype( - float - ) + vals = np.array([r.split("&")[3:] for r in out.split(r"\\")[2:30]]).astype(float) leaderboard = pd.DataFrame( index=pd.MultiIndex.from_product( diff --git a/tsfmhfdemos/neurips/backends/v1/model_util.py b/tsfmhfdemos/neurips/backends/v1/model_util.py index 57cd8166..edfe6f19 100644 --- a/tsfmhfdemos/neurips/backends/v1/model_util.py +++ b/tsfmhfdemos/neurips/backends/v1/model_util.py @@ -149,9 +149,7 @@ def forecast(**kwargs) -> pd.DataFrame: prep_path = get_preprocessor_path(**kwargs) model_class = get_model_class(model_path) - model = model_class.from_pretrained( - model_path, num_input_channels=len(forecast_columns) - ) + model = model_class.from_pretrained(model_path, num_input_channels=len(forecast_columns)) forecast_pipeline = TimeSeriesForecastingPipeline( model=model, @@ -194,9 +192,7 @@ def create_figure(**kwargs) -> graph_objs.Figure: model_class = get_model_class(model_path) - model = model_class.from_pretrained( - model_path, num_input_channels=len(forecast_columns) - ) + model = model_class.from_pretrained(model_path, num_input_channels=len(forecast_columns)) context_length = model.config.context_length periodicity = kwargs["periodicity"] channel = kwargs["channel"]