[mlir][vector] Convert extract_strided_slice to extract & insert chain
authorLei Zhang <antiagainst@google.com>
Thu, 10 Nov 2022 00:37:19 +0000 (19:37 -0500)
committerLei Zhang <antiagainst@google.com>
Thu, 10 Nov 2022 00:42:07 +0000 (19:42 -0500)
This is useful for breaking down extract_strided_slice and potentially
cancel with other extract / insert ops before or after.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D137471

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
mlir/test/Dialect/Vector/vector-extract-strided-slice-lowering.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

index e7169b6..1575661 100644 (file)
@@ -285,6 +285,17 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 void populateVectorInsertExtractStridedSliceDecompositionPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
+/// Populate `patterns` with a pattern to breaks down 1-D extract_strided_slice
+/// ops into a chain of Extract ops to extract each element from the source, and
+/// then a chain of Insert ops to insert to the target vector.
+///
+/// If `controlFn` is not nullptr, the pattern will only be invoked on ops that
+/// `controlFn` returns true. Otherwise runs on ops.
+void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
+    RewritePatternSet &patterns,
+    std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
+    PatternBenefit benefit = 1);
+
 /// Populate `patterns` with the following patterns.
 ///
 /// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns();
index ad6cf85..313a3f9 100644 (file)
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
 
 using namespace mlir;
 using namespace mlir::vector;
@@ -231,6 +232,53 @@ public:
   }
 };
 
+/// For a 1-D ExtractStridedSlice, breaks it down into a chain of Extract ops
+/// to extract each element from the source, and then a chain of Insert ops
+/// to insert to the target vector.
+class Convert1DExtractStridedSliceIntoExtractInsertChain final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  Convert1DExtractStridedSliceIntoExtractInsertChain(
+      MLIRContext *context,
+      std::function<bool(ExtractStridedSliceOp)> controlFn,
+      PatternBenefit benefit)
+      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    if (controlFn && !controlFn(op))
+      return failure();
+
+    // Only handle 1-D cases.
+    if (op.getOffsets().getValue().size() != 1)
+      return failure();
+
+    int64_t offset =
+        op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
+    int64_t size =
+        op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
+    int64_t stride =
+        op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
+
+    Location loc = op.getLoc();
+    SmallVector<Value> elements;
+    elements.reserve(size);
+    for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
+      elements.push_back(rewriter.create<ExtractOp>(loc, op.getVector(), i));
+
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(op.getType()));
+    for (int64_t i = 0; i < size; ++i)
+      result = rewriter.create<InsertOp>(loc, elements[i], result, i);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+private:
+  std::function<bool(ExtractStridedSliceOp)> controlFn;
+};
+
 /// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
 /// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
 /// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
@@ -285,14 +333,22 @@ public:
   }
 };
 
-void mlir::vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
+void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<DecomposeDifferentRankInsertStridedSlice,
                DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
 }
 
+void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
+    RewritePatternSet &patterns,
+    std::function<bool(ExtractStridedSliceOp)> controlFn,
+    PatternBenefit benefit) {
+  patterns.add<Convert1DExtractStridedSliceIntoExtractInsertChain>(
+      patterns.getContext(), std::move(controlFn), benefit);
+}
+
 /// Populate the given list with patterns that convert from Vector to LLVM.
-void mlir::vector::populateVectorInsertExtractStridedSliceTransforms(
+void vector::populateVectorInsertExtractStridedSliceTransforms(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns,
                                                                benefit);
diff --git a/mlir/test/Dialect/Vector/vector-extract-strided-slice-lowering.mlir b/mlir/test/Dialect/Vector/vector-extract-strided-slice-lowering.mlir
new file mode 100644 (file)
index 0000000..ca14dee
--- /dev/null
@@ -0,0 +1,29 @@
+// RUN: mlir-opt -split-input-file -test-vector-extract-strided-slice-lowering %s | FileCheck %s
+
+// CHECK-LABEL: func.func @extract_strided_slice_1D
+//  CHECK-SAME: (%[[INPUT:.+]]: vector<8xf16>)
+func.func @extract_strided_slice_1D(%input: vector<8xf16>) -> vector<4xf16> {
+  %0 = vector.extract_strided_slice %input {offsets = [1], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+  return %0: vector<4xf16>
+}
+
+// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<4xf16>
+// CHECK: %[[E0:.+]] = vector.extract %[[INPUT]][1] : vector<8xf16>
+// CHECK: %[[E1:.+]] = vector.extract %[[INPUT]][2] : vector<8xf16>
+// CHECK: %[[E2:.+]] = vector.extract %[[INPUT]][3] : vector<8xf16>
+// CHECK: %[[E3:.+]] = vector.extract %[[INPUT]][4] : vector<8xf16>
+// CHECK: %[[I0:.+]] = vector.insert %[[E0]], %[[INIT]] [0] : f16 into vector<4xf16>
+// CHECK: %[[I1:.+]] = vector.insert %[[E1]], %[[I0]] [1] : f16 into vector<4xf16>
+// CHECK: %[[I2:.+]] = vector.insert %[[E2]], %[[I1]] [2] : f16 into vector<4xf16>
+// CHECK: %[[I3:.+]] = vector.insert %[[E3]], %[[I2]] [3] : f16 into vector<4xf16>
+// CHECK: return %[[I3]]
+
+
+// -----
+
+// CHECK-LABEL: func.func @extract_strided_slice_2D
+func.func @extract_strided_slice_2D(%input: vector<1x8xf16>) -> vector<1x4xf16> {
+  // CHECK: vector.extract_strided_slice
+  %0 = vector.extract_strided_slice %input {offsets = [0, 1], sizes = [1, 4], strides = [1, 1]} : vector<1x8xf16> to vector<1x4xf16>
+  return %0: vector<1x4xf16>
+}
index de29fc2..4f44d43 100644 (file)
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -785,6 +786,26 @@ struct TestVectorDistribution
   }
 };
 
+struct TestVectorExtractStridedSliceLowering
+    : public PassWrapper<TestVectorExtractStridedSliceLowering,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorExtractStridedSliceLowering)
+
+  StringRef getArgument() const final {
+    return "test-vector-extract-strided-slice-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test lowering patterns that converts vector.extract_strided_slice "
+           "into a chain of vector.extract and vector.insert ops";
+  }
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorExtractStridedSliceToExtractInsertChainPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -819,6 +840,8 @@ void registerTestVectorLowerings() {
   PassRegistration<TestVectorScanLowering>();
 
   PassRegistration<TestVectorDistribution>();
+
+  PassRegistration<TestVectorExtractStridedSliceLowering>();
 }
 } // namespace test
 } // namespace mlir