def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1,
ceil_mode=False, return_indices=False):
# type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tensor # noqa
- return max_pool1d_with_indices(
- input, kernel_size, stride, padding, dilation, ceil_mode)[0]
+ if stride is None:
+ stride = torch.jit.annotate(List[int], [])
+ return torch.max_pool1d(
+ input, kernel_size, stride, padding, dilation, ceil_mode)
max_pool1d = torch._jit_internal.boolean_dispatch(
arg_name='return_indices',
def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
ceil_mode=False, return_indices=False):
# type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tensor # noqa
- return max_pool2d_with_indices(
- input, kernel_size, stride, padding, dilation, ceil_mode)[0]
+ if stride is None:
+ stride = torch.jit.annotate(List[int], [])
+ return torch.max_pool2d(
+ input, kernel_size, stride, padding, dilation, ceil_mode)
max_pool2d = torch._jit_internal.boolean_dispatch(
arg_name='return_indices',
def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1,
ceil_mode=False, return_indices=False):
# type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tensor # noqa
- return max_pool3d_with_indices(
- input, kernel_size, stride, padding, dilation, ceil_mode)[0]
+ if stride is None:
+ stride = torch.jit.annotate(List[int], [])
+ return torch.max_pool3d(
+ input, kernel_size, stride, padding, dilation, ceil_mode)
max_pool3d = torch._jit_internal.boolean_dispatch(
arg_name='return_indices',
return padding_ceil
-@parse_args('v', 'is', 'is', 'is', 'is', 'i')
-def max_pool1d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode):
- if ceil_mode and input.type().kind() != "CompleteTensorType":
- return _unimplemented("max_pool1d_with_indices", "input size not accesible")
- if set(_single(dilation)) != {1}:
- return _unimplemented("max_pool1d_with_indices", "dilation")
- if stride is None:
- stride = kernel_size
- padding = tuple(_single(padding))
- if ceil_mode:
- padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
- padding = padding + tuple(numpy.add(padding_ceil, padding))
- else:
- padding = padding * 2
- r, indices = g.op("MaxPool", input, outputs=2,
- kernel_shape_i=_single(kernel_size),
- pads_i=padding,
- strides_i=_single(stride))
- # easy but hacky way to get flattened indices values
- # to be used to convert the indices values to non-flattened.
- # In ONNX the indices are computed as a flatten 1-D tensor,
- # so the values in indices are in [0, N x C x D1 x ... x Dn).
- # To convert the indices to the same format used by Pytorch,
- # we first execute a maxpool with a kernel and stride of 1 on the same input.
- # This will result in a tensor of indices in which each index will have it's own value.
- # Using this tensor as a reference, we extract the first index of each axis and substract
- # it from each index of this axis in the indices to convert.
- # This step will result in a tensor were each dimension has values of indices within
- # the dimension it is in.
- # For more information :
- # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
- _, flattened_indices = g.op("MaxPool", input, outputs=2,
- kernel_shape_i=[1],
- strides_i=[1])
- # convert indices to have non-flattened indices values
- s = _slice_op(g, flattened_indices, axes=[2], starts=[0], ends=[1])
- indices = sub(g, indices, s)
- return r, indices
-
-
-@parse_args('v', 'is', 'is', 'is', 'is', 'i')
-def max_pool2d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode):
- if ceil_mode and input.type().kind() != "CompleteTensorType":
- return _unimplemented("max_pool2d_with_indices", "input size not accesible")
- if set(_pair(dilation)) != {1}:
- return _unimplemented("max_pool2d_with_indices", "dilation")
- if not stride:
- stride = kernel_size
- padding = tuple(_pair(padding))
- if ceil_mode:
- padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
- padding = padding + tuple(numpy.add(padding_ceil, padding))
- else:
- padding = padding * 2
- r, indices = g.op("MaxPool", input, outputs=2,
- kernel_shape_i=_pair(kernel_size),
- pads_i=padding,
- strides_i=_pair(stride))
- # easy but hacky way to get flattened indices values
- # to be used to convert the indices values to non-flattened
- # See comment in max_pool1d_with_indices for details.
- _, flattened_indices = g.op("MaxPool", input, outputs=2,
- kernel_shape_i=[1, 1],
- strides_i=[1, 1])
- # convert indices to have non-flattened indices values
- s = _slice_op(g, flattened_indices, axes=[2, 3], starts=[0, 0], ends=[1, 1])
- indices = sub(g, indices, s)
- return r, indices
-
-
-@parse_args('v', 'is', 'is', 'is', 'is', 'i')
-def max_pool3d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode):
- if ceil_mode and input.type().kind() != "CompleteTensorType":
- return _unimplemented("max_pool3d_with_indices", "input size not accesible")
- if set(_triple(dilation)) != {1}:
- return _unimplemented("max_pool3d_with_indices", "dilation")
- if not stride:
- stride = kernel_size
- padding = tuple(_triple(padding))
- if ceil_mode:
- padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
- padding = padding + tuple(numpy.add(padding_ceil, padding))
- else:
- padding = padding * 2
- r, indices = g.op("MaxPool", input, outputs=2,
- kernel_shape_i=_triple(kernel_size),
- pads_i=padding,
- strides_i=_triple(stride))
- # easy but hacky way to get flattened indices values
- # to be used to convert the indices values to non-flattened
- # See comment in max_pool1d_with_indices for details.
- _, flattened_indices = g.op("MaxPool", input, outputs=2,
- kernel_shape_i=[1, 1, 1],
- strides_i=[1, 1, 1])
- # convert indices to have non-flattened indices values
- s = _slice_op(g, flattened_indices, axes=[2, 3, 4], starts=[0, 0, 0], ends=[1, 1, 1])
- indices = sub(g, indices, s)
- return r, indices
+def _max_pool(name, tuple_fn, ndims, return_indices):
+ @parse_args('v', 'is', 'is', 'is', 'is', 'i')
+ def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
+ if ceil_mode and input.type().kind() != "CompleteTensorType":
+ return _unimplemented(name, "input size not accesible")
+ if set(tuple_fn(dilation)) != {1}:
+ return _unimplemented(name, "dilation")
+ if not stride:
+ stride = kernel_size
+ padding = tuple(tuple_fn(padding))
+ if ceil_mode:
+ padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
+ padding = padding + tuple(numpy.add(padding_ceil, padding))
+ else:
+ padding = padding * 2
+ # easy but hacky way to get flattened indices values
+ # to be used to convert the indices values to non-flattened.
+ # In ONNX the indices are computed as a flatten 1-D tensor,
+ # so the values in indices are in [0, N x C x D1 x ... x Dn).
+ # To convert the indices to the same format used by Pytorch,
+ # we first execute a maxpool with a kernel and stride of 1 on the same input.
+ # This will result in a tensor of indices in which each index will have it's own value.
+ # Using this tensor as a reference, we extract the first index of each axis and substract
+ # it from each index of this axis in the indices to convert.
+ # This step will result in a tensor were each dimension has values of indices within
+ # the dimension it is in.
+ # For more information :
+ # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
+ if return_indices:
+ r, indices = g.op("MaxPool", input, outputs=2,
+ kernel_shape_i=tuple_fn(kernel_size),
+ pads_i=padding,
+ strides_i=tuple_fn(stride))
+ _, flattened_indices = g.op("MaxPool", input, outputs=2,
+ kernel_shape_i=[1 for _ in range(ndims)],
+ strides_i=[1 for _ in range(ndims)])
+ # convert indices to have non-flattened indices values
+ s = _slice_op(g, flattened_indices, axes=[2 + i for i in range(ndims)],
+ starts=tuple_fn(0), ends=tuple_fn(1))
+ indices = sub(g, indices, s)
+ return r, indices
+ else:
+ r = g.op("MaxPool", input, outputs=1,
+ kernel_shape_i=tuple_fn(kernel_size),
+ pads_i=padding,
+ strides_i=tuple_fn(stride))
+ return r
+
+ return symbolic_fn
+
+
+max_pool1d = _max_pool("max_pool1d", _single, 1, return_indices=False)
+max_pool2d = _max_pool("max_pool2d", _pair, 2, return_indices=False)
+max_pool3d = _max_pool("max_pool3d", _triple, 3, return_indices=False)
+max_pool1d_with_indices = _max_pool("max_pool1d_with_indices", _single, 1, return_indices=True)
+max_pool2d_with_indices = _max_pool("max_pool2d_with_indices", _pair, 2, return_indices=True)
+max_pool3d_with_indices = _max_pool("max_pool3d_with_indices", _triple, 3, return_indices=True)
def _avg_pool(name, tuple_fn):