return _op.nn.max_pool2d(data, pool_size, strides, padding, "NCHW", ceil_mode)
return _impl
+def _maxpool_2d_with_indices():
+ def _impl(inputs, input_types):
+ # returns dummy indices too
+ return _maxpool_2d()(inputs, input_types), None
+ return _impl
+
def _maxpool_1d():
def _impl(inputs, input_types):
data = inputs[0]
"aten::adaptive_avg_pool2d" : _adaptive_avg_pool_2d(),
"aten::adaptive_max_pool2d" : _adaptive_max_pool_2d(),
"aten::max_pool2d" : _maxpool_2d(),
- "aten::max_pool2d_with_indices" : _maxpool_2d(),
+ "aten::max_pool2d_with_indices" : _maxpool_2d_with_indices(),
"aten::max_pool1d" : _maxpool_1d(),
"aten::max_pool3d" : _maxpool_3d(),
"aten::hardtanh" : _hardtanh(),
out_names = _get_output_names(op_node)
outputs.update(zip(out_names, relay_out))
else:
+ assert op_node.outputsSize() == 1
outputs[node_name] = relay_out
return [_wrap_const(outputs[ret_name])
stride=2).eval(),
input_data)
+ class MaxPool2DWithIndices(Module):
+ def __init__(self):
+ super(MaxPool2DWithIndices, self).__init__()
+ self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1], return_indices=True)
+
+ def forward(self, *args):
+ output, indices = self.pool(args[0])
+ return output
+
+ verify_model(MaxPool2DWithIndices().float().eval(), input_data=input_data)
+
def test_forward_maxpool1d():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10]