/// Attribute name used for labeling transfer ops during progressive lowering.
static const char kPassLabel[] = "__vector_to_scf_lowering__";
-/// Lower to 1D transfer ops. Target-specific lowering will lower those.
-static const int64_t kTargetRank = 1;
+/// Patterns that inherit from this struct have access to
+/// VectorTransferToSCFOptions.
+template <typename OpTy>
+struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
+ explicit VectorToSCFPattern(MLIRContext *context,
+ VectorTransferToSCFOptions opt)
+ : OpRewritePattern<OpTy>(context), options(opt) {}
+
+ VectorTransferToSCFOptions options;
+};
/// Given a MemRefType with VectorType element type, unpack one dimension from
/// the VectorType into the MemRefType.
/// Add the pass label to a vector transfer op if its rank is not the target
/// rank.
template <typename OpTy>
-static void maybeApplyPassLabel(OpBuilder &builder, OpTy newXferOp) {
- if (newXferOp.getVectorType().getRank() > kTargetRank)
+static void maybeApplyPassLabel(OpBuilder &builder, OpTy newXferOp,
+ unsigned targetRank) {
+ if (newXferOp.getVectorType().getRank() > targetRank)
newXferOp->setAttr(kPassLabel, builder.getUnitAttr());
}
/// Note: The loop and type cast are generated in TransferOpConversion.
/// The original TransferReadOp and store op are deleted in `cleanup`.
/// Note: The `mask` operand is set in TransferOpConversion.
- static TransferReadOp rewriteOp(OpBuilder &builder, TransferReadOp xferOp,
- Value buffer, Value iv) {
+ static TransferReadOp rewriteOp(OpBuilder &builder,
+ VectorTransferToSCFOptions options,
+ TransferReadOp xferOp, Value buffer,
+ Value iv) {
SmallVector<Value, 8> storeIndices;
getBufferIndices(xferOp, storeIndices);
storeIndices.push_back(iv);
.value;
maybeApplyPassLabel(builder,
- dyn_cast<TransferReadOp>(newXfer.getDefiningOp()));
+ dyn_cast<TransferReadOp>(newXfer.getDefiningOp()),
+ options.targetRank);
memref_store(newXfer, buffer, storeIndices);
return newXfer.getDefiningOp<TransferReadOp>();
/// to memory.
///
/// Note: For more details, see comments on Strategy<TransferReadOp>.
- static TransferWriteOp rewriteOp(OpBuilder &builder, TransferWriteOp xferOp,
- Value buffer, Value iv) {
+ static TransferWriteOp rewriteOp(OpBuilder &builder,
+ VectorTransferToSCFOptions options,
+ TransferWriteOp xferOp, Value buffer,
+ Value iv) {
SmallVector<Value, 8> loadIndices;
getBufferIndices(xferOp, loadIndices);
loadIndices.push_back(iv);
AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)), Value(),
inBoundsAttr);
- maybeApplyPassLabel(builder, newXfer.op);
+ maybeApplyPassLabel(builder, newXfer.op, options.targetRank);
return newXfer;
}
};
template <typename OpTy>
-LogicalResult checkPrepareXferOp(OpTy xferOp) {
+LogicalResult checkPrepareXferOp(OpTy xferOp, unsigned targetRank) {
if (xferOp->hasAttr(kPassLabel))
return failure();
- if (xferOp.getVectorType().getRank() <= kTargetRank)
+ if (xferOp.getVectorType().getRank() <= targetRank)
return failure();
return success();
}
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
-struct PrepareTransferReadConversion : public OpRewritePattern<TransferReadOp> {
- using OpRewritePattern<TransferReadOp>::OpRewritePattern;
+struct PrepareTransferReadConversion
+ : public VectorToSCFPattern<TransferReadOp> {
+ using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
LogicalResult matchAndRewrite(TransferReadOp xferOp,
PatternRewriter &rewriter) const override {
- if (checkPrepareXferOp(xferOp).failed())
+ if (checkPrepareXferOp(xferOp, options.targetRank).failed())
return failure();
ScopedContext scope(rewriter, xferOp.getLoc());
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferWriteConversion
- : public OpRewritePattern<TransferWriteOp> {
- using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
+ : public VectorToSCFPattern<TransferWriteOp> {
+ using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
LogicalResult matchAndRewrite(TransferWriteOp xferOp,
PatternRewriter &rewriter) const override {
- if (checkPrepareXferOp(xferOp).failed())
+ if (checkPrepareXferOp(xferOp, options.targetRank).failed())
return failure();
ScopedContext scope(rewriter, xferOp.getLoc());
/// out-of-bounds, generate an if-check and handle both cases separately.
/// 3. Clean up according to the corresponding Strategy<OpTy>.
template <typename OpTy>
-struct TransferOpConversion : public OpRewritePattern<OpTy> {
- using OpRewritePattern<OpTy>::OpRewritePattern;
+struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
+ using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
LogicalResult matchAndRewrite(OpTy xferOp,
PatternRewriter &rewriter) const override {
/*inBoundsCase=*/
[&](OpBuilder &b, Location /*loc*/) {
// Create new transfer op.
- OpTy newXfer =
- Strategy<OpTy>::rewriteOp(b, xferOp, castedDataBuffer, iv);
+ OpTy newXfer = Strategy<OpTy>::rewriteOp(
+ b, this->options, xferOp, castedDataBuffer, iv);
// If old transfer op has a mask: Set mask on new transfer op.
// Special case: If the mask of the old transfer op is 1D and
/// Note: As an optimization, if the result of the original TransferReadOp
/// was directly inserted into another vector, no new %v_init vector is created.
/// Instead, the new TransferReadOp results are inserted into that vector.
-struct UnrollTransferReadConversion : public OpRewritePattern<TransferReadOp> {
- using OpRewritePattern<TransferReadOp>::OpRewritePattern;
+struct UnrollTransferReadConversion
+ : public VectorToSCFPattern<TransferReadOp> {
+ using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
/// Return the vector into which the newly created TransferReadOp results
/// are inserted.
/// accesses, and broadcasts and transposes in permutation maps.
LogicalResult matchAndRewrite(TransferReadOp xferOp,
PatternRewriter &rewriter) const override {
- if (xferOp.getVectorType().getRank() <= kTargetRank)
+ if (xferOp.getVectorType().getRank() <= options.targetRank)
return failure();
ScopedContext scope(rewriter, xferOp.getLoc());
/// doing so, `a` may become dead, and the number of ExtractOps generated during
/// recursive application of this pattern will be minimal.
struct UnrollTransferWriteConversion
- : public OpRewritePattern<TransferWriteOp> {
- using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
+ : public VectorToSCFPattern<TransferWriteOp> {
+ using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
/// Return the vector from which newly generated ExtracOps will extract.
Value getDataVector(TransferWriteOp xferOp) const {
/// accesses, and broadcasts and transposes in permutation maps.
LogicalResult matchAndRewrite(TransferWriteOp xferOp,
PatternRewriter &rewriter) const override {
- if (xferOp.getVectorType().getRank() <= kTargetRank)
+ if (xferOp.getVectorType().getRank() <= options.targetRank)
return failure();
ScopedContext scope(rewriter, xferOp.getLoc());
/// }
/// ```
template <typename OpTy>
-struct TransferOp1dConversion : public OpRewritePattern<OpTy> {
- using OpRewritePattern<OpTy>::OpRewritePattern;
+struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
+ using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
LogicalResult matchAndRewrite(OpTy xferOp,
PatternRewriter &rewriter) const override {
RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
if (options.unroll) {
patterns.add<UnrollTransferReadConversion, UnrollTransferWriteConversion>(
- patterns.getContext());
+ patterns.getContext(), options);
} else {
patterns.add<PrepareTransferReadConversion, PrepareTransferWriteConversion,
TransferOpConversion<TransferReadOp>,
- TransferOpConversion<TransferWriteOp>>(patterns.getContext());
+ TransferOpConversion<TransferWriteOp>>(patterns.getContext(),
+ options);
}
- if (kTargetRank == 1) {
+ if (options.targetRank == 1) {
patterns.add<TransferOp1dConversion<TransferReadOp>,
- TransferOp1dConversion<TransferWriteOp>>(
- patterns.getContext());
+ TransferOp1dConversion<TransferWriteOp>>(patterns.getContext(),
+ options);
}
}
ConvertVectorToSCFPass() = default;
ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
this->fullUnroll = options.unroll;
+ this->targetRank = options.targetRank;
}
void runOnFunction() override {
+ VectorTransferToSCFOptions options;
+ options.setUnroll(fullUnroll);
+ options.setTargetRank(targetRank);
+
RewritePatternSet patterns(getFunction().getContext());
- populateVectorToSCFConversionPatterns(
- patterns, VectorTransferToSCFOptions().setUnroll(fullUnroll));
+ populateVectorToSCFConversionPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};