}
};
+/// Replace a 1-element vector.store with a vector.extract + memref.store.
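+///
+/// For example (illustrative IR; %v, %m and %i are hypothetical values):
+///
+///   vector.store %v, %m[%i] : memref<8xf32>, vector<1xf32>
+///
+/// becomes:
+///
+///   %e = vector.extract %v[0] : vector<1xf32>
+///   memref.store %e, %m[%i] : memref<8xf32>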
+struct VectorStoreToMemrefStoreLowering
+    : public OpRewritePattern<vector::StoreOp> {
+  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    // Only stores of single-element vectors can be turned into scalar stores.
+    auto vecType = storeOp.getVectorType();
+    if (vecType.getNumElements() != 1)
+      return failure();
+    // Extract the single element at the all-zeros position and store it
+    // directly as a scalar.
+    Value extracted = rewriter.create<vector::ExtractOp>(
+        storeOp.getLoc(), storeOp.valueToStore(),
+        SmallVector<int64_t>(vecType.getRank(), 0));
+    rewriter.replaceOpWithNewOp<memref::StoreOp>(
+        storeOp, extracted, storeOp.base(), storeOp.indices());
+    return success();
+  }
+};
+
/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
      return failure();
    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
-    if (!write.permutation_map().isMinorIdentity())
+    if (!write.isZeroD() && !write.permutation_map().isMinorIdentity())
      return failure();
    auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
    if (!memRefType)
  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
                                PatternRewriter &rewriter) const override {
+    if (op.isZeroD())
+      return failure();
+
    SmallVector<unsigned> permutation;
    AffineMap map = op.permutation_map();
    if (map.isMinorIdentity())
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                   maxTransferRank);
-  patterns.add<VectorLoadToMemrefLoadLowering>(patterns.getContext());
+  patterns
+      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
+          patterns.getContext());
}
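
// Illustrative usage sketch (assumptions: the populate function above is
// vector::populateVectorTransferLoweringPatterns, and funcOp is a hypothetical
// function to transform with the greedy rewrite driver):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::populateVectorTransferLoweringPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
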
void mlir::vector::populateVectorUnrollPatterns(
// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -canonicalize -split-input-file | FileCheck %s
+// CHECK-LABEL: func @vector_transfer_ops_0d(
+// CHECK-SAME: %[[MEM:.*]]: memref<f32>) {
+func @vector_transfer_ops_0d(%M: memref<f32>) {
+  %f0 = constant 0.0 : f32
+
+// CHECK-NEXT: %[[V:.*]] = memref.load %[[MEM]][] : memref<f32>
+  %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->(0)>} :
+    memref<f32>, vector<1xf32>
+
+// CHECK-NEXT: memref.store %[[V]], %[[MEM]][] : memref<f32>
+  vector.transfer_write %0, %M[] {permutation_map = affine_map<()->(0)>} :
+    vector<1xf32>, memref<f32>
+
+  return
+}
+
+// -----
+
// transfer_read/write are lowered to vector.load/store
// CHECK-LABEL: func @transfer_to_load(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,