-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_training_data.py
executable file
·42 lines (36 loc) · 2.14 KB
/
generate_training_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import read_bvh
import numpy as np
from os import listdir
import argparse
import os
def generate_traindata_from_bvh(src_bvh_folder, tar_traindata_folder, representation, train_ratio=0.7):
    """Convert the ``.bvh`` files of a folder to ``.npy`` arrays, split into train/test.

    The ``.bvh`` files in *src_bvh_folder* are sorted by name; the first
    ``int(n_files * train_ratio)`` of them are written to
    ``<tar_traindata_folder>/train/`` and the rest to
    ``<tar_traindata_folder>/test/``. Each output keeps the source file name
    and appends ``.npy`` (``dance.bvh`` -> ``dance.bvh.npy``).

    Args:
        src_bvh_folder: Folder holding the input files. Only entries whose
            name ends in ``.bvh`` (with a non-empty stem) are processed; a
            trailing path separator is optional.
        tar_traindata_folder: Output root. The ``train`` and ``test``
            subfolders are created if they do not exist.
        representation: Forwarded unchanged to ``read_bvh.get_train_data``.
        train_ratio: Fraction of the ``.bvh`` files placed in the training
            split. Defaults to 0.7.

    Raises:
        ValueError: If *train_ratio* is outside ``[0, 1]``.
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio!r}")

    print("Generating training data for " + src_bvh_folder)
    train_dir = os.path.join(tar_traindata_folder, "train")
    test_dir = os.path.join(tar_traindata_folder, "test")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    # Filter to .bvh files *before* computing the split so that unrelated
    # directory entries (READMEs, hidden files, ...) do not skew the
    # train/test ratio. len > 4 keeps the rule that a bare ".bvh" is ignored.
    bvh_dance_names = sorted(
        name for name in os.listdir(src_bvh_folder)
        if len(name) > 4 and name.endswith(".bvh")
    )
    n_train = int(len(bvh_dance_names) * train_ratio)

    for idx, bvh_dance_name in enumerate(bvh_dance_names):
        print("Processing " + bvh_dance_name)
        # os.path.join works whether or not src_bvh_folder ends with a separator.
        dance = read_bvh.get_train_data(os.path.join(src_bvh_folder, bvh_dance_name), representation)
        out_dir = train_dir if idx < n_train else test_dir
        np.save(os.path.join(out_dir, bvh_dance_name + ".npy"), dance)
if __name__ == '__main__':
    # Command-line entry point: record the invocation arguments, then convert
    # every .bvh file in in_folder into train/test .npy files under out_folder.
    parser = argparse.ArgumentParser(description='ACLSTM-Train')
    parser.add_argument('in_folder', help='Path to the folder containing the bvh files to be processed.')
    parser.add_argument('out_folder', help='Path to the folder where the processed bvh files are written.')
    # NOTE(review): when omitted this passes None through to
    # read_bvh.get_train_data; confirm that is an accepted value.
    parser.add_argument('--representation', default=None, help='Which representation to use for the conversion. [positional, euler, 6d, quaternions]')
    args = parser.parse_args()
    # Store the parsed arguments next to the generated data so a dataset can
    # be traced back to the options that produced it.
    os.makedirs(args.out_folder, exist_ok=True)
    gen_info_path = os.path.join(args.out_folder, 'info.txt')
    with open(gen_info_path, "w", encoding="utf-8") as f:
        f.write(f'{args}')
    generate_traindata_from_bvh(args.in_folder, args.out_folder, args.representation)
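# Example invocation:
#   python generate_training_data.py ../train_data_bvh/indian/ ../train_data_xyz/indian/ --representation positional
# The legacy calls below predate the required `representation` argument and
# would raise a TypeError if uncommented as-is.
#generate_traindata_from_bvh("../train_data_bvh/indian/","../train_data_xyz/indian/")
#generate_traindata_from_bvh("../train_data_bvh/salsa/","../train_data_xyz/salsa/")
#generate_traindata_from_bvh("../train_data_bvh/martial/","../train_data_xyz/martial/")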