[mlir][Vector][Bigfix] Fix vector transfer to store lowering to insert a proper...
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 12 Oct 2021 13:27:41 +0000 (13:27 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 12 Oct 2021 13:28:12 +0000 (13:28 +0000)
Differential Revision: https://reviews.llvm.org/D111641

mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir

index 46865160e6f9ed9bb298070305e195ef2ba815a6..6942aad9e790825faa9aefc43aa01cea324377d9 100644 (file)
@@ -2600,8 +2600,9 @@ struct VectorStoreToMemrefStoreLowering
     auto vecType = storeOp.getVectorType();
     if (vecType.getNumElements() != 1)
       return failure();
+    SmallVector<int64_t> indices(vecType.getRank(), 0);
     Value extracted = rewriter.create<vector::ExtractOp>(
-        storeOp.getLoc(), storeOp.valueToStore(), ArrayRef<int64_t>{1});
+        storeOp.getLoc(), storeOp.valueToStore(), indices);
     rewriter.replaceOpWithNewOp<memref::StoreOp>(
         storeOp, extracted, storeOp.base(), storeOp.indices());
     return success();
index 4139a80527afac035e3f7dd382e994832ab8c50a..3e8706f6fbc5343488a0ff58441a6c3b3fecd17c 100644 (file)
@@ -1,8 +1,9 @@
 // RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -canonicalize -split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @vector_transfer_ops_0d(
-//  CHECK-SAME:   %[[MEM:.*]]: memref<f32>) {
-func @vector_transfer_ops_0d(%M: memref<f32>) {
+//  CHECK-SAME:   %[[MEM:.*]]: memref<f32>
+//  CHECK-SAME:   %[[VV:.*]]: vector<1x1x1xf32>
+func @vector_transfer_ops_0d(%M: memref<f32>, %v: vector<1x1x1xf32>) {
     %f0 = constant 0.0 : f32
 
 //  CHECK-NEXT:   %[[V:.*]] = memref.load %[[MEM]][] : memref<f32>
@@ -13,6 +14,10 @@ func @vector_transfer_ops_0d(%M: memref<f32>) {
     vector.transfer_write %0, %M[] {permutation_map = affine_map<()->(0)>} :
       vector<1xf32>, memref<f32>
 
+//  CHECK-NEXT:   %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : vector<1x1x1xf32>
+//  CHECK-NEXT:   memref.store %[[VV]], %[[MEM]][] : memref<f32>
+    vector.store %v, %M[] : memref<f32>, vector<1x1x1xf32>
+
     return
 }