[MLIR][PDL] Refactor the positions for multi-root patterns.
authorStanislav Funiak <stano@cerebras.net>
Tue, 4 Jan 2022 02:33:18 +0000 (08:03 +0530)
committerUday Bondhugula <uday@polymagelabs.com>
Tue, 4 Jan 2022 02:33:44 +0000 (08:03 +0530)
When the original version of multi-root patterns was reviewed, several improvements were made to the pdl_interp operations during the review process. Specifically, the "get users of a value at the specified operand index" was split up into "get users" and "compare the users' operands with that value". The iterative execution was also cleaned up to `pdl_interp.foreach`. However, the positions in the pdl-to-pdl_interp lowering were not similarly refactored. This introduced several problems, including hard-to-detect bugs in the lowering and duplicate evaluation of `pdl_interp.get_users`.

This diff cleans up the positions. The "upward" `OperationPosition` was split-out into `UsersPosition` and `ForEachPosition`, and the operand comparison was replaced with a simple predicate. In the process, I fixed three bugs:
1. When multiple roots were had the same connector (i.e., a node that they shared with a subtree at the previously visited root), we would generate a single foreach loop rather than one foreach loop for each such root. The reason for this is that such connectors shared the position. The solution for this is to add root index as an id to the newly introduced `ForEachPosition`.
2. Previously, we would use `pdl_interp.get_operands` indiscriminately, whether or not the operand was variadic. We now correctly detect variadic operands and insert `pdl_interp.get_operand` when needed.
3. In certain corner cases, we would trigger the "connector has not been traversed yet" assertion. This was caused by not inserting the values during the upward traversal correctly. This has now been fixed.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D116080

mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp
mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir

index 367bbb5..9362a29 100644 (file)
@@ -248,45 +248,43 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
   switch (pos->getKind()) {
   case Predicates::OperationPos: {
     auto *operationPos = cast<OperationPosition>(pos);
-    if (!operationPos->isUpward()) {
+    if (operationPos->isOperandDefiningOp())
       // Standard (downward) traversal which directly follows the defining op.
       value = builder.create<pdl_interp::GetDefiningOpOp>(
           loc, builder.getType<pdl::OperationType>(), parentVal);
-      break;
-    }
+    else
+      // A passthrough operation position.
+      value = parentVal;
+    break;
+  }
+  case Predicates::UsersPos: {
+    auto *usersPos = cast<UsersPosition>(pos);
 
     // The first operation retrieves the representative value of a range.
-    // This applies only when the parent is a range of values.
-    if (parentVal.getType().isa<pdl::RangeType>())
+    // This applies only when the parent is a range of values and we were
+    // requested to use a representative value (e.g., upward traversal).
+    if (parentVal.getType().isa<pdl::RangeType>() &&
+        usersPos->useRepresentative())
       value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
     else
       value = parentVal;
 
     // The second operation retrieves the users.
     value = builder.create<pdl_interp::GetUsersOp>(loc, value);
-
-    // The third operation iterates over them.
+    break;
+  }
+  case Predicates::ForEachPos: {
     assert(!failureBlockStack.empty() && "expected valid failure block");
     auto foreach = builder.create<pdl_interp::ForEachOp>(
-        loc, value, failureBlockStack.back(), /*initLoop=*/true);
+        loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
     value = foreach.getLoopVariable();
 
-    // Create the success and continuation blocks.
-    Block *successBlock = builder.createBlock(&foreach.region());
-    Block *continueBlock = builder.createBlock(successBlock);
+    // Create the continuation block.
+    Block *continueBlock = builder.createBlock(&foreach.region());
     builder.create<pdl_interp::ContinueOp>(loc);
     failureBlockStack.push_back(continueBlock);
 
-    // The fourth operation extracts the operand(s) of the user at the specified
-    // index (which can be None, indicating all operands).
-    builder.setInsertionPointToStart(&foreach.region().front());
-    Value operands = builder.create<pdl_interp::GetOperandsOp>(
-        loc, parentVal.getType(), value, operationPos->getIndex());
-
-    // The fifth operation compares the operands to the parent value / range.
-    builder.create<pdl_interp::AreEqualOp>(loc, parentVal, operands,
-                                           successBlock, continueBlock);
-    currentBlock = successBlock;
+    currentBlock = &foreach.region().front();
     break;
   }
   case Predicates::OperandPos: {
index 07fa5c7..a12f317 100644 (file)
@@ -48,4 +48,6 @@ OperandGroupPosition::OperandGroupPosition(const KeyTy &key) : Base(key) {
 //===----------------------------------------------------------------------===//
 // OperationPosition
 
-constexpr unsigned OperationPosition::kDown;
+bool OperationPosition::isOperandDefiningOp() const {
+  return isa_and_nonnull<OperandPosition, OperandGroupPosition>(parent);
+}
index 266580b..1d72399 100644 (file)
@@ -52,6 +52,8 @@ enum Kind : unsigned {
   TypePos,
   AttributeLiteralPos,
   TypeLiteralPos,
+  UsersPos,
+  ForEachPos,
 
   // Questions, ordered by dependency and decreasing priority.
   IsNotNullQuestion,
@@ -186,6 +188,20 @@ struct AttributeLiteralPosition
 };
 
 //===----------------------------------------------------------------------===//
+// ForEachPosition
+
+/// A position describing an iterative choice of an operation.
+struct ForEachPosition : public PredicateBase<ForEachPosition, Position,
+                                              std::pair<Position *, unsigned>,
+                                              Predicates::ForEachPos> {
+  explicit ForEachPosition(const KeyTy &key) : Base(key) { parent = key.first; }
+
+  /// Returns the ID, for differentiating various loops.
+  /// For upward traversals, this is the index of the root.
+  unsigned getID() const { return key.second; }
+};
+
+//===----------------------------------------------------------------------===//
 // OperandPosition
 
 /// A position describing an operand of an operation.
@@ -229,14 +245,11 @@ struct OperandGroupPosition
 
 /// An operation position describes an operation node in the IR. Other position
 /// kinds are formed with respect to an operation position.
-struct OperationPosition
-    : public PredicateBase<OperationPosition, Position,
-                           std::tuple<Position *, Optional<unsigned>, unsigned>,
-                           Predicates::OperationPos> {
-  static constexpr unsigned kDown = std::numeric_limits<unsigned>::max();
-
+struct OperationPosition : public PredicateBase<OperationPosition, Position,
+                                                std::pair<Position *, unsigned>,
+                                                Predicates::OperationPos> {
   explicit OperationPosition(const KeyTy &key) : Base(key) {
-    parent = std::get<0>(key);
+    parent = key.first;
   }
 
   /// Returns a hash suitable for the given keytype.
@@ -246,31 +259,22 @@ struct OperationPosition
 
   /// Gets the root position.
   static OperationPosition *getRoot(StorageUniquer &uniquer) {
-    return Base::get(uniquer, nullptr, kDown, 0);
+    return Base::get(uniquer, nullptr, 0);
   }
 
-  /// Gets an downward operation position with the given parent.
+  /// Gets an operation position with the given parent.
   static OperationPosition *get(StorageUniquer &uniquer, Position *parent) {
-    return Base::get(uniquer, parent, kDown, parent->getOperationDepth() + 1);
-  }
-
-  /// Gets an upward operation position with the given parent and operand.
-  static OperationPosition *get(StorageUniquer &uniquer, Position *parent,
-                                Optional<unsigned> operand) {
-    return Base::get(uniquer, parent, operand, parent->getOperationDepth() + 1);
+    return Base::get(uniquer, parent, parent->getOperationDepth() + 1);
   }
 
-  /// Returns the operand index for an upward operation position.
-  Optional<unsigned> getIndex() const { return std::get<1>(key); }
-
-  /// Returns if this operation position is upward, accepting an input.
-  bool isUpward() const { return getIndex().getValueOr(0) != kDown; }
-
   /// Returns the depth of this position.
-  unsigned getDepth() const { return std::get<2>(key); }
+  unsigned getDepth() const { return key.second; }
 
   /// Returns if this operation position corresponds to the root.
   bool isRoot() const { return getDepth() == 0; }
+
+  /// Returns if this operation represents an operand defining op.
+  bool isOperandDefiningOp() const;
 };
 
 //===----------------------------------------------------------------------===//
@@ -341,6 +345,26 @@ struct TypeLiteralPosition
 };
 
 //===----------------------------------------------------------------------===//
+// UsersPosition
+
+/// A position describing the users of a value or a range of values. The second
+/// value in the key indicates whether we choose users of a representative for
+/// a range (this is true, e.g., in the upward traversals).
+struct UsersPosition
+    : public PredicateBase<UsersPosition, Position, std::pair<Position *, bool>,
+                           Predicates::UsersPos> {
+  explicit UsersPosition(const KeyTy &key) : Base(key) { parent = key.first; }
+
+  /// Returns a hash suitable for the given keytype.
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    return llvm::hash_value(key);
+  }
+
+  /// Indicates whether to compute a range of a representative.
+  bool useRepresentative() const { return key.second; }
+};
+
+//===----------------------------------------------------------------------===//
 // Qualifiers
 //===----------------------------------------------------------------------===//
 
@@ -496,6 +520,7 @@ public:
     // Register the types of Positions with the uniquer.
     registerParametricStorageType<AttributePosition>();
     registerParametricStorageType<AttributeLiteralPosition>();
+    registerParametricStorageType<ForEachPosition>();
     registerParametricStorageType<OperandPosition>();
     registerParametricStorageType<OperandGroupPosition>();
     registerParametricStorageType<OperationPosition>();
@@ -503,6 +528,7 @@ public:
     registerParametricStorageType<ResultGroupPosition>();
     registerParametricStorageType<TypePosition>();
     registerParametricStorageType<TypeLiteralPosition>();
+    registerParametricStorageType<UsersPosition>();
 
     // Register the types of Questions with the uniquer.
     registerParametricStorageType<AttributeAnswer>();
@@ -550,12 +576,10 @@ public:
     return OperationPosition::get(uniquer, p);
   }
 
-  /// Returns the position of operation using the value at the given index.
-  OperationPosition *getUsersOp(Position *p, Optional<unsigned> operand) {
-    assert((isa<OperandPosition, OperandGroupPosition, ResultPosition,
-                ResultGroupPosition>(p)) &&
-           "expected result position");
-    return OperationPosition::get(uniquer, p, operand);
+  /// Returns the operation position equivalent to the given position.
+  OperationPosition *getPassthroughOp(Position *p) {
+    assert((isa<ForEachPosition>(p)) && "expected users position");
+    return OperationPosition::get(uniquer, p);
   }
 
   /// Returns an attribute position for an attribute of the given operation.
@@ -568,6 +592,10 @@ public:
     return AttributeLiteralPosition::get(uniquer, attr);
   }
 
+  Position *getForEach(Position *p, unsigned id) {
+    return ForEachPosition::get(uniquer, p, id);
+  }
+
   /// Returns an operand position for an operand of the given operation.
   Position *getOperand(OperationPosition *p, unsigned operand) {
     return OperandPosition::get(uniquer, p, operand);
@@ -605,6 +633,14 @@ public:
     return TypeLiteralPosition::get(uniquer, attr);
   }
 
+  /// Returns the users of a position using the value at the given operand.
+  UsersPosition *getUsers(Position *p, bool useRepresentative) {
+    assert((isa<OperandPosition, OperandGroupPosition, ResultPosition,
+                ResultGroupPosition>(p)) &&
+           "expected result position");
+    return UsersPosition::get(uniquer, p, useRepresentative);
+  }
+
   //===--------------------------------------------------------------------===//
   // Qualifiers
   //===--------------------------------------------------------------------===//
index 43c57a8..24b2f19 100644 (file)
@@ -158,8 +158,11 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
   // group, we treat it as all of the operands/results of the operation.
   /// Operands.
   if (operands.size() == 1 && operands[0].getType().isa<pdl::RangeType>()) {
-    getTreePredicates(predList, operands.front(), builder, inputs,
-                      builder.getAllOperands(opPos));
+    // Ignore the operands if we are performing an upward traversal (in that
+    // case, they have already been visited).
+    if (opPos->isRoot() || opPos->isOperandDefiningOp())
+      getTreePredicates(predList, operands.front(), builder, inputs,
+                        builder.getAllOperands(opPos));
   } else {
     bool foundVariableLength = false;
     for (const auto &operandIt : llvm::enumerate(operands)) {
@@ -502,23 +505,47 @@ static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
          "the pattern contains a candidate root disconnected from the others");
 }
 
+/// Returns true if the operand at the given index needs to be queried using an
+/// operand group, i.e., if it is variadic itself or follows a variadic operand.
+static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
+  OperandRange operands = op.operands();
+  assert(index < operands.size() && "operand index out of range");
+  for (unsigned i = 0; i <= index; ++i)
+    if (operands[i].getType().isa<pdl::RangeType>())
+      return true;
+  return false;
+}
+
 /// Visit a node during upward traversal.
-void visitUpward(std::vector<PositionalPredicate> &predList, OpIndex opIndex,
-                 PredicateBuilder &builder,
-                 DenseMap<Value, Position *> &valueToPosition, Position *&pos,
-                 bool &first) {
+static void visitUpward(std::vector<PositionalPredicate> &predList,
+                        OpIndex opIndex, PredicateBuilder &builder,
+                        DenseMap<Value, Position *> &valueToPosition,
+                        Position *&pos, unsigned rootID) {
   Value value = opIndex.parent;
   TypeSwitch<Operation *>(value.getDefiningOp())
       .Case<pdl::OperationOp>([&](auto operationOp) {
         LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
-        OperationPosition *opPos = builder.getUsersOp(pos, opIndex.index);
 
-        // Guard against traversing back to where we came from.
-        if (first) {
-          Position *parent = pos->getParent();
-          predList.emplace_back(opPos, builder.getNotEqualTo(parent));
-          first = false;
+        // Get users and iterate over them.
+        Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true);
+        Position *foreachPos = builder.getForEach(usersPos, rootID);
+        OperationPosition *opPos = builder.getPassthroughOp(foreachPos);
+
+        // Compare the operand(s) of the user against the input value(s).
+        Position *operandPos;
+        if (!opIndex.index) {
+          // We are querying all the operands of the operation.
+          operandPos = builder.getAllOperands(opPos);
+        } else if (useOperandGroup(operationOp, *opIndex.index)) {
+          // We are querying an operand group.
+          Type type = operationOp.operands()[*opIndex.index].getType();
+          bool variadic = type.isa<pdl::RangeType>();
+          operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
+        } else {
+          // We are querying an individual operand.
+          operandPos = builder.getOperand(opPos, *opIndex.index);
         }
+        predList.emplace_back(operandPos, builder.getEqualTo(pos));
 
         // Guard against duplicate upward visits. These are not possible,
         // because if this value was already visited, it would have been
@@ -540,6 +567,9 @@ void visitUpward(std::vector<PositionalPredicate> &predList, OpIndex opIndex,
         auto *opPos = dyn_cast<OperationPosition>(pos);
         assert(opPos && "operations and results must be interleaved");
         pos = builder.getResult(opPos, *opIndex.index);
+
+        // Insert the result position in case we have not visited it yet.
+        valueToPosition.try_emplace(value, pos);
       })
       .Case<pdl::ResultsOp>([&](auto resultOp) {
         // Traverse up a group of results.
@@ -550,6 +580,9 @@ void visitUpward(std::vector<PositionalPredicate> &predList, OpIndex opIndex,
           pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
         else
           pos = builder.getAllResults(opPos);
+
+        // Insert the result position in case we have not visited it yet.
+        valueToPosition.try_emplace(value, pos);
       });
 }
 
@@ -568,7 +601,8 @@ static Value buildPredicateList(pdl::PatternOp pattern,
   LLVM_DEBUG({
     llvm::dbgs() << "Graph:\n";
     for (auto &target : graph) {
-      llvm::dbgs() << "  * " << target.first << "\n";
+      llvm::dbgs() << "  * " << target.first.getLoc() << " " << target.first
+                   << "\n";
       for (auto &source : target.second) {
         RootOrderingEntry &entry = source.second;
         llvm::dbgs() << "      <- " << source.first << ": " << entry.cost.first
@@ -601,6 +635,17 @@ static Value buildPredicateList(pdl::PatternOp pattern,
     bestEdges = solver.preOrderTraversal(roots);
   }
 
+  // Print the best solution.
+  LLVM_DEBUG({
+    llvm::dbgs() << "Best tree:\n";
+    for (const std::pair<Value, Value> &edge : bestEdges) {
+      llvm::dbgs() << "  * " << edge.first;
+      if (edge.second)
+        llvm::dbgs() << " <- " << edge.second;
+      llvm::dbgs() << "\n";
+    }
+  });
+
   LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << bestRoot << "\n");
 
@@ -612,9 +657,9 @@ static Value buildPredicateList(pdl::PatternOp pattern,
   // Traverse the selected optimal branching. For all edges in order, traverse
   // up starting from the connector, until the candidate root is reached, and
   // call getTreePredicates at every node along the way.
-  for (const std::pair<Value, Value> &edge : bestEdges) {
-    Value target = edge.first;
-    Value source = edge.second;
+  for (auto it : llvm::enumerate(bestEdges)) {
+    Value target = it.value().first;
+    Value source = it.value().second;
 
     // Check if we already visited the target root. This happens in two cases:
     // 1) the initial root (bestRoot);
@@ -629,14 +674,13 @@ static Value buildPredicateList(pdl::PatternOp pattern,
     LLVM_DEBUG(llvm::dbgs() << "  * Connector: " << connector.getLoc() << "\n");
     DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target);
     Position *pos = valueToPosition.lookup(connector);
-    assert(pos && "The value has not been traversed yet");
-    bool first = true;
+    assert(pos && "connector has not been traversed yet");
 
     // Traverse from the connector upwards towards the target root.
     for (Value value = connector; value != target;) {
       OpIndex opIndex = parentMap.lookup(value);
       assert(opIndex.parent && "missing parent");
-      visitUpward(predList, opIndex, builder, valueToPosition, pos, first);
+      visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index());
       value = opIndex.parent;
     }
   }
index 984a317..fd6cfe5 100644 (file)
@@ -423,8 +423,8 @@ module @multi_root {
   // CHECK-DAG: %[[OP1:.*]] = pdl_interp.get_defining_op of %[[VAL1]]
   // CHECK-DAG: %[[OPS:.*]] = pdl_interp.get_users of %[[VAL1]] : !pdl.value
   // CHECK-DAG: pdl_interp.foreach %[[ROOT2:.*]] : !pdl.operation in %[[OPS]]
-  // CHECK-DAG:   %[[OPERANDS:.*]] = pdl_interp.get_operands 0 of %[[ROOT2]]
-  // CHECK-DAG:   pdl_interp.are_equal %[[VAL1]], %[[OPERANDS]] : !pdl.value -> ^{{.*}}, ^[[CONTINUE:.*]]
+  // CHECK-DAG:   %[[OPERANDS:.*]] = pdl_interp.get_operand 0 of %[[ROOT2]]
+  // CHECK-DAG:   pdl_interp.are_equal %[[OPERANDS]], %[[VAL1]] : !pdl.value -> ^{{.*}}, ^[[CONTINUE:.*]]
   // CHECK-DAG:   pdl_interp.continue
   // CHECK-DAG:   %[[VAL2:.*]] = pdl_interp.get_operand 1 of %[[ROOT2]]
   // CHECK-DAG:   %[[OP2:.*]] = pdl_interp.get_defining_op of %[[VAL2]]
@@ -433,7 +433,6 @@ module @multi_root {
   // CHECK-DAG:   pdl_interp.is_not_null %[[VAL1]] : !pdl.value
   // CHECK-DAG:   pdl_interp.is_not_null %[[VAL2]] : !pdl.value
   // CHECK-DAG:   pdl_interp.is_not_null %[[ROOT2]] : !pdl.operation
-  // CHECK-DAG:   pdl_interp.are_equal %[[ROOT2]], %[[ROOT1]] : !pdl.operation -> ^[[CONTINUE]]
 
   pdl.pattern @rewrite_multi_root : benefit(1) {
     %input1 = pdl.operand
@@ -556,7 +555,7 @@ module @variadic_results_at {
   // CHECK-DAG: %[[ROOTS2:.*]] = pdl_interp.get_users of %[[VAL0]] : !pdl.value
   // CHECK-DAG: pdl_interp.foreach %[[ROOT2:.*]] : !pdl.operation in %[[ROOTS2]] {
   // CHECK-DAG:   %[[OPERANDS:.*]] = pdl_interp.get_operands 1 of %[[ROOT2]]
-  // CHECK-DAG:   pdl_interp.are_equal %[[VALS]], %[[OPERANDS]] : !pdl.range<value> -> ^{{.*}}, ^[[CONTINUE:.*]]
+  // CHECK-DAG:   pdl_interp.are_equal %[[OPERANDS]], %[[VALS]] : !pdl.range<value> -> ^{{.*}}, ^[[CONTINUE:.*]]
   // CHECK-DAG:   pdl_interp.is_not_null %[[ROOT2]]
   // CHECK-DAG:   pdl_interp.check_operand_count of %[[ROOT2]] is at_least 1
   // CHECK-DAG:   pdl_interp.check_result_count of %[[ROOT2]] is 0
@@ -612,3 +611,83 @@ module @type_literal {
   }
 }
 
+// -----
+
+// CHECK-LABEL: module @common_connector
+module @common_connector {
+  // Check the correct lowering when multiple roots are using the same
+  // connector.
+
+  // CHECK: func @matcher(%[[ROOTC:.*]]: !pdl.operation)
+  // CHECK-DAG: %[[VAL2:.*]] = pdl_interp.get_operand 0 of %[[ROOTC]]
+  // CHECK-DAG: %[[INTER:.*]] = pdl_interp.get_defining_op of %[[VAL2]] : !pdl.value
+  // CHECK-DAG: pdl_interp.is_not_null %[[INTER]] : !pdl.operation -> ^bb2, ^bb1
+  // CHECK-DAG: %[[VAL1:.*]] = pdl_interp.get_operand 0 of %[[INTER]]
+  // CHECK-DAG: %[[OP:.*]] = pdl_interp.get_defining_op of %[[VAL1]] : !pdl.value
+  // CHECK-DAG: pdl_interp.is_not_null %[[OP]]
+  // CHECK-DAG: %[[VAL0:.*]] = pdl_interp.get_result 0 of %[[OP]]
+  // CHECK-DAG: %[[ROOTS:.*]] = pdl_interp.get_users of %[[VAL0]] : !pdl.value
+  // CHECK-DAG: pdl_interp.foreach %[[ROOTA:.*]] : !pdl.operation in %[[ROOTS]] {
+  // CHECK-DAG:   pdl_interp.is_not_null %[[ROOTA]] : !pdl.operation -> ^{{.*}}, ^[[CONTA:.*]]
+  // CHECK-DAG:   pdl_interp.continue
+  // CHECK-DAG:   pdl_interp.foreach %[[ROOTB:.*]] : !pdl.operation in %[[ROOTS]] {
+  // CHECK-DAG:     pdl_interp.is_not_null %[[ROOTB]] : !pdl.operation -> ^{{.*}}, ^[[CONTB:.*]]
+  // CHECK-DAG:     %[[ROOTA_OP:.*]] = pdl_interp.get_operand 0 of %[[ROOTA]]
+  // CHECK-DAG:     pdl_interp.are_equal %[[ROOTA_OP]], %[[VAL0]] : !pdl.value
+  // CHECK-DAG:     %[[ROOTB_OP:.*]] = pdl_interp.get_operand 0 of %[[ROOTB]]
+  // CHECK-DAG:     pdl_interp.are_equal %[[ROOTB_OP]], %[[VAL0]] : !pdl.value
+  // CHECK-DAG    } -> ^[[CONTA:.*]]
+  pdl.pattern @common_connector : benefit(1) {
+      %type = pdl.type
+      %op = pdl.operation -> (%type, %type : !pdl.type, !pdl.type)
+      %val0 = pdl.result 0 of %op
+      %val1 = pdl.result 1 of %op
+      %rootA = pdl.operation (%val0 : !pdl.value)
+      %rootB = pdl.operation (%val0 : !pdl.value)
+      %inter = pdl.operation (%val1 : !pdl.value) -> (%type : !pdl.type)
+      %val2 = pdl.result 0 of %inter
+      %rootC = pdl.operation (%val2 : !pdl.value)
+      pdl.rewrite with "rewriter"(%rootA, %rootB, %rootC : !pdl.operation, !pdl.operation, !pdl.operation)
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @common_connector_range
+module @common_connector_range {
+  // Check the correct lowering when multiple roots are using the same
+  // connector range.
+
+  // CHECK: func @matcher(%[[ROOTC:.*]]: !pdl.operation)
+  // CHECK-DAG: %[[VALS2:.*]] = pdl_interp.get_operands of %[[ROOTC]] : !pdl.range<value>
+  // CHECK-DAG: %[[INTER:.*]] = pdl_interp.get_defining_op of %[[VALS2]] : !pdl.range<value>
+  // CHECK-DAG: pdl_interp.is_not_null %[[INTER]] : !pdl.operation -> ^bb2, ^bb1
+  // CHECK-DAG: %[[VALS1:.*]] = pdl_interp.get_operands of %[[INTER]] : !pdl.range<value>
+  // CHECK-DAG: %[[OP:.*]] = pdl_interp.get_defining_op of %[[VALS1]] : !pdl.range<value>
+  // CHECK-DAG: pdl_interp.is_not_null %[[OP]]
+  // CHECK-DAG: %[[VALS0:.*]] = pdl_interp.get_results 0 of %[[OP]]
+  // CHECK-DAG: %[[VAL0:.*]] = pdl_interp.extract 0 of %[[VALS0]] : !pdl.value
+  // CHECK-DAG: %[[ROOTS:.*]] = pdl_interp.get_users of %[[VAL0]] : !pdl.value
+  // CHECK-DAG: pdl_interp.foreach %[[ROOTA:.*]] : !pdl.operation in %[[ROOTS]] {
+  // CHECK-DAG:   pdl_interp.is_not_null %[[ROOTA]] : !pdl.operation -> ^{{.*}}, ^[[CONTA:.*]]
+  // CHECK-DAG:   pdl_interp.continue
+  // CHECK-DAG:   pdl_interp.foreach %[[ROOTB:.*]] : !pdl.operation in %[[ROOTS]] {
+  // CHECK-DAG:     pdl_interp.is_not_null %[[ROOTB]] : !pdl.operation -> ^{{.*}}, ^[[CONTB:.*]]
+  // CHECK-DAG:     %[[ROOTA_OPS:.*]] = pdl_interp.get_operands of %[[ROOTA]]
+  // CHECK-DAG:     pdl_interp.are_equal %[[ROOTA_OPS]], %[[VALS0]] : !pdl.range<value>
+  // CHECK-DAG:     %[[ROOTB_OPS:.*]] = pdl_interp.get_operands of %[[ROOTB]]
+  // CHECK-DAG:     pdl_interp.are_equal %[[ROOTB_OPS]], %[[VALS0]] : !pdl.range<value>
+  // CHECK-DAG    } -> ^[[CONTA:.*]]
+  pdl.pattern @common_connector_range : benefit(1) {
+    %types = pdl.types
+    %op = pdl.operation -> (%types, %types : !pdl.range<type>, !pdl.range<type>)
+    %vals0 = pdl.results 0 of %op -> !pdl.range<value>
+    %vals1 = pdl.results 1 of %op -> !pdl.range<value>
+    %rootA = pdl.operation (%vals0 : !pdl.range<value>)
+    %rootB = pdl.operation (%vals0 : !pdl.range<value>)
+    %inter = pdl.operation (%vals1 : !pdl.range<value>) -> (%types : !pdl.range<type>)
+    %vals2 = pdl.results of %inter
+    %rootC = pdl.operation (%vals2 : !pdl.range<value>)
+    pdl.rewrite with "rewriter"(%rootA, %rootB, %rootC : !pdl.operation, !pdl.operation, !pdl.operation)
+  }
+}