-
Notifications
You must be signed in to change notification settings - Fork 1
/
config.py
46 lines (41 loc) · 1.94 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# GLOBAL PARAMETERS

# Dataset names known to this configuration module.
DATASETS = [
    'sent140',
    'nist',
    'shakespeare',
    'mnist',
    'synthetic',
    'cifar10',
    'mqtt',
]

# Trainer key -> trainer class name (stored as a string; presumably resolved
# to the actual class elsewhere — confirm against the caller).
TRAINERS = {
    'fedavg': 'FedAvgTrainer',
    'fedavg4': 'FedAvg4Trainer',
    'fedavg5': 'FedAvg5Trainer',
    'fedavg9': 'FedAvg9Trainer',
}
# A live keys view of TRAINERS (dict_keys, not a list).
OPTIMIZERS = TRAINERS.keys()

# The sequence 32, 64, 128, 256 repeated five times (20 entries in total).
BATCH_LIST = [32, 64, 128, 256] * 5

# Server endpoint. Use 'localhost' when running everything on one machine;
# when running in a real distributed setting, change to the server's IP address.
SERVER_ADDR = '169.254.231.192'
SERVER_PORT = 51000
class _UnsupportedModelError(ValueError, KeyError):
    """Raised when the dataset is supported but the requested model is not.

    Inherits from ``ValueError`` to match the unknown-dataset error raised by
    ``ModelConfig``, and from ``KeyError`` so code that caught the bare
    dict-lookup failure raised by earlier versions keeps working.
    """

    # KeyError.__str__ would wrap the message in quotes; use the plain form.
    __str__ = Exception.__str__


class ModelConfig(object):
    """Callable that maps a ``(dataset, model)`` pair to a parameter dict.

    Every call returns a new dict, so callers may mutate the result freely.
    """

    # Datasets whose parameters depend on the chosen model.
    _SENT140 = {
        'bag_dnn': {'num_class': 2},
        'stacked_lstm': {'seq_len': 25, 'num_class': 2, 'num_hidden': 100},
        'stacked_lstm_no_embeddings': {'seq_len': 25, 'num_class': 2, 'num_hidden': 100},
    }
    _SHAKESPEARE = {
        'stacked_lstm': {'seq_len': 80, 'emb_dim': 80, 'num_hidden': 256},
    }

    def __call__(self, dataset, model):
        """Return the model parameters for ``dataset`` / ``model``.

        Args:
            dataset: Dataset name. Only the part before the first ``'_'`` is
                used, so e.g. ``'mnist_all_data'`` resolves like ``'mnist'``.
            model: Model name. For mnist/nist, ``'logistic'`` and ``'2nn'``
                get a flat input size; any other model gets an image shape.

        Returns:
            A fresh dict of parameters (keys such as ``input_shape``,
            ``num_class``, ``seq_len``, ``emb_dim``, ``num_hidden``).

        Raises:
            ValueError: if the dataset prefix is not supported.
            ValueError / KeyError: if the dataset is sent140 or shakespeare
                and the model is not listed for it.
        """
        dataset = dataset.split('_')[0]
        if dataset in ('mnist', 'nist'):
            if model in ('logistic', '2nn'):
                return {'input_shape': 784, 'num_class': 10}
            return {'input_shape': (1, 28, 28), 'num_class': 10}
        if dataset == 'cifar10':
            return {'input_shape': (3, 32, 32), 'num_class': 10}
        if dataset == 'sent140':
            return self._lookup(self._SENT140, dataset, model)
        if dataset == 'shakespeare':
            return self._lookup(self._SHAKESPEARE, dataset, model)
        if dataset == 'synthetic':
            return {'input_shape': 60, 'num_class': 10}
        if dataset == 'mqtt':
            return {'input_shape': 28, 'num_class': 5}
        raise ValueError('Not support dataset {}!'.format(dataset))

    @staticmethod
    def _lookup(table, dataset, model):
        """Return a copy of ``table[model]``, with a descriptive error if absent."""
        params = table.get(model)
        if params is None:
            raise _UnsupportedModelError(
                'Not support model {} for dataset {}!'.format(model, dataset))
        # Shallow copy: values are scalars, and callers must not be able to
        # mutate the shared class-level table.
        return dict(params)


MODEL_PARAMS = ModelConfig()