import chainer
import chainerx
import os
import sys
import tempfile

try:
    from chainer_compiler import _chainer_compiler_core
except ImportError:
    # When testing the module without installing chainer_compiler via pip,
    # `_chainer_compiler_core.so` is not accessible through the
    # `chainer_compiler` package. In that case, `_chainer_compiler_core.so`
    # should be imported directly from `build/chainer_compiler_cc`.
    # TODO(mkusumoto): Seek a more sophisticated way to import the .so file.
    try:
        root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        sys.path.append(os.path.join(root, 'build/chainer_compiler_cc'))
        import _chainer_compiler_core
    except ImportError:
        # We need to allow this failure for build-time imports (e.g., elichika
        # testgen) where the shared object is not ready yet.
        pass

try:
    import cupy
except ImportError:
    cupy = None


def _is_array(v):
    # Anything that is not a known container type is treated as a leaf array.
    return not isinstance(v, (list, tuple, range, dict))


def _flatten(xs):
    if _is_array(xs):
        return [xs]
    o = []
    for x in xs:
        if _is_array(x):
            o.append(x)
        else:
            o.extend(_flatten(x))
    return o
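
# Illustrative sketch of `_flatten` behavior (hypothetical leaf values
# a, b, c):
#
#   _flatten([a, (b, [c])])  ->  [a, b, c]
#   _flatten(a)              ->  [a]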


def _flatten_structured(xs, tmpl):
    o = []
    for x, t in zip(xs, tmpl):
        if _is_array(t):
            assert _is_array(x)
            o.append(x)
        else:
            assert not _is_array(x), '%s vs %s' % (x, t)
            if len(x) == len(t):
                o.extend(_flatten_structured(x, t))
            elif len(x) == 0:
                o.extend([None] * len(t))
            else:
                raise RuntimeError('%s vs %s' % (x, t))
    return o
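
# Illustrative sketch of `_flatten_structured` with hypothetical gradients
# g1..g3 matched against a template of the same shape; an empty sub-sequence
# yields Nones:
#
#   _flatten_structured([g1, [g2, g3]], [a, [b, c]])  ->  [g1, g2, g3]
#   _flatten_structured([g1, []],       [a, [b, c]])  ->  [g1, None, None]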


def _unflatten(xs, tmpl, i=0):
    o = []
    for t in tmpl:
        if _is_array(t):
            o.append(xs[i])
            i += 1
        else:
            no, i = _unflatten(xs, t, i)
            o.append(no)
    return type(tmpl)(o), i
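
# Illustrative sketch of `_unflatten` with hypothetical values; the template
# only determines the nesting, and the index one past the last consumed
# element is returned as well:
#
#   _unflatten([x, y, z], (t1, [t2, t3]))  ->  ((x, [y, z]), 3)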


def _from_var(v, device):
    # Converts a _chainer_compiler_core value (an array or a nested sequence
    # of arrays) into plain arrays on `device`.
    if v.is_array():
        return device.send(v.array())
    return [_from_var(x, device) for x in v.sequence()]


class RunCompiledModel(chainer.function_node.FunctionNode):

    def __init__(self, compiled_model, input_tmpl, runtime_kwargs):
        self.fwd_input_names = compiled_model.fwd_input_names
        self.fwd_output_names = compiled_model.fwd_output_names
        self.bwd_input_names = compiled_model.bwd_input_names
        self.bwd_output_names = compiled_model.bwd_output_names
        self.param_names = compiled_model.param_names
        self.fwd = compiled_model.fwd
        self.bwd = compiled_model.bwd
        self.num_outputs = len(compiled_model.orig_output_names)
        self.input_tmpl = input_tmpl
        self.num_inputs = len(_flatten(input_tmpl))
        self.chainerx_device_name = None
        self.runtime_kwargs = runtime_kwargs

    def _to_var(self, v):
        if _is_array(v):
            if isinstance(v, chainer.Variable):
                v = v.array
            v = chainer.backend.to_chx(v)
            if self.chainerx_device_name is None:
                self.chainerx_device_name = v.device
            else:
                assert self.chainerx_device_name == v.device
            return _chainer_compiler_core.value(v)
        return _chainer_compiler_core.value([self._to_var(a) for a in v])

    def forward(self, args):
        flat_inputs = args[:self.num_inputs]
        param_values = args[self.num_inputs:]
        device = chainer.backend.get_device_from_array(*flat_inputs)
        inputs, i = _unflatten(flat_inputs, self.input_tmpl)
        assert i == len(flat_inputs)

        entire_inputs = {}
        assert len(self.fwd_input_names) == len(inputs)
        for name, value in zip(self.fwd_input_names, inputs):
            entire_inputs[name] = self._to_var(value)
        assert len(self.param_names) == len(param_values)
        for name, value in zip(self.param_names, param_values):
            entire_inputs[name] = self._to_var(value)

        with chainer.using_device(self.chainerx_device_name):
            outputs = self.fwd.run(entire_inputs, **self.runtime_kwargs)

        outputs_and_retained = []
        for name in self.fwd_output_names:
            outputs_and_retained.append(outputs[name])
        self.retained = outputs_and_retained[self.num_outputs:]

        # TODO(hamaji): Do not hold actual arrays.
        self.nested_outputs = []
        for output in outputs_and_retained[:self.num_outputs]:
            self.nested_outputs.append(_from_var(output, device))
        flat_outputs = _flatten(self.nested_outputs)
        return tuple(flat_outputs)

    def unflatten_outputs(self, flat_outputs):
        outputs, _ = _unflatten(flat_outputs, self.nested_outputs)
        return outputs

    def backward(self, indexes, flat_gys):
        device = chainer.backend.get_device_from_array(flat_gys[0].array)
        gys, _ = _unflatten(flat_gys, self.nested_outputs)
        gys = [self._to_var(gy) for gy in gys]
        values = gys + self.retained
        del self.retained
        del self.nested_outputs

        inputs = {}
        assert len(self.bwd_input_names) == len(values)
        for name, value in zip(self.bwd_input_names, values):
            inputs[name] = value
        state = self.bwd.prepare(inputs, **self.runtime_kwargs)
        del inputs
        del values

        with chainer.using_device(self.chainerx_device_name):
            outputs = self.bwd.run(state)

        gxs = []
        assert len(self.input_tmpl) == len(self.fwd_input_names)
        for name, tmpl in zip(self.fwd_input_names, self.input_tmpl):
            grad_name = 'grad_out@' + name
            if grad_name in outputs:
                gx = _from_var(outputs[grad_name], device)
                if _is_array(tmpl):
                    gxs.append(gx)
                else:
                    assert len(gx) == len(tmpl)
                    gxs.extend(_flatten_structured(gx, tmpl))
            else:
                gxs.extend([None] * len(_flatten(tmpl)))

        for name in self.param_names:
            grad_name = 'grad_out@' + name
            if grad_name in outputs:
                gx = _from_var(outputs[grad_name], device)
                gxs.append(gx)
            else:
                gxs.extend([None])

        gxs = tuple(None if gx is None else chainer.Variable(gx) for gx in gxs)
        return gxs


def export(model, inputs, filename=None, translator='onnx_chainer'):
    if translator == 'ch2o':
        from chainer_compiler import ch2o
        xmodel = ch2o.compile_model(model, inputs)
        if filename is None:
            f = tempfile.NamedTemporaryFile(delete=False)
        else:
            f = open(filename, 'wb')
        f.write(xmodel.SerializeToString())
        f.close()
        del xmodel
    elif translator == 'onnx_chainer':
        import onnx_chainer
        if filename is None:
            f = tempfile.NamedTemporaryFile(delete=False)
        else:
            f = open(filename, 'wb')
        onnx_chainer.export(model, inputs, filename=f)
        f.close()
    else:
        raise NotImplementedError('Unsupported translator: %s' % translator)
    return f.name
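
# Illustrative sketch of `export` (assumes `model` is a chainer.Chain and `x`
# is a sample input; the path 'model.onnx' is hypothetical):
#
#   onnx_path = export(model, [x], filename='model.onnx',
#                      translator='onnx_chainer')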


class CompiledModel(chainer.Chain):

    def __init__(self, model, onnx_file, used_translator, dump_onnx=False,
                 computation_order=None,
                 compiler_kwargs=None,
                 runtime_kwargs=None,
                 quiet_period=0):
        super(CompiledModel, self).__init__()
        with self.init_scope():
            self.mc = model
        self.used_translator = used_translator
        self.dump_onnx = dump_onnx
        self.computation_order = computation_order
        self.compiler_kwargs = compiler_kwargs
        self.runtime_kwargs = runtime_kwargs
        self.quiet_period = quiet_period
        self.num_iterations = 0
        self.param_names = None
        self.param_values = None

        # Propagate device from `model` before compiling it.
        self.to_device(model.device)
        self.compile(onnx_file)

    def compile(self, onnx_file):
        if self.compiler_kwargs is not None:
            _chainer_compiler_core.configure(**self.compiler_kwargs)
        graph = _chainer_compiler_core.load(onnx_file)
        self.orig_output_names = graph.output_names()

        if self.computation_order is None:
            fwd_graph, bwd_graph = graph.backward_to(
                graph.input_names() + graph.param_names())
            skip_scheduling = False
        else:
            fwd_graph, bwd_graph = graph.backward_to_with_order(
                self.computation_order)
            skip_scheduling = True

        if self.dump_onnx:
            sys.stderr.write('=== vvv forward vvv ===\n' +
                             fwd_graph.dump() +
                             '\n=== ^^^ forward ^^^ ===\n')
            sys.stderr.write('=== vvv backward vvv ===\n' +
                             bwd_graph.dump() +
                             '\n=== ^^^ backward ^^^ ===\n')

        # TODO(hamaji): Revive shape inference.
        compiler_kwargs = {'skip_inference': True}
        if self.compiler_kwargs is not None:
            compiler_kwargs.update(self.compiler_kwargs)
        _chainer_compiler_core.configure(**compiler_kwargs)

        assert graph.input_names() == fwd_graph.input_names()
        self.fwd_input_names = fwd_graph.input_names()
        self.fwd_output_names = fwd_graph.output_names()
        self.bwd_input_names = bwd_graph.input_names()
        self.bwd_output_names = bwd_graph.output_names()
        self.fwd = fwd_graph.compile(skip_scheduling)
        self.bwd = bwd_graph.compile(skip_scheduling)
        self.param_names = fwd_graph.param_names()

        if self.used_translator == 'ch2o':
            convert_rule = lambda key: key  # noqa
        elif self.used_translator == 'onnx_chainer':
            convert_rule = lambda key: 'param' + key.replace('/', '_')  # noqa
        params = {convert_rule(key): value for key, value
                  in self.mc.namedparams()}

        # Since avg_mean and avg_var in BatchNormalization are not parameters
        # of the chainer link, they need additional handling.
        for link_name, link in self.mc.namedlinks():
            if not isinstance(link, chainer.links.BatchNormalization):
                continue
            for avg_name in ['avg_mean', 'avg_var']:
                key = convert_rule(link_name + '/' + avg_name)
                assert key not in params
                params[key] = getattr(link, avg_name)

        self.param_values = []
        fwd_chxvm_vars = fwd_graph.params()
        for name in self.param_names:
            if name in params:
                self.param_values.append(params[name])
            elif name in fwd_chxvm_vars:
                # Retrieve the initial value from the ONNX initializer.
                # TODO(hamaji): Emit `Constant` in onnx-chainer so we will not
                # need this branch.
                array = fwd_chxvm_vars[name].array()
                array = self.device.send(array)
                self.param_values.append(array)
            else:
                raise NotImplementedError('Initial value is unknown: ' + name)

    def forward(self, *args):
        inputs = list(args)
        flat_inputs = _flatten(inputs)

        runtime_kwargs = {}
        if (self.runtime_kwargs is not None and
                self.num_iterations % (self.quiet_period + 1) == 0):
            runtime_kwargs.update(self.runtime_kwargs)
        self.num_iterations += 1

        runner = RunCompiledModel(self, inputs, runtime_kwargs)
        outputs = runner.apply(flat_inputs + self.param_values)
        outputs = runner.unflatten_outputs(outputs)
        outputs = outputs[:len(self.orig_output_names)]
        if len(outputs) == 1:
            outputs = outputs[0]
        return outputs


def compile(model, inputs, translator='ch2o', **kwargs):
    # Run the translator internally to obtain an ONNX file, then wrap the
    # model so that forward/backward run on the compiled graphs.
    onnx_file = export(model, inputs, filename=None, translator=translator)
    compiled_model = CompiledModel(model, onnx_file, translator, **kwargs)
    return compiled_model
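
# Illustrative usage sketch with a hypothetical `MyModel` whose output is a
# scalar loss and a hypothetical input `x`:
#
#   model = MyModel()
#   compiled = compile(model, [x], translator='onnx_chainer')
#   loss = compiled(x)   # runs the compiled forward graph
#   loss.backward()      # runs the compiled backward graph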


def compile_onnx(model, onnx_file, used_translator, **kwargs):
    return CompiledModel(model, onnx_file, used_translator, **kwargs)
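
# Illustrative sketch of `compile_onnx`, assuming an ONNX file previously
# produced by `export` (the path 'model.onnx' is hypothetical):
#
#   compiled = compile_onnx(model, 'model.onnx', 'onnx_chainer')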


def use_unified_memory_allocator():
    # Let CuPy allocate CUDA managed (unified) memory.
    cupy.cuda.set_allocator(cupy.cuda.memory.malloc_managed)


def use_chainerx_shared_allocator():
    # Share ChainerX's CUDA memory allocator with CuPy.
    if cupy is None:
        return
    chainerx._cuda.cupy_share_allocator()
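
# Illustrative usage sketch (assumption: either allocator hook is called once,
# before building or running a compiled model on GPU):
#
#   use_chainerx_shared_allocator()
#   compiled = compile(model, [x])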