diff --git a/notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py b/notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py index 673be8d9..38724786 100644 --- a/notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py +++ b/notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py @@ -10,6 +10,10 @@ # Have a look at the fine-tune scripts for example usecases of the pre-trained # TTM models. +# Basic usage: +# python ttm_pretrain_sample.py --data_root_path datasets/ +# See the get_ttm_args() function to know more about other TTM arguments + # Standard import math import os @@ -24,8 +28,6 @@ TinyTimeMixerConfig, TinyTimeMixerForPrediction, ) - -# First Party from tsfm_public.models.tinytimemixer.utils import get_data, get_ttm_args @@ -138,7 +140,10 @@ def pretrain(args, model, dset_train, dset_val): # Data prep dset_train, dset_val, dset_test = get_data( - args.dataset, args.context_length, args.forecast_length, data_root_path=args.data_root_path + args.dataset, + args.context_length, + args.forecast_length, + data_root_path=args.data_root_path, ) print("Length of the train dataset =", len(dset_train)) diff --git a/tsfm_public/models/tinytimemixer/README.md b/tsfm_public/models/tinytimemixer/README.md index a2b271f8..e1cabcc5 100644 --- a/tsfm_public/models/tinytimemixer/README.md +++ b/tsfm_public/models/tinytimemixer/README.md @@ -54,7 +54,7 @@ For Installation steps, refer [here](https://github.com/IBM/tsfm/tree/ttm) - Illustration notebook for 512-96 model on the considered target datasets: [here](../../../notebooks/hfdemo/tinytimemixer/ttm_benchmarking_512_96.ipynb) - Illustration notebook for 1024-96 model on the considered target datasets: [here](../../../notebooks/hfdemo/tinytimemixer/ttm_benchmarking_1024_96.ipynb) - M4-hourly transfer learning example: [here](../../../notebooks/hfdemo/tinytimemixer/ttm_m4_hourly.ipynb) -- Sample pretraining script: [here](../../../notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py.ipynb) +- Sample pretraining script: [here](../../../notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py)