VectorTransferPermutationMapLoweringPatterns can be enabled via a pass option. These additional patterns lower permutation maps to minor identity maps with broadcasting, if possible, allowing for more efficient vector load/stores. The option is deactivated by default.
Differential Revision: https://reviews.llvm.org/D102593
"Perform full unrolling when converting vector transfers to SCF">,
Option<"targetRank", "target-rank", "unsigned", /*default=*/"1",
"Target vector rank to which transfer ops should be lowered">,
+ Option<"lowerPermutationMaps", "lower-permutation-maps", "bool",
+ /*default=*/"false", "Replace permutation maps with vector "
+ "transposes/broadcasts before lowering transfer ops">
];
}
/// Options that control the lowering of vector transfer ops to SCF.
struct VectorTransferToSCFOptions {
  /// Perform full unrolling when converting vector transfers to SCF.
  bool unroll = false;
  /// Target vector rank to which transfer ops should be lowered; ops whose
  /// vector rank is already <= targetRank are left untouched.
  unsigned targetRank = 1;
  /// Replace permutation maps with vector transposes/broadcasts before
  /// lowering transfer ops. Deactivated by default.
  bool lowerPermutationMaps = false;

  // Builder-style setters, returning *this so options can be chained.
  VectorTransferToSCFOptions &setUnroll(bool u) {
    unroll = u;
    return *this;
  }

  VectorTransferToSCFOptions &setTargetRank(unsigned r) {
    targetRank = r;
    return *this;
  }

  VectorTransferToSCFOptions &setLowerPermutationMaps(bool l) {
    lowerPermutationMaps = l;
    return *this;
  }
};
/// Collect a set of patterns to convert from the Vector dialect to SCF + std.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
/// Collect a set of transfer read/write lowering patterns.
///
/// These patterns lower transfer ops to simpler ops like `vector.load`,
-/// `vector.store` and `vector.broadcast`.
+/// `vector.store` and `vector.broadcast`. Includes all patterns of
+/// populateVectorTransferPermutationMapLoweringPatterns.
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns);
+/// Collect a set of transfer read/write lowering patterns that simplify the
+/// permutation map (e.g., converting it to a minor identity map) by inserting
+/// broadcasts and transposes.
+void populateVectorTransferPermutationMapLoweringPatterns(
+ RewritePatternSet &patterns);
+
/// These patterns materialize masks for various vector ops such as transfers.
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
bool enableIndexOptimizations);
if (xferOp.mask()) {
auto maskType = MemRefType::get({}, xferOp.mask().getType());
auto maskBuffer = memref_alloca(maskType).value;
+ b.setInsertionPoint(xferOp);
memref_store(xferOp.mask(), maskBuffer);
result.maskBuffer = memref_load(maskBuffer);
}
};
template <typename OpTy>
-LogicalResult checkPrepareXferOp(OpTy xferOp, unsigned targetRank) {
+LogicalResult checkPrepareXferOp(OpTy xferOp,
+ VectorTransferToSCFOptions options) {
if (xferOp->hasAttr(kPassLabel))
return failure();
- if (xferOp.getVectorType().getRank() <= targetRank)
+ if (xferOp.getVectorType().getRank() <= options.targetRank)
return failure();
return success();
}
LogicalResult matchAndRewrite(TransferReadOp xferOp,
PatternRewriter &rewriter) const override {
- if (checkPrepareXferOp(xferOp, options.targetRank).failed())
+ if (checkPrepareXferOp(xferOp, options).failed())
return failure();
ScopedContext scope(rewriter, xferOp.getLoc());
LogicalResult matchAndRewrite(TransferWriteOp xferOp,
PatternRewriter &rewriter) const override {
- if (checkPrepareXferOp(xferOp, options.targetRank).failed())
+ if (checkPrepareXferOp(xferOp, options).failed())
return failure();
ScopedContext scope(rewriter, xferOp.getLoc());
/// Seed the pass options from a VectorTransferToSCFOptions struct so the pass
/// can be constructed programmatically as well as from pass options.
ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  this->fullUnroll = options.unroll;
  this->targetRank = options.targetRank;
  this->lowerPermutationMaps = options.lowerPermutationMaps;
}
void runOnFunction() override {
VectorTransferToSCFOptions options;
- options.setUnroll(fullUnroll);
- options.setTargetRank(targetRank);
+ options.unroll = fullUnroll;
+ options.targetRank = targetRank;
+ options.lowerPermutationMaps = lowerPermutationMaps;
+
+ // Lower permutation maps first.
+ if (lowerPermutationMaps) {
+ RewritePatternSet lowerTransferPatterns(getFunction().getContext());
+ mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
+ lowerTransferPatterns);
+ (void)applyPatternsAndFoldGreedily(getFunction(),
+ std::move(lowerTransferPatterns));
+ }
RewritePatternSet patterns(getFunction().getContext());
populateVectorToSCFConversionPatterns(patterns, options);
/// - The op has no mask.
struct TransferReadToVectorLoadLowering
: public OpRewritePattern<vector::TransferReadOp> {
- TransferReadToVectorLoadLowering(MLIRContext *context)
- : OpRewritePattern<vector::TransferReadOp>(context) {}
+ using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
LogicalResult matchAndRewrite(vector::TransferReadOp read,
PatternRewriter &rewriter) const override {
SmallVector<unsigned, 4> broadcastedDims;
/// - The op has no mask.
struct TransferWriteToVectorStoreLowering
: public OpRewritePattern<vector::TransferWriteOp> {
- TransferWriteToVectorStoreLowering(MLIRContext *context)
- : OpRewritePattern<vector::TransferWriteOp>(context) {}
+ using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
LogicalResult matchAndRewrite(vector::TransferWriteOp write,
PatternRewriter &rewriter) const override {
// TODO: Support non-minor-identity maps
if (permutationMap.isIdentity())
return failure();
+ permutationMap = map.getPermutationMap(permutation, op.getContext());
// Calculate the map of the new read by applying the inverse permutation.
permutationMap = inversePermutation(permutationMap);
AffineMap newMap = permutationMap.compose(map);
patterns.getContext());
}
+void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<TransferReadPermutationLowering,
+ TransferWritePermutationLowering, TransferOpReduceRank>(
+ patterns.getContext());
+}
+
/// Collect transfer read/write lowering patterns that rewrite transfer ops to
/// simpler ops like `vector.load`, `vector.store` and `vector.broadcast`.
/// Includes all patterns of
/// populateVectorTransferPermutationMapLoweringPatterns.
void mlir::vector::populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(patterns.getContext());
  // The load/store lowerings above require minor identity maps, so also pull
  // in the permutation-map simplification patterns.
  populateVectorTransferPermutationMapLoweringPatterns(patterns);
}
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
--- /dev/null
+// RUN: mlir-opt %s -convert-vector-to-scf='lower-permutation-maps=true' -split-input-file | FileCheck %s
+
+// Ensure that the permutation map is lowered (by inserting a transpose op)
+// before lowering the vector.transfer_read.
+
+// CHECK-LABEL: func @transfer_read_2d_mask_transposed(
+// CHECK-DAG: %[[PADDING:.*]] = constant dense<-4.200000e+01> : vector<9xf32>
+// CHECK-DAG: %[[MASK:.*]] = constant dense<{{.*}}> : vector<9x4xi1>
+// CHECK: %[[MASK_MEM:.*]] = memref.alloca() : memref<vector<4x9xi1>>
+// CHECK: %[[MASK_T:.*]] = vector.transpose %[[MASK]], [1, 0] : vector<9x4xi1> to vector<4x9xi1>
+// CHECK: memref.store %[[MASK_T]], %[[MASK_MEM]][] : memref<vector<4x9xi1>>
+// CHECK: %[[MASK_CASTED:.*]] = vector.type_cast %[[MASK_MEM]] : memref<vector<4x9xi1>> to memref<4xvector<9xi1>>
+// CHECK: scf.for {{.*}} {
+// CHECK: scf.if {{.*}} {
+// CHECK: %[[MASK_LOADED:.*]] = memref.load %[[MASK_CASTED]][%{{.*}}] : memref<4xvector<9xi1>>
+// CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}, %{{.*}}, %[[MASK_LOADED]] : memref<?x?xf32>, vector<9xf32>
+// CHECK: memref.store %[[READ]], %{{.*}} : memref<4xvector<9xf32>>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[RESULT:.*]] = memref.load %{{.*}} : memref<vector<4x9xf32>>
+// CHECK: %[[RESULT_T:.*]] = vector.transpose %[[RESULT]], [1, 0] : vector<4x9xf32> to vector<9x4xf32>
+// CHECK: return %[[RESULT_T]] : vector<9x4xf32>
+
+// Vector load with mask + transpose.
+func @transfer_read_2d_mask_transposed(
+ %A : memref<?x?xf32>, %base1: index, %base2: index) -> (vector<9x4xf32>) {
+ %fm42 = constant -42.0: f32
+ %mask = constant dense<[[1, 0, 1, 0], [0, 0, 1, 0],
+ [1, 1, 1, 1], [0, 1, 1, 0],
+ [1, 1, 1, 1], [1, 1, 1, 1],
+ [1, 1, 1, 1], [0, 0, 0, 0],
+ [1, 1, 1, 1]]> : vector<9x4xi1>
+ %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+ {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
+ memref<?x?xf32>, vector<9x4xf32>
+ return %f : vector<9x4xf32>
+}
+++ /dev/null
-// Run test with and without test-vector-transfer-lowering-patterns.
-
-// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
-// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
-// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
-
-// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
-// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
-// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
-
-
-memref.global "private" @gv : memref<3x4xf32> = dense<[[0. , 1. , 2. , 3. ],
- [10., 11., 12., 13.],
- [20., 21., 22., 23.]]>
-
-// Vector load with transpose.
-func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
- %fm42 = constant -42.0: f32
- %f = vector.transfer_read %A[%base1, %base2], %fm42
- {in_bounds = [true, false], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
- memref<?x?xf32>, vector<3x9xf32>
- vector.print %f: vector<3x9xf32>
- return
-}
-
-func @entry() {
- %c0 = constant 0: index
- %c1 = constant 1: index
- %c2 = constant 2: index
- %c3 = constant 3: index
- %0 = memref.get_global @gv : memref<3x4xf32>
- %A = memref.cast %0 : memref<3x4xf32> to memref<?x?xf32>
-
- // 1. Read 2D vector from 2D memref with transpose.
- call @transfer_read_2d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
- // CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( 20, 0, -42, -42, -42, -42, -42, -42, -42 ) )
-
- return
-}
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
-// RUN: mlir-opt %s -convert-vector-to-scf=full-unroll=true -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-opt %s -convert-vector-to-scf='lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
-// RUN: mlir-opt %s -convert-vector-to-scf=full-unroll=true -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-opt %s -convert-vector-to-scf='lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
-// RUN: mlir-opt %s -convert-vector-to-scf=full-unroll=true -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-opt %s -convert-vector-to-scf='lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s