From 4019c6748ee87031018fb8a1735ad90e58eb0122 Mon Sep 17 00:00:00 2001 From: Rishav Chourasia Date: Sun, 14 Apr 2019 18:03:18 +0530 Subject: [PATCH] feat: Implementated average poolin 2d. The supported args cannot be tuple right now. This is WIP. Will extend it to support tuple arguments. Addresses: issue #160 --- elichika/elichika/chainer2onnx.py | 1 + elichika/elichika/functions_buildin.py | 33 ++++++++++++++++ elichika/elichika/parser/core.py | 3 ++ elichika/tests/node/AveragePool2d.py | 53 ++++++++++++++++++++++++++ 4 files changed, 90 insertions(+) create mode 100644 elichika/tests/node/AveragePool2d.py diff --git a/elichika/elichika/chainer2onnx.py b/elichika/elichika/chainer2onnx.py index 1c7974a6..dfa3a058 100644 --- a/elichika/elichika/chainer2onnx.py +++ b/elichika/elichika/chainer2onnx.py @@ -44,6 +44,7 @@ def compile_model(model, inputs) -> 'ONNXModel': oc.chainer_f_converter[F.softmax] = fb.convert_softmax oc.chainer_f_converter[F.pad_sequence] = fb.convert_pad_sequence oc.chainer_f_converter[F.softmax_cross_entropy] = fb.convert_softmax_cross_entropy + oc.chainer_f_converter[F.average_pooling_2d] = fb.convert_average_pool_2d # assign names oc.assigned_names.clear() diff --git a/elichika/elichika/functions_buildin.py b/elichika/elichika/functions_buildin.py index 0ff4f626..67a5d5d7 100644 --- a/elichika/elichika/functions_buildin.py +++ b/elichika/elichika/functions_buildin.py @@ -66,3 +66,36 @@ def convert_softmax_cross_entropy(onnx_graph, node): node.inputs, node.outputs, str(node.lineprop)) + + +def convert_average_pool_2d(onnx_graph, node): + def _pair(x): + if isinstance(x, collections.Iterable): + return x + return (x, x) + + kwargs = {} + ksize = oc.try_get_attribute(node.inputs[1]) + kwargs['kernel_shape'] = _pair(ksize) + + value = oc.try_get_attribute(node.inputs[2]) + if value is not None: + kwargs['strides'] = _pair(value) + else: + kwargs['strides'] = _pair(ksize) + + value = oc.try_get_attribute(node.inputs[3]) + if value is not None: + kwargs['pads'] = _pair(value) * 2 + else: + kwargs['pads'] = _pair(0) + + kwargs['count_include_pad'] = 1 + + onnx_graph.add_node( + "AveragePool", + [node.inputs[0]], + [node.outputs[0]], + name=str(node.lineprop), + **kwargs, + ) diff --git a/elichika/elichika/parser/core.py b/elichika/elichika/parser/core.py index 7a43bf4a..5b0b339b 100644 --- a/elichika/elichika/parser/core.py +++ b/elichika/elichika/parser/core.py @@ -63,6 +63,9 @@ def instance_converter(m, i): functions_builtin.ChainerFunction(F.pad_sequence), None) f_dict.get_field().get_attribute('pad_sequence').revise( values.ValueRef(f_pad_sequence)) + f_average_pooling_2d = values.FuncValue( + functions_builtin.ChainerFunction(F.average_pooling_2d), None) + f_dict.get_field().get_attribute('average_pooling_2d').revise(values.ValueRef(f_average_pooling_2d)) default_module.set_default_value(chainer_functions_module_name, f_dict) # numpy diff --git a/elichika/tests/node/AveragePool2d.py b/elichika/tests/node/AveragePool2d.py new file mode 100644 index 00000000..60747f26 --- /dev/null +++ b/elichika/tests/node/AveragePool2d.py @@ -0,0 +1,53 @@ +# coding: utf-8 + +import chainer +import chainer.functions as F + + +class AvgPool(chainer.Chain): + + def __init__(self): + super(AvgPool, self).__init__() + + def forward(self, x): + y1 = F.average_pooling_2d(x, 1, stride=2) + return y1 + + +class AvgPoolPad(chainer.Chain): + + def __init__(self): + super(AvgPoolPad, self).__init__() + + def forward(self, x): + y1 = F.average_pooling_2d(x, 3, stride=1, pad=2) + return y1 + + +class AvgPoolNoStride(chainer.Chain): + + def __init__(self): + super(AvgPoolNoStride, self).__init__() + + def forward(self, x): + y1 = F.average_pooling_2d(x, 3) + return y1 + + +# ====================================== + +import testtools +import numpy as np + + +def main(): + np.random.seed(123) + x = np.random.rand(2, 20, 15, 17).astype(np.float32) + + testtools.generate_testcase(AvgPool(), [x], subname='default') + testtools.generate_testcase(AvgPoolPad(), [x], subname='withpad') + testtools.generate_testcase(AvgPoolNoStride(), [x], subname='withoutstride') + + +if __name__ == '__main__': + main()