From 8a9961123a36ddd5aabfb3b37d0fdcd0ab763199 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Tue, 12 May 2020 23:21:16 -0700 Subject: [PATCH] [PYTORCH] Support max_pool2d_with_indices (#5549) * Use real output name instead of node_name * Add pytorch max_pool2d_with_indices converter. * Add test for maxpool2d with indices * Add explicit assert for single output * Only consume output (not indices) from max pool 2d with indices * undo change --- python/tvm/relay/frontend/pytorch.py | 9 ++++++++- tests/python/frontend/pytorch/test_forward.py | 11 +++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index d95a912..c7eccf7 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -628,6 +628,12 @@ def _maxpool_2d(): return _op.nn.max_pool2d(data, pool_size, strides, padding, "NCHW", ceil_mode) return _impl +def _maxpool_2d_with_indices(): + def _impl(inputs, input_types): + # returns dummy indices too + return _maxpool_2d()(inputs, input_types), None + return _impl + def _maxpool_1d(): def _impl(inputs, input_types): data = inputs[0] @@ -1654,7 +1660,7 @@ def _get_convert_map(prelude): "aten::adaptive_avg_pool2d" : _adaptive_avg_pool_2d(), "aten::adaptive_max_pool2d" : _adaptive_max_pool_2d(), "aten::max_pool2d" : _maxpool_2d(), - "aten::max_pool2d_with_indices" : _maxpool_2d(), + "aten::max_pool2d_with_indices" : _maxpool_2d_with_indices(), "aten::max_pool1d" : _maxpool_1d(), "aten::max_pool3d" : _maxpool_3d(), "aten::hardtanh" : _hardtanh(), @@ -2252,6 +2258,7 @@ def convert_operators(operators, outputs, ret_names, convert_map, prelude): out_names = _get_output_names(op_node) outputs.update(zip(out_names, relay_out)) else: + assert op_node.outputsSize() == 1 outputs[node_name] = relay_out return [_wrap_const(outputs[ret_name]) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 82a027f..3d9d22b 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -534,6 +534,17 @@ def test_forward_maxpool2d(): stride=2).eval(), input_data) + class MaxPool2DWithIndices(Module): + def __init__(self): + super(MaxPool2DWithIndices, self).__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1], return_indices=True) + + def forward(self, *args): + output, indices = self.pool(args[0]) + return output + + verify_model(MaxPool2DWithIndices().float().eval(), input_data=input_data) + def test_forward_maxpool1d(): torch.set_grad_enabled(False) input_shape = [1, 3, 10] -- 2.7.4