Skip to content

Commit

Permalink
Merge pull request #202 from ibm-granite/vj_2
Browse files Browse the repository at this point in the history
Enhance getting_started notebook
  • Loading branch information
wgifford authored Nov 20, 2024
2 parents 9be4661 + ac5705f commit 15599d3
Show file tree
Hide file tree
Showing 4 changed files with 633 additions and 233 deletions.
141 changes: 124 additions & 17 deletions notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,22 @@
import logging
import math
import os
import tempfile

import pandas as pd
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed

from tsfm_public import TimeSeriesPreprocessor, get_datasets
from tsfm_public.models.tinytimemixer import (
TinyTimeMixerConfig,
TinyTimeMixerForPrediction,
)
from tsfm_public.models.tinytimemixer.utils import get_ttm_args
from tsfm_public.toolkit.data_handling import load_dataset
from tsfm_public.toolkit.get_model import get_model
from tsfm_public.toolkit.lr_finder import optimal_lr_finder
from tsfm_public.toolkit.visualization import plot_predictions


logger = logging.getLogger(__file__)
Expand All @@ -33,7 +38,7 @@
# See the get_ttm_args() function to know more about other TTM arguments


def get_model(args):
def get_base_model(args):
# Pre-train a `TTM` forecasting model
config = TinyTimeMixerConfig(
context_length=args.context_length,
Expand All @@ -42,16 +47,16 @@ def get_model(args):
num_input_channels=1,
patch_stride=args.patch_length,
d_model=args.d_model,
num_layers=2,
num_layers=args.num_layers, # increase the number of layers if we want more complex models
mode="common_channel",
expansion_factor=2,
dropout=0.2,
head_dropout=0.2,
dropout=args.dropout,
head_dropout=args.head_dropout,
scaling="std",
gated_attn=True,
adaptive_patching_levels=args.adaptive_patching_levels,
# decoder params
decoder_num_layers=2,
decoder_num_layers=args.decoder_num_layers, # increase the number of layers if we want more complex models
decoder_adaptive_patching_levels=0,
decoder_mode="common_channel",
decoder_raw_residual=False,
Expand All @@ -64,11 +69,24 @@ def get_model(args):


def pretrain(args, model, dset_train, dset_val):
# Find optimal learning rate
# Use with caution: Set it manually if the suggested learning rate is not suitable

learning_rate, model = optimal_lr_finder(
model,
dset_train,
batch_size=args.batch_size,
)
print("OPTIMAL SUGGESTED LEARNING RATE =", learning_rate)

# learning_rate = args.learning_rate

trainer_args = TrainingArguments(
output_dir=os.path.join(args.save_dir, "checkpoint"),
overwrite_output_dir=True,
learning_rate=args.learning_rate,
num_train_epochs=args.num_epochs,
seed=args.random_seed,
eval_strategy="epoch",
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.batch_size,
Expand All @@ -85,10 +103,10 @@ def pretrain(args, model, dset_train, dset_val):
)

# Optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=args.learning_rate)
optimizer = AdamW(model.parameters(), lr=learning_rate)
scheduler = OneCycleLR(
optimizer,
args.learning_rate,
learning_rate,
epochs=args.num_epochs,
steps_per_epoch=math.ceil(len(dset_train) / args.batch_size),
# steps_per_epoch=math.ceil(len(dset_train) / (args.batch_size * args.num_gpus)),
Expand Down Expand Up @@ -123,7 +141,54 @@ def pretrain(args, model, dset_train, dset_val):
trainer.train()

# Save the pretrained model
trainer.save_model(os.path.join(args.save_dir, "ttm_pretrained"))

model_save_path = os.path.join(args.save_dir, "ttm_pretrained")
trainer.save_model(model_save_path)
return model_save_path


def inference(args, model_path, dset_test):
model = get_model(model_path=model_path)

temp_dir = tempfile.mkdtemp()
trainer = Trainer(
model=model,
args=TrainingArguments(
output_dir=temp_dir,
per_device_eval_batch_size=args.batch_size,
seed=args.random_seed,
report_to="none",
),
)
# evaluate = zero-shot performance
print("+" * 20, "Test MSE output:", "+" * 20)
output = trainer.evaluate(dset_test)
print(output)

# get predictions

predictions_dict = trainer.predict(dset_test)

predictions_np = predictions_dict.predictions[0]

print(predictions_np.shape)

# get backbone embeddings (if needed for further analysis)

backbone_embedding = predictions_dict.predictions[1]

print(backbone_embedding.shape)

plot_path = os.path.join(args.save_dir, "plots")
# plot
plot_predictions(
model=trainer.model,
dset=dset_test,
plot_dir=plot_path,
plot_prefix="test_inference",
channel=0,
)
print("Plots saved in location:", plot_path)


if __name__ == "__main__":
Expand All @@ -138,17 +203,59 @@ def pretrain(args, model, dset_train, dset_val):
)

# Data prep
dset_train, dset_val, dset_test = load_dataset(
args.dataset,
args.context_length,
args.forecast_length,
dataset_root_path=args.data_root_path,
# Dataset
TARGET_DATASET = "etth1"
dataset_path = (
"https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh1.csv" # mention the dataset path
)
timestamp_column = "date"
id_columns = [] # mention the ids that uniquely identify a time-series.

target_columns = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]

# mention the train, valid and split config.
split_config = {
"train": [0, 8640],
"valid": [8640, 11520],
"test": [
11520,
14400,
],
}

data = pd.read_csv(
dataset_path,
parse_dates=[timestamp_column],
)
print("Length of the train dataset =", len(dset_train))

column_specifiers = {
"timestamp_column": timestamp_column,
"id_columns": id_columns,
"target_columns": target_columns,
"control_columns": [],
}

tsp = TimeSeriesPreprocessor(
**column_specifiers,
context_length=args.context_length,
prediction_length=args.forecast_length,
scaling=True,
encode_categorical=False,
scaler_type="standard",
)

dset_train, dset_valid, dset_test = get_datasets(tsp, data, split_config)

# Get model
model = get_model(args)
model = get_base_model(args)

# Pretrain
pretrain(args, model, dset_train, dset_val)
model_save_path = pretrain(args, model, dset_train, dset_valid)
print("=" * 20, "Pretraining Completed!", "=" * 20)
print("Model saved in location:", model_save_path)

# inference

inference(args=args, model_path=model_save_path, dset_test=dset_test)

print("inference completed..")
11 changes: 11 additions & 0 deletions notebooks/hfdemo/tinytimemixer/ttm_pretrain_script.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
python ttm_pretrain_sample.py --context_length 90 \
--forecast_length 30 \
--patch_length 10 \
--batch_size 64 \
--num_layers 3 \
--decoder_num_layers 3 \
--dropout 0.2 \
--head_dropout 0.2 \
--early_stopping 1 \
--adaptive_patching_levels 0 \
--num_epochs 10
Loading

0 comments on commit 15599d3

Please sign in to comment.