Commit c622e6a: use logging

wgifford committed Aug 6, 2024
1 parent 846ba3f commit c622e6a
Showing 4 changed files with 68 additions and 32 deletions.
51 changes: 25 additions & 26 deletions notebooks/hfdemo/tinytimemixer/ttm_pretrain_sample.py
@@ -1,38 +1,36 @@
#!/usr/bin/env python
# coding: utf-8

# TTM pre-training example.
# This script provides a toy example to pretrain a Tiny Time Mixer (TTM) model on
# the `etth1` dataset. For pre-training TTM on a much larger set of datasets, please
# have a look at our paper: https://arxiv.org/pdf/2401.03955.pdf
# If you want to use the pre-trained models directly, please get them from the
# Hugging Face Hub: https://huggingface.co/ibm/TTM
# Have a look at the fine-tuning scripts for example use cases of the pre-trained
# TTM models.

# Basic usage:
# python ttm_pretrain_sample.py --data_root_path datasets/
# See the get_ttm_args() function to learn about the other TTM arguments.

# Standard
import logging
import math
import os

# Third Party
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed

# Local
from tsfm_public.models.tinytimemixer import (
TinyTimeMixerConfig,
TinyTimeMixerForPrediction,
)
from tsfm_public.models.tinytimemixer.utils import get_data, get_ttm_args
from tsfm_public.models.tinytimemixer.utils import get_ttm_args
from tsfm_public.toolkit.data_handling import load_dataset


# Arguments
args = get_ttm_args()
logger = logging.getLogger(__file__)

# TTM pre-training example.
# This script provides a toy example to pretrain a Tiny Time Mixer (TTM) model on
# the `etth1` dataset. For pre-training TTM on a much larger set of datasets, please
# have a look at our paper: https://arxiv.org/pdf/2401.03955.pdf
# If you want to use the pre-trained models directly, please get them from the
# Hugging Face Hub: https://huggingface.co/ibm/TTM
# Have a look at the fine-tuning scripts for example use cases of the pre-trained
# TTM models.

# Basic usage:
# python ttm_pretrain_sample.py --data_root_path datasets/
# See the get_ttm_args() function to learn about the other TTM arguments.


def get_model(args):
@@ -71,7 +69,7 @@ def pretrain(args, model, dset_train, dset_val):
overwrite_output_dir=True,
learning_rate=args.learning_rate,
num_train_epochs=args.num_epochs,
evaluation_strategy="epoch",
eval_strategy="epoch",
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.batch_size,
dataloader_num_workers=args.num_workers,
@@ -129,21 +127,22 @@ def pretrain(args, model, dset_train, dset_val):


if __name__ == "__main__":
# Arguments
args = get_ttm_args()

# Set seed
set_seed(args.random_seed)

print(
"*" * 20,
f"Pre-training a TTM for context len = {args.context_length}, forecast len = {args.forecast_length}",
"*" * 20,
logger.info(
f"{'*' * 20} Pre-training a TTM for context len = {args.context_length}, forecast len = {args.forecast_length} {'*' * 20}"
)

# Data prep
dset_train, dset_val, dset_test = get_data(
dset_train, dset_val, dset_test = load_dataset(
args.dataset,
args.context_length,
args.forecast_length,
data_root_path=args.data_root_path,
dataset_root_path=args.data_root_path,
)
print("Length of the train dataset =", len(dset_train))

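The first hunk above also renames the Trainer option evaluation_strategy to eval_strategy, the newer TrainingArguments keyword. A minimal sketch of the renamed argument, assuming a transformers release that accepts the new spelling; the output directory and batch sizes below are placeholder values, not the script's actual settings:

from transformers import TrainingArguments

# Placeholder values; only eval_strategy mirrors the change shown in the hunk above.
training_args = TrainingArguments(
    output_dir="ttm_pretrain_output",
    eval_strategy="epoch",  # older transformers releases used evaluation_strategy
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
)
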
33 changes: 31 additions & 2 deletions tsfm_public/__init__.py
@@ -1,16 +1,45 @@
# Copyright contributors to the TSFM project
#

import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING

# Check the dependencies satisfy the minimal versions required.
from transformers.utils import _LazyModule, logging
from transformers.utils import _LazyModule

from .version import __version__, __version_tuple__


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
TSFM_PYTHON_LOGGING_LEVEL = os.getenv("TSFM_PYTHON_LOGGING_LEVEL", "INFO")

LevelNamesMapping = {
"INFO": logging.INFO,
"WARN": logging.WARN,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
"DEBUG": logging.DEBUG,
"FATAL": logging.FATAL,
}

TSFM_PYTHON_LOGGING_LEVEL = (
logging.getLevelNamesMapping()[TSFM_PYTHON_LOGGING_LEVEL]
if hasattr(logging, "getLevelNamesMapping")
else LevelNamesMapping[TSFM_PYTHON_LOGGING_LEVEL]
)
TSFM_PYTHON_LOGGING_FORMAT = os.getenv(
"TSFM_PYTHON_LOGGING_FORMAT",
"%(levelname)s:p-%(process)d:t-%(thread)d:%(filename)s:%(funcName)s:%(message)s",
)

logging.basicConfig(
format=TSFM_PYTHON_LOGGING_FORMAT,
level=TSFM_PYTHON_LOGGING_LEVEL,
)

logger = logging.getLogger(__name__) # pylint: disable=invalid-name

# Base objects, independent of any specific backend
_import_structure = {
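The new block in tsfm_public/__init__.py configures logging at import time from two environment variables, keeping a hand-written level map as a fallback because logging.getLevelNamesMapping() only exists on Python 3.11 and later. A minimal usage sketch, assuming only the variable names introduced in the diff above (the DEBUG level and the shortened format are arbitrary example values):

import logging
import os

# Set these before importing tsfm_public, because logging.basicConfig runs at import time.
os.environ["TSFM_PYTHON_LOGGING_LEVEL"] = "DEBUG"
os.environ["TSFM_PYTHON_LOGGING_FORMAT"] = "%(levelname)s:%(name)s:%(message)s"

import tsfm_public  # noqa: E402

logging.getLogger(__name__).debug("now visible with the configured level and format")
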
6 changes: 5 additions & 1 deletion tsfm_public/models/tinytimemixer/utils/ttm_args.py
@@ -3,11 +3,15 @@
"""Utilities for TTM notebooks"""

import argparse
import logging
import os

import torch


logger = logging.getLogger(__name__)


def get_ttm_args():
parser = argparse.ArgumentParser(description="TTM pretrain arguments.")
# Adding a positional argument
@@ -145,7 +149,7 @@ def get_ttm_args():
# Calculate number of gpus
if args.num_gpus is None:
args.num_gpus = torch.cuda.device_count()
print("Automatically calculated number of GPUs =", args.num_gpus)
logger.info(f"Automatically calculated number of GPUs ={args.num_gpus}")

# Create save directory
args.save_dir = os.path.join(
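The ttm_args.py change replaces a print call with a module-level logger when the GPU count is derived automatically. A minimal sketch of that fallback, assuming only the --num_gpus argument visible in this hunk (the empty argv list is passed so the snippet runs with defaults):

import argparse
import logging

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser(description="TTM pretrain arguments.")
parser.add_argument("--num_gpus", type=int, default=None)
args = parser.parse_args([])  # empty argv: fall back to defaults for illustration

if args.num_gpus is None:
    args.num_gpus = torch.cuda.device_count()
    logger.info(f"Automatically calculated number of GPUs = {args.num_gpus}")
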
10 changes: 7 additions & 3 deletions tsfm_public/toolkit/data_handling.py
@@ -1,6 +1,7 @@
"""Utilities for handling datasets"""

import glob
import logging
import os
from importlib import resources
from pathlib import Path
@@ -12,6 +13,9 @@
from .time_series_preprocessor import TimeSeriesPreprocessor, get_datasets


LOGGER = logging.getLogger(__file__)


def load_dataset(
dataset_name: str,
context_length,
@@ -21,7 +25,7 @@ def load_dataset(
dataset_root_path: str = "datasets/",
dataset_path: Optional[str] = None,
):
print(dataset_name, context_length, forecast_length)
LOGGER.info(f"Dataset name: {dataset_name}, context length: {context_length}, prediction length {forecast_length}")

config_path = resources.files("tsfm_public.resources.data_config")
configs = glob.glob(os.path.join(config_path, "*.yaml"))
@@ -31,7 +35,7 @@

if config_path is None:
raise ValueError(
f"Currently `get_data()` function supports the following datasets: {names_to_config.keys()}\n \
f"Currently the `load_dataset()` function supports the following datasets: {names_to_config.keys()}\n \
For other datasets, please provide the proper configs to the TimeSeriesPreprocessor (TSP) module."
)

@@ -71,6 +75,6 @@ def load_dataset(
fewshot_fraction=fewshot_fraction,
fewshot_location=fewshot_location,
)
print(f"Data lengths: train = {len(train_dataset)}, val = {len(valid_dataset)}, test = {len(test_dataset)}")
LOGGER.info(f"Data lengths: train = {len(train_dataset)}, val = {len(valid_dataset)}, test = {len(test_dataset)}")

return train_dataset, valid_dataset, test_dataset
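
For reference, a minimal call mirroring the notebook change above. The dataset name, window lengths, and data directory are placeholder values, and the corresponding data files must already exist under dataset_root_path:

from tsfm_public.toolkit.data_handling import load_dataset

# "etth1" is the dataset named in the pre-training sample; the lengths and path are illustrative only.
dset_train, dset_val, dset_test = load_dataset(
    "etth1",
    context_length=512,
    forecast_length=96,
    dataset_root_path="datasets/",
)
print(len(dset_train), len(dset_val), len(dset_test))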
