forked from TorchEnsemble-Community/Ensemble-Pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
_base.py
288 lines (239 loc) · 8.92 KB
/
_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
import abc
import copy
import torch
import logging
import warnings
import numpy as np
import torch.nn as nn
from . import _constants as const
from .utils.io import split_data_target
from .utils.logging import get_tb_logger
def torchensemble_model_doc(header="", item="model"):
    """
    Build a class/method decorator that installs ensemble documentation.

    The decorated object's ``__doc__`` becomes ``header``, a blank line, and
    then the docstring fragment selected by ``item`` from the project's
    ``_constants`` module. This decorator is modified from `sklearn.py` in
    XGBoost.

    Parameters
    ----------
    header : string
        Introduction to the decorated class or method.
    item : string
        Key selecting which docstring fragment to append (e.g. ``"model"``,
        ``"fit"``, ``"predict"``). An unknown key raises ``KeyError`` when
        the decorator is applied.

    Returns
    -------
    callable
        A decorator that sets ``__doc__`` on its argument and returns it.
    """

    def _lookup_doc(key):
        """Return the docstring fragment registered under ``key``."""
        # Built lazily on each call so the constants are only read when the
        # decorator is actually applied.
        registry = {
            "model": const.__model_doc,
            "seq_model": const.__seq_model_doc,
            "fit": const.__fit_doc,
            "predict": const.__predict_doc,
            "set_optimizer": const.__set_optimizer_doc,
            "set_scheduler": const.__set_scheduler_doc,
            "set_criterion": const.__set_criterion_doc,
            "classifier_forward": const.__classification_forward_doc,
            "classifier_evaluate": const.__classification_evaluate_doc,
            "regressor_forward": const.__regression_forward_doc,
            "regressor_evaluate": const.__regression_evaluate_doc,
        }
        return registry[key]

    def adddoc(cls):
        cls.__doc__ = header + "\n\n" + "".join(_lookup_doc(item))
        return cls

    return adddoc
class BaseModule(nn.Module):
    """Base class for all ensembles.

    WARNING: This class cannot be used directly.
    Please use the derived classes instead.

    Note: the ``@abc.abstractmethod`` markers below are informational only.
    ``nn.Module`` does not use ``abc.ABCMeta`` as its metaclass, so Python
    does not prevent instantiating a subclass that leaves them unimplemented.
    """

    def __init__(
        self,
        estimator,
        n_estimators,
        estimator_args=None,
        cuda=True,
        n_jobs=None,
        logger=None,
    ):
        """
        Parameters
        ----------
        estimator : class or object
            Either a class that is instantiated for every base estimator,
            or an already-built object that is deep-copied instead.
        n_estimators : int
            Requested number of base estimators.
        estimator_args : dict, optional
            Keyword arguments passed to ``estimator`` when it is a class.
        cuda : bool
            Use the ``"cuda"`` device when True, otherwise ``"cpu"``.
        n_jobs : int, optional
            Stored as-is for use by derived ensembles.
        logger : logging.Logger, optional
            Logger to use; the root logger is used when not provided.
        """
        super(BaseModule, self).__init__()
        self.base_estimator_ = estimator
        self.n_estimators = n_estimators
        self.estimator_args = estimator_args

        # `estimator_args` is only consumed when `_make_estimator` calls the
        # estimator class; an already-instantiated object is deep-copied.
        if estimator_args and not isinstance(estimator, type):
            msg = (
                "The input `estimator_args` will have no effect since"
                " `estimator` is already an object after instantiation."
            )
            warnings.warn(msg, RuntimeWarning)

        self.device = torch.device("cuda" if cuda else "cpu")
        self.n_jobs = n_jobs
        self.logger = logging.getLogger() if not logger else logger
        self.tb_logger = get_tb_logger()

        self.estimators_ = nn.ModuleList()
        self.use_scheduler_ = False

    def __len__(self):
        """
        Return the number of base estimators in the ensemble. The real number
        of base estimators may not match `self.n_estimators` because of the
        early stopping stage in several ensembles such as Gradient Boosting.
        """
        return len(self.estimators_)

    def __getitem__(self, index):
        """Return the `index`-th base estimator in the ensemble."""
        return self.estimators_[index]

    @abc.abstractmethod
    def _decide_n_outputs(self, train_loader):
        """Decide the number of outputs according to the `train_loader`."""

    def _make_estimator(self):
        """Make and configure a copy of `self.base_estimator_`.

        Returns a new estimator moved to ``self.device``: a deep copy when
        ``base_estimator_`` is an object, or a fresh instance (optionally
        built with ``estimator_args``) when it is a class.
        """
        if not isinstance(self.base_estimator_, type):
            # Already an object: copy it so estimators do not share weights.
            estimator = copy.deepcopy(self.base_estimator_)
        elif self.estimator_args is None:
            estimator = self.base_estimator_()
        else:
            estimator = self.base_estimator_(**self.estimator_args)

        return estimator.to(self.device)

    def _validate_parameters(self, epochs, log_interval):
        """Validate hyper-parameters on training the ensemble.

        Raises
        ------
        ValueError
            If ``epochs`` or ``log_interval`` is not strictly positive. The
            message is also logged at ERROR level before raising.
        """
        if not epochs > 0:
            msg = (
                "The number of training epochs should be strictly positive"
                ", but got {} instead."
            )
            self.logger.error(msg.format(epochs))
            raise ValueError(msg.format(epochs))

        if not log_interval > 0:
            msg = (
                "The number of batches to wait before printing the"
                " training status should be strictly positive, but got {}"
                " instead."
            )
            self.logger.error(msg.format(log_interval))
            raise ValueError(msg.format(log_interval))

    def set_criterion(self, criterion):
        """Set the training criterion."""
        self._criterion = criterion

    def set_optimizer(self, optimizer_name, **kwargs):
        """Set the parameter optimizer (name plus its keyword arguments)."""
        self.optimizer_name = optimizer_name
        self.optimizer_args = kwargs

    def set_scheduler(self, scheduler_name, **kwargs):
        """Set the learning rate scheduler and enable its use."""
        self.scheduler_name = scheduler_name
        self.scheduler_args = kwargs
        self.use_scheduler_ = True

    @abc.abstractmethod
    def forward(self, *x):
        """
        Implementation on the data forwarding in the ensemble. Notice
        that the input ``x`` should be a data batch instead of a standalone
        data loader that contains many data batches.
        """

    @abc.abstractmethod
    def fit(
        self,
        train_loader,
        epochs=100,
        log_interval=100,
        test_loader=None,
        save_model=True,
        save_dir=None,
    ):
        """
        Implementation on the training stage of the ensemble.
        """

    @torch.no_grad()
    def predict(self, *x):
        """Docstrings decorated by downstream ensembles.

        Switch to eval mode, move every input to ``self.device``, run
        ``forward`` and return the output on the CPU. ``np.ndarray`` inputs
        are converted with ``torch.Tensor`` (i.e. to the default float dtype).

        Raises
        ------
        ValueError
            If an input is neither a ``torch.Tensor`` nor an ``np.ndarray``.
        """
        self.eval()
        x_device = []

        for data in x:
            if isinstance(data, torch.Tensor):
                x_device.append(data.to(self.device))
            elif isinstance(data, np.ndarray):
                x_device.append(torch.Tensor(data).to(self.device))
            else:
                # Plain string (never passed to `.format`), so single braces.
                msg = (
                    "The type of input X should be one of {torch.Tensor,"
                    " np.ndarray}."
                )
                raise ValueError(msg)

        pred = self.forward(*x_device)
        pred = pred.cpu()
        return pred
class BaseClassifier(BaseModule):
    """Base class for all ensemble classifiers.

    WARNING: This class cannot be used directly.
    Please use the derived classes instead.
    """

    def _decide_n_outputs(self, train_loader):
        """
        Decide the number of outputs according to the `train_loader`.

        The number of outputs equals the number of distinct classes for
        classifiers. If the dataset exposes a `classes` attribute, its length
        is used; otherwise every batch is scanned and the distinct target
        values are counted.

        Raises
        ------
        ValueError
            If the number of classes has to be inferred and `train_loader`
            yields no batches.
        """
        if hasattr(train_loader.dataset, "classes"):
            return len(train_loader.dataset.classes)

        # Infer `n_outputs` from the dataloader
        labels = []
        for elem in train_loader:
            _, target = split_data_target(elem, self.device)
            labels.append(target)

        if not labels:
            # `torch.cat([])` would fail with an opaque error otherwise.
            raise ValueError(
                "Cannot infer the number of classes from an empty"
                " `train_loader`."
            )

        labels = torch.unique(torch.cat(labels))
        return labels.size(0)

    @torch.no_grad()
    def evaluate(self, test_loader, return_loss=False):
        """Docstrings decorated by downstream models.

        Return the accuracy (in percent) of the ensemble on `test_loader`.
        When `return_loss` is True, also return the criterion value averaged
        over batches. Requires `set_criterion` to have been called.
        """
        self.eval()
        correct = 0
        total = 0
        loss = 0.0

        for elem in test_loader:
            # `split_data_target` yields a (data, target) pair, as at every
            # other call site in this module.
            data, target = split_data_target(elem, self.device)
            output = self.forward(*data)
            _, predicted = torch.max(output.data, 1)
            correct += (predicted == target).sum().item()
            total += target.size(0)
            loss += self._criterion(output, target)

        acc = 100 * correct / total
        loss /= len(test_loader)

        if return_loss:
            return acc, float(loss)

        return acc
class BaseRegressor(BaseModule):
    """Base class for all ensemble regressors.

    WARNING: This class cannot be used directly.
    Please use the derived classes instead.
    """

    def _decide_n_outputs(self, train_loader):
        """
        Decide the number of outputs according to the `train_loader`.

        The number of outputs equals the number of target variables for
        regressors (e.g., `1` in univariate regression). Only the first batch
        is inspected: a 1-D target gives `1`, otherwise `target.size(1)`.

        Raises
        ------
        ValueError
            If `train_loader` yields no batches.
        """
        for elem in train_loader:
            _, target = split_data_target(elem, self.device)
            if target.dim() == 1:
                return 1  # univariate regression
            return target.size(1)  # multivariate regression

        # Previously this fell through to an UnboundLocalError.
        raise ValueError(
            "Cannot infer the number of outputs from an empty `train_loader`."
        )

    @torch.no_grad()
    def evaluate(self, test_loader):
        """Docstrings decorated by downstream ensembles.

        Return the criterion value on `test_loader`, averaged over batches.
        Requires `set_criterion` to have been called.
        """
        self.eval()
        loss = 0.0

        for elem in test_loader:
            data, target = split_data_target(elem, self.device)
            output = self.forward(*data)
            loss += self._criterion(output, target)

        return float(loss) / len(test_loader)