-
Notifications
You must be signed in to change notification settings - Fork 1
/
val_results.py
97 lines (62 loc) · 2.72 KB
/
val_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Oct 17 13:52:40 2019
@author: john
"""
import torch
import os
import argparse
from utils import AudioDataset, PostProcess
import json
#torch.set_default_tensor_type('torch.cuda.FloatTensor')
# Load the label lookup table once at import time and build the shared
# post-processor used by infer(). The path is relative to the current working
# directory, so the script must be run from the directory holding lookup.json.
# The file is read explicitly as UTF-8 (JSON's standard encoding; presumably
# the table maps targets to kana) and closed as soon as it is parsed.
with open('./lookup.json', encoding='utf-8') as _lookup_file:
    lookup_dict = json.load(_lookup_file)
postprocessor = PostProcess(lookup_dict)
from torch.utils.data import DataLoader
def infer(model, val_loader, postproc=None):
    """Decode a validation set with ``model`` and report Levenshtein error.

    For every sample, the ground-truth and predicted label sequences are
    converted with ``target2kana`` and their Levenshtein distance is computed.
    One line ``original predicted distance normalized_distance`` is printed
    per sample. The mean normalized distance over all samples is printed at
    the end and also returned.

    Args:
        model: Callable mapping a batch of audio to class scores. The class
            axis is assumed to be dim 1 (e.g. ``(batch, classes, time)``)
            because predictions are taken with ``argmax(dim=1)`` -- TODO confirm.
        val_loader: Iterable of batches whose first two items are
            ``(audio, targets)``, such as a ``torch.utils.data.DataLoader``.
        postproc: Object providing ``target2kana`` and ``levenshtein``.
            Defaults to the module-level ``postprocessor``.

    Returns:
        float: The average normalized Levenshtein distance, or ``nan`` if the
        loader yielded no samples.
    """
    if postproc is None:
        postproc = postprocessor
    total_lev_dist = 0.0
    n_samples = 0
    # Pure inference: disabling autograd avoids building (and keeping alive)
    # a gradient graph for every batch.
    with torch.no_grad():
        for data in val_loader:
            # NOTE(review): audio is forwarded exactly as yielded by the loader;
            # this assumes it is already on the model's device -- confirm.
            audio = data[0]
            targets = data[1].cpu().numpy().astype('int')
            output = model(audio)
            outmax = torch.argmax(output, dim=1).cpu().numpy()
            for target, vec in zip(targets, outmax):
                original = postproc.target2kana(target)
                predicted = postproc.target2kana(vec, refine=True)
                lev_dist = postproc.levenshtein(original, predicted)
                # Normalize by the reference length; max(..., 1) prevents a
                # ZeroDivisionError when the reference decodes to "".
                norm_dist = lev_dist / max(len(original), 1)
                total_lev_dist += norm_dist
                n_samples += 1
                print(original, predicted, lev_dist, round(norm_dist, 2))
    # Average over the samples actually scored; nan signals "no data".
    av_lev = total_lev_dist / n_samples if n_samples else float('nan')
    print('average Levenshtein distance', av_lev)
    return av_lev
def main():
    """Evaluate a saved model on the validation split of every dataset set.

    Reads ``info_dict.json`` from the model directory. For each entry in its
    ``sets`` list that has a non-empty ``audio_list_val``, it builds an
    ``AudioDataset``/``DataLoader`` over that list and runs ``infer`` with the
    model stored in ``model.pt``.
    """
    parser = argparse.ArgumentParser()
    # Required: the former default of None made os.path.join raise TypeError.
    parser.add_argument('-i', '--model_path', required=True, type=str)
    parser.add_argument('-n', '--gpu_number', default=0, type=int)
    args = parser.parse_args()
    model_path = args.model_path
    gpu_nbr = args.gpu_number
    torch.cuda.set_device(gpu_nbr)

    with open(os.path.join(model_path, 'info_dict.json'), encoding='utf-8') as f:
        info_dict = json.load(f)
    sets = info_dict['sets']
    list_dict = info_dict['list_dict']
    # Dataset settings shared by every set; read them once, outside the loop.
    target_dir = info_dict['target_dir']
    n_audio_max = info_dict['n_audio_max']
    n_target_max = info_dict['n_target_max']

    model = None
    for set_ in sets:
        audio_list_val = list_dict[set_['path']]['audio_list_val']
        if not audio_list_val:
            continue
        print(set_['path'])
        dataset_val = AudioDataset(audio_list_val, target_dir, n_audio_max,
                                   n_target_max, random_pad=False,
                                   change_speed=False)
        val_loader = DataLoader(dataset_val, batch_size=8, shuffle=False)
        if model is None:
            # Load the checkpoint lazily (only if some set has validation
            # data) and only once, instead of re-reading it for every set.
            # NOTE(review): torch.load unpickles a whole model object; only use
            # trusted checkpoints. PyTorch >= 2.6 defaults to weights_only=True
            # and would need weights_only=False here -- confirm target version.
            model = torch.load(os.path.join(model_path, 'model.pt'),
                               map_location=f'cuda:{gpu_nbr}')
            model.to(gpu_nbr)
            model = model.eval()
        infer(model, val_loader)
        print('=' * 20)
        print()


if __name__ == '__main__':
    main()