[mlir][Vector] Add support for lowering 0-d transfers to load/store.
Author:    Nicolas Vasilache <nicolas.vasilache@gmail.com>
           Tue, 12 Oct 2021 12:26:30 +0000 (12:26 +0000)
Committer: Nicolas Vasilache <nicolas.vasilache@gmail.com>
           Tue, 12 Oct 2021 12:35:19 +0000 (12:35 +0000)
Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D111603

mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-transfer-lowering.mlir

index c76c43afbed3fc123e9c5d9fcf5b7e85d2208f8c..46865160e6f9ed9bb298070305e195ef2ba815a6 100644 (file)
@@ -2590,6 +2590,24 @@ struct VectorLoadToMemrefLoadLowering
   }
 };
 
+/// Replace a scalar vector.store with a memref.store.
+struct VectorStoreToMemrefStoreLowering
+    : public OpRewritePattern<vector::StoreOp> {
+  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    auto vecType = storeOp.getVectorType();
+    if (vecType.getNumElements() != 1)
+      return failure();
+    Value extracted = rewriter.create<vector::ExtractOp>(
+        storeOp.getLoc(), storeOp.valueToStore(), ArrayRef<int64_t>{1});
+    rewriter.replaceOpWithNewOp<memref::StoreOp>(
+        storeOp, extracted, storeOp.base(), storeOp.indices());
+    return success();
+  }
+};
+
 /// Progressive lowering of transfer_write. This pattern supports lowering of
 /// `vector.transfer_write` to `vector.store` if all of the following hold:
 /// - Stride of most minor memref dimension must be 1.
@@ -2611,7 +2629,7 @@ struct TransferWriteToVectorStoreLowering
       return failure();
     // Permutations are handled by VectorToSCF or
     // populateVectorTransferPermutationMapLoweringPatterns.
-    if (!write.permutation_map().isMinorIdentity())
+    if (!write.isZeroD() && !write.permutation_map().isMinorIdentity())
       return failure();
     auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
     if (!memRefType)
@@ -2766,6 +2784,9 @@ struct TransferWritePermutationLowering
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp op,
                                 PatternRewriter &rewriter) const override {
+    if (op.isZeroD())
+      return failure();
+
     SmallVector<unsigned> permutation;
     AffineMap map = op.permutation_map();
     if (map.isMinorIdentity())
@@ -3581,7 +3602,9 @@ void mlir::vector::populateVectorTransferLoweringPatterns(
   patterns.add<TransferReadToVectorLoadLowering,
                TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                    maxTransferRank);
-  patterns.add<VectorLoadToMemrefLoadLowering>(patterns.getContext());
+  patterns
+      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
+          patterns.getContext());
 }
 
 void mlir::vector::populateVectorUnrollPatterns(
index 825d7468dc952c039031a7f5ea8e35dc35821afd..c2db8a501d6bc165d8c9a4b23927d5379761b71f 100644 (file)
@@ -1,5 +1,23 @@
 // RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -canonicalize -split-input-file | FileCheck %s
 
+// CHECK-LABEL: func @vector_transfer_ops_0d(
+//  CHECK-SAME:   %[[MEM:.*]]: memref<f32>) {
+func @vector_transfer_ops_0d(%M: memref<f32>) {
+    %f0 = constant 0.0 : f32
+
+//  CHECK-NEXT:   %[[V:.*]] = memref.load %[[MEM]][] : memref<f32>
+    %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->(0)>} :
+      memref<f32>, vector<1xf32>
+
+//  CHECK-NEXT:   memref.store %[[V]], %[[MEM]][] : memref<f32>
+    vector.transfer_write %0, %M[] {permutation_map = affine_map<()->(0)>} :
+      vector<1xf32>, memref<f32>
+  
+    return
+}
+
+// -----
+
 // transfer_read/write are lowered to vector.load/store
 // CHECK-LABEL:   func @transfer_to_load(
 // CHECK-SAME:                                %[[MEM:.*]]: memref<8x8xf32>,