import logging
import time
from datetime import timedelta


class Locallog(logging.Formatter):
    """Log formatter prefixing every record with level, wall-clock time, and
    the elapsed time since this formatter was created.

    Subclasses ``logging.Formatter`` so handlers receive a genuine formatter
    object (the original duck-typed a bare class, which is fragile against
    logging internals that expect the Formatter interface).
    """

    def __init__(self):
        super().__init__()
        # Reference point for the elapsed-time column; restarted via
        # ``logger.reset_time()`` installed by ``create_logger``.
        self.start_time = time.time()

    def format(self, record):
        """Return ``LEVEL - <local time> - <elapsed> - message`` (or "" for
        an empty message)."""
        elapsed_seconds = round(record.created - self.start_time)

        prefix = "%s - %s - %s" % (
            record.levelname,
            time.strftime("%x %X"),
            timedelta(seconds=elapsed_seconds),
        )
        message = record.getMessage()
        # Indent continuation lines so they align under the message column
        # (prefix plus the " - " separator, hence the +3).
        message = message.replace("\n", "\n" + " " * (len(prefix) + 3))
        return "%s - %s" % (prefix, message) if message else ""


def create_logger(filepath, rank):
    """Configure and return the root logger used for training.

    Parameters
    ----------
    filepath : str or None
        Path of the log file; ``None`` disables file logging. For workers
        with ``rank > 0`` the rank is appended so each process writes its
        own file.
    rank : int
        Distributed rank of the calling process.

    Returns
    -------
    logging.Logger
        Root logger with a DEBUG-level file handler (when ``filepath`` is
        given), an INFO-level console handler, and an extra ``reset_time()``
        callable that restarts the elapsed-time clock in the log prefix.
    """
    log_formatter = Locallog()

    # File handler at DEBUG level, one file per distributed rank.
    if filepath is not None:
        if rank > 0:
            filepath = "%s-%i" % (filepath, rank)
        file_handler = logging.FileHandler(filepath, "a")
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(log_formatter)

    # Console handler at INFO level.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(log_formatter)

    # Configure the root logger; existing handlers are dropped so repeated
    # calls do not duplicate output.
    logger = logging.getLogger()
    logger.handlers = []
    logger.setLevel(logging.DEBUG)
    logger.propagate = False
    if filepath is not None:
        logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    def reset_time():
        """Restart the elapsed-time clock shown in every log prefix."""
        log_formatter.start_time = time.time()

    logger.reset_time = reset_time

    return logger
class PD_Stats(object):
    """Track training statistics in a pandas DataFrame persisted as a pickle.

    Parameters
    ----------
    path : str
        Pickle file used to persist the statistics table.
    columns : list
        Column names for the statistics table.
    """

    def __init__(self, path, columns):
        self.path = path

        # Reload previously saved stats when the pickle already exists.
        if os.path.isfile(self.path):
            self.stats = pd.read_pickle(self.path)

            # The reloaded table must match the requested schema.
            assert list(self.stats.columns) == list(columns)
        else:
            self.stats = pd.DataFrame(columns=columns)

    def update(self, row, save=True):
        """Append one row of statistics and (optionally) persist to disk."""
        self.stats.loc[len(self.stats.index)] = row

        if save:
            self.stats.to_pickle(self.path)


class TensorboardLogger(object):
    """Context manager that logs scalars and histograms to TensorBoard.

    Use as ``with TensorboardLogger(log_dir) as tb: tb.scalar(...)``; the
    summary writer is created on ``__enter__`` and closed on ``__exit__``.
    """

    def __init__(self, log_dir):
        self.log_dir = log_dir
        self.writer = None

    def __enter__(self):
        self.writer = tf.summary.create_file_writer(self.log_dir)
        self.writer.set_as_default()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.writer.close()

    def scalar(self, name, value, step):
        """Log a single scalar under ``name`` at ``step``."""
        with tf.summary.record_if(True):
            tf.summary.scalar(name, value, step=step)

    def histogram(self, name, values, step):
        """Log a single histogram under ``name`` at ``step``."""
        with tf.summary.record_if(True):
            tf.summary.histogram(name, values, step=step)

    def custom_scalars(self, name, scalars, step):
        """Log several scalars under the same tag at one step."""
        with tf.summary.record_if(True):
            for scalar in scalars:
                tf.summary.scalar(name, scalar, step=step)

    def custom_histograms(self, name, histograms, step):
        """Log several histograms under the same tag at one step."""
        with tf.summary.record_if(True):
            for histogram in histograms:
                tf.summary.histogram(name, histogram, step=step)


class WandbLogger(object):
    """Context manager that logs scalars and histograms to Weights & Biases.

    The run is started on ``__enter__`` and finished on ``__exit__``.
    """

    def __init__(self, name, config):
        self.name = name
        self.config = config
        self.logger = None

    def __enter__(self):
        self.logger = wandb.init(name=self.name, config=self.config)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # ``Run.join()`` is deprecated; ``finish()`` flushes and closes the run.
        self.logger.finish()

    def scalar(self, name, value, step):
        # ``Run.log`` expects a mapping; the previous positional call
        # ``log(name, value, step=step)`` raised TypeError (two values
        # for ``step``).
        self.logger.log({name: value}, step=step)

    def histogram(self, name, values, step):
        """Log ``values`` as a wandb histogram under ``name``."""
        self.logger.log({name: wandb.Histogram(values)}, step=step)

    def custom_scalars(self, name, scalars, step):
        """Log several scalars under the same key at one step."""
        for scalar in scalars:
            self.logger.log({name: scalar}, step=step)

    def custom_histograms(self, name, histograms, step):
        """Log several histograms under the same key at one step."""
        for histogram in histograms:
            self.logger.log({name: wandb.Histogram(histogram)}, step=step)


# Backward-compatible alias for the pre-rename spelling.
Wandblogger = WandbLogger
@@ -89,6 +91,18 @@ def batch_meta_train_loss(theta, apply_fn, adapt_fn, loss_fn, tasks):
         state = state.replace(
             step=state.step + 1, params=params, opt_state=new_opt_state
         )
+        # NOTE(review): the loggers are context managers; the previous code
+        # called ``__enter__`` on the *class* with ``state.step`` as ``self``.
+        if logger_type == "tensorboard":
+            with logger.TensorboardLogger(log_dir="logs") as tb:
+                tb.scalar("loss", loss, state.step)
+                for i, metric in enumerate(metrics):
+                    tb.scalar(metric.__name__, metrics_value[i], state.step)
+        if logger_type == "wandb":
+            with logger.WandbLogger(name="meta_train", config={}) as wb:
+                wb.scalar("loss", loss, state.step)
+                for i, metric in enumerate(metrics):
+                    wb.scalar(metric.__name__, metrics_value[i], state.step)
         return state, loss, metrics_value