auto vecType = storeOp.getVectorType();
if (vecType.getNumElements() != 1)
return failure();
+ SmallVector<int64_t> indices(vecType.getRank(), 0);
Value extracted = rewriter.create<vector::ExtractOp>(
- storeOp.getLoc(), storeOp.valueToStore(), ArrayRef<int64_t>{1});
+ storeOp.getLoc(), storeOp.valueToStore(), indices);
rewriter.replaceOpWithNewOp<memref::StoreOp>(
storeOp, extracted, storeOp.base(), storeOp.indices());
return success();
// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -canonicalize -split-input-file | FileCheck %s
// CHECK-LABEL: func @vector_transfer_ops_0d(
-// CHECK-SAME: %[[MEM:.*]]: memref<f32>) {
-func @vector_transfer_ops_0d(%M: memref<f32>) {
+// CHECK-SAME: %[[MEM:.*]]: memref<f32>
+// CHECK-SAME: %[[VV:.*]]: vector<1x1x1xf32>
+func @vector_transfer_ops_0d(%M: memref<f32>, %v: vector<1x1x1xf32>) {
%f0 = constant 0.0 : f32
// CHECK-NEXT: %[[V:.*]] = memref.load %[[MEM]][] : memref<f32>
vector.transfer_write %0, %M[] {permutation_map = affine_map<()->(0)>} :
vector<1xf32>, memref<f32>
+// CHECK-NEXT: %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : vector<1x1x1xf32>
+// CHECK-NEXT: memref.store %[[VV]], %[[MEM]][] : memref<f32>
+ vector.store %v, %M[] : memref<f32>, vector<1x1x1xf32>
+
return
}