[mlir][Tensor] Add rewrites to extract slices through `tensor.collapse_shape`
author    Christopher Bate <cbate@nvidia.com>
          Thu, 8 Sep 2022 21:21:57 +0000 (15:21 -0600)
committer Christopher Bate <cbate@nvidia.com>
          Fri, 9 Sep 2022 03:58:21 +0000 (21:58 -0600)
commit    f4a478cd017818ad6381a8aa1a7e3d29fd263ef9
tree      f0c652fd9a17dae9435a5a356c7eb2b046db1626
parent    7fa1d743d073b4af6acb0a34b6324edf1d92f518
[mlir][Tensor] Add rewrites to extract slices through `tensor.collapse_shape`

This change adds a set of utilities to replace the result of a
`tensor.collapse_shape -> tensor.extract_slice` chain with the
equivalent result formed by aggregating slices of the
`tensor.collapse_shape` source. In general, it is not possible to
commute `extract_slice` and `collapse_shape` if linearized dimensions
are sliced. The i-th dimension of the `tensor.collapse_shape`
result is a "linearized sliced dimension" if:

1) The reassociation group at the i-th position of the `tensor.collapse_shape`
   has more than one index (multiple dimensions of the source are collapsed), and
2) the i-th dimension is sliced by `tensor.extract_slice`.
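For contrast, here is a minimal sketch (value names and shapes are
illustrative only, reusing the shapes from the example further below) of the
case where only a non-linearized dimension is sliced; the two operations then
simply commute and no tiling is needed:

```
// Only the trailing dimension (reassociation group [3]) is sliced, so the
// slice can be taken directly on the expanded source:
%1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : tensor<3x7x11x10xf32> into tensor<231x10xf32>
%2 = tensor.extract_slice %1 [0, 2] [231, 5] [1, 1] : tensor<231x10xf32> to tensor<231x5xf32>
// ... is equivalent to:
%s = tensor.extract_slice %0 [0, 0, 0, 2] [3, 7, 11, 5] [1, 1, 1, 1] :
      tensor<3x7x11x10xf32> to tensor<3x7x11x5xf32>
%2 = tensor.collapse_shape %s [[0, 1, 2], [3]] : tensor<3x7x11x5xf32> into tensor<231x5xf32>
```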

We can work around this by stitching together the result of the
`tensor.extract_slice`, iterating over any linearized sliced dimensions.
This is equivalent to "tiling" the linearized-and-sliced dimensions of
the `tensor.collapse_shape` operation in order to materialize the result
tile (the result of the `tensor.extract_slice`). The user of the
utilities must provide the mechanism that creates the tiling (e.g. a loop).
The tests demonstrate how to apply the utilities using either
`scf.for` or `scf.foreach_thread`.

The below example illustrates the pattern using `scf.for`:

```
%0 = linalg.generic ... -> tensor<3x7x11x10xf32>
%1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : ... into tensor<231x10xf32>
%2 = tensor.extract_slice %1 [13, 0] [10, 10] [2, 1] : ... to tensor<10x10xf32>
```

We can construct %2 by generating the following IR:

```
%dest = linalg.init_tensor [10, 10] : tensor<10x10xf32>
%2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %dest) -> tensor<10x10xf32> {
   // Step 1: Map this output index (%iv) to a multi-index for the input (%3):
   %linear_index = affine.apply affine_map<(d0) -> (d0 * 2 + 13)>(%iv)
   %3:3 = affine.delinearize_index %linear_index into (3, 7, 11)
   // Step 2: Extract the slice from the input
   %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
         tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
   %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
         tensor<1x1x1x10xf32> into tensor<1x10xf32>
   // Step 3: Insert the slice into the destination
   %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
         tensor<1x10xf32> into tensor<10x10xf32>
   scf.yield %6 : tensor<10x10xf32>
}
```
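To make the index arithmetic concrete, the comments below walk through one
iteration of this loop (a worked check, not part of the generated IR):

```
// For %iv = 4:
//   %linear_index = 4 * 2 + 13 = 21
//   delinearizing 21 over basis (3, 7, 11), i.e. strides (77, 11, 1):
//     21 = 0 * 77 + 1 * 11 + 10   =>   %3 = (0, 1, 10)
//   so iteration 4 extracts %0[0, 1, 10, 0..9] and inserts it as row 4 of the
//   result, matching row 13 + 4 * 2 = 21 of the collapsed tensor<231x10xf32>.
```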

The pattern was discussed in the RFC here: https://discourse.llvm.org/t/rfc-tensor-extracting-slices-from-tensor-collapse-shape/64034

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D129699
14 files changed:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h [new file with mode: 0644]
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp [new file with mode: 0644]
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Tensor/CMakeLists.txt
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel