diff --git a/notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py b/notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py index 38724786..8e77b147 100644 --- a/notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py +++ b/notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py @@ -1,38 +1,36 @@ #!/usr/bin/env python # coding: utf-8 -# TTM pre-training example. -# This scrips provides a toy example to pretrain a Tiny Time Mixer (TTM) model on -# the `etth1` dataset. For pre-training TTM on a much large set of datasets, please -# have a look at our paper: https://arxiv.org/pdf/2401.03955.pdf -# If you want to directly utilize the pre-trained models. Please use them from the -# Hugging Face Hub: https://huggingface.co/ibm/TTM -# Have a look at the fine-tune scripts for example usecases of the pre-trained -# TTM models. - -# Basic usage: -# python ttm_pretrain_sample.py --data_root_path datasets/ -# See the get_ttm_args() function to know more about other TTM arguments - -# Standard +import logging import math import os -# Third Party from torch.optim import AdamW from torch.optim.lr_scheduler import OneCycleLR from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed -# Local from tsfm_public.models.tinytimemixer import ( TinyTimeMixerConfig, TinyTimeMixerForPrediction, ) -from tsfm_public.models.tinytimemixer.utils import get_data, get_ttm_args +from tsfm_public.models.tinytimemixer.utils import get_ttm_args +from tsfm_public.toolkit.data_handling import load_dataset -# Arguments -args = get_ttm_args() +logger = logging.getLogger(__file__) + +# TTM pre-training example. +# This script provides a toy example to pretrain a Tiny Time Mixer (TTM) model on +# the `etth1` dataset. For pre-training TTM on a much larger set of datasets, please +# have a look at our paper: https://arxiv.org/pdf/2401.03955.pdf +# If you want to directly utilize the pre-trained models. 
Please use them from the +# Hugging Face Hub: https://huggingface.co/ibm/TTM +# Have a look at the fine-tune scripts for example usecases of the pre-trained +# TTM models. + +# Basic usage: +# python ttm_pretrain_sample.py --data_root_path datasets/ +# See the get_ttm_args() function to know more about other TTM arguments def get_model(args): @@ -71,7 +69,7 @@ def pretrain(args, model, dset_train, dset_val): overwrite_output_dir=True, learning_rate=args.learning_rate, num_train_epochs=args.num_epochs, - evaluation_strategy="epoch", + eval_strategy="epoch", per_device_train_batch_size=args.batch_size, per_device_eval_batch_size=args.batch_size, dataloader_num_workers=args.num_workers, @@ -129,21 +127,22 @@ def pretrain(args, model, dset_train, dset_val): if __name__ == "__main__": + # Arguments + args = get_ttm_args() + # Set seed set_seed(args.random_seed) - print( - "*" * 20, - f"Pre-training a TTM for context len = {args.context_length}, forecast len = {args.forecast_length}", - "*" * 20, + logger.info( + f"{'*' * 20} Pre-training a TTM for context len = {args.context_length}, forecast len = {args.forecast_length} {'*' * 20}" ) # Data prep - dset_train, dset_val, dset_test = get_data( + dset_train, dset_val, dset_test = load_dataset( args.dataset, args.context_length, args.forecast_length, - data_root_path=args.data_root_path, + dataset_root_path=args.data_root_path, ) print("Length of the train dataset =", len(dset_train)) diff --git a/tsfm_public/__init__.py b/tsfm_public/__init__.py index 6faab255..5725aa8d 100644 --- a/tsfm_public/__init__.py +++ b/tsfm_public/__init__.py @@ -1,16 +1,45 @@ # Copyright contributors to the TSFM project # +import logging +import os from pathlib import Path from typing import TYPE_CHECKING # Check the dependencies satisfy the minimal versions required. 
-from transformers.utils import _LazyModule, logging +from transformers.utils import _LazyModule from .version import __version__, __version_tuple__ -logger = logging.get_logger(__name__) # pylint: disable=invalid-name +TSFM_PYTHON_LOGGING_LEVEL = os.getenv("TSFM_PYTHON_LOGGING_LEVEL", "INFO") + +LevelNamesMapping = { + "INFO": logging.INFO, + "WARN": logging.WARN, + "WARNING": logging.WARNING, + "ERROR": logging.ERROR, + "CRITICAL": logging.CRITICAL, + "DEBUG": logging.DEBUG, + "FATAL": logging.FATAL, +} + +TSFM_PYTHON_LOGGING_LEVEL = ( + logging.getLevelNamesMapping()[TSFM_PYTHON_LOGGING_LEVEL] + if hasattr(logging, "getLevelNamesMapping") + else LevelNamesMapping[TSFM_PYTHON_LOGGING_LEVEL] +) +TSFM_PYTHON_LOGGING_FORMAT = os.getenv( + "TSFM_PYTHON_LOGGING_FORMAT", + "%(levelname)s:p-%(process)d:t-%(thread)d:%(filename)s:%(funcName)s:%(message)s", +) + +logging.basicConfig( + format=TSFM_PYTHON_LOGGING_FORMAT, + level=TSFM_PYTHON_LOGGING_LEVEL, +) + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name # Base objects, independent of any specific backend _import_structure = { diff --git a/tsfm_public/models/tinytimemixer/utils/ttm_args.py b/tsfm_public/models/tinytimemixer/utils/ttm_args.py index e889fcbb..c4ef3ed6 100644 --- a/tsfm_public/models/tinytimemixer/utils/ttm_args.py +++ b/tsfm_public/models/tinytimemixer/utils/ttm_args.py @@ -3,11 +3,15 @@ """Utilities for TTM notebooks""" import argparse +import logging import os import torch +logger = logging.getLogger(__name__) + + def get_ttm_args(): parser = argparse.ArgumentParser(description="TTM pretrain arguments.") # Adding a positional argument @@ -145,7 +149,7 @@ def get_ttm_args(): # Calculate number of gpus if args.num_gpus is None: args.num_gpus = torch.cuda.device_count() - print("Automatically calculated number of GPUs =", args.num_gpus) + logger.info(f"Automatically calculated number of GPUs ={args.num_gpus}") # Create save directory args.save_dir = os.path.join( diff --git 
a/tsfm_public/toolkit/data_handling.py b/tsfm_public/toolkit/data_handling.py index 7bd984d2..094d0a6a 100644 --- a/tsfm_public/toolkit/data_handling.py +++ b/tsfm_public/toolkit/data_handling.py @@ -1,6 +1,7 @@ """Utilities for handling datasets""" import glob +import logging import os from importlib import resources from pathlib import Path @@ -12,6 +13,9 @@ from .time_series_preprocessor import TimeSeriesPreprocessor, get_datasets +LOGGER = logging.getLogger(__file__) + + def load_dataset( dataset_name: str, context_length, @@ -21,7 +25,7 @@ def load_dataset( dataset_root_path: str = "datasets/", dataset_path: Optional[str] = None, ): - print(dataset_name, context_length, forecast_length) + LOGGER.info(f"Dataset name: {dataset_name}, context length: {context_length}, prediction length {forecast_length}") config_path = resources.files("tsfm_public.resources.data_config") configs = glob.glob(os.path.join(config_path, "*.yaml")) @@ -31,7 +35,7 @@ def load_dataset( if config_path is None: raise ValueError( - f"Currently `get_data()` function supports the following datasets: {names_to_config.keys()}\n \ + f"Currently the `load_dataset()` function supports the following datasets: {names_to_config.keys()}\n \ For other datasets, please provide the proper configs to the TimeSeriesPreprocessor (TSP) module." ) @@ -71,6 +75,6 @@ def load_dataset( fewshot_fraction=fewshot_fraction, fewshot_location=fewshot_location, ) - print(f"Data lengths: train = {len(train_dataset)}, val = {len(valid_dataset)}, test = {len(test_dataset)}") + LOGGER.info(f"Data lengths: train = {len(train_dataset)}, val = {len(valid_dataset)}, test = {len(test_dataset)}") return train_dataset, valid_dataset, test_dataset