Skip to content

Commit

Permalink
redirect to test models
Browse files Browse the repository at this point in the history
  • Loading branch information
wgifford committed Aug 16, 2024
1 parent 0263bd5 commit bbab0f5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
4 changes: 2 additions & 2 deletions services/inference/tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_zero_shot_forecast_inference(ts_data):
test_data_ = test_data[test_data[id_columns[0]] == "a"].copy()

msg = {
"model_id": "ibm-granite/granite-timeseries-ttm-v1",
"model_id": "ibm/test-ttm-v1",
"parameters": {
# "prediction_length": params["prediction_length"],
},
Expand All @@ -89,7 +89,7 @@ def test_zero_shot_forecast_inference(ts_data):
test_data_ = test_data.copy()

msg = {
"model_id": "ibm-granite/granite-timeseries-ttm-v1",
"model_id": "ibm/test-ttm-v1",
"parameters": {
# "prediction_length": params["prediction_length"],
},
Expand Down
8 changes: 5 additions & 3 deletions tests/toolkit/test_time_series_forecasting_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
from transformers import PatchTSTForPrediction

from tsfm_public import TinyTimeMixerForPrediction
from tsfm_public import TinyTimeMixerConfig, TinyTimeMixerForPrediction
from tsfm_public.toolkit.time_series_forecasting_pipeline import (
TimeSeriesForecastingPipeline,
)
Expand All @@ -25,8 +25,10 @@ def patchtst_model():

@pytest.fixture(scope="module")
def ttm_model():
model_path = "ibm-granite/granite-timeseries-ttm-v1"
model = TinyTimeMixerForPrediction.from_pretrained(model_path)
# model_path = "ibm-granite/granite-timeseries-ttm-v1"

conf = TinyTimeMixerConfig()
model = TinyTimeMixerForPrediction(conf)

return model

Expand Down

0 comments on commit bbab0f5

Please sign in to comment.