These patterns devectorize scalar transfers such as vector<f32> or vector<1xf32>.
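
For example (exercised by the tests added in this patch), a single-use scalar
read such as

  %0 = vector.transfer_read %m[%idx], %cst : memref<?xf32>, vector<f32>
  %1 = vector.extractelement %0[] : vector<f32>

is rewritten to a scalar load

  %1 = memref.load %m[%idx] : memref<?xf32>

and a transfer_write of a broadcast scalar is rewritten to a memref.store. On
tensors, tensor.extract and tensor.insert are used instead.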
Differential Revision: https://reviews.llvm.org/D140215
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);
+/// Collects patterns that lower scalar vector transfer ops to memref loads
+/// and stores (or tensor.extract/tensor.insert on tensors) when beneficial.
+void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns,
+                                                  PatternBenefit benefit = 1);
+
/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);
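
For pass authors, the new entry point is wired up like the other vector
lowering populate functions. A minimal sketch (mirroring the test pass added
at the end of this patch; the surrounding pass boilerplate is elided):

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    vector::populateScalarVectorTransferLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }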
  MLIRMemRefDialect
  MLIRSCFDialect
  MLIRSideEffectInterfaces
+  MLIRTensorDialect
  MLIRTransforms
  MLIRVectorDialect
  MLIRVectorInterfaces
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
  }
};
+/// Rewrite extractelement(transfer_read) to memref.load or tensor.extract.
+///
+/// Rewrite only if the extractelement op is the single user of the transfer op.
+/// E.g., do not rewrite IR such as:
+/// %0 = vector.transfer_read ... : vector<1024xf32>
+/// %1 = vector.extractelement %0[%a : index] : vector<1024xf32>
+/// %2 = vector.extractelement %0[%b : index] : vector<1024xf32>
+/// Rewriting such IR (replacing one vector load with multiple scalar loads) may
+/// negatively affect performance.
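+///
+/// A rewrite that does fire (a sketch mirroring the transfer_read_1d test in
+/// this patch) folds the extract position into the innermost transfer index:
+///   %0 = vector.transfer_read %m[%idx], %cst {in_bounds = [true]}
+///       : memref<?xf32>, vector<5xf32>
+///   %1 = vector.extractelement %0[%pos : index] : vector<5xf32>
+/// becomes:
+///   %new = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%idx, %pos]
+///   %1 = memref.load %m[%new] : memref<?xf32>
+/// If the folded index is a compile-time constant, an arith.constant index is
+/// materialized instead of the affine.apply.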
+class FoldScalarExtractOfTransferRead
+    : public OpRewritePattern<vector::ExtractElementOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto xferOp =
+        extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
+    if (!xferOp)
+      return failure();
+    // The xfer result must have a single use. Otherwise, it may be better to
+    // perform a vector load.
+    if (!extractOp.getVector().hasOneUse())
+      return failure();
+    // Mask not supported.
+    if (xferOp.getMask())
+      return failure();
+    // Map not supported.
+    if (!xferOp.getPermutationMap().isMinorIdentity())
+      return failure();
+    // Cannot rewrite if the indices may be out of bounds. The starting
+    // indices are always in bounds, so 0-d transfers are never out of bounds.
+    if (xferOp.hasOutOfBoundsDim() && xferOp.getType().getRank() > 0)
+      return failure();
+    // Construct scalar load.
+    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
+                                  xferOp.getIndices().end());
+    if (extractOp.getPosition()) {
+      // Fold the extract position into the innermost transfer index.
+      AffineExpr sym0, sym1;
+      bindSymbols(extractOp.getContext(), sym0, sym1);
+      OpFoldResult ofr = makeComposedFoldedAffineApply(
+          rewriter, extractOp.getLoc(), sym0 + sym1,
+          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
+      if (ofr.is<Value>()) {
+        newIndices[newIndices.size() - 1] = ofr.get<Value>();
+      } else {
+        // The sum folded to a constant; materialize it as a constant index.
+        newIndices[newIndices.size() - 1] =
+            rewriter.create<arith::ConstantIndexOp>(extractOp.getLoc(),
+                                                    *getConstantIntValue(ofr));
+      }
+    }
+    if (xferOp.getSource().getType().isa<MemRefType>()) {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(
+          extractOp, xferOp.getSource(), newIndices);
+    } else {
+      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+          extractOp, xferOp.getSource(), newIndices);
+    }
+    return success();
+  }
+};
+
+/// Rewrite scalar transfer_write(broadcast) to memref.store or tensor.insert.
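+///
+/// E.g. (a sketch mirroring the transfer_write_0d test in this patch):
+///   %0 = vector.broadcast %f : f32 to vector<f32>
+///   vector.transfer_write %0, %m[%idx] : vector<f32>, memref<?xf32>
+/// becomes:
+///   memref.store %f, %m[%idx] : memref<?xf32>
+/// On the tensor path, tensor.insert is created instead.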
+class FoldScalarTransferWriteOfBroadcast
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
+                                PatternRewriter &rewriter) const override {
+    // Must be a scalar write.
+    auto vecType = xferOp.getVectorType();
+    if (vecType.getRank() != 0 &&
+        (vecType.getRank() != 1 || vecType.getShape()[0] != 1))
+      return failure();
+    // Mask not supported.
+    if (xferOp.getMask())
+      return failure();
+    // Map not supported.
+    if (!xferOp.getPermutationMap().isMinorIdentity())
+      return failure();
+    // Must be a broadcast of a scalar.
+    auto broadcastOp = xferOp.getVector().getDefiningOp<vector::BroadcastOp>();
+    if (!broadcastOp || broadcastOp.getSource().getType().isa<VectorType>())
+      return failure();
+    // Construct a scalar store.
+    if (xferOp.getSource().getType().isa<MemRefType>()) {
+      rewriter.replaceOpWithNewOp<memref::StoreOp>(
+          xferOp, broadcastOp.getSource(), xferOp.getSource(),
+          xferOp.getIndices());
+    } else {
+      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
+          xferOp, broadcastOp.getSource(), xferOp.getSource(),
+          xferOp.getIndices());
+    }
+    return success();
+  }
+};
} // namespace
void mlir::vector::transferOpflowOpt(Operation *rootOp) {
  opt.removeDeadOp();
}
+void mlir::vector::populateScalarVectorTransferLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns
+      .add<FoldScalarExtractOfTransferRead, FoldScalarTransferWriteOfBroadcast>(
+          patterns.getContext(), benefit);
+}
+
void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
--- /dev/null
+// RUN: mlir-opt %s -test-scalar-vector-transfer-lowering -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @transfer_read_0d(
+// CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
+// CHECK: %[[r:.*]] = memref.load %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK: return %[[r]]
+func.func @transfer_read_0d(%m: memref<?x?x?xf32>, %idx: index) -> f32 {
+  %cst = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %m[%idx, %idx, %idx], %cst : memref<?x?x?xf32>, vector<f32>
+  %1 = vector.extractelement %0[] : vector<f32>
+  return %1 : f32
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @transfer_read_1d(
+// CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[idx2:.*]]: index
+// CHECK: %[[added:.*]] = affine.apply #[[$map]]()[%[[idx]], %[[idx2]]]
+// CHECK: %[[r:.*]] = memref.load %[[m]][%[[idx]], %[[idx]], %[[added]]]
+// CHECK: return %[[r]]
+func.func @transfer_read_1d(%m: memref<?x?x?xf32>, %idx: index, %idx2: index) -> f32 {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_read %m[%idx, %idx, %idx], %cst {in_bounds = [true]} : memref<?x?x?xf32>, vector<5xf32>
+  %1 = vector.extractelement %0[%idx2 : index] : vector<5xf32>
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_transfer_read_0d(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?x?xf32>, %[[idx:.*]]: index
+// CHECK: %[[r:.*]] = tensor.extract %[[t]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK: return %[[r]]
+func.func @tensor_transfer_read_0d(%t: tensor<?x?x?xf32>, %idx: index) -> f32 {
+  %cst = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %t[%idx, %idx, %idx], %cst : tensor<?x?x?xf32>, vector<f32>
+  %1 = vector.extractelement %0[] : vector<f32>
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_0d(
+// CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
+// CHECK: memref.store %[[f]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+func.func @transfer_write_0d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
+  %0 = vector.broadcast %f : f32 to vector<f32>
+  vector.transfer_write %0, %m[%idx, %idx, %idx] : vector<f32>, memref<?x?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_1d(
+// CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
+// CHECK: memref.store %[[f]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+func.func @transfer_write_1d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
+  %0 = vector.broadcast %f : f32 to vector<1xf32>
+  vector.transfer_write %0, %m[%idx, %idx, %idx] : vector<1xf32>, memref<?x?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_transfer_write_0d(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
+// CHECK: %[[r:.*]] = tensor.insert %[[f]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK: return %[[r]]
+func.func @tensor_transfer_write_0d(%t: tensor<?x?x?xf32>, %idx: index, %f: f32) -> tensor<?x?x?xf32> {
+  %0 = vector.broadcast %f : f32 to vector<f32>
+  %1 = vector.transfer_write %0, %t[%idx, %idx, %idx] : vector<f32>, tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
  }
};
+struct TestScalarVectorTransferLoweringPatterns
+    : public PassWrapper<TestScalarVectorTransferLoweringPatterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestScalarVectorTransferLoweringPatterns)
+
+  StringRef getArgument() const final {
+    return "test-scalar-vector-transfer-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test lowering of scalar vector transfers to memref loads/stores.";
+  }
+  TestScalarVectorTransferLoweringPatterns() = default;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, memref::MemRefDialect,
+                    tensor::TensorDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    vector::populateScalarVectorTransferLoweringPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
struct TestVectorTransferOpt
    : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
  PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
+  PassRegistration<TestScalarVectorTransferLoweringPatterns>();
+
  PassRegistration<TestVectorTransferOpt>();
  PassRegistration<TestVectorTransferLoweringPatterns>();