diff --git a/basicsr/models/base_model.py b/basicsr/models/base_model.py index f06f9ca2c..525db9903 100644 --- a/basicsr/models/base_model.py +++ b/basicsr/models/base_model.py @@ -192,6 +192,12 @@ def update_learning_rate(self, current_iter, warmup_iter=-1): def get_current_learning_rate(self): return [param_group['lr'] for param_group in self.optimizers[0].param_groups] + def get_save_path(self, net_label, current_iter): + if current_iter == -1: + current_iter = 'latest' + save_filename = f'{net_label}_{current_iter}.pth' + return os.path.join(self.opt['path']['models'], save_filename) + @master_only def save_network(self, net, net_label, current_iter, param_key='params'): """Save networks. @@ -203,10 +209,7 @@ def save_network(self, net, net_label, current_iter, param_key='params'): param_key (str | list[str]): The parameter key(s) to save network. Default: 'params'. """ - if current_iter == -1: - current_iter = 'latest' - save_filename = f'{net_label}_{current_iter}.pth' - save_path = os.path.join(self.opt['path']['models'], save_filename) + save_path = self.get_save_path(net_label, current_iter) net = net if isinstance(net, list) else [net] param_key = param_key if isinstance(param_key, list) else [param_key] diff --git a/basicsr/train.py b/basicsr/train.py index f63149c64..4509244ac 100644 --- a/basicsr/train.py +++ b/basicsr/train.py @@ -10,14 +10,13 @@ from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher from basicsr.models import build_model from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str, - init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir) + init_tb_logger, wandb_enabled, init_wandb_logger, log_artifact, make_exp_dirs, mkdir_and_rename, scandir) from basicsr.utils.options import copy_opt_file, dict2str, parse_options def init_tb_loggers(opt): # initialize wandb logger before tensorboard logger to allow proper sync - if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') - is not None) and ('debug' not in opt['name']): + if wandb_enabled(opt): assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb') init_wandb_logger(opt) tb_logger = None @@ -184,6 +183,7 @@ def train_pipeline(root_path): if current_iter % opt['logger']['save_checkpoint_freq'] == 0: logger.info('Saving models and training states.') model.save(epoch, current_iter) + log_artifact(opt, model, 'net_g', current_iter) # validation if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0): diff --git a/basicsr/utils/__init__.py b/basicsr/utils/__init__.py index 902b293cd..482b77a2e 100644 --- a/basicsr/utils/__init__.py +++ b/basicsr/utils/__init__.py @@ -2,7 +2,7 @@ from .file_client import FileClient from .img_process_util import USMSharp, usm_sharp from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img -from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger +from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger, wandb_enabled, log_artifact from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt __all__ = [ @@ -19,6 +19,8 @@ 'AvgTimer', 'init_tb_logger', 'init_wandb_logger', + 'wandb_enabled', + 'log_artifact', 'get_root_logger', 'get_env_info', # misc.py diff --git a/basicsr/utils/logger.py b/basicsr/utils/logger.py index 73553dc66..7f547aa88 100644 --- a/basicsr/utils/logger.py +++ b/basicsr/utils/logger.py @@ -143,6 +143,32 @@ def init_wandb_logger(opt): logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') +def wandb_enabled(opt): + if 'debug' in opt['name']: return False + if opt['logger'].get('wandb') is None: return False + if opt['logger']['wandb'].get('project') is None: return False + return True + + +@master_only +def log_artifact(opt, model, net_label, current_step): + if not wandb_enabled(opt): + return + + if not opt['logger']['wandb'].get('log_model', False): + return + + import wandb + + # Prepend run id to artifact name so it is attributed to the current run + name = opt['name'] + name = f'{wandb.run.id}_{name}' + artifact = wandb.Artifact(name, type='model') + save_path = model.get_save_path(net_label, current_step) + artifact.add_file(save_path) + wandb.run.log_artifact(artifact) + + def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): """Get the root logger.