-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
74 lines (72 loc) · 3.19 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
from sacred import Experiment
# Sacred experiment object named "PSTL"; `my_config` in this file registers
# its configuration scope on it via `@ex.config`.
# `save_git_info=False` is passed so that Sacred does not collect git
# commit/state information for the source files.
ex = Experiment(
    "PSTL",
    save_git_info=False,
)
@ex.config
def my_config():
    """Default configuration scope for the PSTL experiment.

    Registered on ``ex`` through ``@ex.config``; Sacred collects the local
    variables assigned below as configuration entries, so the variable names
    here are the configuration keys.

    Notes:
        * ``weight_path`` and ``log_path`` are built from ``version`` (and
          ``train_mode``); ``train_list`` / ``test_list`` are built from
          ``dataset`` and ``split``; the label and missing-joint paths are
          built from ``split``. Those inputs must therefore be assigned
          before the paths that use them.
        * Evaluating this scope sets the ``CUDA_VISIBLE_DEVICES`` environment
          variable to the value of ``gpus``.
        * NOTE(review): avoid adding helper locals here - they would
          presumably show up as extra configuration entries.
    """
    # ---- General settings ----
    version = "ntu60_xsub_joint"  # tag embedded in weight_path and log_path
    dataset = "NTU60_occ"  # author's note: ntu60 / ntu120
    split = "xsub"  # sub-directory used in the data/label paths below
    view = "joint"  # joint / motion / bone
    save_lp = False
    save_finetune = False
    save_semi = False
    pretrain_epoch = 150
    ft_epoch = 150
    lp_epoch = 150
    pretrain_lr = 5e-3
    lp_lr = 0.01
    ft_lr = 5e-3
    label_percent = 0.1
    weight_decay = 1e-5
    hidden_size = 256
    label_num = 60

    # ---- ST-GCN ----
    in_channels = 3
    hidden_channels = 16
    hidden_dim = 256
    dropout = 0.5
    graph_args = {"layout": "ntu-rgb+d", "strategy": "spatial"}
    edge_importance_weighting = True

    # ---- Downstream ----
    # Checkpoint file name is derived from `version`.
    weight_path = (
        "/cvhci/temp/ychen2/data_occ_frame50/OPSTL/kmeans+knn/OPSTL_"
        + version
        + "_frame50_epoch_150_pretrain.pt"
    )
    # Other modes the author left commented out here: 'finetune', 'semi'.
    train_mode = "pretrain"
    # Log file name is derived from `version` and `train_mode`.
    log_path = (
        "/cvhci/temp/ychen2/data_occ_frame50/OPSTL/kmeans+knn/"
        + version
        + "_"
        + train_mode
        + ".log"
    )

    # ---- GPU ----
    gpus = "0"
    # Side effect: exported every time this scope is evaluated.
    os.environ["CUDA_VISIBLE_DEVICES"] = gpus

    # ---- Skeleton settings ----
    batch_size = 128
    channel_num = 3
    person_num = 2
    joint_num = 25
    max_frame = 50
    # Disabled alternative inputs kept from the original file:
    #   train_list = '/cvhci/temp/ychen2/data_occ_frame50/'+dataset+'_frame50_relative/'+split+'/train_position.npy'
    #   test_list = '/cvhci/temp/ychen2/data_occ_frame50/'+dataset+'_frame50_relative/'+split+'/val_position.npy'
    train_list = (
        "/cvhci/temp/ychen2/data_occ_frame50/OPSTL/kmeans+knn/"
        + dataset
        + "_completed_frame50/"
        + split
        + "/train_position.npy"
    )
    test_list = (
        "/cvhci/temp/ychen2/data_occ_frame50/OPSTL/kmeans+knn/"
        + dataset
        + "_completed_frame50/"
        + split
        + "/val_position.npy"
    )
    train_label = (
        "/lsdf/users/ychen/data_occ/NTU-RGB-D-60-occ/"
        + split
        + "/train_label.pkl"
    )
    test_label = (
        "/lsdf/users/ychen/data_occ/NTU-RGB-D-60-occ/"
        + split
        + "/val_label.pkl"
    )
    joints_list = (
        "/lsdf/users/ychen/data_occ/NTU-RGB-D-60-occ/"
        + split
        + "/train_missing_joints_distribution.pkl"
    )

    # ---- Joint completion ----
    original_train_data_nan = (
        "/lsdf/users/ychen/data_occ/NTU-RGB-D-60-occ/"
        + split
        + "/train_data_nan.npy"
    )
    original_test_data_nan = (
        "/lsdf/users/ychen/data_occ/NTU-RGB-D-60-occ/"
        + split
        + "/val_data_nan.npy"
    )
    missing_joints_train = (
        "/lsdf/users/ychen/data_occ/NTU-RGB-D-60-occ/"
        + split
        + "/train_missing_joints.pkl"
    )
    missing_joints_test = (
        "/lsdf/users/ychen/data_occ/NTU-RGB-D-60-occ/"
        + split
        + "/val_missing_joints.pkl"
    )
    output_path = "/cvhci/temp/ychen2/data_occ_frame50/OPSTL/kmeans+knn"
    # NOTE(review): presumably the k-means cluster count and the KNN
    # neighbour count (the output directory is named "kmeans+knn") - confirm
    # against the completion script.
    cluster_num = 60
    n_neighbors = 5

    # ---- Data augmentation ----
    # Spelling "temperal" is kept on purpose: the name is the config key.
    temperal_padding_ratio = 6
    shear_amp = 1
    mask_joint = 9
    mask_frame = 10

    # ---- Barlow Twins ----
    pj_size = 6144
    lambd = 2e-4