//===----------------------------------------------------------------------===//
namespace {
+
+// Merge multiple `shape.assuming_all` operations together.
+//
+// %0 = shape.assuming_all %w0, %w1
+// %1 = shape.assuming_all %w2, %0
+//
+// to:
+//
+// %0 = shape.assuming_all %w0, %w2, %w2
+struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
+ using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(AssumingAllOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<Value> operands;
+
+ for (Value operand : op.getInputs()) {
+ if (auto assume_all = operand.getDefiningOp<AssumingAllOp>())
+ operands.append(assume_all.operand_begin(), assume_all->operand_end());
+ else
+ operands.push_back(operand);
+ }
+
+ // We didn't find any other `assuming_all` ops to merge with.
+ if (operands.size() == op.getNumOperands())
+ return failure();
+
+ // Replace with a new `assuming_all` operation with merged constraints.
+ rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
+ return success();
+ }
+};
+
struct AssumingAllToCstrEqCanonicalization
: public OpRewritePattern<AssumingAllOp> {
using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization,
+ patterns.add<MergeAssumingAllOps, AssumingAllOneOp,
+ AssumingAllToCstrEqCanonicalization,
RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}
return
}
+// -----
+
+// merge assuming_all operations
+// CHECK-LABEL: func @f
+func @f() {
+ // CHECK-NEXT: %[[W0:.*]] = "test.source"
+ // CHECK-NEXT: %[[W1:.*]] = "test.source"
+ // CHECK-NEXT: %[[W2:.*]] = "test.source"
+ // CHECK-NEXT: shape.assuming_all %[[W0]], %[[W1]], %[[W2]]
+ // CHECK-NEXT: consume.witness
+ // CHECK-NEXT: return
+ %0 = "test.source"() : () -> !shape.witness
+ %1 = "test.source"() : () -> !shape.witness
+ %2 = "test.source"() : () -> !shape.witness
+ %3 = shape.assuming_all %0, %1
+ %4 = shape.assuming_all %3, %2
+ "consume.witness"(%4) : (!shape.witness) -> ()
+ return
+}
+
// -----
// `assuming_all` with all `cstr_eq` and shared operands can be collapsed.
// CHECK-LABEL: func @assuming_all_to_cstr_eq