From 0a9aa9b7bb62a2169d4cd57ec1ded0cb847d8a79 Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Wed, 18 Sep 2024 14:33:34 -0400 Subject: [PATCH] fix frequency inference --- services/inference/tsfminference/inference.py | 2 ++ tsfm_public/toolkit/time_series_forecasting_pipeline.py | 8 +------- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/services/inference/tsfminference/inference.py b/services/inference/tsfminference/inference.py index 09e45fcd..fe534dd9 100644 --- a/services/inference/tsfminference/inference.py +++ b/services/inference/tsfminference/inference.py @@ -115,6 +115,8 @@ def _forecast_common(self, input_payload: ForecastingInferenceInput) -> PredictO # train to estimate freq if not available preprocessor.train(data) + LOGGER.info(f"Data frequency determined: {preprocessor.freq}") + # warn if future data is not provided, but is needed by the model if preprocessor.exogenous_channel_indices and future_data is None: raise ValueError( diff --git a/tsfm_public/toolkit/time_series_forecasting_pipeline.py b/tsfm_public/toolkit/time_series_forecasting_pipeline.py index f22353fd..bacad2de 100644 --- a/tsfm_public/toolkit/time_series_forecasting_pipeline.py +++ b/tsfm_public/toolkit/time_series_forecasting_pipeline.py @@ -3,7 +3,7 @@ """Hugging Face Pipeline for Time Series Tasks""" import inspect -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Union import numpy as np import pandas as pd @@ -116,13 +116,11 @@ def __init__( self, model: Union["PreTrainedModel"], *args, - freq: Optional[str] = None, explode_forecasts: bool = False, inverse_scale_outputs: bool = True, add_known_ground_truth: bool = True, **kwargs, ): - kwargs["freq"] = freq kwargs["explode_forecasts"] = explode_forecasts kwargs["inverse_scale_outputs"] = inverse_scale_outputs kwargs["add_known_ground_truth"] = add_known_ground_truth @@ -142,10 +140,6 @@ def __init__( if p not in kwargs: kwargs[p] = getattr(kwargs["feature_extractor"], p) - # get freq from kwargs or the preprocessor - if "freq" not in kwargs: - kwargs["freq"] = kwargs["feature_extractor"].freq - if "context_length" not in kwargs: kwargs["context_length"] = model.config.context_length