diff --git a/services/inference/tests/test_inference_service.py b/services/inference/tests/test_inference_service.py index a791b47..3c1dd18 100644 --- a/services/inference/tests/test_inference_service.py +++ b/services/inference/tests/test_inference_service.py @@ -21,7 +21,8 @@ "ttm-1536-96-r2": {"context_length": 1536, "prediction_length": 96}, "ibm/test-patchtst": {"context_length": 512, "prediction_length": 96}, "ibm/test-patchtsmixer": {"context_length": 512, "prediction_length": 96}, - "chronos-t5-tiny": {"context_length": 512, "prediction_length": 96}, + "chronos-t5-tiny": {"context_length": 512, "prediction_length": 16}, + "chronos-bolt-tiny": {"context_length": 512, "prediction_length": 16}, } @@ -369,24 +370,29 @@ def test_zero_shot_forecast_inference(ts_data): assert counts["output_data_points"] == (prediction_length // 4) * len(params["target_columns"][1:]) -@pytest.mark.parametrize("ts_data", ["chronos-t5-tiny"], indirect=True) +@pytest.mark.parametrize("ts_data", ["chronos-t5-tiny", "chronos-bolt-tiny"], indirect=True) def test_zero_shot_forecast_inference_chronos(ts_data): test_data, params = ts_data prediction_length = params["prediction_length"] model_id = params["model_id"] model_id_path: str = model_id id_columns = params["id_columns"] + num_samples = 10 # test single test_data_ = test_data[test_data[id_columns[0]] == "a"].copy() + parameters = { + "prediction_length": params["prediction_length"], + } + if model_id == "chronos-t5-tiny": + parameters["num_samples"] = num_samples + msg = { "model_id": model_id_path, - "parameters": { - "prediction_length": params["prediction_length"], - }, + "parameters": parameters, "schema": { "timestamp_column": params["timestamp_column"], "id_columns": params["id_columns"], @@ -400,6 +405,7 @@ def test_zero_shot_forecast_inference_chronos(ts_data): assert len(df_out) == 1 assert df_out[0].shape[0] == prediction_length + # test with future data. should throw error. 
test_data_ = test_data[test_data[id_columns[0]] == "a"].copy() future_data = extend_time_series( select_by_index(test_data_, id_columns=params["id_columns"], start_index=-1), @@ -414,9 +420,7 @@ def test_zero_shot_forecast_inference_chronos(ts_data): msg = { "model_id": model_id, - "parameters": { - # "prediction_length": params["prediction_length"], - }, + "parameters": parameters, "schema": { "timestamp_column": params["timestamp_column"], "id_columns": params["id_columns"], @@ -430,6 +434,74 @@ def test_zero_shot_forecast_inference_chronos(ts_data): out, _ = get_inference_response(msg) assert "Chronos does not support or require future exogenous." in out.text + # test multi-time series + num_ids = test_data[id_columns[0]].nunique() + test_data_ = test_data.copy() + + msg = { + "model_id": model_id_path, + "parameters": parameters, + "schema": { + "timestamp_column": params["timestamp_column"], + "id_columns": params["id_columns"], + "target_columns": params["target_columns"], + }, + "data": encode_data(test_data_, params["timestamp_column"]), + "future_data": {}, + } + + df_out, _ = get_inference_response(msg) + + assert len(df_out) == 1 + assert df_out[0].shape[0] == prediction_length * num_ids + + # test multi-time series multi-id + multi_df = [] + for grp in ["A", "B"]: + td = test_data.copy() + td["id2"] = grp + multi_df.append(td) + test_data_ = pd.concat(multi_df, ignore_index=True) + new_id_columns = id_columns + ["id2"] + + num_ids = test_data_[new_id_columns[0]].nunique() * test_data_[new_id_columns[1]].nunique() + + msg = { + "model_id": model_id_path, + "parameters": parameters, + "schema": { + "timestamp_column": params["timestamp_column"], + "id_columns": new_id_columns, + "target_columns": params["target_columns"], + }, + "data": encode_data(test_data_, params["timestamp_column"]), + "future_data": {}, + } + + df_out, _ = get_inference_response(msg) + assert len(df_out) == 1 + assert df_out[0].shape[0] == prediction_length * num_ids + + # single 
series, less columns, no id + test_data_ = test_data[test_data[id_columns[0]] == "a"].copy() + + msg = { + "model_id": model_id_path, + "parameters": parameters, + "schema": { + "timestamp_column": params["timestamp_column"], + "id_columns": [], + "target_columns": ["HULL"], + }, + "data": encode_data(test_data_, params["timestamp_column"]), + "future_data": {}, + } + + df_out, counts = get_inference_response(msg) + assert len(df_out) == 1 + assert df_out[0].shape[0] == prediction_length + assert df_out[0].shape[1] == 2 + @pytest.mark.parametrize("ts_data", ["ttm-r2-etth-finetuned-control"], indirect=True) def test_future_data_forecast_inference(ts_data): diff --git a/services/inference/tsfminference/chronos_service_handler.py b/services/inference/tsfminference/chronos_service_handler.py index bf21ddd..ca55998 100644 --- a/services/inference/tsfminference/chronos_service_handler.py +++ b/services/inference/tsfminference/chronos_service_handler.py @@ -1,17 +1,18 @@ """Service handler for Chronos""" import copy +import importlib import logging from pathlib import Path from typing import Dict, Optional, Union -import numpy as np import pandas as pd import torch -from chronos import ChronosPipeline from tsfm_public import TimeSeriesPreprocessor -from tsfm_public.toolkit.time_series_preprocessor import extend_time_series +from tsfm_public.toolkit.time_series_preprocessor import ( + create_timestamps, +) from .inference_payloads import ( ForecastingMetadataInput, @@ -66,11 +67,21 @@ def _prepare( ChronosForecastingHandler: The updated service handler object. 
""" - model = ChronosPipeline.from_pretrained( - self.model_path, - ) + # load model class + try: + mod = importlib.import_module(self.handler_config.module_path) + except ModuleNotFoundError as exc: + raise AttributeError(f"Could not load module '{self.handler_config.module_path}'.") from exc + + model_class = getattr(mod, self.handler_config.model_class_name) + model = model_class.from_pretrained(self.model_path) + self.model = model - self.config = model.model.model.config + if hasattr(self.model.model, "model"): # chronos t5 family + self.config = model.model.model.config + else: # chronos bolt family + self.config = model.model.config + self.chronos_config = self.config.chronos_config preprocessor_params = copy.deepcopy(schema.model_dump()) @@ -78,6 +89,12 @@ def _prepare( parameters.prediction_length or self.chronos_config["prediction_length"] ) + # model specific parameters + preprocessor_params["num_samples"] = getattr(parameters, "num_samples", None) + preprocessor_params["temperature"] = getattr(parameters, "temperature", None) + preprocessor_params["top_k"] = getattr(parameters, "top_k", None) + preprocessor_params["top_p"] = getattr(parameters, "top_p", None) + LOGGER.info("initializing TSFM TimeSeriesPreprocessor") preprocessor = TimeSeriesPreprocessor( **preprocessor_params, @@ -128,47 +145,73 @@ def _run( pd.DataFrame: The forecasts produced by the model. 
""" + predictions = None if self.preprocessor.exogenous_channel_indices or future_data is not None: raise ValueError("Chronos does not support or require future exogenous.") target_columns = self.preprocessor.target_columns prediction_length = self.preprocessor.prediction_length + timestamp_column = self.preprocessor.timestamp_column + id_columns = self.preprocessor.id_columns - num_samples = self.chronos_config["num_samples"] - temperature = self.chronos_config["temperature"] - top_k = self.chronos_config["top_k"] - top_p = self.chronos_config["top_p"] + additional_params = {} + if "num_samples" in self.chronos_config: # chronos t5 family + additional_params["num_samples"] = self.preprocessor.num_samples or self.chronos_config["num_samples"] + additional_params["temperature"] = self.preprocessor.temperature or self.chronos_config["temperature"] + additional_params["top_k"] = self.preprocessor.top_k or self.chronos_config["top_k"] + additional_params["top_p"] = self.preprocessor.top_p or self.chronos_config["top_p"] + + LOGGER.info("model specific params: {}".format(additional_params)) + + scoped_cols = [timestamp_column] + id_columns + target_columns - context = torch.tensor(data[target_columns].values).transpose(1, 0) LOGGER.info("computing chronos forecasts.") - forecasts = self.model.predict( - context, - prediction_length=prediction_length, - num_samples=num_samples, - temperature=temperature, - top_k=top_k, - top_p=top_p, - limit_prediction_length=False, - ) - median_forecast_arr = [] - for i in range(len(target_columns)): - median_forecast_arr.append(np.quantile(forecasts[i].numpy(), [0.5], axis=0).flatten()) - - result = pd.DataFrame(np.array(median_forecast_arr).transpose(), columns=target_columns) - LOGGER.info("extend the time series.") - time_series = extend_time_series( - time_series=data, - freq=self.preprocessor.freq, - timestamp_column=schema.timestamp_column, - grouping_columns=schema.id_columns, - periods=prediction_length, - ) - # append time 
stamp column to the result - result[schema.timestamp_column] = ( - time_series[schema.timestamp_column].tail(result.shape[0]).reset_index(drop=True) - ) - return result + if not id_columns: + LOGGER.info("id columns are not provided, proceeding without groups.") + context = torch.tensor(data[target_columns].values).transpose(1, 0) + forecasts = self.model.predict( + context, + prediction_length=prediction_length, + limit_prediction_length=False, + **additional_params, + ) + median_forecasts = torch.quantile(forecasts, 0.5, dim=1).transpose(1, 0) + result = pd.DataFrame(median_forecasts, columns=target_columns) + if timestamp_column: + result[timestamp_column] = create_timestamps( + data[timestamp_column].iloc[-1], + freq=self.preprocessor.freq, + periods=result.shape[0], + ) + predictions = result + else: # create groups + LOGGER.info("using id columns {} to create groups.".format(id_columns)) + accumulator = [] + for grp, batch in data[scoped_cols].groupby(id_columns): + context = torch.tensor(batch[target_columns].values).transpose(1, 0) + forecasts = self.model.predict( + context, + prediction_length=prediction_length, + limit_prediction_length=False, + **additional_params, + ) + median_forecasts = torch.quantile(forecasts, 0.5, dim=1).transpose(1, 0) + result = pd.DataFrame(median_forecasts, columns=target_columns) + if timestamp_column: + result[timestamp_column] = create_timestamps( + batch[timestamp_column].iloc[-1], + freq=self.preprocessor.freq, + periods=result.shape[0], + ) + if (id_columns is not None) and id_columns: + for k, id_col in enumerate(id_columns): + result[id_col] = grp[k] + accumulator.append(result) + + predictions = pd.concat(accumulator, ignore_index=True) + + return predictions[scoped_cols] def _train( self,