-
Notifications
You must be signed in to change notification settings - Fork 9
/
run_swda.py
executable file
·71 lines (57 loc) · 2.35 KB
/
run_swda.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
# finetune all layers, chunk size 40
# finetune 2 layers, chunk size 240
# finetune 1 layer, chunk size 300
# finetune 1 layer, chunk size 160, batch size 2
def count_gpus(gpu):
    """Return how many comma-separated device ids *gpu* lists ('' -> 0)."""
    return len(gpu.split(',')) if gpu else 0


def build_result_path(corpus, chunk_size, nlayer, nfinetune=2,
                      speaker_info='none', topic_info='none',
                      lr=1e-4, dropout=0.5, n_gpu=2):
    """Return the results-file path that encodes the run's settings.

    ``chunk_size`` and ``nlayer`` always appear in the name. Every other
    setting is appended as a ``_key=value`` tag only when it differs from
    the baseline given as this function's default, so baseline runs keep
    short names.
    """
    tags = [f'chunk={chunk_size}', f'nlayer={nlayer}']
    if nfinetune != 2:
        tags.append(f'nfinetune={nfinetune}')
    if speaker_info != 'none':
        tags.append(f'speaker={speaker_info}')
    if topic_info != 'none':
        tags.append(f'tinfo={topic_info}')
    if lr != 1e-4:
        tags.append(f'lr={lr}')
    if dropout != 0.5:
        tags.append(f'dropout={dropout}')
    if n_gpu != 2:
        tags.append(f'ngpu={n_gpu}')
    return f"results_{corpus}/{corpus}_" + '_'.join(tags) + '.txt'


def build_argv(options):
    """Return the argument vector that runs engine.py with *options*.

    *options* maps flag names to values; each pair is emitted as
    ``--name=value`` in the mapping's iteration order.
    """
    flags = [f'--{name}={value}' for name, value in options.items()]
    return ['python', '-u', 'engine.py'] + flags


def main():
    """Run engine.py with the configuration below and return its exit status."""
    # Imported here so this block does not depend on edits to the file header.
    import subprocess

    corpus = 'swda'
    mode = ['train', 'inference'][0]
    batch_size = 2
    batch_size_val = 2  # don't change this number
    emb_batch = 0
    epochs = 100  # default 100
    gpu = ['0,1', '', '0,1,2,3', '0,1,2,3,4,5,6,7'][0]  # default 0,1
    lr = 1e-4  # default 1e-4
    nlayer = 2  # default 1
    chunk_size = 196
    dropout = 0.5  # default 0.5
    nfinetune = 1  # default 2
    nclass = 43
    speaker_info = ['none', 'emb_cls'][1]  # whether to use speaker embeddings or not
    topic_info = ['none', 'emb_cls'][0]  # whether to use topic embeddings or not
    # Set to True to send the child's stdout to the results file instead of
    # the console (replaces the former commented-out " > file_name" suffix).
    log_to_file = False

    os.makedirs(f'results_{corpus}', exist_ok=True)
    file_name = build_result_path(
        corpus, chunk_size, nlayer,
        nfinetune=nfinetune, speaker_info=speaker_info, topic_info=topic_info,
        lr=lr, dropout=dropout, n_gpu=count_gpus(gpu),
    )

    # NOTE: an earlier, disabled step rewrote the gpu count inside
    # run_{corpus}.sh before launching; it is intentionally not performed.

    argv = build_argv({
        'corpus': corpus,
        'mode': mode,
        'gpu': gpu,
        'batch_size': batch_size,
        'batch_size_val': batch_size_val,
        'epochs': epochs,
        'lr': lr,
        'nlayer': nlayer,
        'chunk_size': chunk_size,
        'dropout': dropout,
        'nfinetune': nfinetune,
        'speaker_info': speaker_info,
        'topic_info': topic_info,
        'nclass': nclass,
        'emb_batch': emb_batch,
    })

    # Flush so the echoed command precedes the child's output even when
    # stdout is a pipe or a file.
    print(' '.join(argv), flush=True)

    # An argument list (no shell) avoids quoting problems in the values.
    if log_to_file:
        with open(file_name, 'w') as log:
            completed = subprocess.run(argv, stdout=log)
    else:
        completed = subprocess.run(argv)
    return completed.returncode


if __name__ == '__main__':
    import sys

    # Propagate the child's status so callers can detect a failed run.
    sys.exit(main())