ONNX Export Maxpool Indices
author Lara Haidar <haidar.lara@gmail.com>
Wed, 20 Feb 2019 05:06:43 +0000 (21:06 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 20 Feb 2019 05:10:14 +0000 (21:10 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16455

Differential Revision: D14140375

Pulled By: houseroad

fbshipit-source-id: 12d02c447e7fe0fae49969d1daf40a87660ed416

test/onnx/expect/TestOperators.test_maxpool_indices.expect [new file with mode: 0644]
test/onnx/test_operators.py
torch/csrc/jit/passes/onnx/peephole.cpp
torch/onnx/symbolic.py

diff --git a/test/onnx/expect/TestOperators.test_maxpool_indices.expect b/test/onnx/expect/TestOperators.test_maxpool_indices.expect
new file mode 100644 (file)
index 0000000..0e93060
--- /dev/null
@@ -0,0 +1,130 @@
+ir_version: 4
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+  node {
+    input: "0"
+    output: "1"
+    output: "2"
+    op_type: "MaxPool"
+    attribute {
+      name: "kernel_shape"
+      ints: 3
+      type: INTS
+    }
+    attribute {
+      name: "pads"
+      ints: 0
+      ints: 0
+      type: INTS
+    }
+    attribute {
+      name: "strides"
+      ints: 2
+      type: INTS
+    }
+  }
+  node {
+    input: "0"
+    output: "3"
+    output: "4"
+    op_type: "MaxPool"
+    attribute {
+      name: "kernel_shape"
+      ints: 1
+      type: INTS
+    }
+    attribute {
+      name: "strides"
+      ints: 1
+      type: INTS
+    }
+  }
+  node {
+    input: "4"
+    output: "5"
+    op_type: "Slice"
+    attribute {
+      name: "axes"
+      ints: 2
+      type: INTS
+    }
+    attribute {
+      name: "ends"
+      ints: 1
+      type: INTS
+    }
+    attribute {
+      name: "starts"
+      ints: 0
+      type: INTS
+    }
+  }
+  node {
+    input: "2"
+    input: "5"
+    output: "6"
+    op_type: "Sub"
+  }
+  name: "torch-jit-export"
+  input {
+    name: "0"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 20
+          }
+          dim {
+            dim_value: 16
+          }
+          dim {
+            dim_value: 50
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "1"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 20
+          }
+          dim {
+            dim_value: 16
+          }
+          dim {
+            dim_value: 24
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "6"
+    type {
+      tensor_type {
+        elem_type: 7
+        shape {
+          dim {
+            dim_value: 20
+          }
+          dim {
+            dim_value: 16
+          }
+          dim {
+            dim_value: 24
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 9
+}
index 4e1b6f2..90ac298 100644 (file)
@@ -260,6 +260,10 @@ class TestOperators(TestCase):
         x = torch.randn(20, 16, 50)
         self.assertONNX(nn.MaxPool1d(3, stride=2), x)
 
+    def test_maxpool_indices(self):
+        x = torch.randn(20, 16, 50)
+        self.assertONNX(nn.MaxPool1d(3, stride=2, return_indices=True), x)
+
     def test_at_op(self):
         x = torch.randn(3, 4)
 
index d47766c..83f062a 100644 (file)
@@ -587,6 +587,20 @@ static void fuseSplitListUnpack(Block *b) {
   }
 }
 
+void removeMaxPoolUnusedOutput(Block* b) {
+    for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
+    auto n = *it;
+    for (auto* child_block : n->blocks()) {
+      removeMaxPoolUnusedOutput(child_block);
+    }
+    if (strcmp(n->kind().toQualString(), "onnx::MaxPool") == 0) {
+      if (n->outputs().size() == 2 && n->outputs().at(1)->uses().empty()) {
+        it->eraseOutput(1);
+      }
+    }
+  }
+}
+
 // This optimization does ONNX-specific peephole optimizations.
 //
 // At the moment, here are the optimizations it does:
@@ -620,6 +634,7 @@ void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph) {
   speculateOps(graph->block());
   eraseListConstruct(graph->block());
   fuseSplitListUnpack(graph->block());
+  removeMaxPoolUnusedOutput(graph->block());
 }
 
 } // namespace jit
index e594b1a..0c980c5 100644 (file)
@@ -591,11 +591,30 @@ def max_pool1d_with_indices(g, input, kernel_size, stride, padding, dilation, ce
         return _unimplemented("max_pool1d_with_indices", "dilation")
     if stride is None:
         stride = kernel_size
-    r = g.op("MaxPool", input,
-             kernel_shape_i=_single(kernel_size),
-             pads_i=_single(padding) * 2,
-             strides_i=_single(stride))
-    return r, None
+    r, indices = g.op("MaxPool", input, outputs=2,
+                      kernel_shape_i=_single(kernel_size),
+                      pads_i=_single(padding) * 2,
+                      strides_i=_single(stride))
+    # easy but hacky way to get flattened indices values
+    # to be used to convert the indices values to non-flattened.
+    # In ONNX the indices are computed as a flatten 1-D tensor,
+    # so the values in indices are in [0, N x C x D1 x ... x Dn).
+    # To convert the indices to the same format used by PyTorch,
+    # we first execute a maxpool with a kernel and stride of 1 on the same input.
+    # This will result in a tensor of indices in which each index will have its own value.
+    # Using this tensor as a reference, we extract the first index of each axis and subtract
+    # it from each index of this axis in the indices to convert.
+    # This step will result in a tensor where each dimension has values of indices within
+    # the dimension it is in.
+    # For more information :
+    # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
+    _, flattened_indices = g.op("MaxPool", input, outputs=2,
+                                kernel_shape_i=[1],
+                                strides_i=[1])
+    # convert indices to have non-flattened indices values
+    s = g.op("Slice", flattened_indices, axes_i=[2], starts_i=[0], ends_i=[1])
+    indices = sub(g, indices, s)
+    return r, indices
 
 
 @parse_args('v', 'is', 'is', 'is', 'is', 'i')
@@ -606,11 +625,20 @@ def max_pool2d_with_indices(g, input, kernel_size, stride, padding, dilation, ce
         return _unimplemented("max_pool2d_with_indices", "dilation")
     if not stride:
         stride = kernel_size
-    r = g.op("MaxPool", input,
-             kernel_shape_i=_pair(kernel_size),
-             pads_i=_pair(padding) * 2,
-             strides_i=_pair(stride))
-    return r, None
+    r, indices = g.op("MaxPool", input, outputs=2,
+                      kernel_shape_i=_pair(kernel_size),
+                      pads_i=_pair(padding) * 2,
+                      strides_i=_pair(stride))
+    # easy but hacky way to get flattened indices values
+    # to be used to convert the indices values to non-flattened
+    # See comment in max_pool1d_with_indices for details.
+    _, flattened_indices = g.op("MaxPool", input, outputs=2,
+                                kernel_shape_i=[1, 1],
+                                strides_i=[1, 1])
+    # convert indices to have non-flattened indices values
+    s = g.op("Slice", flattened_indices, axes_i=[2, 3], starts_i=[0, 0], ends_i=[1, 1])
+    indices = sub(g, indices, s)
+    return r, indices
 
 
 @parse_args('v', 'is', 'is', 'is', 'is', 'i')
@@ -621,11 +649,20 @@ def max_pool3d_with_indices(g, input, kernel_size, stride, padding, dilation, ce
         return _unimplemented("max_pool3d_with_indices", "dilation")
     if not stride:
         stride = kernel_size
-    r = g.op("MaxPool", input,
-             kernel_shape_i=_triple(kernel_size),
-             pads_i=_triple(padding) * 2,
-             strides_i=_triple(stride))
-    return r, None
+    r, indices = g.op("MaxPool", input, outputs=2,
+                      kernel_shape_i=_triple(kernel_size),
+                      pads_i=_triple(padding) * 2,
+                      strides_i=_triple(stride))
+    # easy but hacky way to get flattened indices values
+    # to be used to convert the indices values to non-flattened
+    # See comment in max_pool1d_with_indices for details.
+    _, flattened_indices = g.op("MaxPool", input, outputs=2,
+                                kernel_shape_i=[1, 1, 1],
+                                strides_i=[1, 1, 1])
+    # convert indices to have non-flattened indices values
+    s = g.op("Slice", flattened_indices, axes_i=[2, 3, 4], starts_i=[0, 0, 0], ends_i=[1, 1, 1])
+    indices = sub(g, indices, s)
+    return r, indices
 
 
 def _avg_pool(name, tuple_fn):