Set a new parameter to save the model to a specific path #319

Open · wants to merge 1 commit into master
train.py: 15 changes (8 additions & 7 deletions)
@@ -30,7 +30,7 @@ def train(opt):
     opt.batch_ratio = opt.batch_ratio.split('-')
     train_dataset = Batch_Balanced_Dataset(opt)
 
-    log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a')
+    log = open(f'{opt.save_model_to}/{opt.exp_name}/log_dataset.txt', 'a')
     AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
     valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
     valid_loader = torch.utils.data.DataLoader(
@@ -119,7 +119,7 @@ def train(opt):
 
     """ final options """
     # print(opt)
-    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
+    with open(f'{opt.save_model_to}/{opt.exp_name}/opt.txt', 'a') as opt_file:
         opt_log = '------------ Options -------------\n'
         args = vars(opt)
         for k, v in args.items():
@@ -175,7 +175,7 @@ def train(opt):
         if (iteration + 1) % opt.valInterval == 0 or iteration == 0:  # To see training progress, we also conduct validation when 'iteration == 0'
             elapsed_time = time.time() - start_time
             # for log
-            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
+            with open(f'{opt.save_model_to}/{opt.exp_name}/log_train.txt', 'a') as log:
                 model.eval()
                 with torch.no_grad():
                     valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
@@ -191,10 +191,10 @@ def train(opt):
                 # keep best accuracy model (on valid dataset)
                 if current_accuracy > best_accuracy:
                     best_accuracy = current_accuracy
-                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_accuracy.pth')
+                    torch.save(model.state_dict(), f'{opt.save_model_to}/{opt.exp_name}/best_accuracy.pth')
                 if current_norm_ED > best_norm_ED:
                     best_norm_ED = current_norm_ED
-                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
+                    torch.save(model.state_dict(), f'{opt.save_model_to}/{opt.exp_name}/best_norm_ED.pth')
                 best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'
 
                 loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
@@ -218,7 +218,7 @@ def train(opt):
         # save model per 1e+5 iter.
         if (iteration + 1) % 1e+5 == 0:
             torch.save(
-                model.state_dict(), f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')
+                model.state_dict(), f'{opt.save_model_to}/{opt.exp_name}/iter_{iteration+1}.pth')
 
         if (iteration + 1) == opt.num_iter:
             print('end the training')
@@ -237,6 +237,7 @@ def train(opt):
     parser.add_argument('--num_iter', type=int, default=300000, help='number of iterations to train for')
     parser.add_argument('--valInterval', type=int, default=2000, help='Interval between each validation')
    parser.add_argument('--saved_model', default='', help="path to model to continue training")
+    parser.add_argument('--save_model_to', default='./saved_models', help="path to save your new model")
     parser.add_argument('--FT', action='store_true', help='whether to do fine-tuning')
     parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is Adadelta)')
     parser.add_argument('--lr', type=float, default=1, help='learning rate, default=1.0 for Adadelta')
@@ -281,7 +282,7 @@ def train(opt):
         opt.exp_name += f'-Seed{opt.manualSeed}'
         # print(opt.exp_name)
 
-    os.makedirs(f'./saved_models/{opt.exp_name}', exist_ok=True)
+    os.makedirs(f'{opt.save_model_to}/{opt.exp_name}', exist_ok=True)
 
     """ vocab / character number configuration """
     if opt.sensitive:
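The net effect is that every hardcoded ./saved_models prefix is replaced by the new opt.save_model_to value. A minimal sketch of an invocation using the new flag; the data paths and experiment name below are placeholders, not part of this PR, and any other arguments this train.py requires are assumed to be set:

    python train.py --train_data path/to/train_lmdb --valid_data path/to/valid_lmdb \
        --exp_name my_experiment --save_model_to /data/checkpoints

Logs (log_dataset.txt, log_train.txt, opt.txt) and checkpoints (best_accuracy.pth, best_norm_ED.pth, iter_*.pth) are then written under /data/checkpoints/my_experiment/ instead of ./saved_models/my_experiment/; since os.makedirs is called with exist_ok=True, the target directory tree is created at startup if it does not already exist.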