[mlir][MemRef] Simplify extract_strided_metadata(expand_shape)
author Quentin Colombet <quentin.colombet@gmail.com>
Sat, 10 Sep 2022 00:23:13 +0000 (00:23 +0000)
committer Quentin Colombet <quentin.colombet@gmail.com>
Thu, 22 Sep 2022 19:07:09 +0000 (19:07 +0000)
commit d0aeb74e8869e5db23c079b98c5e1f325aeeeefe
tree 073fbc04d120a52b0509654357a024b6f4aae7d6
parent 9dc0b1674841243110f95cdc9c9d97bcf30c1544

Add a pattern for expand_shape to the pass that simplifies
extract_strided_metadata(other_op(memref)).

The new pattern gets rid of the expand_shape operation while
materializing its effects on the sizes and strides of the
base object.

In other words, this simplification replaces:
```
baseBuffer, offset, sizes, strides =
             extract_strided_metadata(expand_shape(memref))
```

With

```
baseBuffer, offset, baseSizes, baseStrides =
    extract_strided_metadata(memref)
for each reassociation group `group` at reassDim:
  for each reassIdx in group:
    sizes#reassIdx = expandShapeSizes#reassIdx
    strides#reassIdx =
        baseStrides#reassDim * product(expandShapeSizes#j,
                                       for j in reassIdx+1..
                                         last index of group)
```

Where `reassIdx` is a reassociation index for the group at
`reassDim` and `expandShapeSizes#j` is either:
- The constant size at dimension j, derived directly from the
  result type of the expand_shape op, or
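- An affine expression: baseSizes#reassDim / product of all
  constant sizes of the group. expand_shape allows at most one
  dynamic size per reassociation group, so this case applies
  to at most one j in each group.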
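The size and stride arithmetic can be modeled with plain integers. The sketch below is hypothetical: `expand_strided_metadata` and its parameters are illustrative names, not MLIR API, and it assumes every base size is known (the actual pattern builds affine expressions instead of folding to integers). `None` in `result_shape` marks a dynamic size in the expand_shape result type. The stride of each expanded dimension is baseStrides#reassDim times the product of the sizes of the dimensions that follow it within its group.

```python
from math import prod
from typing import List, Optional, Sequence, Tuple


def expand_strided_metadata(
    base_sizes: Sequence[int],
    base_strides: Sequence[int],
    reassociation: Sequence[Sequence[int]],
    result_shape: Sequence[Optional[int]],
) -> Tuple[List[int], List[int]]:
    """Hypothetical model: sizes/strides of expand_shape(memref).

    result_shape holds the static sizes of the expand_shape result type,
    with None standing in for a dynamic size (at most one per group).
    """
    sizes: List[int] = [0] * len(result_shape)
    strides: List[int] = [0] * len(result_shape)
    for reass_dim, group in enumerate(reassociation):
        static = [result_shape[j] for j in group if result_shape[j] is not None]
        dynamic = [j for j in group if result_shape[j] is None]
        assert len(dynamic) <= 1, "at most one dynamic size per group"
        # expandShapeSizes#j: the constant from the result type, or
        # baseSizes#reassDim / product of the group's constant sizes.
        group_sizes = {j: result_shape[j] for j in group}
        if dynamic:
            group_sizes[dynamic[0]] = base_sizes[reass_dim] // prod(static)
        for pos, reass_idx in enumerate(group):
            sizes[reass_idx] = group_sizes[reass_idx]
            # Product of the sizes of the dims after reass_idx in this group
            # (an empty product is 1 for the innermost dim of the group).
            inner = prod(group_sizes[j] for j in group[pos + 1:])
            strides[reass_idx] = base_strides[reass_dim] * inner
    return sizes, strides
```

For example, a base memref<?x4xf32> with a run-time size of 6 (strides [4, 1]) expanded to memref<2x?x4xf32> with reassociation [[0, 1], [2]] yields sizes [2, 3, 4] and strides [12, 4, 1].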

Note: baseBuffer and offset are unaffected by the expand_shape
operation.
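Why baseBuffer and offset can be reused as-is can be checked by brute force on small shapes: with the same offset, every expanded index must address the same element as the base index obtained by folding each group back into one row-major index. The helpers below (`linear_address`, `collapse_index`, `same_addresses`) are hypothetical and only illustrate the strided layout arithmetic; they enumerate every index, so they are meant for tiny shapes only.

```python
from itertools import product
from typing import List, Sequence


def linear_address(offset: int, strides: Sequence[int], indices: Sequence[int]) -> int:
    # Strided layout: offset + sum(index * stride), counted in elements.
    return offset + sum(i * s for i, s in zip(indices, strides))


def collapse_index(indices: Sequence[int], sizes: Sequence[int],
                   reassociation: Sequence[Sequence[int]]) -> List[int]:
    # Fold each group of expanded indices back into one row-major index.
    collapsed = []
    for group in reassociation:
        idx = 0
        for j in group:
            idx = idx * sizes[j] + indices[j]
        collapsed.append(idx)
    return collapsed


def same_addresses(offset: int, base_strides: Sequence[int],
                   exp_sizes: Sequence[int], exp_strides: Sequence[int],
                   reassociation: Sequence[Sequence[int]]) -> bool:
    # True if the expanded view, sharing the base offset, aliases the base
    # view element for element.
    for idx in product(*(range(n) for n in exp_sizes)):
        base_idx = collapse_index(idx, exp_sizes, reassociation)
        if linear_address(offset, exp_strides, idx) != linear_address(offset, base_strides, base_idx):
            return False
    return True
```

For instance, `same_addresses(5, [4, 1], [2, 3, 4], [12, 4, 1], [[0, 1], [2]])` holds, while swapping the first two expanded strides breaks the equivalence.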

Differential Revision: https://reviews.llvm.org/D133625
mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir