[mlir][vector] Add fold for ExtractStridedSlice(non-splat ConstantOp)
authorJakub Kuderski <kubak@google.com>
Fri, 25 Nov 2022 18:41:24 +0000 (13:41 -0500)
committerJakub Kuderski <kubak@google.com>
Fri, 25 Nov 2022 18:42:56 +0000 (13:42 -0500)
This allows us to better canonicalize/clean-up code created by the Wide
Integer Emulation pass.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D138606

mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir

index b71c2a0..af81c15 100644 (file)
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/ADT/bit.h"
+
+#include <cassert>
 #include <numeric>
 
 #include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc"
@@ -2680,28 +2684,117 @@ public:
 };
 
 // Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
-class StridedSliceConstantFolder final
+class StridedSliceSplatConstantFolder final
     : public OpRewritePattern<ExtractStridedSliceOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                 PatternRewriter &rewriter) const override {
-    // Return if 'extractStridedSliceOp' operand is not defined by a
+    // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
     // ConstantOp.
-    auto constantOp =
-        extractStridedSliceOp.getVector().getDefiningOp<arith::ConstantOp>();
-    if (!constantOp)
+    Value sourceVector = extractStridedSliceOp.getVector();
+    Attribute vectorCst;
+    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
       return failure();
-    auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
-    if (!dense)
+
+    auto splat = vectorCst.dyn_cast<SplatElementsAttr>();
+    if (!splat)
+      return failure();
+
+    auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
+                                          splat.getSplatValue<Attribute>());
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
+                                                   newAttr);
+    return success();
+  }
+};
+
+// Pattern to rewrite a ExtractStridedSliceOp(non-splat ConstantOp) ->
+// ConstantOp.
+class StridedSliceNonSplatConstantFolder final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
+                                PatternRewriter &rewriter) const override {
+    // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
+    // ConstantOp.
+    Value sourceVector = extractStridedSliceOp.getVector();
+    Attribute vectorCst;
+    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
+      return failure();
+
+    // The splat case is handled by `StridedSliceSplatConstantFolder`.
+    auto dense = vectorCst.dyn_cast<DenseElementsAttr>();
+    if (!dense || dense.isSplat())
+      return failure();
+
+    // TODO: Handle non-unit strides when they become available.
+    if (extractStridedSliceOp.hasNonUnitStrides())
       return failure();
-    auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(),
-                                          dense.getSplatValue<Attribute>());
+
+    auto sourceVecTy = sourceVector.getType().cast<VectorType>();
+    ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
+    SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
+
+    VectorType sliceVecTy = extractStridedSliceOp.getType();
+    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
+    int64_t sliceRank = sliceVecTy.getRank();
+
+    // Expand offsets and sizes to match the vector rank.
+    SmallVector<int64_t, 4> offsets(sliceRank, 0);
+    llvm::copy(getI64SubArray(extractStridedSliceOp.getOffsets()),
+               offsets.begin());
+
+    SmallVector<int64_t, 4> sizes(sourceShape.begin(), sourceShape.end());
+    llvm::copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
+
+    // Calcualte the slice elements by enumerating all slice positions and
+    // linearizing them. The enumeration order is lexicographic which yields a
+    // sequence of monotonically increasing linearized position indices.
+    auto denseValuesBegin = dense.value_begin<Attribute>();
+    SmallVector<Attribute> sliceValues;
+    sliceValues.reserve(sliceVecTy.getNumElements());
+    SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
+    do {
+      int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
+      assert(linearizedPosition < sourceVecTy.getNumElements() &&
+             "Invalid index");
+      sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
+    } while (succeeded(incPosition(currSlicePosition, sliceShape, offsets)));
+
+    assert(static_cast<int64_t>(sliceValues.size()) ==
+               sliceVecTy.getNumElements() &&
+           "Invalid number of slice elements");
+    auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
     rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
                                                    newAttr);
     return success();
   }
+
+private:
+  // Calculate the next `position` in the n-D vector of size `shape`,
+  // applying an offset `offsets`. Modifies the `position` in place.
+  // Returns a failure when `position` becomes the end position.
+  static LogicalResult incPosition(MutableArrayRef<int64_t> position,
+                                   ArrayRef<int64_t> shape,
+                                   ArrayRef<int64_t> offsets) {
+    assert(position.size() == shape.size());
+    assert(position.size() == offsets.size());
+    for (auto [posInDim, dimSize, offsetInDim] :
+         llvm::reverse(llvm::zip(position, shape, offsets))) {
+      ++posInDim;
+      if (posInDim < dimSize + offsetInDim)
+        return success();
+
+      // Carry the overflow to the next loop iteration.
+      posInDim = offsetInDim;
+    }
+
+    return failure();
+  }
 };
 
 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
@@ -2770,8 +2863,9 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
   // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
-  results.add<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
-              StridedSliceBroadcast, StridedSliceSplat>(context);
+  results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
+              StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
+              StridedSliceSplat>(context);
 }
 
 //===----------------------------------------------------------------------===//
index eb1fb24..19a06af 100644 (file)
@@ -1533,6 +1533,63 @@ func.func @extract_splat_vector_3d_constant() -> (vector<2xi32>, vector<2xi32>,
 
 // -----
 
+// CHECK-LABEL: func.func @extract_strided_slice_1d_constant
+//   CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xi32>
+//   CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[1, 2]> : vector<2xi32>
+//   CHECK-DAG: %[[CCST:.*]] = arith.constant dense<2> : vector<1xi32>
+//  CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]] : vector<3xi32>, vector<2xi32>, vector<1xi32>
+func.func @extract_strided_slice_1d_constant() -> (vector<3xi32>, vector<2xi32>, vector<1xi32>) {
+  %cst = arith.constant dense<[0, 1, 2]> : vector<3xi32>
+  %a = vector.extract_strided_slice %cst
+   {offsets = [0], sizes = [3], strides = [1]} : vector<3xi32> to vector<3xi32>
+  %b = vector.extract_strided_slice %cst
+   {offsets = [1], sizes = [2], strides = [1]} : vector<3xi32> to vector<2xi32>
+  %c = vector.extract_strided_slice %cst
+   {offsets = [2], sizes = [1], strides = [1]} : vector<3xi32> to vector<1xi32>
+  return %a, %b, %c : vector<3xi32>, vector<2xi32>, vector<1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_strided_slice_2d_constant
+//   CHECK-DAG: %[[ACST:.*]] = arith.constant dense<0> : vector<1x1xi32>
+//   CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[4, 5\]\]}}> : vector<1x2xi32>
+//   CHECK-DAG: %[[CCST:.*]] = arith.constant dense<{{\[\[1, 2\], \[4, 5\]\]}}> : vector<2x2xi32>
+//  CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]] : vector<1x1xi32>, vector<1x2xi32>, vector<2x2xi32>
+func.func @extract_strided_slice_2d_constant() -> (vector<1x1xi32>, vector<1x2xi32>, vector<2x2xi32>) {
+  %cst = arith.constant dense<[[0, 1, 2], [3, 4, 5]]> : vector<2x3xi32>
+  %a = vector.extract_strided_slice %cst
+   {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<2x3xi32> to vector<1x1xi32>
+  %b = vector.extract_strided_slice %cst
+   {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi32> to vector<1x2xi32>
+  %c = vector.extract_strided_slice %cst
+   {offsets = [0, 1], sizes = [2, 2], strides = [1, 1]} : vector<2x3xi32> to vector<2x2xi32>
+  return %a, %b, %c : vector<1x1xi32>, vector<1x2xi32>, vector<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_strided_slice_3d_constant
+//   CHECK-DAG: %[[ACST:.*]] = arith.constant dense<{{\[\[\[8, 9\], \[10, 11\]\]\]}}> : vector<1x2x2xi32>
+//   CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[\[2, 3\]\]\]}}> : vector<1x1x2xi32>
+//   CHECK-DAG: %[[CCST:.*]] = arith.constant dense<{{\[\[\[6, 7\]\], \[\[10, 11\]\]\]}}> : vector<2x1x2xi32>
+//   CHECK-DAG: %[[DCST:.*]] = arith.constant dense<11> : vector<1x1x1xi32>
+//  CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]]
+func.func @extract_strided_slice_3d_constant() -> (vector<1x2x2xi32>, vector<1x1x2xi32>, vector<2x1x2xi32>, vector<1x1x1xi32>) {
+  %cst = arith.constant dense<[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]]> : vector<3x2x2xi32>
+  %a = vector.extract_strided_slice %cst
+   {offsets = [2], sizes = [1], strides = [1]} : vector<3x2x2xi32> to vector<1x2x2xi32>
+  %b = vector.extract_strided_slice %cst
+   {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<3x2x2xi32> to vector<1x1x2xi32>
+  %c = vector.extract_strided_slice %cst
+   {offsets = [1, 1, 0], sizes = [2, 1, 2], strides = [1, 1, 1]} : vector<3x2x2xi32> to vector<2x1x2xi32>
+  %d = vector.extract_strided_slice %cst
+   {offsets = [2, 1, 1], sizes = [1, 1, 1], strides = [1, 1, 1]} : vector<3x2x2xi32> to vector<1x1x1xi32>
+  return %a, %b, %c, %d : vector<1x2x2xi32>, vector<1x1x2xi32>, vector<2x1x2xi32>, vector<1x1x1xi32>
+}
+
+// -----
+
 // CHECK-LABEL: extract_extract_strided
 //  CHECK-SAME: %[[A:.*]]: vector<32x16x4xf16>
 //       CHECK: %[[V:.*]] = vector.extract %[[A]][9, 7] : vector<32x16x4xf16>