Add other callbacks (aramis-lab#497)
* add logger and mlflow and wandb callbacks
camillebrianceau committed Mar 21, 2024
1 parent d2732ff commit 36c3512
Showing 2 changed files with 125 additions and 80 deletions.
97 changes: 97 additions & 0 deletions clinicadl/utils/callbacks/callbacks.py
@@ -122,3 +122,100 @@ def on_train_begin(self, parameters, **kwargs):

    def on_train_end(self, parameters, **kwargs):
        self.tracker.stop()


class Tracker(Callback):
    def on_train_begin(self, parameters, **kwargs):
        if parameters["track_exp"] == "wandb":
            from clinicadl.utils.tracking_exp import WandB_handler

            self.run = WandB_handler(
                kwargs["split"], parameters, kwargs["maps_path"].name
            )

        if parameters["track_exp"] == "mlflow":
            from clinicadl.utils.tracking_exp import Mlflow_handler

            self.run = Mlflow_handler(
                kwargs["split"], parameters, kwargs["maps_path"].name
            )

    def on_epoch_end(self, parameters, **kwargs):
        if parameters["track_exp"] == "wandb":
            self.run.log_metrics(
                self.run._wandb,
                parameters["track_exp"],
                parameters["network_task"],
                kwargs["metrics_train"],
                kwargs["metrics_valid"],
            )

        if parameters["track_exp"] == "mlflow":
            self.run.log_metrics(
                self.run._mlflow,
                parameters["track_exp"],
                parameters["network_task"],
                kwargs["metrics_train"],
                kwargs["metrics_valid"],
            )

    def on_train_end(self, parameters, **kwargs):
        if parameters["track_exp"] == "mlflow":
            self.run._mlflow.end_run()

        if parameters["track_exp"] == "wandb":
            self.run._wandb.finish()


class LoggerCallback(Callback):
    def on_train_begin(self, parameters, **kwargs):
        logger.info(
            f"Criterion for {parameters['network_task']} is {(kwargs['criterion'])}"
        )
        logger.debug(f"Optimizer used for training is {kwargs['optimizer']}")

    def on_epoch_begin(self, parameters, **kwargs):
        logger.info(f"Beginning epoch {kwargs['epoch']}.")

    def on_epoch_end(self, parameters, **kwargs):
        logger.info(
            f"{kwargs['mode']} level training loss is {kwargs['metrics_train']['loss']} "
            f"at the end of iteration {kwargs['i']}"
        )
        logger.info(
            f"{kwargs['mode']} level validation loss is {kwargs['metrics_valid']['loss']} "
            f"at the end of iteration {kwargs['i']}"
        )

    def on_train_end(self, parameters, **kwargs):
        logger.info("tests")


# class ProfilerHandler(Callback):
#     def on_train_begin(self, parameters, **kwargs):
#         if self.profiler:
#             from contextlib import nullcontext
#             from datetime import datetime
#             from clinicadl.utils.maps_manager.cluster.profiler import (
#                 ProfilerActivity,
#                 profile,
#                 schedule,
#                 tensorboard_trace_handler,
#             )

#             time = datetime.now().strftime("%H:%M:%S")
#             filename = [self.maps_path / "profiler" / f"clinicadl_{time}"]
#             dist.broadcast_object_list(filename, src=0)
#             profiler = profile(
#                 activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
#                 schedule=schedule(wait=2, warmup=2, active=30, repeat=1),
#                 on_trace_ready=tensorboard_trace_handler(filename[0]),
#                 profile_memory=True,
#                 record_shapes=False,
#                 with_stack=False,
#                 with_flops=False,
#             )
#         else:
#             profiler = nullcontext()
#             profiler.step = lambda *args, **kwargs: None
#         return profiler
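
Note on the commented-out ProfilerHandler above: it follows the usual torch.profiler pattern, a scheduled profiler whose step() is advanced once per iteration, with a nullcontext fallback when profiling is disabled. Below is a minimal, self-contained sketch of that pattern for reference; run_profiled_steps, train_step, and the logdir default are placeholders, and clinicadl's own wrapper in clinicadl.utils.maps_manager.cluster.profiler may differ in detail.

import torch
from torch.profiler import (
    ProfilerActivity,
    profile,
    schedule,
    tensorboard_trace_handler,
)


def run_profiled_steps(model, loader, train_step, logdir="profiler/clinicadl_demo"):
    # Profile the CPU always, and CUDA only when a GPU is present.
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)

    with profile(
        activities=activities,
        # Skip 2 steps, warm up for 2, record 30, and run the cycle once,
        # matching the schedule used in the commented-out handler.
        schedule=schedule(wait=2, warmup=2, active=30, repeat=1),
        on_trace_ready=tensorboard_trace_handler(logdir),
        profile_memory=True,
    ) as prof:
        for batch in loader:
            train_step(model, batch)
            # Advance the profiling schedule once per training iteration.
            prof.step()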
108 changes: 28 additions & 80 deletions clinicadl/utils/maps_manager/maps_manager.py
@@ -1085,7 +1085,6 @@ def _train(
resume (bool): If True the job is resumed from the checkpoint.
"""
self._init_callbacks()
self.callback_handler.on_train_begin(self.parameters)
model, beginning_epoch = self._init_model(
split=split,
resume=resume,
@@ -1096,10 +1095,14 @@
model = DDP(model)
criterion = self.task_manager.get_criterion(self.loss)

logger.info(f"Criterion for {self.network_task} is {criterion}")

optimizer = self._init_optimizer(model, split=split, resume=resume)
logger.debug(f"Optimizer used for training is {optimizer}")
self.callback_handler.on_train_begin(
self.parameters,
criterion=criterion,
optimizer=optimizer,
split=split,
maps_path=self.maps_path,
)

model.train()
train_loader.dataset.train()
@@ -1134,18 +1137,8 @@
scaler = GradScaler(enabled=self.amp)
profiler = self._init_profiler()

if self.parameters["track_exp"] == "wandb":
from clinicadl.utils.tracking_exp import WandB_handler

run = WandB_handler(split, self.parameters, self.maps_path.name)

if self.parameters["track_exp"] == "mlflow":
from clinicadl.utils.tracking_exp import Mlflow_handler

run = Mlflow_handler(split, self.parameters, self.maps_path.name)

while epoch < self.epochs and not early_stopping.step(metrics_valid["loss"]):
logger.info(f"Beginning epoch {epoch}.")
# self.callback_handler.on_epoch_begin(self.parameters, epoch = epoch)

if isinstance(train_loader.sampler, DistributedSampler):
# It should always be true for a random sampler. But just in case
@@ -1245,63 +1238,14 @@
model.train()
train_loader.dataset.train()

if cluster.master:
log_writer.step(
epoch, i, metrics_train, metrics_valid, len(train_loader)
)
logger.info(
f"{self.mode} level training loss is {metrics_train['loss']} "
f"at the end of iteration {i}"
)
logger.info(
f"{self.mode} level validation loss is {metrics_valid['loss']} "
f"at the end of iteration {i}"
)

if self.track_exp == "wandb":
run.log_metrics(
run._wandb,
self.track_exp,
self.network_task,
metrics_train,
metrics_valid,
)

if self.track_exp == "mlflow":
run.log_metrics(
run._mlflow,
self.track_exp,
self.network_task,
metrics_train,
metrics_valid,
)
self.callback_handler.on_epoch_end(
self.parameters,
metrics_train=metrics_train,
metrics_valid=metrics_valid,
mode=self.mode,
i=i,
)

# log_writer.step(epoch, i, metrics_train, metrics_valid, len(train_loader))
# logger.info(
# f"{self.mode} level training loss is {metrics_train['loss']} "
# f"at the end of iteration {i}"
# )
# logger.info(
# f"{self.mode} level validation loss is {metrics_valid['loss']} "
# f"at the end of iteration {i}"
# )
# if self.track_exp == "wandb":
# run.log_metrics(
# run._wandb,
# self.track_exp,
# self.network_task,
# metrics_train,
# metrics_valid,
# )

# if self.track_exp == "mlflow":
# run.log_metrics(
# run._mlflow,
# self.track_exp,
# self.network_task,
# metrics_train,
# metrics_valid,
# )
if cluster.master:
# Save checkpoints and best models
best_dict = retain_best.step(metrics_valid)
@@ -1334,12 +1278,6 @@

epoch += 1

if self.parameters["track_exp"] == "mlflow":
run._mlflow.end_run()

if self.parameters["track_exp"] == "wandb":
run._wandb.finish()

del model
self._test_loader(
train_loader,
@@ -1377,7 +1315,8 @@ def _train(
nb_images=1,
network=network,
)
self.callback_handler.on_train_end(self.parameters)

self.callback_handler.on_train_end(parameters=self.parameters)

def _train_ssdann(
self,
@@ -3181,7 +3120,11 @@ def get_interpretation(
return map_pt

def _init_callbacks(self):
from clinicadl.utils.callbacks.callbacks import Callback, CallbacksHandler
from clinicadl.utils.callbacks.callbacks import (
Callback,
CallbacksHandler,
LoggerCallback,
)

# if self.callbacks is None:
# self.callbacks = [Callback()]
@@ -3193,5 +3136,10 @@ def _init_callbacks(self):

self.callback_handler.add_callback(CodeCarbonTracker())

# self.callback_handler.add_callback(ProgressBarCallback())
if self.parameters["track_exp"]:
from clinicadl.utils.callbacks.callbacks import Tracker

self.callback_handler.add_callback(Tracker)

self.callback_handler.add_callback(LoggerCallback())
# self.callback_handler.add_callback(MetricConsolePrinterCallback())
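
For readers following the refactor: _train now emits events through a CallbacksHandler instead of logging and tracking inline. The method names (add_callback, on_train_begin, on_epoch_begin, on_epoch_end, on_train_end) come from the diff, but the handler below is only an illustrative sketch under the assumption that it simply fans each event out to every registered callback; the real implementation in clinicadl/utils/callbacks/callbacks.py may differ.

# Illustrative sketch only -- not clinicadl's actual CallbacksHandler.
class Callback:
    def on_train_begin(self, parameters, **kwargs): ...

    def on_epoch_begin(self, parameters, **kwargs): ...

    def on_epoch_end(self, parameters, **kwargs): ...

    def on_train_end(self, parameters, **kwargs): ...


class CallbacksHandler:
    def __init__(self, callbacks=None):
        self.callbacks = []
        for callback in callbacks or []:
            self.add_callback(callback)

    def add_callback(self, callback):
        # Accept a class or an instance, since the diff passes both
        # (Tracker vs. LoggerCallback()).
        self.callbacks.append(callback() if isinstance(callback, type) else callback)

    def _dispatch(self, event, parameters, **kwargs):
        # Fan the event out to every registered callback.
        for callback in self.callbacks:
            getattr(callback, event)(parameters, **kwargs)

    def on_train_begin(self, parameters, **kwargs):
        self._dispatch("on_train_begin", parameters, **kwargs)

    def on_epoch_begin(self, parameters, **kwargs):
        self._dispatch("on_epoch_begin", parameters, **kwargs)

    def on_epoch_end(self, parameters, **kwargs):
        self._dispatch("on_epoch_end", parameters, **kwargs)

    def on_train_end(self, parameters, **kwargs):
        self._dispatch("on_train_end", parameters, **kwargs)

With a handler like this, the training loop only has to call on_train_begin once (passing criterion, optimizer, split, and maps_path), on_epoch_end after each epoch with metrics_train and metrics_valid, and on_train_end when training finishes; Tracker and LoggerCallback then decide what to do with each event.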
