Skip to content

Commit

Permalink
feat: Implementated average poolin 2d.
Browse files Browse the repository at this point in the history
The supported args cannot be tuple right now. This is WIP. Will extend
it to support tuple arguments.

Addresses: issue pfnet-research#160
  • Loading branch information
Rishav1 committed Apr 14, 2019
1 parent dd2bc00 commit 4019c67
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 0 deletions.
1 change: 1 addition & 0 deletions elichika/elichika/chainer2onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def compile_model(model, inputs) -> 'ONNXModel':
oc.chainer_f_converter[F.softmax] = fb.convert_softmax
oc.chainer_f_converter[F.pad_sequence] = fb.convert_pad_sequence
oc.chainer_f_converter[F.softmax_cross_entropy] = fb.convert_softmax_cross_entropy
oc.chainer_f_converter[F.average_pooling_2d] = fb.convert_average_pool_2d

# assign names
oc.assigned_names.clear()
Expand Down
33 changes: 33 additions & 0 deletions elichika/elichika/functions_buildin.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,36 @@ def convert_softmax_cross_entropy(onnx_graph, node):
node.inputs,
node.outputs,
str(node.lineprop))


def convert_average_pool_2d(onnx_graph, node):
def _pair(x):
if isinstance(x, collections.Iterable):
return x
return (x, x)

kwargs = {}
ksize = oc.try_get_attribute(node.inputs[1])
kwargs['kernel_shape'] = _pair(ksize)

value = oc.try_get_attribute(node.inputs[2])
if value is not None:
kwargs['strides'] = _pair(value)
else:
kwargs['strides'] = _pair(ksize)

value = oc.try_get_attribute(node.inputs[3])
if value is not None:
kwargs['pads'] = _pair(value) * 2
else:
kwargs['pads'] = _pair(0)

kwargs['count_include_pad'] = 1

onnx_graph.add_node(
"AveragePool",
[node.inputs[0]],
[node.outputs[0]],
name=str(node.lineprop),
**kwargs,
)
3 changes: 3 additions & 0 deletions elichika/elichika/parser/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def instance_converter(m, i):
functions_builtin.ChainerFunction(F.pad_sequence), None)
f_dict.get_field().get_attribute('pad_sequence').revise(
values.ValueRef(f_pad_sequence))
f_average_pooling_2d = values.FuncValue(
functions_builtin.ChainerFunction(F.average_pooling_2d), None)
f_dict.get_field().get_attribute('average_pooling_2d').revise(values.ValueRef(f_average_pooling_2d))
default_module.set_default_value(chainer_functions_module_name, f_dict)

# numpy
Expand Down
53 changes: 53 additions & 0 deletions elichika/tests/node/AveragePool2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# coding: utf-8

import chainer
import chainer.functions as F


class AvgPool(chainer.Chain):

def __init__(self):
super(AvgPool, self).__init__()

def forward(self, x):
y1 = F.average_pooling_2d(x, 1, stride=2)
return y1


class AvgPoolPad(chainer.Chain):

def __init__(self):
super(AvgPoolPad, self).__init__()

def forward(self, x):
y1 = F.average_pooling_2d(x, 3, stride=1, pad=2)
return y1


class AvgPoolNoStride(chainer.Chain):

def __init__(self):
super(AvgPoolNoStride, self).__init__()

def forward(self, x):
y1 = F.average_pooling_2d(x, 3)
return y1


# ======================================

import testtools
import numpy as np


def main():
np.random.seed(123)
x = np.random.rand(2, 20, 15, 17).astype(np.float32)

testtools.generate_testcase(AvgPool(), [x], subname='default')
testtools.generate_testcase(AvgPoolPad(), [x], subname='withpad')
testtools.generate_testcase(AvgPoolNoStride(), [x], subname='withoutstride')


if __name__ == '__main__':
main()

0 comments on commit 4019c67

Please sign in to comment.