dispatch max_pools with no indices, expose max_pools to torch namespace (#19449)
authorWanchao Liang <wanchaol@users.noreply.github.com>
Tue, 23 Apr 2019 18:16:28 +0000 (11:16 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 23 Apr 2019 18:20:05 +0000 (11:20 -0700)
Summary:
in functional interfaces we do boolean dispatch, but all to max_pool\*d_with_indices. This change it to emit max_pool\*d op instead when it's not necessary to expose with_indices ops to different backends (for jit).

It also bind max_pool\*d to the torch namespace, which is the same behavior with avg_pool\*d
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19449

Differential Revision: D15016839

Pulled By: wanchaol

fbshipit-source-id: f77cd5f0bcd6d8534c1296d89b061023a8288a2c

aten/src/ATen/native/native_functions.yaml
test/onnx/expect/TestOperators.test_maxpool.expect
tools/autograd/gen_python_functions.py
torch/nn/functional.py
torch/onnx/symbolic.py

index 7f016b4..7b4e911 100644 (file)
 
 - func: max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor
 
-- func: max_pool2d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor
+- func: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
 
-- func: max_pool3d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor
+- func: max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor
 
 # FIXME: These could be combined as optional<ScalarType> but for https://github.com/pytorch/pytorch/issues/6593.
 - func: mean(Tensor self, *, ScalarType dtype) -> Tensor
index 1f359a6..5a04a8f 100644 (file)
@@ -3,7 +3,7 @@ producer_name: "pytorch"
 producer_version: "1.1"
 graph {
   node {
-    input: "input"
+    input: "0"
     output: "1"
     op_type: "MaxPool"
     attribute {
@@ -25,7 +25,7 @@ graph {
   }
   name: "torch-jit-export"
   input {
-    name: "input"
+    name: "0"
     type {
       tensor_type {
         elem_type: 1
index 821058e..9b4eb1e 100644 (file)
@@ -29,8 +29,7 @@ SKIP_PYTHON_BINDINGS = [
     'arange.*', 'range.*', '_solve.*', '_getri.*', '_inverse.*',
     '_cholesky.*', '_triangular_solve.*',
     'slice', 'randint(_out)?',
-    'item', '_local_scalar_dense',
-    'max_pool1d', 'max_pool2d', 'max_pool3d', 'linear', 'to',
+    'item', '_local_scalar_dense', 'linear', 'to',
     'copy_sparse_to_sparse_',
 ]
 
index 7c4a8f3..db4fc16 100644 (file)
@@ -455,8 +455,10 @@ def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0,
 def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1,
                 ceil_mode=False, return_indices=False):
     # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tensor  # noqa
-    return max_pool1d_with_indices(
-        input, kernel_size, stride, padding, dilation, ceil_mode)[0]
+    if stride is None:
+        stride = torch.jit.annotate(List[int], [])
+    return torch.max_pool1d(
+        input, kernel_size, stride, padding, dilation, ceil_mode)
 
 max_pool1d = torch._jit_internal.boolean_dispatch(
     arg_name='return_indices',
@@ -486,8 +488,10 @@ def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation
 def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
                 ceil_mode=False, return_indices=False):
     # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tensor  # noqa
-    return max_pool2d_with_indices(
-        input, kernel_size, stride, padding, dilation, ceil_mode)[0]
+    if stride is None:
+        stride = torch.jit.annotate(List[int], [])
+    return torch.max_pool2d(
+        input, kernel_size, stride, padding, dilation, ceil_mode)
 
 max_pool2d = torch._jit_internal.boolean_dispatch(
     arg_name='return_indices',
@@ -518,8 +522,10 @@ def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0,
 def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1,
                 ceil_mode=False, return_indices=False):
     # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tensor  # noqa
-    return max_pool3d_with_indices(
-        input, kernel_size, stride, padding, dilation, ceil_mode)[0]
+    if stride is None:
+        stride = torch.jit.annotate(List[int], [])
+    return torch.max_pool3d(
+        input, kernel_size, stride, padding, dilation, ceil_mode)
 
 max_pool3d = torch._jit_internal.boolean_dispatch(
     arg_name='return_indices',
index 8f87391..da07135 100644 (file)
@@ -667,104 +667,63 @@ def get_pool_ceil_padding(input, kernel_size, stride, padding):
     return padding_ceil
 
 
-@parse_args('v', 'is', 'is', 'is', 'is', 'i')
-def max_pool1d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode):
-    if ceil_mode and input.type().kind() != "CompleteTensorType":
-        return _unimplemented("max_pool1d_with_indices", "input size not accesible")
-    if set(_single(dilation)) != {1}:
-        return _unimplemented("max_pool1d_with_indices", "dilation")
-    if stride is None:
-        stride = kernel_size
-    padding = tuple(_single(padding))
-    if ceil_mode:
-        padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
-        padding = padding + tuple(numpy.add(padding_ceil, padding))
-    else:
-        padding = padding * 2
-    r, indices = g.op("MaxPool", input, outputs=2,
-                      kernel_shape_i=_single(kernel_size),
-                      pads_i=padding,
-                      strides_i=_single(stride))
-    # easy but hacky way to get flattened indices values
-    # to be used to convert the indices values to non-flattened.
-    # In ONNX the indices are computed as a flatten 1-D tensor,
-    # so the values in indices are in [0, N x C x D1 x ... x Dn).
-    # To convert the indices to the same format used by Pytorch,
-    # we first execute a maxpool with a kernel and stride of 1 on the same input.
-    # This will result in a tensor of indices in which each index will have it's own value.
-    # Using this tensor as a reference, we extract the first index of each axis and substract
-    # it from each index of this axis in the indices to convert.
-    # This step will result in a tensor were each dimension has values of indices within
-    # the dimension it is in.
-    # For more information :
-    # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
-    _, flattened_indices = g.op("MaxPool", input, outputs=2,
-                                kernel_shape_i=[1],
-                                strides_i=[1])
-    # convert indices to have non-flattened indices values
-    s = _slice_op(g, flattened_indices, axes=[2], starts=[0], ends=[1])
-    indices = sub(g, indices, s)
-    return r, indices
-
-
-@parse_args('v', 'is', 'is', 'is', 'is', 'i')
-def max_pool2d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode):
-    if ceil_mode and input.type().kind() != "CompleteTensorType":
-        return _unimplemented("max_pool2d_with_indices", "input size not accesible")
-    if set(_pair(dilation)) != {1}:
-        return _unimplemented("max_pool2d_with_indices", "dilation")
-    if not stride:
-        stride = kernel_size
-    padding = tuple(_pair(padding))
-    if ceil_mode:
-        padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
-        padding = padding + tuple(numpy.add(padding_ceil, padding))
-    else:
-        padding = padding * 2
-    r, indices = g.op("MaxPool", input, outputs=2,
-                      kernel_shape_i=_pair(kernel_size),
-                      pads_i=padding,
-                      strides_i=_pair(stride))
-    # easy but hacky way to get flattened indices values
-    # to be used to convert the indices values to non-flattened
-    # See comment in max_pool1d_with_indices for details.
-    _, flattened_indices = g.op("MaxPool", input, outputs=2,
-                                kernel_shape_i=[1, 1],
-                                strides_i=[1, 1])
-    # convert indices to have non-flattened indices values
-    s = _slice_op(g, flattened_indices, axes=[2, 3], starts=[0, 0], ends=[1, 1])
-    indices = sub(g, indices, s)
-    return r, indices
-
-
-@parse_args('v', 'is', 'is', 'is', 'is', 'i')
-def max_pool3d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode):
-    if ceil_mode and input.type().kind() != "CompleteTensorType":
-        return _unimplemented("max_pool3d_with_indices", "input size not accesible")
-    if set(_triple(dilation)) != {1}:
-        return _unimplemented("max_pool3d_with_indices", "dilation")
-    if not stride:
-        stride = kernel_size
-    padding = tuple(_triple(padding))
-    if ceil_mode:
-        padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
-        padding = padding + tuple(numpy.add(padding_ceil, padding))
-    else:
-        padding = padding * 2
-    r, indices = g.op("MaxPool", input, outputs=2,
-                      kernel_shape_i=_triple(kernel_size),
-                      pads_i=padding,
-                      strides_i=_triple(stride))
-    # easy but hacky way to get flattened indices values
-    # to be used to convert the indices values to non-flattened
-    # See comment in max_pool1d_with_indices for details.
-    _, flattened_indices = g.op("MaxPool", input, outputs=2,
-                                kernel_shape_i=[1, 1, 1],
-                                strides_i=[1, 1, 1])
-    # convert indices to have non-flattened indices values
-    s = _slice_op(g, flattened_indices, axes=[2, 3, 4], starts=[0, 0, 0], ends=[1, 1, 1])
-    indices = sub(g, indices, s)
-    return r, indices
+def _max_pool(name, tuple_fn, ndims, return_indices):
+    @parse_args('v', 'is', 'is', 'is', 'is', 'i')
+    def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
+        if ceil_mode and input.type().kind() != "CompleteTensorType":
+            return _unimplemented(name, "input size not accesible")
+        if set(tuple_fn(dilation)) != {1}:
+            return _unimplemented(name, "dilation")
+        if not stride:
+            stride = kernel_size
+        padding = tuple(tuple_fn(padding))
+        if ceil_mode:
+            padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
+            padding = padding + tuple(numpy.add(padding_ceil, padding))
+        else:
+            padding = padding * 2
+        # easy but hacky way to get flattened indices values
+        # to be used to convert the indices values to non-flattened.
+        # In ONNX the indices are computed as a flatten 1-D tensor,
+        # so the values in indices are in [0, N x C x D1 x ... x Dn).
+        # To convert the indices to the same format used by Pytorch,
+        # we first execute a maxpool with a kernel and stride of 1 on the same input.
+        # This will result in a tensor of indices in which each index will have it's own value.
+        # Using this tensor as a reference, we extract the first index of each axis and substract
+        # it from each index of this axis in the indices to convert.
+        # This step will result in a tensor were each dimension has values of indices within
+        # the dimension it is in.
+        # For more information :
+        # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
+        if return_indices:
+            r, indices = g.op("MaxPool", input, outputs=2,
+                              kernel_shape_i=tuple_fn(kernel_size),
+                              pads_i=padding,
+                              strides_i=tuple_fn(stride))
+            _, flattened_indices = g.op("MaxPool", input, outputs=2,
+                                        kernel_shape_i=[1 for _ in range(ndims)],
+                                        strides_i=[1 for _ in range(ndims)])
+            # convert indices to have non-flattened indices values
+            s = _slice_op(g, flattened_indices, axes=[2 + i for i in range(ndims)],
+                          starts=tuple_fn(0), ends=tuple_fn(1))
+            indices = sub(g, indices, s)
+            return r, indices
+        else:
+            r = g.op("MaxPool", input, outputs=1,
+                     kernel_shape_i=tuple_fn(kernel_size),
+                     pads_i=padding,
+                     strides_i=tuple_fn(stride))
+            return r
+
+    return symbolic_fn
+
+
+max_pool1d = _max_pool("max_pool1d", _single, 1, return_indices=False)
+max_pool2d = _max_pool("max_pool2d", _pair, 2, return_indices=False)
+max_pool3d = _max_pool("max_pool3d", _triple, 3, return_indices=False)
+max_pool1d_with_indices = _max_pool("max_pool1d_with_indices", _single, 1, return_indices=True)
+max_pool2d_with_indices = _max_pool("max_pool2d_with_indices", _pair, 2, return_indices=True)
+max_pool3d_with_indices = _max_pool("max_pool3d_with_indices", _triple, 3, return_indices=True)
 
 
 def _avg_pool(name, tuple_fn):