Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

logging support added #40

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions jeta/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import logging
import os
import time
from datetime import timedelta

import pandas as pd
import tensorflow as tf
import wandb
Comment on lines +7 to +8
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we do the lazy import for TensorFlow and wandb? Since these libraries will only be used in Iogging, I think having them as optional dependencies makes sense.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a good idea. We can actually use lazy importing for most of the common dependencies here at a later point. But I don't mind doing this for wandb and tf now

Copy link
Member

@abhi-glitchhg abhi-glitchhg Aug 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 for pandas.

Other modules are present in python 3 by default i suppose.

Comment on lines +6 to +8
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
import pandas as pd
import tensorflow as tf
import wandb



class Locallog:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class Locallog:
class LocalLog:

def __init__(self):
self.start_time = time.time()

def format(self, record):
elapsed_seconds = round(record.created - self.start_time)

prefix = "%s - %s - %s" % (
record.levelname,
time.strftime("%x %X"),
timedelta(seconds=elapsed_seconds),
)
message = record.getMessage()
message = message.replace("\n", "\n" + " " * (len(prefix) + 3))
return "%s - %s" % (prefix, message) if message else ""


def create_logger(filepath, rank):

log_formatter = Locallog()
# create file handler and set level to debug
if filepath is not None:
if rank > 0:
filepath = "%s-%i" % (filepath, rank)
file_handler = logging.FileHandler(filepath, "a")
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(log_formatter)

# create console handler and set level to info
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(log_formatter)

# create logger and set level to debug
logger = logging.getLogger()
logger.handlers = []
logger.setLevel(logging.DEBUG)
logger.propagate = False
if filepath is not None:
logger.addHandler(file_handler)
logger.addHandler(console_handler)

# reset logger elapsed time
def reset_time():
log_formatter.start_time = time.time()

logger.reset_time = reset_time

return logger


class PD_Stats(object):
"""
Log stuff with pandas library
"""

Comment on lines +65 to +66
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""
"""
try:
import pandas as pd
except:
raise RuntimeError("Pandas not found. PD_Stats needs to have pandas installed. You can try installing pandas with pip: pip install pandas")

def __init__(self, path, columns):
self.path = path

# reload path stats
if os.path.isfile(self.path):
self.stats = pd.read_pickle(self.path)

# check that columns are the same
assert list(self.stats.columns) == list(columns)

else:
self.stats = pd.DataFrame(columns=columns)

def update(self, row, save=True):
self.stats.loc[len(self.stats.index)] = row

# save the statistics
if save:
self.stats.to_pickle(self.path)


class TensorboardLogger(object):
"""
Logging with tensorboard
"""

Comment on lines +91 to +92
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""
"""
try:
import tensorflow as tf
except:
raise RuntimeError("Tensorflow not found. TensorboardLogger needs to have tensorflow installed. You can try installing tensorflow with pip: pip install tensorflow")

def __init__(self, log_dir):
self.log_dir = log_dir
self.writer = None

def __enter__(self):
self.writer = tf.summary.create_file_writer(self.log_dir)
self.writer.set_as_default()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.writer.close()

def scalar(self, name, value, step):
with tf.summary.record_if(True):
tf.summary.scalar(name, value, step=step)

def histogram(self, name, values, step):
with tf.summary.record_if(True):
tf.summary.histogram(name, values, step=step)

def custom_scalars(self, name, scalars, step):
with tf.summary.record_if(True):
for scalar in scalars:
tf.summary.scalar(name, scalar, step=step)

def custom_histograms(self, name, histograms, step):
with tf.summary.record_if(True):
for histogram in histograms:
tf.summary.histogram(name, histogram, step=step)


class WandbLogger(object):
"""
Creating a wandb logger
"""

Comment on lines +127 to +128
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""
"""
"""
try:
import wandb
except:
raise RuntimeError("Wandb not found. WandbLogger needs to have wandb installed. You can try installing wandb with pip: pip install wandb")

def __init__(self, name, config):
self.name = name
self.config = config
self.logger = None

def __enter__(self):
self.logger = wandb.init(name=self.name, config=self.config)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.logger.join()

def scalar(self, name, value, step):
self.logger.log(name, value, step=step)

def histogram(self, name, values, step):
self.logger.log(name, values, step=step)

def custom_scalars(self, name, scalars, step):
for scalar in scalars:
self.logger.log(name, scalar, step=step)

def custom_histograms(self, name, histograms, step):
for histogram in histograms:
self.logger.log(name, histogram, step=step)
14 changes: 14 additions & 0 deletions jeta/opti_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Callable, List, Tuple

import jax
import logger
import optax
from flax import struct
from flax.training import train_state
Expand Down Expand Up @@ -41,6 +42,7 @@ def meta_train_step(
state: MetaTrainState,
tasks,
metrics: List[Callable[[ndarray, ndarray], ndarray]] = [],
logger_type: str = "tensorboard",
) -> Tuple[MetaTrainState, ndarray, List[ndarray]]:
"""Performs a single meta-training step on a batch of tasks.

Expand Down Expand Up @@ -89,6 +91,18 @@ def batch_meta_train_loss(theta, apply_fn, adapt_fn, loss_fn, tasks):
state = state.replace(
step=state.step + 1, params=params, opt_state=new_opt_state
)
if logger_type == "tensorboard":
logger.TensorboardLogger.__enter__(state.step)
logger.TensorboardLogger.scalar("loss", loss, state.step)
for i, metric in enumerate(metrics):
logger.TensorboardLogger.scalar(
metric.__name__, metrics_value[i], state.step
)
if logger_type == "wandb":
logger.WandbLogger.__enter__(state.step)
logger.WandbLogger.scalar("loss", loss, state.step)
for i, metric in enumerate(metrics):
logger.WandbLogger.scalar(metric.__name__, metrics_value[i], state.step)

return state, loss, metrics_value

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ optax==0.1.3
pytest==6.2.5
typing-extensions==4.1.1
scipy==1.5
wandb
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
wandb