[mlir][tensor] ExtractSliceFromReshape: handle collapsing of unit dim edge cases
author     Christopher Bate <cbate@nvidia.com>
           Sun, 16 Oct 2022 19:57:18 +0000 (13:57 -0600)
committer  Christopher Bate <cbate@nvidia.com>
           Sat, 22 Oct 2022 19:29:34 +0000 (13:29 -0600)
commit     446981bdb64d0ae24ac77b8ba07f3ee3808c3936
tree       09291079efb99464165fa99aa50e9b50fa75bb98
parent     b24a9f0cef88760ae9383d445541b513bcc66018

Prior to this change, the "ExtractSliceFromReshape" pattern would transform

```
%collapsed = tensor.collapse_shape %input [[0, 1], [2]]
                : tensor<1x11x100xf32> into tensor<11x100xf32>
%slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 100] [1, 1]
                : tensor<11x100xf32> to tensor<?x100xf32>
```

into a loop over the `%size` indices selected by the slice, which pieces
together multiple sub-slices of `%input` along the first dimension. This
is correct but needlessly inefficient. The key observation is that when a
reassociation group collapses at most one non-unit dimension of the source
tensor (`%input` above), a subsequent slice along the corresponding
dimension of `%collapsed` cannot cross discontinuities in the index space
of the source. The definition of a "linearized dimension" (from the
perspective of `tensor.collapse_shape`) is therefore updated to reflect
this condition: only a group that collapses more than one non-unit
dimension counts as linearized.
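To make the condition concrete, here is a small C++ sketch over static
shapes. The names `isLinearizedGroup` and `delinearize` are hypothetical
and do not correspond to the helpers in `ReshapeOpsUtils.cpp`; the sketch
assumes row-major layout and statically known dimension sizes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not the MLIR API): a reassociation group is
// "linearized" only if it collapses more than one non-unit dimension of the
// source shape. Unit dimensions always have index 0, so a group with at most
// one non-unit dimension maps any contiguous range of the collapsed
// dimension onto a contiguous range of that single non-unit dimension.
bool isLinearizedGroup(const std::vector<int64_t> &srcShape,
                       const std::vector<int64_t> &group) {
  int nonUnitDims = 0;
  for (int64_t d : group) {
    assert(d >= 0 && static_cast<size_t>(d) < srcShape.size());
    if (srcShape[d] != 1)
      ++nonUnitDims;
  }
  return nonUnitDims > 1;
}

// Hypothetical helper: maps an index of the collapsed dimension back to a
// multi-index over the group's source dimensions (row-major delinearization).
std::vector<int64_t> delinearize(int64_t index,
                                 const std::vector<int64_t> &groupDims) {
  std::vector<int64_t> multiIndex(groupDims.size());
  for (size_t i = groupDims.size(); i-- > 0;) {
    multiIndex[i] = index % groupDims[i];
    index /= groupDims[i];
  }
  return multiIndex;
}
```

For example, `isLinearizedGroup({1, 11, 100}, {0, 1})` is false, while the
same group over a `2x11x100` source is linearized. In the latter case
`delinearize(10, {2, 11})` yields `{0, 10}` but `delinearize(11, {2, 11})`
yields `{1, 0}`, so a slice covering collapsed indices 10 and 11 straddles a
row boundary of the source and cannot be expressed as one rectangular slice.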

The transform will now generate

```
%slice = tensor.extract_slice %input [0, %offt, 0] [1, %size, 100] [1, 1, 1]
            : tensor<1x11x100xf32> to tensor<1x?x100xf32>
%result = tensor.collapse_shape %slice [[0, 1], [2]]
            : tensor<1x?x100xf32> into tensor<?x100xf32>
```

which can be further canonicalized.
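The slice parameters in the rewritten IR follow a simple rule: within each
non-linearized group, the collapsed offset, size and stride move onto the
group's single non-unit dimension, and every unit dimension gets offset 0,
size 1 and stride 1. The C++ sketch below illustrates that rule; it is not
the code in `ExtractSliceFromReshapeUtils.cpp`. `SliceParams` and
`expandSliceThroughCollapse` are hypothetical names, and static integers
stand in for SSA values such as `%offt` and `%size`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical container for static slice parameters.
struct SliceParams {
  std::vector<int64_t> offsets, sizes, strides;
};

// Hypothetical sketch: rewrites a slice of the collapsed tensor into a slice
// of the source tensor. Returns std::nullopt if any group is linearized
// (collapses more than one non-unit dimension), since such a slice may cross
// discontinuities in the source index space.
std::optional<SliceParams>
expandSliceThroughCollapse(const std::vector<int64_t> &srcShape,
                           const std::vector<std::vector<int64_t>> &groups,
                           const SliceParams &collapsed) {
  assert(groups.size() == collapsed.offsets.size());
  assert(groups.size() == collapsed.sizes.size());
  assert(groups.size() == collapsed.strides.size());
  // Unit dimensions default to offset 0, size 1, stride 1.
  SliceParams src{std::vector<int64_t>(srcShape.size(), 0),
                  std::vector<int64_t>(srcShape.size(), 1),
                  std::vector<int64_t>(srcShape.size(), 1)};
  for (size_t g = 0; g < groups.size(); ++g) {
    const std::vector<int64_t> &group = groups[g];
    assert(!group.empty());
    // The dimension that carries the slice: the single non-unit dimension,
    // or the last dimension of the group if all of them are unit.
    int64_t carrier = group.back();
    int nonUnitDims = 0;
    for (int64_t d : group) {
      if (srcShape[d] != 1) {
        carrier = d;
        ++nonUnitDims;
      }
    }
    if (nonUnitDims > 1)
      return std::nullopt;
    src.offsets[carrier] = collapsed.offsets[g];
    src.sizes[carrier] = collapsed.sizes[g];
    src.strides[carrier] = collapsed.strides[g];
  }
  return src;
}
```

With the example above, a source shape of `1x11x100`, groups
`[[0, 1], [2]]` and collapsed parameters `[3, 0] [5, 100] [1, 1]` produce
offsets `[0, 3, 0]`, sizes `[1, 5, 100]` and strides `[1, 1, 1]`, which
matches the shape of the rewritten `tensor.extract_slice`. A `2x11x100`
source makes the first group linearized, so the sketch returns
`std::nullopt`.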

Additional tests are added to check this family of edge cases.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D135726
mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp