# model_selection.py -- forked from facebookresearch/DomainBed
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import itertools
import numpy as np
def get_test_records(records):
    """Keep only the records that hold out exactly one test environment.

    Each record's ``r['args']['test_envs']`` list is inspected; records whose
    list has length one are kept, all others (i.e. records that hold out
    additional environments besides the common test env) are dropped.
    """
    def _has_single_test_env(record):
        return len(record['args']['test_envs']) == 1

    return records.filter(_has_single_test_env)
class SelectionMethod:
    """Abstract base for model-selection strategies over hparams and timesteps.

    Everything is exposed through classmethods, so the class acts as a
    namespace: subclasses override ``run_acc`` and are never instantiated.
    """

    def __init__(self):
        # Instantiation is rejected on purpose; the API is classmethod-only.
        raise TypeError(
            "SelectionMethod subclasses are used through their classmethods "
            "and must not be instantiated"
        )

    @classmethod
    def run_acc(cls, run_records):
        """
        Given records from a run, return a {val_acc, test_acc} dict representing
        the best val-acc and corresponding test-acc for that run.

        Must be overridden by subclasses; the base implementation raises
        NotImplementedError. ``hparams_accs`` discards runs for which this
        returns None.
        """
        raise NotImplementedError

    @classmethod
    def hparams_accs(cls, records):
        """
        Given all records from a single (dataset, algorithm, test env) pair,
        return a list of (run_acc, records) tuples, one per hparams seed,
        ordered from highest to lowest ``val_acc``.

        Runs whose ``run_acc`` is None are left out of the result.
        """
        # Two positional parameters on purpose: each grouped element is a
        # (hparams_seed, run_records) pair that ``map`` is expected to unpack.
        def _score_run(_, run_records):
            return (cls.run_acc(run_records), run_records)

        scored = records.group('args.hparams_seed').map(_score_run)
        scored = scored.filter(lambda pair: pair[0] is not None)
        # Sort ascending by val_acc, then reverse so the best run comes first.
        return scored.sorted(key=lambda pair: pair[0]['val_acc'])[::-1]

    @classmethod
    def sweep_acc(cls, records):
        """
        Given all records from a single (dataset, algorithm, test env) pair,
        return the test acc of the single hparams run with the highest val
        acc, or None when no run yields an accuracy.
        """
        ranked = cls.hparams_accs(records)
        if len(ranked):
            # ranked[0] is the (run_acc, records) tuple of the best run.
            return ranked[0][0]['test_acc']
        return None
class OracleSelectionMethod(SelectionMethod):
    """Select with the test domain's own held-out split, at the last checkpoint.

    Unlike an argmax over every checkpoint, each run is scored using only
    the record with the largest ``step`` (i.e. no early stopping).
    """
    name = "test-domain validation set (oracle)"

    @classmethod
    def run_acc(cls, run_records):
        """Return {val_acc, test_acc} for the final step of a run, or None.

        Only records with exactly one test env are considered; None is
        returned when there are none. ``val_acc`` is the test env's ``out``
        accuracy and ``test_acc`` its ``in`` accuracy.
        """
        single_env_records = run_records.filter(
            lambda rec: len(rec['args']['test_envs']) == 1)
        if len(single_env_records) == 0:
            return None

        held_out = single_env_records[0]['args']['test_envs'][0]
        # Ascending by step, so the last element is the final checkpoint.
        by_step = single_env_records.sorted(lambda rec: rec['step'])
        final = by_step[-1]
        return {
            'val_acc': final['env{}_out_acc'.format(held_out)],
            'test_acc': final['env{}_in_acc'.format(held_out)],
        }
class IIDAccuracySelectionMethod(SelectionMethod):
    """Picks argmax(mean(env_out_acc for env in train_envs))"""
    name = "training-domain validation set"

    @classmethod
    def _step_acc(cls, record):
        """Turn one record into a {val_acc, test_acc} dict.

        ``val_acc`` is the mean ``out`` accuracy over every environment
        except the record's first test env; ``test_acc`` is that test env's
        ``in`` accuracy.
        """
        held_out = record['args']['test_envs'][0]

        # Walk env0, env1, ... until the record has no entry for that index.
        out_keys = []
        env = 0
        while f'env{env}_out_acc' in record:
            if env != held_out:
                out_keys.append(f'env{env}_out_acc')
            env += 1

        return {
            'val_acc': np.mean([record[k] for k in out_keys]),
            'test_acc': record['env{}_in_acc'.format(held_out)],
        }

    @classmethod
    def run_acc(cls, run_records):
        """Return the per-step accuracy dict with the highest val_acc.

        Only records with a single test env are scored; None is returned
        when the run has no such record.
        """
        candidates = get_test_records(run_records)
        if len(candidates) == 0:
            return None
        per_step = candidates.map(cls._step_acc)
        return per_step.argmax('val_acc')
class LeaveOneOutSelectionMethod(SelectionMethod):
    """Picks (hparams, step) by leave-one-out cross validation."""
    name = "leave-one-domain-out cross-validation"

    @classmethod
    def _step_acc(cls, records):
        """Compute {val_acc, test_acc} for the records of a single step.

        Returns None unless exactly one record has a single test env, and
        unless every other environment has a record that holds it out
        together with the test env.
        """
        singles = get_test_records(records)
        if len(singles) != 1:
            return None
        test_record = singles[0]
        test_env = test_record['args']['test_envs'][0]

        # Count environments by probing env0, env1, ... on the first record.
        first = records[0]
        n_envs = 0
        while f'env{n_envs}_out_acc' in first:
            n_envs += 1

        # -1 marks an environment for which no validation record was seen.
        val_accs = np.zeros(n_envs) - 1
        pair_records = records.filter(
            lambda rec: len(rec['args']['test_envs']) == 2)
        for rec in pair_records:
            # The env held out alongside the test env acts as validation env.
            val_env = (set(rec['args']['test_envs']) - {test_env}).pop()
            val_accs[val_env] = rec['env{}_in_acc'.format(val_env)]

        # Drop the test env's own slot before checking/averaging.
        others = list(val_accs[:test_env]) + list(val_accs[test_env + 1:])
        if any(v == -1 for v in others):
            return None

        return {
            'val_acc': np.sum(others) / (n_envs - 1),
            'test_acc': test_record['env{}_in_acc'.format(test_env)],
        }

    @classmethod
    def run_acc(cls, records):
        """Return the step-level accuracy dict with the highest val_acc.

        Records are grouped by ``step``; steps for which ``_step_acc`` gives
        None are ignored, and None is returned if no step remains.
        """
        # Two positional parameters on purpose: each grouped element is a
        # (step, step_records) pair that ``map`` is expected to unpack.
        def _score_step(step, step_records):
            return cls._step_acc(step_records)

        scored = records.group('step').map(_score_step).filter_not_none()
        if not len(scored):
            return None
        return scored.argmax('val_acc')