From 1d465931e448cab7607e907dd3030a2f994c9ad9 Mon Sep 17 00:00:00 2001
From: hans
Date: Thu, 13 Jun 2024 14:33:22 +0900
Subject: [PATCH] feat: Improve model checkpoint loading

---
 models.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/models.py b/models.py
index 84bbb03d..b7abf6ac 100644
--- a/models.py
+++ b/models.py
@@ -696,12 +696,21 @@ def build_model(args, text_aligner, pitch_extractor, bert):
 def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[]):
     state = torch.load(path, map_location='cpu')
     params = state['net']
+
     for key in model:
         if key in params and key not in ignore_modules:
+            try:
+                model[key].load_state_dict(params[key], strict=True)
+            except:
+                from collections import OrderedDict
+                state_dict = params[key]
+                new_state_dict = OrderedDict()
+                print(f'{key} key length: {len(model[key].state_dict().keys())}, state_dict key length: {len(state_dict.keys())}')
+                for (k_m, v_m), (k_c, v_c) in zip(model[key].state_dict().items(), state_dict.items()):
+                    new_state_dict[k_m] = v_c
+                model[key].load_state_dict(new_state_dict, strict=True)
             print('%s loaded' % key)
-            model[key].load_state_dict(params[key], strict=False)
-    _ = [model[key].eval() for key in model]
-
+
     if not load_only_params:
         epoch = state["epoch"]
         iters = state["iters"]
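
The fallback introduced above first attempts a strict load and, if that raises (for example because the checkpoint keys carry a different prefix, such as the "module." prefix left by nn.DataParallel), remaps the checkpoint tensors onto the model's own key names purely by position. A minimal standalone sketch of the same idea follows; the nn.Linear model and the simulated checkpoint are illustrative assumptions, not code from models.py:

# Standalone sketch of the positional-remap fallback (illustrative only).
from collections import OrderedDict

import torch.nn as nn

model = nn.Linear(4, 2)

# Simulate a checkpoint saved under nn.DataParallel, whose keys gain a "module." prefix.
checkpoint = OrderedDict(('module.' + k, v) for k, v in model.state_dict().items())

try:
    model.load_state_dict(checkpoint, strict=True)
except RuntimeError:
    # Key names differ, so pair the model's keys with the checkpoint's values by position.
    remapped = OrderedDict(
        (k_m, v_c)
        for (k_m, _), (_, v_c) in zip(model.state_dict().items(), checkpoint.items())
    )
    model.load_state_dict(remapped, strict=True)
    print('loaded after positional key remap')

Because zip pairs entries in iteration order and stops at the shorter dict, this remap only succeeds when the two state dicts hold the same tensors in the same order and only the key names differ; a genuine architecture mismatch still fails at the final strict load.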