From bf2f94bd8ca231c8916f3bfd434e85f3321fe2b7 Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Thu, 12 Sep 2024 15:05:29 -0400 Subject: [PATCH 1/3] add short dataset test --- tests/toolkit/test_dataset.py | 28 ++++++++++++++++++++++++++++ tsfm_public/toolkit/dataset.py | 4 +++- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/tests/toolkit/test_dataset.py b/tests/toolkit/test_dataset.py index 6055ffba..80537a60 100644 --- a/tests/toolkit/test_dataset.py +++ b/tests/toolkit/test_dataset.py @@ -176,6 +176,34 @@ def test_forecasting_df_dataset(ts_data_with_categorical): assert np.all(ds[0]["future_values"][:, 2].numpy() == 0) +def test_short_forecasting_df_dataset(ts_data_with_categorical): + prediction_length = 3 + context_length = 4 + target_columns = ["value1"] + + # df = ts_data_with_categorical.iloc[:2].copy() + + df = pd.DataFrame( + { + "timestamp": pd.to_datetime(range(10)), + "id": ["A"] * 10, + "value1": range(10), + } + ) + df = df.iloc[:1] + + ds = ForecastDFDataset( + df, + timestamp_column="timestamp", + id_columns=["id"], + target_columns=target_columns, + context_length=context_length, + prediction_length=prediction_length, + ) + + assert ds[0]["timestamp"] is pd.NaT + + def test_forecasting_df_dataset_stride(ts_data_with_categorical): prediction_length = 2 context_length = 3 diff --git a/tsfm_public/toolkit/dataset.py b/tsfm_public/toolkit/dataset.py index e199fbf6..55931a8a 100644 --- a/tsfm_public/toolkit/dataset.py +++ b/tsfm_public/toolkit/dataset.py @@ -902,7 +902,9 @@ def ts_padding( pad_df[c] = pad_df[c].astype(df.dtypes[c], copy=False) if timestamp_column: - if (df[timestamp_column].dtype.type == np.datetime64) or (df[timestamp_column].dtype == int): + if len(df) < 2: + pad_df[timestamp_column] = None + elif (df[timestamp_column].dtype.type == np.datetime64) or (df[timestamp_column].dtype == int): last_timestamp = df.iloc[0][timestamp_column] period = df.iloc[1][timestamp_column] - df.iloc[0][timestamp_column] prepended_timestamps = [last_timestamp + offset * period for offset in range(-fill_length, 0)] From 420d53f310c762a1b9d939159fab9ce8f8c1a597 Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Thu, 24 Oct 2024 16:16:27 -0400 Subject: [PATCH 2/3] add test with very short input --- services/inference/tests/test_inference.py | 58 ++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/services/inference/tests/test_inference.py b/services/inference/tests/test_inference.py index 90c8eac7..1bbc5aff 100644 --- a/services/inference/tests/test_inference.py +++ b/services/inference/tests/test_inference.py @@ -125,6 +125,26 @@ def test_zero_shot_forecast_inference(ts_data): assert len(df_out) == 1 assert df_out[0].shape[0] == prediction_length + # test single, very short (length 2) + test_data_ = test_data[test_data[id_columns[0]] == "a"].copy() + + msg = { + "model_id": model_id_path, + "parameters": { + "prediction_length": params["prediction_length"], + }, + "schema": { + "timestamp_column": params["timestamp_column"], + # "id_columns": params["id_columns"], + "target_columns": params["target_columns"], + }, + "data": encode_data(test_data_.iloc[:2], params["timestamp_column"]), + "future_data": {}, + } + + out = get_inference_response(msg) + assert "Received 2 time points for id a" in out.text + # test single, more data test_data_ = test_data[test_data[id_columns[0]] == "a"].copy() @@ -361,3 +381,41 @@ def test_trained_model_inference(ts_data): df_out = get_inference_response(msg) assert len(df_out) == 1 assert df_out[0].shape[0] == prediction_length + + +# def test_simple(): +# import numpy as np +# import pandas as pd + +# series_length = 512 +# timestamps = pd.date_range("2021-01-01", periods=series_length).to_list() +# num_series = 5 + +# def encode_data(df: pd.DataFrame, timestamp_column: str) -> Dict[str, Any]: +# df[timestamp_column] = df[timestamp_column].apply(lambda x: x.isoformat()) +# data_payload = df.to_dict(orient="list") +# return data_payload + +# test_data = pd.DataFrame( +# { +# "date": timestamps * num_series, +# "id": np.array([f"id{i}" for i in range(num_series)]).repeat(series_length), +# "target": np.tile(np.arange(series_length).astype(float), num_series), +# } +# ) + +# msg = { +# "model_id": "ttm-r2", +# "parameters": { +# "prediction_length": 96, +# }, +# "schema": { +# "timestamp_column": "date", +# "id_columns": ["id"], +# "target_columns": ["target"], +# }, +# "data": encode_data(test_data, "date"), +# } + +# df_out = get_inference_response(msg) +# print(df_out) From 3e1f4a5dee0b0f48e1808a2d697ed40096d6abfc Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Thu, 24 Oct 2024 16:23:25 -0400 Subject: [PATCH 3/3] fix test --- services/inference/tests/test_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/services/inference/tests/test_inference.py b/services/inference/tests/test_inference.py index 1bbc5aff..ab1a782b 100644 --- a/services/inference/tests/test_inference.py +++ b/services/inference/tests/test_inference.py @@ -135,7 +135,7 @@ def test_zero_shot_forecast_inference(ts_data): }, "schema": { "timestamp_column": params["timestamp_column"], - # "id_columns": params["id_columns"], + "id_columns": params["id_columns"], "target_columns": params["target_columns"], }, "data": encode_data(test_data_.iloc[:2], params["timestamp_column"]),