diff --git a/clinicadl/utils/callbacks/callbacks.py b/clinicadl/utils/callbacks/callbacks.py
index 09b647a89..7415249d4 100644
--- a/clinicadl/utils/callbacks/callbacks.py
+++ b/clinicadl/utils/callbacks/callbacks.py
@@ -122,3 +122,100 @@ def on_train_begin(self, parameters, **kwargs):
 
     def on_train_end(self, parameters, **kwargs):
         self.tracker.stop()
+
+
+class Tracker(Callback):
+    def on_train_begin(self, parameters, **kwargs):
+        if parameters["track_exp"] == "wandb":
+            from clinicadl.utils.tracking_exp import WandB_handler
+
+            self.run = WandB_handler(
+                kwargs["split"], parameters, kwargs["maps_path"].name
+            )
+
+        if parameters["track_exp"] == "mlflow":
+            from clinicadl.utils.tracking_exp import Mlflow_handler
+
+            self.run = Mlflow_handler(
+                kwargs["split"], parameters, kwargs["maps_path"].name
+            )
+
+    def on_epoch_end(self, parameters, **kwargs):
+        if parameters["track_exp"] == "wandb":
+            self.run.log_metrics(
+                self.run._wandb,
+                parameters["track_exp"],
+                parameters["network_task"],
+                kwargs["metrics_train"],
+                kwargs["metrics_valid"],
+            )
+
+        if parameters["track_exp"] == "mlflow":
+            self.run.log_metrics(
+                self.run._mlflow,
+                parameters["track_exp"],
+                parameters["network_task"],
+                kwargs["metrics_train"],
+                kwargs["metrics_valid"],
+            )
+
+    def on_train_end(self, parameters, **kwargs):
+        if parameters["track_exp"] == "mlflow":
+            self.run._mlflow.end_run()
+
+        if parameters["track_exp"] == "wandb":
+            self.run._wandb.finish()
+
+
+class LoggerCallback(Callback):
+    def on_train_begin(self, parameters, **kwargs):
+        logger.info(
+            f"Criterion for {parameters['network_task']} is {kwargs['criterion']}"
+        )
+        logger.debug(f"Optimizer used for training is {kwargs['optimizer']}")
+
+    def on_epoch_begin(self, parameters, **kwargs):
+        logger.info(f"Beginning epoch {kwargs['epoch']}.")
+
+    def on_epoch_end(self, parameters, **kwargs):
+        logger.info(
+            f"{kwargs['mode']} level training loss is {kwargs['metrics_train']['loss']} "
+            f"at the end of iteration {kwargs['i']}"
+        )
+        logger.info(
+            f"{kwargs['mode']} level validation loss is {kwargs['metrics_valid']['loss']} "
+            f"at the end of iteration {kwargs['i']}"
+        )
+
+    def on_train_end(self, parameters, **kwargs):
+        logger.info("tests")
+
+
+# class ProfilerHandler(Callback):
+#     def on_train_begin(self, parameters, **kwargs):
+#         if self.profiler:
+#             from contextlib import nullcontext
+#             from datetime import datetime
+#             from clinicadl.utils.maps_manager.cluster.profiler import (
+#                 ProfilerActivity,
+#                 profile,
+#                 schedule,
+#                 tensorboard_trace_handler,
+#             )
+
+#             time = datetime.now().strftime("%H:%M:%S")
+#             filename = [self.maps_path / "profiler" / f"clinicadl_{time}"]
+#             dist.broadcast_object_list(filename, src=0)
+#             profiler = profile(
+#                 activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+#                 schedule=schedule(wait=2, warmup=2, active=30, repeat=1),
+#                 on_trace_ready=tensorboard_trace_handler(filename[0]),
+#                 profile_memory=True,
+#                 record_shapes=False,
+#                 with_stack=False,
+#                 with_flops=False,
+#             )
+#         else:
+#             profiler = nullcontext()
+#             profiler.step = lambda *args, **kwargs: None
+#         return profiler
diff --git a/clinicadl/utils/maps_manager/maps_manager.py b/clinicadl/utils/maps_manager/maps_manager.py
index 4876823b9..bd1f4859a 100644
--- a/clinicadl/utils/maps_manager/maps_manager.py
+++ b/clinicadl/utils/maps_manager/maps_manager.py
@@ -1085,7 +1085,6 @@ def _train(
             resume (bool): If True the job is resumed from the checkpoint.
""" self._init_callbacks() - self.callback_handler.on_train_begin(self.parameters) model, beginning_epoch = self._init_model( split=split, resume=resume, @@ -1096,10 +1095,14 @@ def _train( model = DDP(model) criterion = self.task_manager.get_criterion(self.loss) - logger.info(f"Criterion for {self.network_task} is {criterion}") - optimizer = self._init_optimizer(model, split=split, resume=resume) - logger.debug(f"Optimizer used for training is {optimizer}") + self.callback_handler.on_train_begin( + self.parameters, + criterion=criterion, + optimizer=optimizer, + split=split, + maps_path=self.maps_path, + ) model.train() train_loader.dataset.train() @@ -1134,18 +1137,8 @@ def _train( scaler = GradScaler(enabled=self.amp) profiler = self._init_profiler() - if self.parameters["track_exp"] == "wandb": - from clinicadl.utils.tracking_exp import WandB_handler - - run = WandB_handler(split, self.parameters, self.maps_path.name) - - if self.parameters["track_exp"] == "mlflow": - from clinicadl.utils.tracking_exp import Mlflow_handler - - run = Mlflow_handler(split, self.parameters, self.maps_path.name) - while epoch < self.epochs and not early_stopping.step(metrics_valid["loss"]): - logger.info(f"Beginning epoch {epoch}.") + # self.callback_handler.on_epoch_begin(self.parameters, epoch = epoch) if isinstance(train_loader.sampler, DistributedSampler): # It should always be true for a random sampler. But just in case @@ -1245,63 +1238,14 @@ def _train( model.train() train_loader.dataset.train() - if cluster.master: - log_writer.step( - epoch, i, metrics_train, metrics_valid, len(train_loader) - ) - logger.info( - f"{self.mode} level training loss is {metrics_train['loss']} " - f"at the end of iteration {i}" - ) - logger.info( - f"{self.mode} level validation loss is {metrics_valid['loss']} " - f"at the end of iteration {i}" - ) - - if self.track_exp == "wandb": - run.log_metrics( - run._wandb, - self.track_exp, - self.network_task, - metrics_train, - metrics_valid, - ) - - if self.track_exp == "mlflow": - run.log_metrics( - run._mlflow, - self.track_exp, - self.network_task, - metrics_train, - metrics_valid, - ) + self.callback_handler.on_epoch_end( + self.parameters, + metrics_train=metrics_train, + metrics_valid=metrics_valid, + mode=self.mode, + i=i, + ) - # log_writer.step(epoch, i, metrics_train, metrics_valid, len(train_loader)) - # logger.info( - # f"{self.mode} level training loss is {metrics_train['loss']} " - # f"at the end of iteration {i}" - # ) - # logger.info( - # f"{self.mode} level validation loss is {metrics_valid['loss']} " - # f"at the end of iteration {i}" - # ) - # if self.track_exp == "wandb": - # run.log_metrics( - # run._wandb, - # self.track_exp, - # self.network_task, - # metrics_train, - # metrics_valid, - # ) - - # if self.track_exp == "mlflow": - # run.log_metrics( - # run._mlflow, - # self.track_exp, - # self.network_task, - # metrics_train, - # metrics_valid, - # ) if cluster.master: # Save checkpoints and best models best_dict = retain_best.step(metrics_valid) @@ -1334,12 +1278,6 @@ def _train( epoch += 1 - if self.parameters["track_exp"] == "mlflow": - run._mlflow.end_run() - - if self.parameters["track_exp"] == "wandb": - run._wandb.finish() - del model self._test_loader( train_loader, @@ -1377,7 +1315,8 @@ def _train( nb_images=1, network=network, ) - self.callback_handler.on_train_end(self.parameters) + + self.callback_handler.on_train_end(parameters=self.parameters) def _train_ssdann( self, @@ -3181,7 +3120,11 @@ def get_interpretation( return map_pt def 
-        from clinicadl.utils.callbacks.callbacks import Callback, CallbacksHandler
+        from clinicadl.utils.callbacks.callbacks import (
+            Callback,
+            CallbacksHandler,
+            LoggerCallback,
+        )
 
         # if self.callbacks is None:
         #     self.callbacks = [Callback()]
@@ -3193,5 +3136,10 @@ def _init_callbacks(self):
 
             self.callback_handler.add_callback(CodeCarbonTracker())
 
-        # self.callback_handler.add_callback(ProgressBarCallback())
+        if self.parameters["track_exp"]:
+            from clinicadl.utils.callbacks.callbacks import Tracker
+
+            self.callback_handler.add_callback(Tracker())
+
+        self.callback_handler.add_callback(LoggerCallback())
         # self.callback_handler.add_callback(MetricConsolePrinterCallback())
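
Reviewer note (not part of the diff): the change only works if CallbacksHandler forwards each on_* hook, with its keyword arguments, to every registered callback, since _train now passes criterion/optimizer/split/maps_path at train begin and metrics_train/metrics_valid/mode/i at epoch end. The sketch below is a minimal, self-contained illustration of that contract; SimpleHandler and the dummy metric values are hypothetical stand-ins, not the clinicadl implementation.

import logging

logger = logging.getLogger("clinicadl.example")
logging.basicConfig(level=logging.INFO)


class Callback:
    # Base class: every hook takes the parameters dict plus free-form kwargs.
    def on_train_begin(self, parameters, **kwargs): ...
    def on_epoch_begin(self, parameters, **kwargs): ...
    def on_epoch_end(self, parameters, **kwargs): ...
    def on_train_end(self, parameters, **kwargs): ...


class LoggerCallback(Callback):
    # Simplified version of the LoggerCallback added in this diff
    # (only the epoch-end hook is shown).
    def on_epoch_end(self, parameters, **kwargs):
        logger.info(
            f"{kwargs['mode']} level training loss is "
            f"{kwargs['metrics_train']['loss']} at the end of iteration {kwargs['i']}"
        )


class SimpleHandler:
    # Hypothetical stand-in for CallbacksHandler: stores callback instances
    # and fans each hook call, with its kwargs, out to all of them.
    def __init__(self):
        self.callbacks = []

    def add_callback(self, callback):
        self.callbacks.append(callback)

    def on_epoch_end(self, parameters, **kwargs):
        for callback in self.callbacks:
            callback.on_epoch_end(parameters, **kwargs)


handler = SimpleHandler()
handler.add_callback(LoggerCallback())
# Same keyword arguments as the on_epoch_end call added in _train.
handler.on_epoch_end(
    {"track_exp": None},
    metrics_train={"loss": 0.42},
    metrics_valid={"loss": 0.47},
    mode="image",
    i=10,
)

Keeping every hook signature as (parameters, **kwargs) lets new data be routed to callbacks (e.g. maps_path for Tracker) without changing the signatures of existing callbacks.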