Skip to content

Commit

Permalink
Merge pull request #116 from ibm-granite/pipeline_updates
Browse files Browse the repository at this point in the history
Pipeline updates
  • Loading branch information
wgifford authored Aug 23, 2024
2 parents 6cb3f1a + c312d23 commit c2f07f2
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
2 changes: 1 addition & 1 deletion tests/toolkit/test_time_series_forecasting_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def etth_data():
def test_forecasting_pipeline_defaults():
model = PatchTSTForPrediction(PatchTSTConfig(prediction_length=3, context_length=33))

tspipe = TimeSeriesForecastingPipeline(model=model)
tspipe = TimeSeriesForecastingPipeline(model)

assert tspipe._preprocess_params["prediction_length"] == 3
assert tspipe._preprocess_params["context_length"] == 33
Expand Down
8 changes: 3 additions & 5 deletions tsfm_public/toolkit/time_series_forecasting_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import PreTrainedModel
from transformers.data.data_collator import default_data_collator
from transformers.pipelines.base import (
GenericTensor,
Expand Down Expand Up @@ -112,6 +113,7 @@ class TimeSeriesForecastingPipeline(TimeSeriesPipeline):

def __init__(
self,
model: Union["PreTrainedModel"],
*args,
freq: Optional[str] = None,
explode_forecasts: bool = False,
Expand Down Expand Up @@ -143,10 +145,6 @@ def __init__(
if "freq" not in kwargs:
kwargs["freq"] = kwargs["feature_extractor"].freq

model = kwargs.get("model", None)
if not model:
raise ValueError("A model must be supplied during instantiation of a TimeSeriesForecastingPipeline")

if "context_length" not in kwargs:
kwargs["context_length"] = model.config.context_length

Expand All @@ -161,7 +159,7 @@ def __init__(
else:
kwargs["frequency_token"] = None

super().__init__(*args, **kwargs)
super().__init__(model, *args, **kwargs)

if self.framework == "tf":
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
Expand Down

0 comments on commit c2f07f2

Please sign in to comment.