[mlir][interfaces] Add helpers for detecting recursive regions
authorMatthias Springer <springerm@google.com>
Tue, 19 Apr 2022 07:12:40 +0000 (16:12 +0900)
committerMatthias Springer <springerm@google.com>
Tue, 19 Apr 2022 07:13:32 +0000 (16:13 +0900)
Add helper functions to check if an op may be executed multiple times based on RegionBranchOpInterface.

Differential Revision: https://reviews.llvm.org/D123789

mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
mlir/lib/Interfaces/ControlFlowInterfaces.cpp
mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp

index ff4304c..e9be31e 100644 (file)
@@ -216,6 +216,16 @@ private:
 /// RegionBranchOpInterface.
 bool insideMutuallyExclusiveRegions(Operation *a, Operation *b);
 
+/// Return the first enclosing region of the given op that may be executed
+/// repetitively as per RegionBranchOpInterface or `nullptr` if no such region
+/// exists.
+Region *getEnclosingRepetitiveRegion(Operation *op);
+
+/// Return the first enclosing region of the given Value that may be executed
+/// repetitively as per RegionBranchOpInterface or `nullptr` if no such region
+/// exists.
+Region *getEnclosingRepetitiveRegion(Value value);
+
 //===----------------------------------------------------------------------===//
 // RegionBranchTerminatorOpInterface
 //===----------------------------------------------------------------------===//
index ac805ea..198a38c 100644 (file)
@@ -211,6 +211,11 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
        SmallVector<Attribute, 2> nullAttrs(getOperation()->getNumOperands());
        getSuccessorRegions(index, nullAttrs, regions);
     }
+
+    /// Return `true` if control flow originating from the given region may
+    /// eventually branch back to the same region. (Maybe after passing through
+    /// other regions.)
+    bool isRepetitiveRegion(unsigned index);
   }];
 }
 
index 69ed30a..2ed3a9f 100644 (file)
@@ -309,6 +309,57 @@ bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
   return false;
 }
 
+bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
+  SmallVector<bool> visited(getOperation()->getNumRegions(), false);
+  visited[index] = true;
+
+  // Retrieve all successors of the region and enqueue them in the worklist.
+  SmallVector<unsigned> worklist;
+  auto enqueueAllSuccessors = [&](unsigned index) {
+    SmallVector<RegionSuccessor> successors;
+    this->getSuccessorRegions(index, successors);
+    for (RegionSuccessor successor : successors)
+      if (!successor.isParent())
+        worklist.push_back(successor.getSuccessor()->getRegionNumber());
+  };
+  enqueueAllSuccessors(index);
+
+  // Process all regions in the worklist via DFS.
+  while (!worklist.empty()) {
+    unsigned nextRegion = worklist.pop_back_val();
+    if (nextRegion == index)
+      return true;
+    if (visited[nextRegion])
+      continue;
+    visited[nextRegion] = true;
+    enqueueAllSuccessors(nextRegion);
+  }
+
+  return false;
+}
+
+Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
+  while (Region *region = op->getParentRegion()) {
+    op = region->getParentOp();
+    if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
+      if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
+        return region;
+  }
+  return nullptr;
+}
+
+Region *mlir::getEnclosingRepetitiveRegion(Value value) {
+  Region *region = value.getParentRegion();
+  while (region) {
+    Operation *op = region->getParentOp();
+    if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
+      if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
+        return region;
+    region = op->getParentRegion();
+  }
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // RegionBranchTerminatorOpInterface
 //===----------------------------------------------------------------------===//
index 4f47316..5f43321 100644 (file)
@@ -42,6 +42,29 @@ struct MutuallyExclusiveRegionsOp
                            SmallVectorImpl<RegionSuccessor> &regions) {}
 };
 
+/// All regions of this op call each other in a large circle.
+struct LoopRegionsOp
+    : public Op<LoopRegionsOp, RegionBranchOpInterface::Trait> {
+  using Op::Op;
+  static const unsigned kNumRegions = 3;
+
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+  static StringRef getOperationName() { return "cftest.loop_regions_op"; }
+
+  void getSuccessorRegions(Optional<unsigned> index,
+                           ArrayRef<Attribute> operands,
+                           SmallVectorImpl<RegionSuccessor> &regions) {
+    if (index) {
+      if (*index == 1)
+        // This region also branches back to the parent.
+        regions.push_back(RegionSuccessor());
+      regions.push_back(
+          RegionSuccessor(&getOperation()->getRegion(*index % kNumRegions)));
+    }
+  }
+};
+
 /// Regions are executed sequentially.
 struct SequentialRegionsOp
     : public Op<SequentialRegionsOp, RegionBranchOpInterface::Trait> {
@@ -65,7 +88,8 @@ struct SequentialRegionsOp
 struct CFTestDialect : Dialect {
   explicit CFTestDialect(MLIRContext *ctx)
       : Dialect(getDialectNamespace(), ctx, TypeID::get<CFTestDialect>()) {
-    addOperations<DummyOp, MutuallyExclusiveRegionsOp, SequentialRegionsOp>();
+    addOperations<DummyOp, MutuallyExclusiveRegionsOp, LoopRegionsOp,
+                  SequentialRegionsOp>();
   }
   static StringRef getDialectNamespace() { return "cftest"; }
 };
@@ -142,3 +166,52 @@ TEST(RegionBranchOpInterface, NestedMutuallyExclusiveOps) {
   EXPECT_TRUE(insideMutuallyExclusiveRegions(op3, op2));
   EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op3));
 }
+
+TEST(RegionBranchOpInterface, RecursiveRegions) {
+  const char *ir = R"MLIR(
+"cftest.loop_regions_op"() (
+      {"cftest.dummy_op"() : () -> ()},  // op1
+      {"cftest.dummy_op"() : () -> ()},  // op2
+      {"cftest.dummy_op"() : () -> ()}   // op3
+  ) : () -> ()
+  )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<CFTestDialect>();
+  MLIRContext ctx(registry);
+
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+  Operation *testOp = &module->getBody()->getOperations().front();
+  auto regionOp = cast<RegionBranchOpInterface>(testOp);
+  Operation *op1 = &testOp->getRegion(0).front().front();
+  Operation *op2 = &testOp->getRegion(1).front().front();
+  Operation *op3 = &testOp->getRegion(2).front().front();
+
+  EXPECT_TRUE(regionOp.isRepetitiveRegion(0));
+  EXPECT_TRUE(regionOp.isRepetitiveRegion(1));
+  EXPECT_TRUE(regionOp.isRepetitiveRegion(2));
+  EXPECT_NE(getEnclosingRepetitiveRegion(op1), nullptr);
+  EXPECT_NE(getEnclosingRepetitiveRegion(op2), nullptr);
+  EXPECT_NE(getEnclosingRepetitiveRegion(op3), nullptr);
+}
+
+TEST(RegionBranchOpInterface, NotRecursiveRegions) {
+  const char *ir = R"MLIR(
+"cftest.sequential_regions_op"() (
+      {"cftest.dummy_op"() : () -> ()},  // op1
+      {"cftest.dummy_op"() : () -> ()}   // op2
+  ) : () -> ()
+  )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<CFTestDialect>();
+  MLIRContext ctx(registry);
+
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+  Operation *testOp = &module->getBody()->getOperations().front();
+  Operation *op1 = &testOp->getRegion(0).front().front();
+  Operation *op2 = &testOp->getRegion(1).front().front();
+
+  EXPECT_EQ(getEnclosingRepetitiveRegion(op1), nullptr);
+  EXPECT_EQ(getEnclosingRepetitiveRegion(op2), nullptr);
+}