Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
// 3. Reduce the problem to lowering a new InsertStridedSlice op with
// smaller rank.
- InsertStridedSliceOp insertStridedSliceOp =
- rewriter.create<InsertStridedSliceOp>(
- loc, extractedSource, extractedDest,
- getI64SubArray(op.offsets(), /* dropFront=*/1),
- getI64SubArray(op.strides(), /* dropFront=*/1));
- // Call matchAndRewrite recursively from within the pattern. This
- // circumvents the current limitation that a given pattern cannot
- // be called multiple times by the PatternRewrite infrastructure (to
- // avoid infinite recursion, but in this case, infinite recursion
- // cannot happen because the rank is strictly decreasing).
- // TODO(rriddle, nicolasvasilache) Implement something like a hook for
- // a potential function that must decrease and allow the same pattern
- // multiple times.
- auto success = matchAndRewrite(insertStridedSliceOp, rewriter);
- (void)success;
- assert(succeeded(success) && "Unexpected failure");
- extractedSource = insertStridedSliceOp;
+ extractedSource = rewriter.create<InsertStridedSliceOp>(
+ loc, extractedSource, extractedDest,
+ getI64SubArray(op.offsets(), /* dropFront=*/1),
+ getI64SubArray(op.strides(), /* dropFront=*/1));
}
// 4. Insert the extractedSource into the res vector.
res = insertOne(rewriter, loc, extractedSource, res, off);
rewriter.replaceOp(op, res);
return success();
}
+ /// This pattern creates recursive InsertStridedSliceOp, but the recursion is
+ /// bounded as the rank is strictly decreasing.
+ bool hasBoundedRewriteRecursion() const final { return true; }
};
class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
off += stride, ++idx) {
Value extracted = extractOne(rewriter, loc, op.vector(), off);
if (op.offsets().getValue().size() > 1) {
- StridedSliceOp stridedSliceOp = rewriter.create<StridedSliceOp>(
+ extracted = rewriter.create<StridedSliceOp>(
loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1),
getI64SubArray(op.sizes(), /* dropFront=*/1),
getI64SubArray(op.strides(), /* dropFront=*/1));
- // Call matchAndRewrite recursively from within the pattern. This
- // circumvents the current limitation that a given pattern cannot
- // be called multiple times by the PatternRewrite infrastructure (to
- // avoid infinite recursion, but in this case, infinite recursion
- // cannot happen because the rank is strictly decreasing).
- // TODO(rriddle, nicolasvasilache) Implement something like a hook for
- // a potential function that must decrease and allow the same pattern
- // multiple times.
- auto success = matchAndRewrite(stridedSliceOp, rewriter);
- (void)success;
- assert(succeeded(success) && "Unexpected failure");
- extracted = stridedSliceOp;
}
res = insertOne(rewriter, loc, extracted, res, idx);
}
rewriter.replaceOp(op, {res});
return success();
}
+ /// This pattern creates recursive StridedSliceOp, but the recursion is
+ /// bounded as the rank is strictly decreasing.
+ bool hasBoundedRewriteRecursion() const final { return true; }
};
} // namespace
return success();
}
};
+
+//===----------------------------------------------------------------------===//
+// Recursive Rewrite Testing
+/// This pattern is applied to the same operation multiple times, but has a
+/// bounded recursion.
+struct TestBoundedRecursiveRewrite
+ : public OpRewritePattern<TestRecursiveRewriteOp> {
+ using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
+ PatternRewriter &rewriter) const final {
+ // Decrement the depth of the op in-place.
+ rewriter.updateRootInPlace(op, [&] {
+ op.setAttr("depth",
+ rewriter.getI64IntegerAttr(op.depth().getSExtValue() - 1));
+ });
+ return success();
+ }
+
+ /// The conversion target handles bounding the recursion of this pattern.
+ bool hasBoundedRewriteRecursion() const final { return true; }
+};
} // namespace
namespace {
TestCreateIllegalBlock, TestPassthroughInvalidOp, TestSplitReturnType,
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
- TestNonRootReplacement>(&getContext());
+ TestNonRootReplacement, TestBoundedRecursiveRewrite>(&getContext());
patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
converter);
op->getAttrOfType<UnitAttr>("test.recursively_legal"));
});
+ // Mark the bound recursion operation as dynamically legal.
+ target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
+ [](TestRecursiveRewriteOp op) { return op.depth() == 0; });
+
// Handle a partial conversion.
if (mode == ConversionMode::Partial) {
(void)applyPartialConversion(getOperation(), target, patterns,