[mlir] Add a new RewritePattern::hasBoundedRewriteRecursion hook.
authorRiver Riddle <riddleriver@gmail.com>
Thu, 9 Apr 2020 19:38:52 +0000 (12:38 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Thu, 9 Apr 2020 19:42:28 +0000 (12:42 -0700)
Summary: Some pattern rewriters, like dialect conversion, prohibit the unbounded recursion(or reapplication) of patterns on generated IR. Most patterns are not written with recursive application in mind, so will generally explode the stack if uncaught. This revision adds a hook to RewritePattern, `hasBoundedRewriteRecursion`, to signal that the pattern can safely be applied to the generated IR of a previous application of the same pattern. This allows for establishing a contract between the pattern and rewriter that the pattern knows and can handle the potential recursive application.

Differential Revision: https://reviews.llvm.org/D77782

mlir/include/mlir/IR/PatternMatch.h
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/Transforms/test-legalizer.mlir
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Dialect/Test/TestPatterns.cpp

index ef3c9fa..457a4b1 100644 (file)
@@ -131,6 +131,12 @@ public:
     return failure();
   }
 
+  /// Returns true if this pattern is known to result in recursive application,
+  /// i.e. this pattern may generate IR that also matches this pattern, but is
+  /// known to bound the recursion. This signals to a rewriter that it is safe
+  /// to apply this pattern recursively to generated IR.
+  virtual bool hasBoundedRewriteRecursion() const { return false; }
+
   /// Return a list of operations that may be generated when rewriting an
   /// operation instance with this pattern.
   ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
index 2fc9b22..eb4bf3b 100644 (file)
@@ -789,23 +789,10 @@ public:
         Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
         // smaller rank.
-        InsertStridedSliceOp insertStridedSliceOp =
-            rewriter.create<InsertStridedSliceOp>(
-                loc, extractedSource, extractedDest,
-                getI64SubArray(op.offsets(), /* dropFront=*/1),
-                getI64SubArray(op.strides(), /* dropFront=*/1));
-        // Call matchAndRewrite recursively from within the pattern. This
-        // circumvents the current limitation that a given pattern cannot
-        // be called multiple times by the PatternRewrite infrastructure (to
-        // avoid infinite recursion, but in this case, infinite recursion
-        // cannot happen because the rank is strictly decreasing).
-        // TODO(rriddle, nicolasvasilache) Implement something like a hook for
-        // a potential function that must decrease and allow the same pattern
-        // multiple times.
-        auto success = matchAndRewrite(insertStridedSliceOp, rewriter);
-        (void)success;
-        assert(succeeded(success) && "Unexpected failure");
-        extractedSource = insertStridedSliceOp;
+        extractedSource = rewriter.create<InsertStridedSliceOp>(
+            loc, extractedSource, extractedDest,
+            getI64SubArray(op.offsets(), /* dropFront=*/1),
+            getI64SubArray(op.strides(), /* dropFront=*/1));
       }
       // 4. Insert the extractedSource into the res vector.
       res = insertOne(rewriter, loc, extractedSource, res, off);
@@ -814,6 +801,9 @@ public:
     rewriter.replaceOp(op, res);
     return success();
   }
+  /// This pattern creates recursive InsertStridedSliceOp, but the recursion is
+  /// bounded as the rank is strictly decreasing.
+  bool hasBoundedRewriteRecursion() const final { return true; }
 };
 
 class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
@@ -1068,28 +1058,19 @@ public:
          off += stride, ++idx) {
       Value extracted = extractOne(rewriter, loc, op.vector(), off);
       if (op.offsets().getValue().size() > 1) {
-        StridedSliceOp stridedSliceOp = rewriter.create<StridedSliceOp>(
+        extracted = rewriter.create<StridedSliceOp>(
             loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1),
             getI64SubArray(op.sizes(), /* dropFront=*/1),
             getI64SubArray(op.strides(), /* dropFront=*/1));
-        // Call matchAndRewrite recursively from within the pattern. This
-        // circumvents the current limitation that a given pattern cannot
-        // be called multiple times by the PatternRewrite infrastructure (to
-        // avoid infinite recursion, but in this case, infinite recursion
-        // cannot happen because the rank is strictly decreasing).
-        // TODO(rriddle, nicolasvasilache) Implement something like a hook for
-        // a potential function that must decrease and allow the same pattern
-        // multiple times.
-        auto success = matchAndRewrite(stridedSliceOp, rewriter);
-        (void)success;
-        assert(succeeded(success) && "Unexpected failure");
-        extracted = stridedSliceOp;
       }
       res = insertOne(rewriter, loc, extracted, res, idx);
     }
     rewriter.replaceOp(op, {res});
     return success();
   }
+  /// This pattern creates recursive StridedSliceOp, but the recursion is
+  /// bounded as the rank is strictly decreasing.
+  bool hasBoundedRewriteRecursion() const final { return true; }
 };
 
 } // namespace
index ab6924f..e153e5c 100644 (file)
@@ -1256,10 +1256,9 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
   });
 
   // Ensure that we don't cycle by not allowing the same pattern to be
-  // applied twice in the same recursion stack.
-  // TODO(riverriddle) We could eventually converge, but that requires more
-  // complicated analysis.
-  if (!appliedPatterns.insert(pattern).second) {
+  // applied twice in the same recursion stack if it is not known to be safe.
+  if (!pattern->hasBoundedRewriteRecursion() &&
+      !appliedPatterns.insert(pattern).second) {
     LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern was already applied"));
     return failure();
   }
index 3305e01..557908d 100644 (file)
@@ -143,6 +143,13 @@ func @create_block() {
   return
 }
 
+// CHECK-LABEL: @bounded_recursion
+func @bounded_recursion() {
+  // CHECK: test.recursive_rewrite 0
+  test.recursive_rewrite 3
+  return
+}
+
 // -----
 
 func @fail_to_convert_illegal_op() -> i32 {
index 8859d50..8eedd1f 100644 (file)
@@ -1061,6 +1061,12 @@ def TestRewriteOp : TEST_Op<"rewrite">,
   Arguments<(ins AnyType)>, Results<(outs AnyType)>;
 def : Pat<(TestRewriteOp $input), (replaceWithValue $input)>;
 
+// Check that patterns can specify bounded recursion when rewriting.
+def TestRecursiveRewriteOp : TEST_Op<"recursive_rewrite"> {
+  let arguments = (ins I64Attr:$depth);
+  let assemblyFormat = "$depth attr-dict";
+}
+
 //===----------------------------------------------------------------------===//
 // Test Type Legalization
 //===----------------------------------------------------------------------===//
index 39b3fc1..90b34d9 100644 (file)
@@ -360,6 +360,28 @@ struct TestNonRootReplacement : public RewritePattern {
     return success();
   }
 };
+
+//===----------------------------------------------------------------------===//
+// Recursive Rewrite Testing
+/// This pattern is applied to the same operation multiple times, but has a
+/// bounded recursion.
+struct TestBoundedRecursiveRewrite
+    : public OpRewritePattern<TestRecursiveRewriteOp> {
+  using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
+                                PatternRewriter &rewriter) const final {
+    // Decrement the depth of the op in-place.
+    rewriter.updateRootInPlace(op, [&] {
+      op.setAttr("depth",
+                 rewriter.getI64IntegerAttr(op.depth().getSExtValue() - 1));
+    });
+    return success();
+  }
+
+  /// The conversion target handles bounding the recursion of this pattern.
+  bool hasBoundedRewriteRecursion() const final { return true; }
+};
 } // namespace
 
 namespace {
@@ -414,7 +436,7 @@ struct TestLegalizePatternDriver
         TestCreateIllegalBlock, TestPassthroughInvalidOp, TestSplitReturnType,
         TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
         TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
-        TestNonRootReplacement>(&getContext());
+        TestNonRootReplacement, TestBoundedRecursiveRewrite>(&getContext());
     patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
     mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
                                               converter);
@@ -449,6 +471,10 @@ struct TestLegalizePatternDriver
           op->getAttrOfType<UnitAttr>("test.recursively_legal"));
     });
 
+    // Mark the bound recursion operation as dynamically legal.
+    target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
+        [](TestRecursiveRewriteOp op) { return op.depth() == 0; });
+
     // Handle a partial conversion.
     if (mode == ConversionMode::Partial) {
       (void)applyPartialConversion(getOperation(), target, patterns,