From c7103810bde9e300f0a272f0dc55eb324f5415f2 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 15 Dec 2021 06:42:36 +0000 Subject: [PATCH] [mlir][scf] Add getNumRegionInvocations to IfOp Implements the RegionBranchOpInterface method getNumRegionInvocations to `scf::IfOp` so that, when the condition is constant, the number of region executions can be analyzed by `NumberOfExecutions`. Reviewed By: jpienaar, ftynse Differential Revision: https://reviews.llvm.org/D115087 --- mlir/include/mlir/Dialect/SCF/SCFOps.td | 6 +++ mlir/lib/Dialect/SCF/SCF.cpp | 19 ++++++++- mlir/unittests/Dialect/CMakeLists.txt | 1 + mlir/unittests/Dialect/SCF/CMakeLists.txt | 10 +++++ mlir/unittests/Dialect/SCF/SCFOps.cpp | 67 +++++++++++++++++++++++++++++++ 5 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 mlir/unittests/Dialect/SCF/CMakeLists.txt create mode 100644 mlir/unittests/Dialect/SCF/SCFOps.cpp diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td index 89293bd..29e36cb 100644 --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -403,6 +403,12 @@ def IfOp : SCF_Op<"if", YieldOp thenYield(); Block* elseBlock(); YieldOp elseYield(); + + /// If the condition is a constant, returns 1 for the executed block and 0 + /// for the other. Otherwise, returns `kUnknownNumRegionInvocations` for + /// both successors. + void getNumRegionInvocations(ArrayRef operands, + SmallVectorImpl &countPerRegion); }]; let hasCanonicalizer = 1; diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp index c7f1436..cea88d1 100644 --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -474,7 +474,7 @@ void ForOp::getNumRegionInvocations(ArrayRef operands, // Loop bounds are not known statically. if (!lb || !ub || !step || step.getValue().getSExtValue() == 0) { - countPerRegion[0] = -1; + countPerRegion[0] = kUnknownNumRegionInvocations; return; } @@ -1181,6 +1181,23 @@ void IfOp::getSuccessorRegions(Optional index, regions.push_back(RegionSuccessor(condition ? &thenRegion() : elseRegion)); } +/// If the condition is a constant, returns 1 for the executed block and 0 for +/// the other. Otherwise, returns `kUnknownNumRegionInvocations` for both +/// successors. +void IfOp::getNumRegionInvocations(ArrayRef operands, + SmallVectorImpl &countPerRegion) { + if (auto condAttr = operands.front().dyn_cast_or_null()) { + // If the condition is true, `then` is executed once and `else` zero times, + // and vice-versa. + bool cond = condAttr.getValue().isOneValue(); + countPerRegion.assign(1, cond ? 1 : 0); + countPerRegion.push_back(cond ? 0 : 1); + } else { + // Non-constant condition: unknown invocations for both successors. + countPerRegion.assign(2, kUnknownNumRegionInvocations); + } +} + namespace { // Pattern to remove unused IfOp results. struct RemoveUnusedResults : public OpRewritePattern { diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt index f37f578..91aec50 100644 --- a/mlir/unittests/Dialect/CMakeLists.txt +++ b/mlir/unittests/Dialect/CMakeLists.txt @@ -7,6 +7,7 @@ target_link_libraries(MLIRDialectTests MLIRDialect) add_subdirectory(Quant) +add_subdirectory(SCF) add_subdirectory(SparseTensor) add_subdirectory(SPIRV) add_subdirectory(Utils) diff --git a/mlir/unittests/Dialect/SCF/CMakeLists.txt b/mlir/unittests/Dialect/SCF/CMakeLists.txt new file mode 100644 index 0000000..81e05d3 --- /dev/null +++ b/mlir/unittests/Dialect/SCF/CMakeLists.txt @@ -0,0 +1,10 @@ +add_mlir_unittest(MLIRSCFTests + SCFOps.cpp + ) +target_link_libraries(MLIRSCFTests + PRIVATE + MLIRIR + MLIRParser + MLIRSCF + MLIRStandard + ) diff --git a/mlir/unittests/Dialect/SCF/SCFOps.cpp b/mlir/unittests/Dialect/SCF/SCFOps.cpp new file mode 100644 index 0000000..099bd27 --- /dev/null +++ b/mlir/unittests/Dialect/SCF/SCFOps.cpp @@ -0,0 +1,67 @@ +//===- SCFOps.cpp - SCF Op Unit Tests -------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Parser.h" +#include "gtest/gtest.h" + +using namespace mlir; + +namespace { +class SCFOpsTest : public testing::Test { +public: + SCFOpsTest() { + context.getOrLoadDialect(); + context.getOrLoadDialect(); + } + +protected: + MLIRContext context; +}; + +TEST_F(SCFOpsTest, IfOpNumRegionInvocations) { + const char *const code = R"mlir( +func @test(%cond : i1) -> () { + scf.if %cond { + scf.yield + } else { + scf.yield + } + return +} +)mlir"; + Builder builder(&context); + + auto module = parseSourceString(code, &context); + ASSERT_TRUE(module); + scf::IfOp op; + module->walk([&](scf::IfOp ifOp) { op = ifOp; }); + ASSERT_TRUE(op); + + SmallVector countPerRegion; + op.getNumRegionInvocations({Attribute()}, countPerRegion); + EXPECT_EQ(countPerRegion.size(), 2u); + EXPECT_EQ(countPerRegion[0], kUnknownNumRegionInvocations); + EXPECT_EQ(countPerRegion[1], kUnknownNumRegionInvocations); + + countPerRegion.clear(); + op.getNumRegionInvocations( + {builder.getIntegerAttr(builder.getI1Type(), true)}, countPerRegion); + EXPECT_EQ(countPerRegion.size(), 2u); + EXPECT_EQ(countPerRegion[0], 1); + EXPECT_EQ(countPerRegion[1], 0); + + countPerRegion.clear(); + op.getNumRegionInvocations( + {builder.getIntegerAttr(builder.getI1Type(), false)}, countPerRegion); + EXPECT_EQ(countPerRegion.size(), 2u); + EXPECT_EQ(countPerRegion[0], 0); + EXPECT_EQ(countPerRegion[1], 1); +} +} // end anonymous namespace -- 2.7.4