diff --git a/chainer_compiler/ch2o/funcs.py b/chainer_compiler/ch2o/funcs.py index 03046320..994d6508 100644 --- a/chainer_compiler/ch2o/funcs.py +++ b/chainer_compiler/ch2o/funcs.py @@ -65,7 +65,7 @@ def call_impl(self, env, x, ksize, stride, pad, cover_all, return_indices): 'MaxPool', inputs=[x.to_tensor(env).name], kernel_shape=kernel_shape, - chainer_cover_all=cover_all.to_bool(), + ceil_mode=int(cover_all.to_bool()), **kwargs) diff --git a/chainer_compiler/elichika/functions_builtin.py b/chainer_compiler/elichika/functions_builtin.py index 37652855..a5cb2a44 100644 --- a/chainer_compiler/elichika/functions_builtin.py +++ b/chainer_compiler/elichika/functions_builtin.py @@ -40,7 +40,7 @@ def get_onnx_dtype(dtype): class BaseConverter(object): def __init__(self): self.expected_args = () - + def parse_args(self, onnx_graph, node): assert hasattr(self, 'expected_args'), 'BaseConverter subclass must have `expected_args`' parser = oc.NodeParse() @@ -52,7 +52,7 @@ def parse_args(self, onnx_graph, node): def __call__(self, onnx_graph, node): raise NotImplementedError - + class ConverterChainerMathMisc(BaseConverter): def __init__(self, operator, arg_name = 'x'): self.arg_name = arg_name @@ -284,7 +284,7 @@ def __call__(self, onnx_graph, node): str(node.lineprop), min=parser.get('x_min'), max=parser.get('x_max')) - + class ConverterSum(BaseConverter): def __init__(self): self.expected_args = ( @@ -482,7 +482,7 @@ def __call__(self, onnx_graph, node): [parser.get('x')], node.outputs, name=str(node.lineprop), - chainer_cover_all=parser.get('cover_all'), + ceil_mode=int(parser.get('cover_all')), **kwargs) diff --git a/compiler/chxvm/simple_node_emitter.cc b/compiler/chxvm/simple_node_emitter.cc index cb6ff532..0924540e 100644 --- a/compiler/chxvm/simple_node_emitter.cc +++ b/compiler/chxvm/simple_node_emitter.cc @@ -321,10 +321,10 @@ void EmitSimpleNode(const Node& node, const ValueIdManager& id_manager, ChxVMPro CHECK_EQ(3UL, node.outputs().size()); CHECK(node.output(1)->IsNull()); } - EMIT(MaxPool, out(0), oout(2), in(0), node.kernel_shape(), strides(), pads(), node.chainer_cover_all(), auto_pad()); + EMIT(MaxPool, out(0), oout(2), in(0), node.kernel_shape(), strides(), pads(), node.ceil_mode(), auto_pad()); } else if (node.op_type() == Node::kChainerMaxPoolGrad) { CHECK_EQ("NOTSET", node.auto_pad()) << "auto_pad is not supported for MaxPool"; - EMIT(MaxPoolGrad, out(0), in(0), in(1), node.kernel_shape(), node.chainer_cover_all()); + EMIT(MaxPoolGrad, out(0), in(0), in(1), node.kernel_shape(), node.ceil_mode()); } else if (node.op_type() == Node::kChainerROIMaxPool2D) { EMIT(ROIMaxPool2D, out(0), in(0), in(1), in(2), node.output_shape(), node.spatial_scale()); } else if (node.op_type() == Node::kChainerROIAveragePool2D) { @@ -348,9 +348,11 @@ void EmitSimpleNode(const Node& node, const ValueIdManager& id_manager, ChxVMPro } else if (node.op_type() == Node::kAveragePool) { CHECK_EQ("NOTSET", node.auto_pad()) << "auto_pad is not supported for AveragePool"; CHECK_EQ(1UL, node.inputs().size()); + CHECK_EQ(0, node.ceil_mode()) << "ceil_mode for AveragePool is not supported yet"; EMIT(AveragePool, out(0), oout(1), in(0), node.kernel_shape(), strides(), pads(), node.count_include_pad()); } else if (node.op_type() == Node::kChainerAveragePoolGrad) { CHECK_EQ("NOTSET", node.auto_pad()) << "auto_pad is not supported for AveragePool"; + CHECK_EQ(0, node.ceil_mode()) << "ceil_mode for AveragePool is not supported yet"; EMIT(AveragePoolGrad, out(0), in(0), in(1), node.kernel_shape(), node.count_include_pad()); } else if (node.op_type() == Node::kChainerPadBatchSize) { EMIT(PadBatchSize, out(0), in(0), node.size()); diff --git a/compiler/custom_onnx_ops.cc b/compiler/custom_onnx_ops.cc index 49a85059..92071afe 100644 --- a/compiler/custom_onnx_ops.cc +++ b/compiler/custom_onnx_ops.cc @@ -92,159 +92,6 @@ ONNX_CHAINER_OPERATOR_SET_SCHEMA( namespace { -// From onnx/onnx/defs/nn/defs.cc -void convPoolTypeAndShapeInference(InferenceContext& ctx, bool use_dilation, bool require_kernel_shape) { - propagateElemTypeFromInputToOutput(ctx, 0, 0); - if (ctx.getNumOutputs() > 1) { - // MaxPool with two outputs case. - auto output_type = ctx.getOutputType(1); - if (output_type->value_case() == TypeProto::kTensorType || output_type->value_case() == TypeProto::VALUE_NOT_SET) { - output_type->mutable_tensor_type()->set_elem_type(TensorProto::INT64); - } - } - - // we need the first input shape for this inference. - if (!hasNInputShapes(ctx, 1)) { - return; - } - - // if kernel shape is an input (and not attribute) - // we need the shape of the second input. - if (!require_kernel_shape && !hasNInputShapes(ctx, 2)) { - return; - } - - // don't bother with legacy auto_pad for now - if (ctx.getAttribute("auto_pad")) { - return; - } - - auto input_shape = ctx.getInputType(0)->tensor_type().shape(); - if (input_shape.dim_size() < 2) { - fail_shape_inference("Input tensor must have atleast 2 dimensions"); - } - - // first dim is the batch axis and the next is the number of channels. - size_t n_input_dims = static_cast(input_shape.dim_size() - 2); - - // Pooling operations don't support dilation, only Conv. For - // simplicity of the code, we just treat them as having all-1s - // dilation. - std::vector dilations; - if (use_dilation && getRepeatedAttribute(ctx, "dilations", dilations)) { - if (dilations.size() != n_input_dims) { - fail_shape_inference("Attribute dilations has incorrect size"); - } - } else { - dilations.assign(n_input_dims, 1); - } - - std::vector pads; - if (getRepeatedAttribute(ctx, "pads", pads)) { - if (pads.size() != n_input_dims * 2) { - fail_shape_inference("Attribute pads has incorrect size"); - } - } else { - pads.assign(n_input_dims * 2, 0); - } - - std::vector strides; - if (getRepeatedAttribute(ctx, "strides", strides)) { - if (strides.size() != n_input_dims) { - fail_shape_inference("Attribute strides has incorrect size"); - } - } else { - strides.assign(n_input_dims, 1); - } - - std::vector kernel_shape; - if (getRepeatedAttribute(ctx, "kernel_shape", kernel_shape)) { - if (kernel_shape.size() != n_input_dims) { - fail_shape_inference("Attribute kernel_shape has incorrect size"); - } - } else if (require_kernel_shape) { - fail_shape_inference("Attribute kernel_shape must be specified"); - } else { - auto second_input_shape = ctx.getInputType(1)->tensor_type().shape(); - for (int i = 2; i < second_input_shape.dim_size(); ++i) { - if (!second_input_shape.dim(i).has_dim_value()) { - return; - } - kernel_shape.push_back(second_input_shape.dim(i).dim_value()); - } - } - - auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); - - if (require_kernel_shape) { - // add the first two dimensions from the input. - *output_shape->add_dim() = input_shape.dim(0); - *output_shape->add_dim() = input_shape.dim(1); - } else { - *output_shape->add_dim() = input_shape.dim(0); - auto& second_input_shape = getInputShape(ctx, 1); - if (second_input_shape.dim_size() < 1) { - fail_shape_inference("Second input tensor has wrong dimension"); - } - *output_shape->add_dim() = second_input_shape.dim(0); - } - - // EDIT(hamaji): Check if `chainer_cover_all` is set. - const bool cover_all = getAttribute(ctx, "chainer_cover_all", 0); - - int kernel_shape_size = static_cast(kernel_shape.size()); - for (int i = 0; i < kernel_shape_size; ++i) { - auto newdim = output_shape->add_dim(); - if (!input_shape.dim(2 + i).has_dim_value()) { - continue; - } - // how big is the input, including padding - int64_t effective_input_size = input_shape.dim(2 + i).dim_value(); - effective_input_size += pads[i]; - effective_input_size += pads[i + kernel_shape_size]; - - int64_t effective_kernel_size = kernel_shape[i]; - // accounting for dilation, how big is the kernel in this dimension - effective_kernel_size = (effective_kernel_size - 1) * dilations[i] + 1; - - // how many times we can move the kernel from it's initial position, based - // on the stride - int64_t strided_kernel_positions = (effective_input_size - effective_kernel_size) / strides[i]; - // EDIT(hamaji): Adjustment for `chainer_cover_all`. - if (cover_all && (effective_input_size - effective_kernel_size) % strides[i]) { - ++strided_kernel_positions; - } - - // add in the initial position - newdim->set_dim_value(1 + strided_kernel_positions); - } - - if (ctx.getNumOutputs() > 1) { - // MaxPool with two outputs case. - auto second_output_shape = ctx.getOutputType(1)->mutable_tensor_type()->mutable_shape(); - second_output_shape->CopyFrom(*output_shape); - } -} - -} // namespace - -ONNX_WORKAROUND_OPERATOR_SET_SCHEMA( - MaxPool, - 11, - OpSchema() - .SetDoc("TBD") - .Input(0, "X", "Input tensor", "T") - .Output(0, "Y", "Output tensor", "T") - .Output(1, "Indices", "Indices tensor", "I", OpSchema::Optional) - .TypeConstraint( - "T", - {"tensor(float)", "tensor(float16)", "tensor(double)"}, - "Constrain input and output types to signed numeric tensors.") - .TypeConstraint("I", {"tensor(int64)"}, "Constrain index tensor to int64") - .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { convPoolTypeAndShapeInference(ctx, false, true); })); - -namespace { - void InferROI(InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); @@ -612,7 +459,6 @@ class Custom_OpSet_Onnx_ver9 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); - fn(GetOpSchema()); fn(GetOpSchema()); } }; diff --git a/compiler/fusion_ngraph.cc b/compiler/fusion_ngraph.cc index fa734803..62d7b87b 100644 --- a/compiler/fusion_ngraph.cc +++ b/compiler/fusion_ngraph.cc @@ -127,7 +127,7 @@ void FuseNGraphOperations(Graph* graph) { return false; } } else if (node.op_type() == Node::kMaxPool) { - if (node.chainer_cover_all()) { + if (node.ceil_mode()) { return false; } } else if ( diff --git a/compiler/gen_node.py b/compiler/gen_node.py index d1f8950a..1db0687a 100644 --- a/compiler/gen_node.py +++ b/compiler/gen_node.py @@ -197,9 +197,10 @@ def __init__(self, op_type, num_inputs, num_outputs, domain='', **kwargs): kernel_shape=Required([int]), pads=[int], storage_order=0, - strides=[int]) + strides=[int], + ceil_mode=0) # Extension: the third output is for backward context. -NodeDef('MaxPool', 1, (1, 2, 3), chainer_cover_all=False, **pool_attrs) +NodeDef('MaxPool', 1, (1, 2, 3), **pool_attrs) # Extension: the second output is for backward context. NodeDef('AveragePool', 1, (1, 2), count_include_pad=False, **pool_attrs) NodeDef('GlobalMaxPool', 1, 1) @@ -266,7 +267,7 @@ def __init__(self, op_type, num_inputs, num_outputs, domain='', **kwargs): # For experimental ops. NodeDef('ChainerDoSomething', None, None, function_name=Required(str)) -NodeDef('ChainerMaxPoolGrad', 2, 1, chainer_cover_all=False, **pool_attrs) +NodeDef('ChainerMaxPoolGrad', 2, 1, **pool_attrs) NodeDef('ChainerAveragePoolGrad', 2, 1, count_include_pad=False, **pool_attrs) NodeDef('ChainerResizeGrad', 2, 1) NodeDef('ChainerBatchNormalizationGrad', 2, 3) diff --git a/compiler/gradient_ops.cc b/compiler/gradient_ops.cc index 9c382379..b54e6be8 100644 --- a/compiler/gradient_ops.cc +++ b/compiler/gradient_ops.cc @@ -509,7 +509,7 @@ void MaxPoolGradFn(GradientOpContext* gc) { ->set_pads(node->pads()) ->set_storage_order(node->storage_order()) ->set_strides(node->strides()) - ->set_chainer_cover_all(node->chainer_cover_all()); + ->set_ceil_mode(node->ceil_mode()); } void AveragePoolGradFn(GradientOpContext* gc) { diff --git a/compiler/simplifier.cc b/compiler/simplifier.cc index 03359b25..d689b1c1 100644 --- a/compiler/simplifier.cc +++ b/compiler/simplifier.cc @@ -351,7 +351,7 @@ bool ReplaceMaxPool(Graph* graph, Node* node) { Value* padded = PadForPool(&gb, node, -std::numeric_limits::infinity()); gb.Op(Node::kMaxPool, {padded}, node->output(0)) ->producer() - ->set_chainer_cover_all(node->chainer_cover_all()) + ->set_ceil_mode(node->ceil_mode()) ->set_auto_pad(node->auto_pad()) ->set_kernel_shape(node->kernel_shape()) ->set_storage_order(node->storage_order()) diff --git a/scripts/gen_extra_test.py b/scripts/gen_extra_test.py index 1fa68349..976b2de5 100644 --- a/scripts/gen_extra_test.py +++ b/scripts/gen_extra_test.py @@ -875,7 +875,6 @@ def gen_incomplete_transpose_test(test_name): def gen_maxpool_cover_all_test(test_name): - # A custom attribute for Chainer/ChainerX's `cover_all` parameter. gb = onnx_script.GraphBuilder(test_name) input = np.random.random((1, 3, 7, 7)) @@ -889,14 +888,14 @@ def gen_maxpool_cover_all_test(test_name): outputs=['not_cover_all']), F.max_pooling_2d(input, ksize=3, stride=2, cover_all=False)) gb.output(gb.MaxPool([input_v], kernel_shape=[3, 3], strides=[2, 2], - chainer_cover_all=True, + ceil_mode=1, outputs=['cover_all']), F.max_pooling_2d(input, ksize=3, stride=2, cover_all=True)) gb.output(gb.MaxPool([dynamic_v], kernel_shape=[3, 3], strides=[2, 2], outputs=['not_cover_all_dynamic']), F.max_pooling_2d(input, ksize=3, stride=2, cover_all=False)) gb.output(gb.MaxPool([dynamic_v], kernel_shape=[3, 3], strides=[2, 2], - chainer_cover_all=True, + ceil_mode=1, outputs=['cover_all_dynamic']), F.max_pooling_2d(input, ksize=3, stride=2, cover_all=True)) diff --git a/scripts/runtests.py b/scripts/runtests.py index a52d9721..c6aa321f 100755 --- a/scripts/runtests.py +++ b/scripts/runtests.py @@ -216,9 +216,11 @@ TestCase(NODE_TEST, 'test_constant_pad'), # TODO(hamaji): auto_pad is not supported. TestCase(NODE_TEST, 'test_maxpool_1d_default', fail=fail_1d_conv_pool), + TestCase(NODE_TEST, 'test_maxpool_2d_ceil'), TestCase(NODE_TEST, 'test_maxpool_2d_default'), TestCase(NODE_TEST, 'test_maxpool_2d_pads'), TestCase(NODE_TEST, 'test_maxpool_2d_precomputed_pads'), + TestCase(NODE_TEST, 'test_maxpool_2d_precomputed_same_upper'), TestCase(NODE_TEST, 'test_maxpool_2d_precomputed_strides'), TestCase(NODE_TEST, 'test_maxpool_2d_strides'), TestCase(NODE_TEST, 'test_maxpool_3d_default'),