return false;
}
+bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
+ SmallVector<bool> visited(getOperation()->getNumRegions(), false);
+ visited[index] = true;
+
+ // Retrieve all successors of the region and enqueue them in the worklist.
+ SmallVector<unsigned> worklist;
+ auto enqueueAllSuccessors = [&](unsigned index) {
+ SmallVector<RegionSuccessor> successors;
+ this->getSuccessorRegions(index, successors);
+ for (RegionSuccessor successor : successors)
+ if (!successor.isParent())
+ worklist.push_back(successor.getSuccessor()->getRegionNumber());
+ };
+ enqueueAllSuccessors(index);
+
+ // Process all regions in the worklist via DFS.
+ while (!worklist.empty()) {
+ unsigned nextRegion = worklist.pop_back_val();
+ if (nextRegion == index)
+ return true;
+ if (visited[nextRegion])
+ continue;
+ visited[nextRegion] = true;
+ enqueueAllSuccessors(nextRegion);
+ }
+
+ return false;
+}
+
+Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
+ while (Region *region = op->getParentRegion()) {
+ op = region->getParentOp();
+ if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
+ if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
+ return region;
+ }
+ return nullptr;
+}
+
+Region *mlir::getEnclosingRepetitiveRegion(Value value) {
+ Region *region = value.getParentRegion();
+ while (region) {
+ Operation *op = region->getParentOp();
+ if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
+ if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
+ return region;
+ region = op->getParentRegion();
+ }
+ return nullptr;
+}
+
//===----------------------------------------------------------------------===//
// RegionBranchTerminatorOpInterface
//===----------------------------------------------------------------------===//
SmallVectorImpl<RegionSuccessor> ®ions) {}
};
+/// All regions of this op call each other in a large circle.
+struct LoopRegionsOp
+ : public Op<LoopRegionsOp, RegionBranchOpInterface::Trait> {
+ using Op::Op;
+ static const unsigned kNumRegions = 3;
+
+ static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+ static StringRef getOperationName() { return "cftest.loop_regions_op"; }
+
+ void getSuccessorRegions(Optional<unsigned> index,
+ ArrayRef<Attribute> operands,
+ SmallVectorImpl<RegionSuccessor> ®ions) {
+ if (index) {
+ if (*index == 1)
+ // This region also branches back to the parent.
+ regions.push_back(RegionSuccessor());
+ regions.push_back(
+ RegionSuccessor(&getOperation()->getRegion(*index % kNumRegions)));
+ }
+ }
+};
+
/// Regions are executed sequentially.
struct SequentialRegionsOp
: public Op<SequentialRegionsOp, RegionBranchOpInterface::Trait> {
struct CFTestDialect : Dialect {
explicit CFTestDialect(MLIRContext *ctx)
: Dialect(getDialectNamespace(), ctx, TypeID::get<CFTestDialect>()) {
- addOperations<DummyOp, MutuallyExclusiveRegionsOp, SequentialRegionsOp>();
+ addOperations<DummyOp, MutuallyExclusiveRegionsOp, LoopRegionsOp,
+ SequentialRegionsOp>();
}
static StringRef getDialectNamespace() { return "cftest"; }
};
EXPECT_TRUE(insideMutuallyExclusiveRegions(op3, op2));
EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op3));
}
+
+TEST(RegionBranchOpInterface, RecursiveRegions) {
+ const char *ir = R"MLIR(
+"cftest.loop_regions_op"() (
+ {"cftest.dummy_op"() : () -> ()}, // op1
+ {"cftest.dummy_op"() : () -> ()}, // op2
+ {"cftest.dummy_op"() : () -> ()} // op3
+ ) : () -> ()
+ )MLIR";
+
+ DialectRegistry registry;
+ registry.insert<CFTestDialect>();
+ MLIRContext ctx(registry);
+
+ OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+ Operation *testOp = &module->getBody()->getOperations().front();
+ auto regionOp = cast<RegionBranchOpInterface>(testOp);
+ Operation *op1 = &testOp->getRegion(0).front().front();
+ Operation *op2 = &testOp->getRegion(1).front().front();
+ Operation *op3 = &testOp->getRegion(2).front().front();
+
+ EXPECT_TRUE(regionOp.isRepetitiveRegion(0));
+ EXPECT_TRUE(regionOp.isRepetitiveRegion(1));
+ EXPECT_TRUE(regionOp.isRepetitiveRegion(2));
+ EXPECT_NE(getEnclosingRepetitiveRegion(op1), nullptr);
+ EXPECT_NE(getEnclosingRepetitiveRegion(op2), nullptr);
+ EXPECT_NE(getEnclosingRepetitiveRegion(op3), nullptr);
+}
+
+TEST(RegionBranchOpInterface, NotRecursiveRegions) {
+ const char *ir = R"MLIR(
+"cftest.sequential_regions_op"() (
+ {"cftest.dummy_op"() : () -> ()}, // op1
+ {"cftest.dummy_op"() : () -> ()} // op2
+ ) : () -> ()
+ )MLIR";
+
+ DialectRegistry registry;
+ registry.insert<CFTestDialect>();
+ MLIRContext ctx(registry);
+
+ OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+ Operation *testOp = &module->getBody()->getOperations().front();
+ Operation *op1 = &testOp->getRegion(0).front().front();
+ Operation *op2 = &testOp->getRegion(1).front().front();
+
+ EXPECT_EQ(getEnclosingRepetitiveRegion(op1), nullptr);
+ EXPECT_EQ(getEnclosingRepetitiveRegion(op2), nullptr);
+}