return _unimplemented("max_pool1d_with_indices", "dilation")
if stride is None:
stride = kernel_size
- r = g.op("MaxPool", input,
- kernel_shape_i=_single(kernel_size),
- pads_i=_single(padding) * 2,
- strides_i=_single(stride))
- return r, None
+ r, indices = g.op("MaxPool", input, outputs=2,
+ kernel_shape_i=_single(kernel_size),
+ pads_i=_single(padding) * 2,
+ strides_i=_single(stride))
+ # easy but hacky way to get flattened indices values
+ # to be used to convert the indices values to non-flattened.
+ # In ONNX the indices are computed as a flattened 1-D tensor,
+ # so the values in indices are in [0, N x C x D1 x ... x Dn).
+ # To convert the indices to the same format used by PyTorch,
+ # we first execute a maxpool with a kernel and stride of 1 on the same input.
+ # This will result in a tensor of indices in which each index will have its own value.
+ # Using this tensor as a reference, we extract the first index of each axis and subtract
+ # it from each index of this axis in the indices to convert.
+ # This step will result in a tensor where each dimension has values of indices within
+ # the dimension it is in.
+ # For more information:
+ # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
+ _, flattened_indices = g.op("MaxPool", input, outputs=2,
+ kernel_shape_i=[1],
+ strides_i=[1])
+ # convert indices to have non-flattened indices values
+ s = g.op("Slice", flattened_indices, axes_i=[2], starts_i=[0], ends_i=[1])
+ indices = sub(g, indices, s)
+ return r, indices
@parse_args('v', 'is', 'is', 'is', 'is', 'i')
return _unimplemented("max_pool2d_with_indices", "dilation")
if not stride:
stride = kernel_size
- r = g.op("MaxPool", input,
- kernel_shape_i=_pair(kernel_size),
- pads_i=_pair(padding) * 2,
- strides_i=_pair(stride))
- return r, None
+ r, indices = g.op("MaxPool", input, outputs=2,
+ kernel_shape_i=_pair(kernel_size),
+ pads_i=_pair(padding) * 2,
+ strides_i=_pair(stride))
+ # easy but hacky way to get flattened indices values
+ # to be used to convert the indices values to non-flattened.
+ # See comment in max_pool1d_with_indices for details.
+ _, flattened_indices = g.op("MaxPool", input, outputs=2,
+ kernel_shape_i=[1, 1],
+ strides_i=[1, 1])
+ # convert indices to have non-flattened indices values
+ s = g.op("Slice", flattened_indices, axes_i=[2, 3], starts_i=[0, 0], ends_i=[1, 1])
+ indices = sub(g, indices, s)
+ return r, indices
@parse_args('v', 'is', 'is', 'is', 'is', 'i')
return _unimplemented("max_pool3d_with_indices", "dilation")
if not stride:
stride = kernel_size
- r = g.op("MaxPool", input,
- kernel_shape_i=_triple(kernel_size),
- pads_i=_triple(padding) * 2,
- strides_i=_triple(stride))
- return r, None
+ r, indices = g.op("MaxPool", input, outputs=2,
+ kernel_shape_i=_triple(kernel_size),
+ pads_i=_triple(padding) * 2,
+ strides_i=_triple(stride))
+ # easy but hacky way to get flattened indices values
+ # to be used to convert the indices values to non-flattened.
+ # See comment in max_pool1d_with_indices for details.
+ _, flattened_indices = g.op("MaxPool", input, outputs=2,
+ kernel_shape_i=[1, 1, 1],
+ strides_i=[1, 1, 1])
+ # convert indices to have non-flattened indices values
+ s = g.op("Slice", flattened_indices, axes_i=[2, 3, 4], starts_i=[0, 0, 0], ends_i=[1, 1, 1])
+ indices = sub(g, indices, s)
+ return r, indices
def _avg_pool(name, tuple_fn):