From b8d1f4a423204c9ff2f714151e2d740fb69181fa Mon Sep 17 00:00:00 2001 From: Lara Haidar Date: Tue, 19 Feb 2019 21:06:43 -0800 Subject: [PATCH] ONNX Export Maxpool Indices Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16455 Differential Revision: D14140375 Pulled By: houseroad fbshipit-source-id: 12d02c447e7fe0fae49969d1daf40a87660ed416 --- .../TestOperators.test_maxpool_indices.expect | 130 +++++++++++++++++++++ test/onnx/test_operators.py | 4 + torch/csrc/jit/passes/onnx/peephole.cpp | 15 +++ torch/onnx/symbolic.py | 67 ++++++++--- 4 files changed, 201 insertions(+), 15 deletions(-) create mode 100644 test/onnx/expect/TestOperators.test_maxpool_indices.expect diff --git a/test/onnx/expect/TestOperators.test_maxpool_indices.expect b/test/onnx/expect/TestOperators.test_maxpool_indices.expect new file mode 100644 index 0000000..0e93060 --- /dev/null +++ b/test/onnx/expect/TestOperators.test_maxpool_indices.expect @@ -0,0 +1,130 @@ +ir_version: 4 +producer_name: "pytorch" +producer_version: "0.4" +graph { + node { + input: "0" + output: "1" + output: "2" + op_type: "MaxPool" + attribute { + name: "kernel_shape" + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 0 + ints: 0 + type: INTS + } + attribute { + name: "strides" + ints: 2 + type: INTS + } + } + node { + input: "0" + output: "3" + output: "4" + op_type: "MaxPool" + attribute { + name: "kernel_shape" + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + type: INTS + } + } + node { + input: "4" + output: "5" + op_type: "Slice" + attribute { + name: "axes" + ints: 2 + type: INTS + } + attribute { + name: "ends" + ints: 1 + type: INTS + } + attribute { + name: "starts" + ints: 0 + type: INTS + } + } + node { + input: "2" + input: "5" + output: "6" + op_type: "Sub" + } + name: "torch-jit-export" + input { + name: "0" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 20 + } + dim { + dim_value: 16 + } + dim { + dim_value: 50 + } + } + } + 
} + } + output { + name: "1" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 20 + } + dim { + dim_value: 16 + } + dim { + dim_value: 24 + } + } + } + } + } + output { + name: "6" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 20 + } + dim { + dim_value: 16 + } + dim { + dim_value: 24 + } + } + } + } + } +} +opset_import { + version: 9 +} diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index 4e1b6f2..90ac298 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -260,6 +260,10 @@ class TestOperators(TestCase): x = torch.randn(20, 16, 50) self.assertONNX(nn.MaxPool1d(3, stride=2), x) + def test_maxpool_indices(self): + x = torch.randn(20, 16, 50) + self.assertONNX(nn.MaxPool1d(3, stride=2, return_indices=True), x) + def test_at_op(self): x = torch.randn(3, 4) diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index d47766c..83f062a 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -587,6 +587,20 @@ static void fuseSplitListUnpack(Block *b) { } } +void removeMaxPoolUnusedOutput(Block* b) { + for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) { + auto n = *it; + for (auto* child_block : n->blocks()) { + removeMaxPoolUnusedOutput(child_block); + } + if (strcmp(n->kind().toQualString(), "onnx::MaxPool") == 0) { + if (n->outputs().size() == 2 && n->outputs().at(1)->uses().empty()) { + it->eraseOutput(1); + } + } + } +} + // This optimization does ONNX-specific peephole optimizations. 
// // At the moment, here are the optimizations it does: @@ -620,6 +634,7 @@ void PeepholeOptimizeONNX(std::shared_ptr& graph) { speculateOps(graph->block()); eraseListConstruct(graph->block()); fuseSplitListUnpack(graph->block()); + removeMaxPoolUnusedOutput(graph->block()); } } // namespace jit diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index e594b1a..0c980c5 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -591,11 +591,30 @@ def max_pool1d_with_indices(g, input, kernel_size, stride, padding, dilation, ce return _unimplemented("max_pool1d_with_indices", "dilation") if stride is None: stride = kernel_size - r = g.op("MaxPool", input, - kernel_shape_i=_single(kernel_size), - pads_i=_single(padding) * 2, - strides_i=_single(stride)) - return r, None + r, indices = g.op("MaxPool", input, outputs=2, + kernel_shape_i=_single(kernel_size), + pads_i=_single(padding) * 2, + strides_i=_single(stride)) + # easy but hacky way to get flattened indices values + # to be used to convert the indices values to non-flattened. + # In ONNX the indices are computed as a flattened 1-D tensor, + # so the values in indices are in [0, N x C x D1 x ... x Dn). + # To convert the indices to the same format used by PyTorch, + # we first execute a maxpool with a kernel and stride of 1 on the same input. + # This will result in a tensor of indices in which each index will have its own value. + # Using this tensor as a reference, we extract the first index of each axis and subtract + # it from each index of this axis in the indices to convert. + # This step will result in a tensor where each dimension has values of indices within + # the dimension it is in. 
+ # For more information : + # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407 + _, flattened_indices = g.op("MaxPool", input, outputs=2, + kernel_shape_i=[1], + strides_i=[1]) + # convert indices to have non-flattened indices values + s = g.op("Slice", flattened_indices, axes_i=[2], starts_i=[0], ends_i=[1]) + indices = sub(g, indices, s) + return r, indices @parse_args('v', 'is', 'is', 'is', 'is', 'i') @@ -606,11 +625,20 @@ def max_pool2d_with_indices(g, input, kernel_size, stride, padding, dilation, ce return _unimplemented("max_pool2d_with_indices", "dilation") if not stride: stride = kernel_size - r = g.op("MaxPool", input, - kernel_shape_i=_pair(kernel_size), - pads_i=_pair(padding) * 2, - strides_i=_pair(stride)) - return r, None + r, indices = g.op("MaxPool", input, outputs=2, + kernel_shape_i=_pair(kernel_size), + pads_i=_pair(padding) * 2, + strides_i=_pair(stride)) + # easy but hacky way to get flattened indices values + # to be used to convert the indices values to non-flattened + # See comment in max_pool1d_with_indices for details. 
+ _, flattened_indices = g.op("MaxPool", input, outputs=2, + kernel_shape_i=[1, 1], + strides_i=[1, 1]) + # convert indices to have non-flattened indices values + s = g.op("Slice", flattened_indices, axes_i=[2, 3], starts_i=[0, 0], ends_i=[1, 1]) + indices = sub(g, indices, s) + return r, indices @parse_args('v', 'is', 'is', 'is', 'is', 'i') @@ -621,11 +649,20 @@ def max_pool3d_with_indices(g, input, kernel_size, stride, padding, dilation, ce return _unimplemented("max_pool3d_with_indices", "dilation") if not stride: stride = kernel_size - r = g.op("MaxPool", input, - kernel_shape_i=_triple(kernel_size), - pads_i=_triple(padding) * 2, - strides_i=_triple(stride)) - return r, None + r, indices = g.op("MaxPool", input, outputs=2, + kernel_shape_i=_triple(kernel_size), + pads_i=_triple(padding) * 2, + strides_i=_triple(stride)) + # easy but hacky way to get flattened indices values + # to be used to convert the indices values to non-flattened + # See comment in max_pool1d_with_indices for details. + _, flattened_indices = g.op("MaxPool", input, outputs=2, + kernel_shape_i=[1, 1, 1], + strides_i=[1, 1, 1]) + # convert indices to have non-flattened indices values + s = g.op("Slice", flattened_indices, axes_i=[2, 3, 4], starts_i=[0, 0, 0], ends_i=[1, 1, 1]) + indices = sub(g, indices, s) + return r, indices def _avg_pool(name, tuple_fn): -- 2.7.4