[mlir] Add canonicalizer to merge shape.assuming_all ops
authorEugene Zhulenev <ezhulenev@google.com>
Fri, 4 Feb 2022 19:31:59 +0000 (11:31 -0800)
committerEugene Zhulenev <ezhulenev@google.com>
Fri, 4 Feb 2022 23:27:37 +0000 (15:27 -0800)
Depends On D119021

Reviewed By: frgossen

Differential Revision: https://reviews.llvm.org/D119025

mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir

index 661f621f1cee1c2d77e529e9ee60d00264507587..ecf9ade0b05c5e6e3aaf27e40d51ba085255e655 100644 (file)
@@ -460,6 +460,39 @@ LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
 //===----------------------------------------------------------------------===//
 
 namespace {
+
+// Merge multiple `shape.assuming_all` operations together.
+//
+//   %0 = shape.assuming_all %w0, %w1
+//   %1 = shape.assuming_all %w2, %0
+//
+// to:
+//
+//   %0 = shape.assuming_all %w0, %w2, %w2
+struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
+  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AssumingAllOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> operands;
+
+    for (Value operand : op.getInputs()) {
+      if (auto assume_all = operand.getDefiningOp<AssumingAllOp>())
+        operands.append(assume_all.operand_begin(), assume_all->operand_end());
+      else
+        operands.push_back(operand);
+    }
+
+    // We didn't find any other `assuming_all` ops to merge with.
+    if (operands.size() == op.getNumOperands())
+      return failure();
+
+    // Replace with a new `assuming_all` operation with merged constraints.
+    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
+    return success();
+  }
+};
+
 struct AssumingAllToCstrEqCanonicalization
     : public OpRewritePattern<AssumingAllOp> {
   using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
@@ -506,7 +539,8 @@ struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
 
 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
-  patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization,
+  patterns.add<MergeAssumingAllOps, AssumingAllOneOp,
+               AssumingAllToCstrEqCanonicalization,
                RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
 }
 
index 75b92640a998d8f0c9bf12748bce4d79b65a64fb..425f8fe71a42ff8f0b9b5db74ed2c7032d3f9061 100644 (file)
@@ -463,6 +463,26 @@ func @cstr_require_no_fold(%arg0: i1) {
   return
 }
 
+// -----
+
+// merge assuming_all operations
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: %[[W0:.*]] = "test.source"
+  // CHECK-NEXT: %[[W1:.*]] = "test.source"
+  // CHECK-NEXT: %[[W2:.*]] = "test.source"
+  // CHECK-NEXT: shape.assuming_all %[[W0]], %[[W1]], %[[W2]]
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %0 = "test.source"() : () -> !shape.witness
+  %1 = "test.source"() : () -> !shape.witness
+  %2 = "test.source"() : () -> !shape.witness
+  %3 = shape.assuming_all %0, %1
+  %4 = shape.assuming_all %3, %2
+  "consume.witness"(%4) : (!shape.witness) -> ()
+  return
+}
+
 // -----
 // `assuming_all` with all `cstr_eq` and shared operands can be collapsed.
 // CHECK-LABEL: func @assuming_all_to_cstr_eq