From 621d7cca3751f934f991e34fe0e26187c33314f4 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 5 Mar 2020 12:40:23 -0800 Subject: [PATCH] [mlir] Add a new BranchOpInterface to allow for opaquely interfacing with branching terminator operations. This interface contains the necessary components to provide the same builtin behavior that terminators have. This will be used in future revisions to remove many of the hardcoded constraints placed on successors and successor operands. The interface initially contains three methods: ```c++ // Return a set of values corresponding to the operands for successor 'index', or None if the operands do not correspond to materialized values. Optional getSuccessorOperands(unsigned index); // Return true if this terminator can have it's successor operands erased. bool canEraseSuccessorOperand(); // Erase the operand of a successor. This is only valid to call if 'canEraseSuccessorOperand' returns true. void eraseSuccessorOperand(unsigned succIdx, unsigned opIdx); ``` Differential Revision: https://reviews.llvm.org/D75314 --- mlir/include/mlir/Analysis/CMakeLists.txt | 5 + mlir/include/mlir/Analysis/ControlFlowInterfaces.h | 43 +++++++++ .../include/mlir/Analysis/ControlFlowInterfaces.td | 85 +++++++++++++++++ mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h | 1 + mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 11 ++- .../mlir/Dialect/SPIRV/SPIRVControlFlowOps.td | 10 +- mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h | 1 + mlir/include/mlir/Dialect/StandardOps/IR/Ops.h | 1 + mlir/include/mlir/Dialect/StandardOps/IR/Ops.td | 7 +- mlir/include/mlir/IR/OperationSupport.h | 4 + mlir/lib/Analysis/CMakeLists.txt | 4 + mlir/lib/Analysis/ControlFlowInterfaces.cpp | 101 +++++++++++++++++++++ mlir/lib/Dialect/LLVMIR/CMakeLists.txt | 2 +- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 31 ++++++- mlir/lib/Dialect/SPIRV/CMakeLists.txt | 1 + mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 20 ++++ mlir/lib/Dialect/StandardOps/CMakeLists.txt | 1 + mlir/lib/Dialect/StandardOps/IR/Ops.cpp | 16 +++- mlir/lib/IR/Operation.cpp | 26 +----- mlir/lib/IR/OperationSupport.cpp | 7 ++ mlir/test/lib/TestDialect/CMakeLists.txt | 1 + mlir/test/lib/TestDialect/TestDialect.cpp | 11 +++ mlir/test/lib/TestDialect/TestDialect.h | 1 + mlir/test/lib/TestDialect/TestOps.td | 9 +- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 4 +- 25 files changed, 364 insertions(+), 39 deletions(-) create mode 100644 mlir/include/mlir/Analysis/ControlFlowInterfaces.h create mode 100644 mlir/include/mlir/Analysis/ControlFlowInterfaces.td create mode 100644 mlir/lib/Analysis/ControlFlowInterfaces.cpp diff --git a/mlir/include/mlir/Analysis/CMakeLists.txt b/mlir/include/mlir/Analysis/CMakeLists.txt index 3d9a7ed..cc8c493 100644 --- a/mlir/include/mlir/Analysis/CMakeLists.txt +++ b/mlir/include/mlir/Analysis/CMakeLists.txt @@ -3,6 +3,11 @@ mlir_tablegen(CallInterfaces.h.inc -gen-op-interface-decls) mlir_tablegen(CallInterfaces.cpp.inc -gen-op-interface-defs) add_public_tablegen_target(MLIRCallOpInterfacesIncGen) +set(LLVM_TARGET_DEFINITIONS ControlFlowInterfaces.td) +mlir_tablegen(ControlFlowInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(ControlFlowInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRControlFlowInterfacesIncGen) + set(LLVM_TARGET_DEFINITIONS InferTypeOpInterface.td) mlir_tablegen(InferTypeOpInterface.h.inc -gen-op-interface-decls) mlir_tablegen(InferTypeOpInterface.cpp.inc -gen-op-interface-defs) diff --git a/mlir/include/mlir/Analysis/ControlFlowInterfaces.h b/mlir/include/mlir/Analysis/ControlFlowInterfaces.h new file mode 100644 index 0000000..87f8204 --- /dev/null +++ b/mlir/include/mlir/Analysis/ControlFlowInterfaces.h @@ -0,0 +1,43 @@ +//===- ControlFlowInterfaces.h - ControlFlow Interfaces ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the definitions of the branch interfaces defined in +// `ControlFlowInterfaces.td`. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_CONTROLFLOWINTERFACES_H +#define MLIR_ANALYSIS_CONTROLFLOWINTERFACES_H + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +class BranchOpInterface; + +namespace detail { +/// Erase an operand from a branch operation that is used as a successor +/// operand. `operandIndex` is the operand within `operands` to be erased. +void eraseBranchSuccessorOperand(OperandRange operands, unsigned operandIndex, + Operation *op); + +/// Return the `BlockArgument` corresponding to operand `operandIndex` in some +/// successor if `operandIndex` is within the range of `operands`, or None if +/// `operandIndex` isn't a successor operand index. +Optional +getBranchSuccessorArgument(Optional operands, + unsigned operandIndex, Block *successor); + +/// Verify that the given operands match those of the given successor block. +LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo, + Optional operands); +} // namespace detail + +#include "mlir/Analysis/ControlFlowInterfaces.h.inc" +} // end namespace mlir + +#endif // MLIR_ANALYSIS_CONTROLFLOWINTERFACES_H diff --git a/mlir/include/mlir/Analysis/ControlFlowInterfaces.td b/mlir/include/mlir/Analysis/ControlFlowInterfaces.td new file mode 100644 index 0000000..b34b7a9 --- /dev/null +++ b/mlir/include/mlir/Analysis/ControlFlowInterfaces.td @@ -0,0 +1,85 @@ +//===-- ControlFlowInterfaces.td - ControlFlow Interfaces --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a set of interfaces that can be used to define information +// about control flow operations, e.g. branches. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_CONTROLFLOWINTERFACES +#define MLIR_ANALYSIS_CONTROLFLOWINTERFACES + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// BranchOpInterface +//===----------------------------------------------------------------------===// + +def BranchOpInterface : OpInterface<"BranchOpInterface"> { + let description = [{ + This interface provides information for branching terminator operations, + i.e. terminator operations with successors. + }]; + let methods = [ + InterfaceMethod<[{ + Returns a set of values that correspond to the arguments to the + successor at the given index. Returns None if the operands to the + successor are non-materialized values, i.e. they are internal to the + operation. + }], + "Optional", "getSuccessorOperands", (ins "unsigned":$index) + >, + InterfaceMethod<[{ + Return true if this operation can erase an operand to a successor block. + }], + "bool", "canEraseSuccessorOperand" + >, + InterfaceMethod<[{ + Erase the operand at `operandIndex` from the `index`-th successor. This + should only be called if `canEraseSuccessorOperand` returns true. + }], + "void", "eraseSuccessorOperand", + (ins "unsigned":$index, "unsigned":$operandIndex), [{}], + /*defaultImplementation=*/[{ + ConcreteOp *op = static_cast(this); + Optional operands = op->getSuccessorOperands(index); + assert(operands && "unable to query operands for successor"); + detail::eraseBranchSuccessorOperand(*operands, operandIndex, *op); + }] + >, + InterfaceMethod<[{ + Returns the `BlockArgument` corresponding to operand `operandIndex` in + some successor, or None if `operandIndex` isn't a successor operand + index. + }], + "Optional", "getSuccessorBlockArgument", + (ins "unsigned":$operandIndex), [{ + Operation *opaqueOp = op; + for (unsigned i = 0, e = opaqueOp->getNumSuccessors(); i != e; ++i) { + if (Optional arg = detail::getBranchSuccessorArgument( + op.getSuccessorOperands(i), operandIndex, + opaqueOp->getSuccessor(i))) + return arg; + } + return llvm::None; + }] + > + ]; + + let verify = [{ + auto concreteOp = cast($_op); + for (unsigned i = 0, e = $_op->getNumSuccessors(); i != e; ++i) { + Optional operands = concreteOp.getSuccessorOperands(i); + if (failed(detail::verifyBranchSuccessorOperands($_op, i, operands))) + return failure(); + } + return success(); + }]; +} + +#endif // MLIR_ANALYSIS_CONTROLFLOWINTERFACES diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index d1378b8..a8b0be0 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_ #define MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_ +#include "mlir/Analysis/ControlFlowInterfaces.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Function.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index b458540..2ab9041 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -14,6 +14,7 @@ #define LLVMIR_OPS include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Analysis/ControlFlowInterfaces.td" class LLVM_Builder { string llvmBuilder = builder; @@ -315,7 +316,9 @@ def LLVM_FPExtOp : LLVM_CastOp<"fpext", "CreateFPExt">; def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc">; // Call-related operations. -def LLVM_InvokeOp : LLVM_Op<"invoke", [Terminator]>, +def LLVM_InvokeOp : LLVM_Op<"invoke", [ + DeclareOpInterfaceMethods, Terminator + ]>, Arguments<(ins OptionalAttr:$callee, Variadic)>, Results<(outs Variadic)> { @@ -458,11 +461,13 @@ def LLVM_FreezeOp : LLVM_OneResultOp<"freeze", [SameOperandsAndResultType]>, } // Terminators. -def LLVM_BrOp : LLVM_TerminatorOp<"br", []> { +def LLVM_BrOp : LLVM_TerminatorOp<"br", + [DeclareOpInterfaceMethods]> { let successors = (successor AnySuccessor:$dest); let assemblyFormat = "$dest attr-dict"; } -def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", []> { +def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", + [DeclareOpInterfaceMethods]> { let arguments = (ins LLVMI1:$condition); let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest); let assemblyFormat = "$condition `,` successors attr-dict"; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td index 03884af..5ef825af 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -16,10 +16,13 @@ include "mlir/Dialect/SPIRV/SPIRVBase.td" include "mlir/Analysis/CallInterfaces.td" +include "mlir/Analysis/ControlFlowInterfaces.td" // ----- -def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> { +def SPV_BranchOp : SPV_Op<"Branch",[ + DeclareOpInterfaceMethods, InFunctionScope, + Terminator]> { let summary = "Unconditional branch to target block."; let description = [{ @@ -75,8 +78,9 @@ def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> { // ----- -def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", - [InFunctionScope, Terminator]> { +def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [ + DeclareOpInterfaceMethods, InFunctionScope, + Terminator]> { let summary = [{ If Condition is true, branch to true block, otherwise branch to false block. diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h index ea541c0..e223f17 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_SPIRV_SPIRVOPS_H_ #define MLIR_DIALECT_SPIRV_SPIRVOPS_H_ +#include "mlir/Analysis/ControlFlowInterfaces.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/Function.h" #include "llvm/Support/PointerLikeTypeTraits.h" diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h index 1e19c02..1082276 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -15,6 +15,7 @@ #define MLIR_DIALECT_STANDARDOPS_IR_OPS_H #include "mlir/Analysis/CallInterfaces.h" +#include "mlir/Analysis/ControlFlowInterfaces.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td index 2fe0365..e44f8ff 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -14,6 +14,7 @@ #define STANDARD_OPS include "mlir/Analysis/CallInterfaces.td" +include "mlir/Analysis/ControlFlowInterfaces.td" include "mlir/IR/OpAsmInterface.td" def Std_Dialect : Dialect { @@ -331,7 +332,8 @@ def AtomicRMWOp : Std_Op<"atomic_rmw", [ // BranchOp //===----------------------------------------------------------------------===// -def BranchOp : Std_Op<"br", [Terminator]> { +def BranchOp : Std_Op<"br", + [DeclareOpInterfaceMethods, Terminator]> { let summary = "branch operation"; let description = [{ The "br" operation represents a branch operation in a function. @@ -668,7 +670,8 @@ def CmpIOp : Std_Op<"cmpi", // CondBranchOp //===----------------------------------------------------------------------===// -def CondBranchOp : Std_Op<"cond_br", [Terminator]> { +def CondBranchOp : Std_Op<"cond_br", + [DeclareOpInterfaceMethods, Terminator]> { let summary = "conditional branch operation"; let description = [{ The "cond_br" operation represents a conditional branch operation in a diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index 96545d6..7735dd1 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -639,6 +639,10 @@ public: type_range getTypes() const { return {begin(), end()}; } auto getType() const { return getTypes(); } + /// Return the operand index of the first element of this range. The range + /// must not be empty. + unsigned getBeginOperandIndex() const; + private: /// See `detail::indexed_accessor_range_base` for details. static OpOperand *offset_base(OpOperand *object, ptrdiff_t index) { diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt index 9eccde5..44ec43c 100644 --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES AffineAnalysis.cpp AffineStructures.cpp CallGraph.cpp + ControlFlowInterfaces.cpp Dominance.cpp InferTypeOpInterface.cpp Liveness.cpp @@ -14,6 +15,7 @@ set(LLVM_OPTIONAL_SOURCES add_llvm_library(MLIRAnalysis CallGraph.cpp + ControlFlowInterfaces.cpp InferTypeOpInterface.cpp Liveness.cpp SliceAnalysis.cpp @@ -26,6 +28,7 @@ add_llvm_library(MLIRAnalysis add_dependencies(MLIRAnalysis MLIRAffineOps MLIRCallOpInterfacesIncGen + MLIRControlFlowInterfacesIncGen MLIRTypeInferOpInterfaceIncGen MLIRLoopOps ) @@ -45,6 +48,7 @@ add_llvm_library(MLIRLoopAnalysis add_dependencies(MLIRLoopAnalysis MLIRAffineOps MLIRCallOpInterfacesIncGen + MLIRControlFlowInterfacesIncGen MLIRTypeInferOpInterfaceIncGen MLIRLoopOps ) diff --git a/mlir/lib/Analysis/ControlFlowInterfaces.cpp b/mlir/lib/Analysis/ControlFlowInterfaces.cpp new file mode 100644 index 0000000..7d98f29 --- /dev/null +++ b/mlir/lib/Analysis/ControlFlowInterfaces.cpp @@ -0,0 +1,101 @@ +//===- ControlFlowInterfaces.h - ControlFlow Interfaces -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/ControlFlowInterfaces.h" +#include "mlir/IR/StandardTypes.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ControlFlowInterfaces +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/ControlFlowInterfaces.cpp.inc" + +//===----------------------------------------------------------------------===// +// BranchOpInterface +//===----------------------------------------------------------------------===// + +/// Erase an operand from a branch operation that is used as a successor +/// operand. 'operandIndex' is the operand within 'operands' to be erased. +void mlir::detail::eraseBranchSuccessorOperand(OperandRange operands, + unsigned operandIndex, + Operation *op) { + assert(operandIndex < operands.size() && + "invalid index for successor operands"); + + // Erase the operand from the operation. + size_t fullOperandIndex = operands.getBeginOperandIndex() + operandIndex; + op->eraseOperand(fullOperandIndex); + + // If this operation has an OperandSegmentSizeAttr, keep it up to date. + auto operandSegmentAttr = + op->getAttrOfType("operand_segment_sizes"); + if (!operandSegmentAttr) + return; + + // Find the segment containing the full operand index and decrement it. + // TODO: This seems like a general utility that could be added somewhere. + SmallVector values(operandSegmentAttr.getValues()); + unsigned currentSize = 0; + for (unsigned i = 0, e = values.size(); i != e; ++i) { + currentSize += values[i]; + if (fullOperandIndex < currentSize) { + --values[i]; + break; + } + } + op->setAttr("operand_segment_sizes", + DenseIntElementsAttr::get(operandSegmentAttr.getType(), values)); +} + +/// Returns the `BlockArgument` corresponding to operand `operandIndex` in some +/// successor if 'operandIndex' is within the range of 'operands', or None if +/// `operandIndex` isn't a successor operand index. +Optional mlir::detail::getBranchSuccessorArgument( + Optional operands, unsigned operandIndex, Block *successor) { + // Check that the operands are valid. + if (!operands || operands->empty()) + return llvm::None; + + // Check to ensure that this operand is within the range. + unsigned operandsStart = operands->getBeginOperandIndex(); + if (operandIndex < operandsStart || + operandIndex >= (operandsStart + operands->size())) + return llvm::None; + + // Index the successor. + unsigned argIndex = operandIndex - operandsStart; + return successor->getArgument(argIndex); +} + +/// Verify that the given operands match those of the given successor block. +LogicalResult +mlir::detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, + Optional operands) { + if (!operands) + return success(); + + // Check the count. + unsigned operandCount = operands->size(); + Block *destBB = op->getSuccessor(succNo); + if (operandCount != destBB->getNumArguments()) + return op->emitError() << "branch has " << operandCount + << " operands for successor #" << succNo + << ", but target block has " + << destBB->getNumArguments(); + + // Check the types. + auto operandIt = operands->begin(); + for (unsigned i = 0; i != operandCount; ++i, ++operandIt) { + if ((*operandIt).getType() != destBB->getArgument(i).getType()) + return op->emitError() << "type mismatch for bb argument #" << i + << " of successor #" << succNo; + } + return success(); +} diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt index 8eafbd8..97b70ca 100644 --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -4,7 +4,7 @@ add_mlir_dialect_library(MLIRLLVMIR ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR ) -add_dependencies(MLIRLLVMIR MLIRLLVMOpsIncGen MLIRLLVMConversionsIncGen MLIROpenMP LLVMFrontendOpenMP LLVMAsmParser LLVMCore LLVMSupport) +add_dependencies(MLIRLLVMIR MLIRControlFlowInterfacesIncGen MLIRLLVMOpsIncGen MLIRLLVMConversionsIncGen MLIROpenMP LLVMFrontendOpenMP LLVMAsmParser LLVMCore LLVMSupport) target_link_libraries(MLIRLLVMIR LLVMAsmParser LLVMCore LLVMSupport LLVMFrontendOpenMP MLIROpenMP MLIRIR) add_mlir_dialect_library(MLIRNVVMIR diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 7536914..567ddee 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -154,6 +154,28 @@ static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) { } //===----------------------------------------------------------------------===// +// LLVM::BrOp +//===----------------------------------------------------------------------===// + +Optional BrOp::getSuccessorOperands(unsigned index) { + assert(index == 0 && "invalid successor index"); + return getOperands(); +} + +bool BrOp::canEraseSuccessorOperand() { return true; } + +//===----------------------------------------------------------------------===// +// LLVM::CondBrOp +//===----------------------------------------------------------------------===// + +Optional CondBrOp::getSuccessorOperands(unsigned index) { + assert(index < getNumSuccessors() && "invalid successor index"); + return index == 0 ? trueDestOperands() : falseDestOperands(); +} + +bool CondBrOp::canEraseSuccessorOperand() { return true; } + +//===----------------------------------------------------------------------===// // Printing/parsing for LLVM::LoadOp. //===----------------------------------------------------------------------===// @@ -229,9 +251,16 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { } ///===----------------------------------------------------------------------===// -/// Verifying/Printing/Parsing for LLVM::InvokeOp. +/// LLVM::InvokeOp ///===----------------------------------------------------------------------===// +Optional InvokeOp::getSuccessorOperands(unsigned index) { + assert(index < getNumSuccessors() && "invalid successor index"); + return index == 0 ? normalDestOperands() : unwindDestOperands(); +} + +bool InvokeOp::canEraseSuccessorOperand() { return true; } + static LogicalResult verify(InvokeOp op) { if (op.getNumResults() > 1) return op.emitOpError("must have 0 or 1 result"); diff --git a/mlir/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/CMakeLists.txt index ad1bb4d..7e77f3e 100644 --- a/mlir/lib/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/CMakeLists.txt @@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRSPIRV ) add_dependencies(MLIRSPIRV + MLIRControlFlowInterfacesIncGen MLIRSPIRVAvailabilityIncGen MLIRSPIRVCanonicalizationIncGen MLIRSPIRVEnumAvailabilityIncGen diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 6a63867..907f8f8 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -943,9 +943,29 @@ static LogicalResult verify(spirv::BitcastOp bitcastOp) { } //===----------------------------------------------------------------------===// +// spv.BranchOp +//===----------------------------------------------------------------------===// + +Optional spirv::BranchOp::getSuccessorOperands(unsigned index) { + assert(index == 0 && "invalid successor index"); + return getOperands(); +} + +bool spirv::BranchOp::canEraseSuccessorOperand() { return true; } + +//===----------------------------------------------------------------------===// // spv.BranchConditionalOp //===----------------------------------------------------------------------===// +Optional +spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) { + assert(index < 2 && "invalid successor index"); + return index == kTrueIndex ? getTrueBlockArguments() + : getFalseBlockArguments(); +} + +bool spirv::BranchConditionalOp::canEraseSuccessorOperand() { return true; } + static ParseResult parseBranchConditionalOp(OpAsmParser &parser, OperationState &state) { auto &builder = parser.getBuilder(); diff --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt index c8af470..9b8ffcd 100644 --- a/mlir/lib/Dialect/StandardOps/CMakeLists.txt +++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt @@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRStandardOps add_dependencies(MLIRStandardOps MLIRCallOpInterfacesIncGen + MLIRControlFlowInterfacesIncGen MLIREDSC MLIRIR MLIRStandardOpsIncGen diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index 1059e66..6cb1f21 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -482,7 +482,7 @@ Block *BranchOp::getDest() { return getSuccessor(); } void BranchOp::setDest(Block *block) { return setSuccessor(block); } void BranchOp::eraseOperand(unsigned index) { - getOperation()->eraseSuccessorOperand(0, index); + getOperation()->eraseOperand(index); } void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, @@ -490,6 +490,13 @@ void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); } +Optional BranchOp::getSuccessorOperands(unsigned index) { + assert(index == 0 && "invalid successor index"); + return getOperands(); +} + +bool BranchOp::canEraseSuccessorOperand() { return true; } + //===----------------------------------------------------------------------===// // CallOp //===----------------------------------------------------------------------===// @@ -749,6 +756,13 @@ void CondBranchOp::getCanonicalizationPatterns( results.insert(context); } +Optional CondBranchOp::getSuccessorOperands(unsigned index) { + assert(index < getNumSuccessors() && "invalid successor index"); + return index == trueIndex ? getTrueOperands() : getFalseOperands(); +} + +bool CondBranchOp::canEraseSuccessorOperand() { return true; } + //===----------------------------------------------------------------------===// // Constant*Op //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index bfd4b40..2af425d 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -950,37 +950,13 @@ LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { return success(); } -static LogicalResult verifySuccessor(Operation *op, unsigned succNo) { - Operation::operand_range operands = op->getSuccessorOperands(succNo); - unsigned operandCount = op->getNumSuccessorOperands(succNo); - Block *destBB = op->getSuccessor(succNo); - if (operandCount != destBB->getNumArguments()) - return op->emitError() << "branch has " << operandCount - << " operands for successor #" << succNo - << ", but target block has " - << destBB->getNumArguments(); - - auto operandIt = operands.begin(); - for (unsigned i = 0, e = operandCount; i != e; ++i, ++operandIt) { - if ((*operandIt).getType() != destBB->getArgument(i).getType()) - return op->emitError() << "type mismatch for bb argument #" << i - << " of successor #" << succNo; - } - - return success(); -} - static LogicalResult verifyTerminatorSuccessors(Operation *op) { auto *parent = op->getParentRegion(); // Verify that the operands lines up with the BB arguments in the successor. - for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { - auto *succ = op->getSuccessor(i); + for (Block *succ : op->getSuccessors()) if (succ->getParent() != parent) return op->emitError("reference to block defined in another region"); - if (failed(verifySuccessor(op, i))) - return failure(); - } return success(); } diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp index 107fc48..25859a5 100644 --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -183,6 +183,13 @@ Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) { OperandRange::OperandRange(Operation *op) : OperandRange(op->getOpOperands().data(), op->getNumOperands()) {} +/// Return the operand index of the first element of this range. The range +/// must not be empty. +unsigned OperandRange::getBeginOperandIndex() const { + assert(!empty() && "range must not be empty"); + return base->getOperandNumber(); +} + //===----------------------------------------------------------------------===// // ResultRange diff --git a/mlir/test/lib/TestDialect/CMakeLists.txt b/mlir/test/lib/TestDialect/CMakeLists.txt index 15459b9..d815009 100644 --- a/mlir/test/lib/TestDialect/CMakeLists.txt +++ b/mlir/test/lib/TestDialect/CMakeLists.txt @@ -16,6 +16,7 @@ add_llvm_library(MLIRTestDialect TestPatterns.cpp ) add_dependencies(MLIRTestDialect + MLIRControlFlowInterfacesIncGen MLIRTestOpsIncGen MLIRTypeInferOpInterfaceIncGen ) diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp index 3ded7b9..649b547 100644 --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -164,6 +164,17 @@ TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, } //===----------------------------------------------------------------------===// +// TestBranchOp +//===----------------------------------------------------------------------===// + +Optional TestBranchOp::getSuccessorOperands(unsigned index) { + assert(index == 0 && "invalid successor index"); + return getOperands(); +} + +bool TestBranchOp::canEraseSuccessorOperand() { return true; } + +//===----------------------------------------------------------------------===// // Test IsolatedRegionOp - parse passthrough region arguments. //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/TestDialect/TestDialect.h b/mlir/test/lib/TestDialect/TestDialect.h index 3003276..8228f31 100644 --- a/mlir/test/lib/TestDialect/TestDialect.h +++ b/mlir/test/lib/TestDialect/TestDialect.h @@ -15,6 +15,7 @@ #define MLIR_TESTDIALECT_H #include "mlir/Analysis/CallInterfaces.h" +#include "mlir/Analysis/ControlFlowInterfaces.h" #include "mlir/Analysis/InferTypeOpInterface.h" #include "mlir/Dialect/Traits.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index 26205e7..5ee4a46 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -11,6 +11,7 @@ include "mlir/IR/OpBase.td" include "mlir/IR/OpAsmInterface.td" +include "mlir/Analysis/ControlFlowInterfaces.td" include "mlir/Analysis/CallInterfaces.td" include "mlir/Analysis/InferTypeOpInterface.td" @@ -446,6 +447,11 @@ def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_ty ]> { let arguments = (ins AnyTensor, AnyTensor); let results = (outs AnyTensor); + + let extraClassDeclaration = [{ + LogicalResult reifyReturnTypeShapes(OpBuilder &builder, + SmallVectorImpl &shapes); + }]; } def IsNotScalar : Constraint>; @@ -454,7 +460,8 @@ def UpdateAttr : Pat<(I32ElementsAttrOp $attr), (I32ElementsAttrOp ConstantAttr), [(IsNotScalar $attr)]>; -def TestBranchOp : TEST_Op<"br", [Terminator]> { +def TestBranchOp : TEST_Op<"br", + [DeclareOpInterfaceMethods, Terminator]> { let successors = (successor AnySuccessor:$target); } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index ebd82f9..7aa51bd 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1155,8 +1155,8 @@ void OpEmitter::genOpInterfaceMethods() { continue; auto interface = opTrait->getOpInterface(); for (auto method : interface.getMethods()) { - // Don't declare if the method has a body. - if (method.getBody()) + // Don't declare if the method has a body or a default implementation. + if (method.getBody() || method.getDefaultImplementation()) continue; std::string args; llvm::raw_string_ostream os(args); -- 2.7.4