From 0f4ba02db3985051adac07a87ca9da549c0eb8ad Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 19 Apr 2022 16:12:40 +0900 Subject: [PATCH] [mlir][interfaces] Add helpers for detecting recursive regions Add helper functions to check if an op may be executed multiple times based on RegionBranchOpInterface. Differential Revision: https://reviews.llvm.org/D123789 --- .../mlir/Interfaces/ControlFlowInterfaces.h | 10 +++ .../mlir/Interfaces/ControlFlowInterfaces.td | 5 ++ mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 51 +++++++++++++++ .../Interfaces/ControlFlowInterfacesTest.cpp | 75 +++++++++++++++++++++- 4 files changed, 140 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h index ff4304c..e9be31e 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h @@ -216,6 +216,16 @@ private: /// RegionBranchOpInterface. bool insideMutuallyExclusiveRegions(Operation *a, Operation *b); +/// Return the first enclosing region of the given op that may be executed +/// repetitively as per RegionBranchOpInterface or `nullptr` if no such region +/// exists. +Region *getEnclosingRepetitiveRegion(Operation *op); + +/// Return the first enclosing region of the given Value that may be executed +/// repetitively as per RegionBranchOpInterface or `nullptr` if no such region +/// exists. +Region *getEnclosingRepetitiveRegion(Value value); + //===----------------------------------------------------------------------===// // RegionBranchTerminatorOpInterface //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td index ac805ea..198a38c 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -211,6 +211,11 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> { SmallVector nullAttrs(getOperation()->getNumOperands()); getSuccessorRegions(index, nullAttrs, regions); } + + /// Return `true` if control flow originating from the given region may + /// eventually branch back to the same region. (Maybe after passing through + /// other regions.) + bool isRepetitiveRegion(unsigned index); }]; } diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp index 69ed30a..2ed3a9f 100644 --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -309,6 +309,57 @@ bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) { return false; } +bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) { + SmallVector visited(getOperation()->getNumRegions(), false); + visited[index] = true; + + // Retrieve all successors of the region and enqueue them in the worklist. + SmallVector worklist; + auto enqueueAllSuccessors = [&](unsigned index) { + SmallVector successors; + this->getSuccessorRegions(index, successors); + for (RegionSuccessor successor : successors) + if (!successor.isParent()) + worklist.push_back(successor.getSuccessor()->getRegionNumber()); + }; + enqueueAllSuccessors(index); + + // Process all regions in the worklist via DFS. + while (!worklist.empty()) { + unsigned nextRegion = worklist.pop_back_val(); + if (nextRegion == index) + return true; + if (visited[nextRegion]) + continue; + visited[nextRegion] = true; + enqueueAllSuccessors(nextRegion); + } + + return false; +} + +Region *mlir::getEnclosingRepetitiveRegion(Operation *op) { + while (Region *region = op->getParentRegion()) { + op = region->getParentOp(); + if (auto branchOp = dyn_cast(op)) + if (branchOp.isRepetitiveRegion(region->getRegionNumber())) + return region; + } + return nullptr; +} + +Region *mlir::getEnclosingRepetitiveRegion(Value value) { + Region *region = value.getParentRegion(); + while (region) { + Operation *op = region->getParentOp(); + if (auto branchOp = dyn_cast(op)) + if (branchOp.isRepetitiveRegion(region->getRegionNumber())) + return region; + region = op->getParentRegion(); + } + return nullptr; +} + //===----------------------------------------------------------------------===// // RegionBranchTerminatorOpInterface //===----------------------------------------------------------------------===// diff --git a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp index 4f47316..5f43321 100644 --- a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp +++ b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp @@ -42,6 +42,29 @@ struct MutuallyExclusiveRegionsOp SmallVectorImpl ®ions) {} }; +/// All regions of this op call each other in a large circle. +struct LoopRegionsOp + : public Op { + using Op::Op; + static const unsigned kNumRegions = 3; + + static ArrayRef getAttributeNames() { return {}; } + + static StringRef getOperationName() { return "cftest.loop_regions_op"; } + + void getSuccessorRegions(Optional index, + ArrayRef operands, + SmallVectorImpl ®ions) { + if (index) { + if (*index == 1) + // This region also branches back to the parent. + regions.push_back(RegionSuccessor()); + regions.push_back( + RegionSuccessor(&getOperation()->getRegion(*index % kNumRegions))); + } + } +}; + /// Regions are executed sequentially. struct SequentialRegionsOp : public Op { @@ -65,7 +88,8 @@ struct SequentialRegionsOp struct CFTestDialect : Dialect { explicit CFTestDialect(MLIRContext *ctx) : Dialect(getDialectNamespace(), ctx, TypeID::get()) { - addOperations(); + addOperations(); } static StringRef getDialectNamespace() { return "cftest"; } }; @@ -142,3 +166,52 @@ TEST(RegionBranchOpInterface, NestedMutuallyExclusiveOps) { EXPECT_TRUE(insideMutuallyExclusiveRegions(op3, op2)); EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op3)); } + +TEST(RegionBranchOpInterface, RecursiveRegions) { + const char *ir = R"MLIR( +"cftest.loop_regions_op"() ( + {"cftest.dummy_op"() : () -> ()}, // op1 + {"cftest.dummy_op"() : () -> ()}, // op2 + {"cftest.dummy_op"() : () -> ()} // op3 + ) : () -> () + )MLIR"; + + DialectRegistry registry; + registry.insert(); + MLIRContext ctx(registry); + + OwningOpRef module = parseSourceString(ir, &ctx); + Operation *testOp = &module->getBody()->getOperations().front(); + auto regionOp = cast(testOp); + Operation *op1 = &testOp->getRegion(0).front().front(); + Operation *op2 = &testOp->getRegion(1).front().front(); + Operation *op3 = &testOp->getRegion(2).front().front(); + + EXPECT_TRUE(regionOp.isRepetitiveRegion(0)); + EXPECT_TRUE(regionOp.isRepetitiveRegion(1)); + EXPECT_TRUE(regionOp.isRepetitiveRegion(2)); + EXPECT_NE(getEnclosingRepetitiveRegion(op1), nullptr); + EXPECT_NE(getEnclosingRepetitiveRegion(op2), nullptr); + EXPECT_NE(getEnclosingRepetitiveRegion(op3), nullptr); +} + +TEST(RegionBranchOpInterface, NotRecursiveRegions) { + const char *ir = R"MLIR( +"cftest.sequential_regions_op"() ( + {"cftest.dummy_op"() : () -> ()}, // op1 + {"cftest.dummy_op"() : () -> ()} // op2 + ) : () -> () + )MLIR"; + + DialectRegistry registry; + registry.insert(); + MLIRContext ctx(registry); + + OwningOpRef module = parseSourceString(ir, &ctx); + Operation *testOp = &module->getBody()->getOperations().front(); + Operation *op1 = &testOp->getRegion(0).front().front(); + Operation *op2 = &testOp->getRegion(1).front().front(); + + EXPECT_EQ(getEnclosingRepetitiveRegion(op1), nullptr); + EXPECT_EQ(getEnclosingRepetitiveRegion(op2), nullptr); +} -- 2.7.4