[mlir] OperationFolder: fix crash in creation of single-result-ops with in-place...
authorAlex Zinenko <zinenko@google.com>
Wed, 6 May 2020 15:39:23 +0000 (17:39 +0200)
committerAlex Zinenko <zinenko@google.com>
Wed, 6 May 2020 18:40:32 +0000 (20:40 +0200)
When the folding is performed in place, the `::fold` function does not populate
its `results` argument to indicate that. (In the folding hook for single-result
operations, the result of the original operation is expected to be returned,
but it is then ignored by the wrapper.) `OperationFolder::create` would
erronously rely on the _operation_ having zero results instead of on the
_folding_ producing zero new results to populate the list of results with those
of the original operation. This would lead to a crash for single-result ops
with in-place folds where the first result is accessed uncondtionally because
the list of results was not properly populated. Use the list of values produced
by the folding instead.

Differential Revision: https://reviews.llvm.org/D79497

mlir/include/mlir/Transforms/FoldUtils.h
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Dialect/Test/TestPatterns.cpp

index f8c678d..d427f0b 100644 (file)
@@ -77,14 +77,14 @@ public:
   void create(OpBuilder &builder, SmallVectorImpl<Value> &results,
               Location location, Args &&... args) {
     // The op needs to be inserted only if the fold (below) fails, or the number
-    // of results of the op is zero (which is treated as an in-place
-    // fold). Using create methods of the builder will insert the op, so not
-    // using it here.
+    // of results produced by the successful folding is zero (which is treated
+    // as an in-place fold). Using create methods of the builder will insert the
+    // op, so not using it here.
     OperationState state(location, OpTy::getOperationName());
     OpTy::build(builder, state, std::forward<Args>(args)...);
     Operation *op = Operation::create(state);
 
-    if (failed(tryToFold(builder, op, results)) || op->getNumResults() == 0) {
+    if (failed(tryToFold(builder, op, results)) || results.empty()) {
       builder.insert(op);
       results.assign(op->result_begin(), op->result_end());
       return;
index 1a40f99..fb7acde 100644 (file)
@@ -323,6 +323,15 @@ LogicalResult TestOpWithVariadicResultsAndFolder::fold(
   return success();
 }
 
+OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1);
+  if (operands.front()) {
+    setAttr("attr", operands.front());
+    return getResult();
+  }
+  return {};
+}
+
 LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
     MLIRContext *, Optional<Location> location, ValueRange operands,
     ArrayRef<NamedAttribute> attributes, RegionRange regions,
index f9140f2..3e49a1d 100644 (file)
@@ -734,6 +734,17 @@ def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> {
   let results = (outs I32);
 }
 
+def TestOpInPlaceFoldAnchor : TEST_Op<"op_in_place_fold_anchor"> {
+  let arguments = (ins I32);
+  let results = (outs I32);
+}
+
+def TestOpInPlaceFold : TEST_Op<"op_in_place_fold"> {
+  let arguments = (ins I32:$op, I32Attr:$attr);
+  let results = (outs I32);
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Test Patterns (Symbol Binding)
 
index deb1cf5..7c91d5f 100644 (file)
@@ -8,9 +8,11 @@
 
 #include "TestDialect.h"
 #include "mlir/Conversion/StandardToStandard/StandardToStandard.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/FoldUtils.h"
 
 using namespace mlir;
 
@@ -39,13 +41,36 @@ namespace {
 //===----------------------------------------------------------------------===//
 
 namespace {
+struct FoldingPattern : public RewritePattern {
+public:
+  FoldingPattern(MLIRContext *context)
+      : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
+                       /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // Exercice OperationFolder API for a single-result operation that is folded
+    // upon construction. The operation being created through the folder has an
+    // in-place folder, and it should be still present in the output.
+    // Furthermore, the folder should not crash when attempting to recover the
+    // (unchanged) opeation result.
+    OperationFolder folder(op->getContext());
+    Value result = folder.create<TestOpInPlaceFold>(
+        rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0),
+        rewriter.getI32IntegerAttr(0));
+    assert(result);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
   void runOnFunction() override {
     mlir::OwningRewritePatternList patterns;
     populateWithGenerated(&getContext(), &patterns);
 
     // Verify named pattern is generated with expected name.
-    patterns.insert<TestNamedPatternRule>(&getContext());
+    patterns.insert<FoldingPattern, TestNamedPatternRule>(&getContext());
 
     applyPatternsAndFoldGreedily(getFunction(), patterns);
   }