[mlir][vector] Add scalar vector xfer to memref patterns
authorMatthias Springer <springerm@google.com>
Mon, 19 Dec 2022 09:24:46 +0000 (10:24 +0100)
committerMatthias Springer <springerm@google.com>
Mon, 19 Dec 2022 09:27:49 +0000 (10:27 +0100)
These patterns devectorize scalar transfers such as vector<f32> or vector<1xf32>.

Differential Revision: https://reviews.llvm.org/D140215

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

index a4735ae..0028abe 100644 (file)
@@ -142,6 +142,11 @@ void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
 void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
                                              PatternBenefit benefit = 1);
 
+/// Collects patterns that lower scalar vector transfer ops to memref loads and
+/// stores when beneficial.
+void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns,
+                                                  PatternBenefit benefit = 1);
+
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);
 
index cfbf289..6fb1b8c 100644 (file)
@@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   MLIRMemRefDialect
   MLIRSCFDialect
   MLIRSideEffectInterfaces
+  MLIRTensorDialect
   MLIRTransforms
   MLIRVectorDialect
   MLIRVectorInterfaces
index b59b10c..727a356 100644 (file)
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
@@ -556,6 +558,101 @@ class FlattenContiguousRowMajorTransferWritePattern
   }
 };
 
+/// Rewrite extractelement(transfer_read) to memref.load.
+///
+/// Rewrite only if the extractelement op is the single user of the transfer op.
+/// E.g., do not rewrite IR such as:
+/// %0 = vector.transfer_read ... : vector<1024xf32>
+/// %1 = vector.extractelement %0[%a : index] : vector<1024xf32>
+/// %2 = vector.extractelement %0[%b : index] : vector<1024xf32>
+/// Rewriting such IR (replacing one vector load with multiple scalar loads) may
+/// negatively affect performance.
+class FoldScalarExtractOfTransferRead
+    : public OpRewritePattern<vector::ExtractElementOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
+    if (!xferOp)
+      return failure();
+    // xfer result must have a single use. Otherwise, it may be better to
+    // perform a vector load.
+    if (!extractOp.getVector().hasOneUse())
+      return failure();
+    // Mask not supported.
+    if (xferOp.getMask())
+      return failure();
+    // Map not supported.
+    if (!xferOp.getPermutationMap().isMinorIdentity())
+      return failure();
+    // Cannot rewrite if the indices may be out of bounds. The starting point is
+    // always inbounds, so we don't care in case of 0d transfers.
+    if (xferOp.hasOutOfBoundsDim() && xferOp.getType().getRank() > 0)
+      return failure();
+    // Construct scalar load.
+    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
+                                  xferOp.getIndices().end());
+    if (extractOp.getPosition()) {
+      AffineExpr sym0, sym1;
+      bindSymbols(extractOp.getContext(), sym0, sym1);
+      OpFoldResult ofr = makeComposedFoldedAffineApply(
+          rewriter, extractOp.getLoc(), sym0 + sym1,
+          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
+      if (ofr.is<Value>()) {
+        newIndices[newIndices.size() - 1] = ofr.get<Value>();
+      } else {
+        newIndices[newIndices.size() - 1] =
+            rewriter.create<arith::ConstantIndexOp>(extractOp.getLoc(),
+                                                    *getConstantIntValue(ofr));
+      }
+    }
+    if (xferOp.getSource().getType().isa<MemRefType>()) {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
+                                                  newIndices);
+    } else {
+      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+          extractOp, xferOp.getSource(), newIndices);
+    }
+    return success();
+  }
+};
+
+/// Rewrite scalar transfer_write(broadcast) to memref.store.
+class FoldScalarTransferWriteOfBroadcast
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
+                                PatternRewriter &rewriter) const override {
+    // Must be a scalar write.
+    auto vecType = xferOp.getVectorType();
+    if (vecType.getRank() != 0 &&
+        (vecType.getRank() != 1 || vecType.getShape()[0] != 1))
+      return failure();
+    // Mask not supported.
+    if (xferOp.getMask())
+      return failure();
+    // Map not supported.
+    if (!xferOp.getPermutationMap().isMinorIdentity())
+      return failure();
+    // Must be a broadcast of a scalar.
+    auto broadcastOp = xferOp.getVector().getDefiningOp<vector::BroadcastOp>();
+    if (!broadcastOp || broadcastOp.getSource().getType().isa<VectorType>())
+      return failure();
+    // Construct a scalar store.
+    if (xferOp.getSource().getType().isa<MemRefType>()) {
+      rewriter.replaceOpWithNewOp<memref::StoreOp>(
+          xferOp, broadcastOp.getSource(), xferOp.getSource(),
+          xferOp.getIndices());
+    } else {
+      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
+          xferOp, broadcastOp.getSource(), xferOp.getSource(),
+          xferOp.getIndices());
+    }
+    return success();
+  }
+};
 } // namespace
 
 void mlir::vector::transferOpflowOpt(Operation *rootOp) {
@@ -574,6 +671,13 @@ void mlir::vector::transferOpflowOpt(Operation *rootOp) {
   opt.removeDeadOp();
 }
 
+void mlir::vector::populateScalarVectorTransferLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns
+      .add<FoldScalarExtractOfTransferRead, FoldScalarTransferWriteOfBroadcast>(
+          patterns.getContext(), benefit);
+}
+
 void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
new file mode 100644 (file)
index 0000000..d34b9c3
--- /dev/null
@@ -0,0 +1,75 @@
+// RUN: mlir-opt %s -test-scalar-vector-transfer-lowering -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @transfer_read_0d(
+//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
+//       CHECK:   %[[r:.*]] = memref.load %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+//       CHECK:   return %[[r]]
+func.func @transfer_read_0d(%m: memref<?x?x?xf32>, %idx: index) -> f32 {
+  %cst = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %m[%idx, %idx, %idx], %cst : memref<?x?x?xf32>, vector<f32>
+  %1 = vector.extractelement %0[] : vector<f32>
+  return %1 : f32
+}
+
+// -----
+
+//       CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @transfer_read_1d(
+//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[idx2:.*]]: index
+//       CHECK:   %[[added:.*]] = affine.apply #[[$map]]()[%[[idx]], %[[idx2]]]
+//       CHECK:   %[[r:.*]] = memref.load %[[m]][%[[idx]], %[[idx]], %[[added]]]
+//       CHECK:   return %[[r]]
+func.func @transfer_read_1d(%m: memref<?x?x?xf32>, %idx: index, %idx2: index) -> f32 {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_read %m[%idx, %idx, %idx], %cst {in_bounds = [true]} : memref<?x?x?xf32>, vector<5xf32>
+  %1 = vector.extractelement %0[%idx2 : index] : vector<5xf32>
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_transfer_read_0d(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[idx:.*]]: index
+//       CHECK:   %[[r:.*]] = tensor.extract %[[t]][%[[idx]], %[[idx]], %[[idx]]]
+//       CHECK:   return %[[r]]
+func.func @tensor_transfer_read_0d(%t: tensor<?x?x?xf32>, %idx: index) -> f32 {
+  %cst = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %t[%idx, %idx, %idx], %cst : tensor<?x?x?xf32>, vector<f32>
+  %1 = vector.extractelement %0[] : vector<f32>
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_0d(
+//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
+//       CHECK:   memref.store %[[f]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+func.func @transfer_write_0d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
+  %0 = vector.broadcast %f : f32 to vector<f32>
+  vector.transfer_write %0, %m[%idx, %idx, %idx] : vector<f32>, memref<?x?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_1d(
+//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
+//       CHECK:   memref.store %[[f]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+func.func @transfer_write_1d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
+  %0 = vector.broadcast %f : f32 to vector<1xf32>
+  vector.transfer_write %0, %m[%idx, %idx, %idx] : vector<1xf32>, memref<?x?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_transfer_write_0d(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
+//       CHECK:   %[[r:.*]] = tensor.insert %[[f]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
+//       CHECK:   return %[[r]]
+func.func @tensor_transfer_write_0d(%t: tensor<?x?x?xf32>, %idx: index, %f: f32) -> tensor<?x?x?xf32> {
+  %0 = vector.broadcast %f : f32 to vector<f32>
+  %1 = vector.transfer_write %0, %t[%idx, %idx, %idx] : vector<f32>, tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
index b033186..48dc95a 100644 (file)
@@ -462,6 +462,33 @@ struct TestVectorTransferFullPartialSplitPatterns
   }
 };
 
+struct TestScalarVectorTransferLoweringPatterns
+    : public PassWrapper<TestScalarVectorTransferLoweringPatterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestScalarVectorTransferLoweringPatterns)
+
+  StringRef getArgument() const final {
+    return "test-scalar-vector-transfer-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test lowering of scalar vector transfers to memref loads/stores.";
+  }
+  TestScalarVectorTransferLoweringPatterns() = default;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, memref::MemRefDialect, tensor::TensorDialect,
+                    vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    vector::populateScalarVectorTransferLoweringPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestVectorTransferOpt
     : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
@@ -869,6 +896,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
 
+  PassRegistration<TestScalarVectorTransferLoweringPatterns>();
+
   PassRegistration<TestVectorTransferOpt>();
 
   PassRegistration<TestVectorTransferLoweringPatterns>();