[PYTORCH] Support max_pool2d_with_indices (#5549)
authorTrevor Morris <trevmorr@amazon.com>
Wed, 13 May 2020 06:21:16 +0000 (23:21 -0700)
committerGitHub <noreply@github.com>
Wed, 13 May 2020 06:21:16 +0000 (15:21 +0900)
* Use real output name instead of node_name

* Add pytorch max_pool2d_with_indices converter.

* Add test for maxpool2d with indices

* Add explicit assert for single output

* Only consume output (not indices) from max pool 2d with indices

* undo change

python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

index d95a912..c7eccf7 100644 (file)
@@ -628,6 +628,12 @@ def _maxpool_2d():
         return _op.nn.max_pool2d(data, pool_size, strides, padding, "NCHW", ceil_mode)
     return _impl
 
+def _maxpool_2d_with_indices():
+    def _impl(inputs, input_types):
+        # returns dummy indices too
+        return _maxpool_2d()(inputs, input_types), None
+    return _impl
+
 def _maxpool_1d():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -1654,7 +1660,7 @@ def _get_convert_map(prelude):
         "aten::adaptive_avg_pool2d"             : _adaptive_avg_pool_2d(),
         "aten::adaptive_max_pool2d"             : _adaptive_max_pool_2d(),
         "aten::max_pool2d"                      : _maxpool_2d(),
-        "aten::max_pool2d_with_indices"         : _maxpool_2d(),
+        "aten::max_pool2d_with_indices"         : _maxpool_2d_with_indices(),
         "aten::max_pool1d"                      : _maxpool_1d(),
         "aten::max_pool3d"                      : _maxpool_3d(),
         "aten::hardtanh"                        : _hardtanh(),
@@ -2252,6 +2258,7 @@ def convert_operators(operators, outputs, ret_names, convert_map, prelude):
                 out_names = _get_output_names(op_node)
                 outputs.update(zip(out_names, relay_out))
             else:
+                assert op_node.outputsSize() == 1
                 outputs[node_name] = relay_out
 
     return [_wrap_const(outputs[ret_name])
index 82a027f..3d9d22b 100644 (file)
@@ -534,6 +534,17 @@ def test_forward_maxpool2d():
                                     stride=2).eval(),
                  input_data)
 
+    class MaxPool2DWithIndices(Module):
+        def __init__(self):
+            super(MaxPool2DWithIndices, self).__init__()
+            self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1], return_indices=True)
+
+        def forward(self, *args):
+            output, indices = self.pool(args[0])
+            return output
+
+    verify_model(MaxPool2DWithIndices().float().eval(), input_data=input_data)
+
 def test_forward_maxpool1d():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10]