From e68faa91c5c7f78957b45d4832035e368ebe68a8 Mon Sep 17 00:00:00 2001 From: Arindam Jati Date: Tue, 9 Apr 2024 12:09:17 -0400 Subject: [PATCH 1/2] readme link update Signed-off-by: Arindam Jati --- .../hfdemo/tinytimemixer/ttm_pretrain_sample.py | 16 +++++++++++----- tsfm_public/models/tinytimemixer/README.md | 2 +- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py b/notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py index 673be8d9..7aac7aa4 100644 --- a/notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py +++ b/notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py @@ -10,6 +10,10 @@ # Have a look at the fine-tune scripts for example usecases of the pre-trained # TTM models. +# Basic usage: +# python ttm_pretrain_sample.py --data_root_path datasets/ +# See the get_ttm_args() function to know more about other TTM arguments + # Standard import math import os @@ -24,11 +28,8 @@ TinyTimeMixerConfig, TinyTimeMixerForPrediction, ) - -# First Party from tsfm_public.models.tinytimemixer.utils import get_data, get_ttm_args - # Arguments args = get_ttm_args() @@ -78,7 +79,9 @@ def pretrain(args, model, dset_train, dset_val): save_strategy="epoch", logging_strategy="epoch", save_total_limit=1, - logging_dir=os.path.join(args.save_dir, "logs"), # Make sure to specify a logging directory + logging_dir=os.path.join( + args.save_dir, "logs" + ), # Make sure to specify a logging directory load_best_model_at_end=True, # Load the best model when training ends metric_for_best_model="eval_loss", # Metric to monitor for early stopping greater_is_better=False, # For loss @@ -138,7 +141,10 @@ def pretrain(args, model, dset_train, dset_val): # Data prep dset_train, dset_val, dset_test = get_data( - args.dataset, args.context_length, args.forecast_length, data_root_path=args.data_root_path + args.dataset, + args.context_length, + args.forecast_length, + data_root_path=args.data_root_path, ) print("Length of the train dataset =", len(dset_train)) diff --git a/tsfm_public/models/tinytimemixer/README.md b/tsfm_public/models/tinytimemixer/README.md index a2b271f8..e1cabcc5 100644 --- a/tsfm_public/models/tinytimemixer/README.md +++ b/tsfm_public/models/tinytimemixer/README.md @@ -54,7 +54,7 @@ For Installation steps, refer [here](https://github.com/IBM/tsfm/tree/ttm) - Illustration notebook for 512-96 model on the considered target datasets: [here](../../../notebooks/hfdemo/tinytimemixer/ttm_benchmarking_512_96.ipynb) - Illustration notebook for 1024-96 model on the considered target datasets: [here](../../../notebooks/hfdemo/tinytimemixer/ttm_benchmarking_1024_96.ipynb) - M4-hourly transfer learning example: [here](../../../notebooks/hfdemo/tinytimemixer/ttm_m4_hourly.ipynb) -- Sample pretraining script: [here](../../../notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py.ipynb) +- Sample pretraining script: [here](../../../notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py) From 75b9965053ac1c0d26abd8d28a66ff9a3a9db5b6 Mon Sep 17 00:00:00 2001 From: Arindam Jati Date: Tue, 9 Apr 2024 12:10:59 -0400 Subject: [PATCH 2/2] remake style Signed-off-by: Arindam Jati --- notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py b/notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py index 7aac7aa4..38724786 100644 --- a/notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py +++ b/notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py @@ -30,6 +30,7 @@ ) from tsfm_public.models.tinytimemixer.utils import get_data, get_ttm_args + # Arguments args = get_ttm_args() @@ -79,9 +80,7 @@ def pretrain(args, model, dset_train, dset_val): save_strategy="epoch", logging_strategy="epoch", save_total_limit=1, - logging_dir=os.path.join( - args.save_dir, "logs" - ), # Make sure to specify a logging directory + logging_dir=os.path.join(args.save_dir, "logs"), # Make sure to specify a logging directory load_best_model_at_end=True, # Load the best model when training ends metric_for_best_model="eval_loss", # Metric to monitor for early stopping greater_is_better=False, # For loss