diff --git a/tsfm_public/__init__.py b/tsfm_public/__init__.py index 51ab4da0..1eb66186 100644 --- a/tsfm_public/__init__.py +++ b/tsfm_public/__init__.py @@ -1,4 +1,66 @@ # Copyright contributors to the TSFM project # +from pathlib import Path +from typing import TYPE_CHECKING + +# Check the dependencies satisfy the minimal versions required. +from transformers.utils import _LazyModule, logging + from .version import __version__, __version_tuple__ + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Base objects, independent of any specific backend +_import_structure = { + "models": [], + "models.tinytimemixer": ["TINYTIMEMIXER_PRETRAINED_CONFIG_ARCHIVE_MAP", "TinyTimeMixerConfig"], + "toolkit": [ + "TimeSeriesPreprocessor", + "TimeSeriesForecastingPipeline", + "ForecastDFDataset", + "PretrainDFDataset", + "RegressionDFDataset", + ], +} + + +# PyTorch-backed objects +_import_structure["models.tinytimemixer"].extend( + [ + "TINYTIMEMIXER_PRETRAINED_MODEL_ARCHIVE_LIST", + "TinyTimeMixerPreTrainedModel", + "TinyTimeMixerModel", + "TinyTimeMixerForPrediction", + ] +) + +# Direct imports for type-checking +if TYPE_CHECKING: + from .models.tinytimemixer import ( + TINYTIMEMIXER_PRETRAINED_CONFIG_ARCHIVE_MAP, + TINYTIMEMIXER_PRETRAINED_MODEL_ARCHIVE_LIST, + TinyTimeMixerConfig, + TinyTimeMixerForPrediction, + TinyTimeMixerModel, + TinyTimeMixerPreTrainedModel, + ) + from .toolkit import ( + ForecastDFDataset, + PretrainDFDataset, + RegressionDFDataset, + TimeSeriesForecastingPipeline, + TimeSeriesPreprocessor, + ) +else: + # Standard + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + extra_objects={"__version__": __version__}, + ) diff --git a/tsfm_public/toolkit/__init__.py b/tsfm_public/toolkit/__init__.py index 4f85bd0b..6c019cb7 100644 --- a/tsfm_public/toolkit/__init__.py +++ b/tsfm_public/toolkit/__init__.py @@ -1,2 +1,6 @@ # Copyright contributors to the TSFM project # + +from .dataset import ForecastDFDataset, PretrainDFDataset, RegressionDFDataset +from .time_series_forecasting_pipeline import TimeSeriesForecastingPipeline +from .time_series_preprocessor import TimeSeriesPreprocessor