[MLIR] Canonicalize `shape.assuming` op to yield only inner values
authorFrederik Gossen <frgossen@google.com>
Tue, 23 Mar 2021 11:00:16 +0000 (12:00 +0100)
committerFrederik Gossen <frgossen@google.com>
Tue, 23 Mar 2021 11:34:50 +0000 (12:34 +0100)
Differential Revision: https://reviews.llvm.org/D99156

mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir

index d2a10a9..0feac87 100644 (file)
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Traits.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -268,12 +269,72 @@ struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
     return success();
   }
 };
+
+// Results of an assuming op that are defined outside its body are available
+// indepentently of the assuming op. There is no need to yield such values. This
+// canonicalization replaces such results with their definition.
+struct AssumingBypassIndependentResult : public OpRewritePattern<AssumingOp> {
+  using OpRewritePattern<AssumingOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AssumingOp op,
+                                PatternRewriter &rewriter) const override {
+    Block *body = op.getBody();
+    auto yieldOp = llvm::dyn_cast<AssumingYieldOp>(body->getTerminator());
+    if (!yieldOp)
+      return failure();
+
+    // See if there is at least one result that can bypass the assuming op.
+    auto isDefinedInBody = [&](Value val) {
+      Operation *def = val.getDefiningOp();
+      return def && op->isAncestor(def);
+    };
+    if (llvm::all_of(yieldOp.operands(), isDefinedInBody))
+      return failure();
+
+    SmallVector<Value, 2> replacementValues;
+    auto newAssumingOp = rewriter.create<shape::AssumingOp>(
+        op.getLoc(), op.witness(), [&](OpBuilder &b, Location loc) {
+          // Copy body.
+          BlockAndValueMapping mapping;
+          for (auto &nested : body->without_terminator())
+            b.clone(nested, mapping);
+
+          // Collect new yielded values.
+          SmallVector<Value, 2> mappedResults;
+          for (auto result : yieldOp.getOperands()) {
+            if (isDefinedInBody(result)) {
+              // This value is a result of the assuming op. We can obtain the
+              // replacement value only after the new op is fully constructed.
+              mappedResults.push_back(mapping.lookup(result));
+              replacementValues.push_back(nullptr);
+            } else {
+              // When defined outside of the assuming block, we can use it
+              // direclty. There is no need to yield the value from within the
+              // block.
+              replacementValues.push_back(result);
+            }
+          }
+          return mappedResults;
+        });
+
+    // Use the assuming op's results for the missing replacement values, which
+    // could not bypass the op.
+    auto src = newAssumingOp.getResults().begin();
+    for (auto &dst : replacementValues) {
+      if (dst)
+        continue;
+      dst = *src++;
+    }
+
+    rewriter.replaceOp(op, replacementValues);
+    return success();
+  }
+};
 } // namespace
 
 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
-  // If taking a passing witness, inline region.
-  patterns.add<AssumingWithTrue>(context);
+  patterns.add<AssumingBypassIndependentResult, AssumingWithTrue>(context);
 }
 
 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
index 39f17e9..3c4a282 100644 (file)
@@ -1144,3 +1144,28 @@ func @broadcast_on_single_operand(%a : tensor<3xindex>) {
   "use"(%0) : (tensor<?xindex>) -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: @bypass_assmunig
+// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>)
+func @bypass_assmunig(%arg : tensor<2x3xf32>)
+    -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) {
+  // CHECK: %[[OUTER:.*]] = "some.tensor"
+  // CHECK: %[[WITNESS:.*]] = "some.witness"
+  // CHECK: %[[YIELDED:.*]] = shape.assuming %[[WITNESS]] -> (tensor<2x3xf32>) {
+  // CHECK:   %[[INNER:.*]] = "some.tensor"
+  // CHECK:   shape.assuming_yield %[[INNER]] : tensor<2x3xf32>
+  // CHECK: }
+  // CHECK: return %[[YIELDED]], %[[OUTER]], %[[ARG]]
+  %outer = "some.tensor"() : () -> tensor<2x3xf32>
+  %witness = "some.witness"() : () -> !shape.witness
+  %results:3 = shape.assuming %witness
+      -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) {
+    %inner = "some.tensor"() : () -> tensor<2x3xf32>
+    shape.assuming_yield %inner, %outer, %arg
+        : tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>
+  }
+  return %results#0, %results#1, %results#2
+      : tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>
+}