[mlir] Add a new BranchOpInterface to allow for opaquely interfacing with branching...
authorRiver Riddle <riddleriver@gmail.com>
Thu, 5 Mar 2020 20:40:23 +0000 (12:40 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Thu, 5 Mar 2020 20:50:35 +0000 (12:50 -0800)
This interface contains the necessary components to provide the same builtin behavior that terminators have. This will be used in future revisions to remove many of the hardcoded constraints placed on successors and successor operands. The interface initially contains three methods:

```c++
// Return a set of values corresponding to the operands for successor 'index', or None if the operands do not correspond to materialized values.
Optional<OperandRange> getSuccessorOperands(unsigned index);

// Return true if this terminator can have it's successor operands erased.
bool canEraseSuccessorOperand();

// Erase the operand of a successor. This is only valid to call if 'canEraseSuccessorOperand' returns true.
void eraseSuccessorOperand(unsigned succIdx, unsigned opIdx);
```

Differential Revision: https://reviews.llvm.org/D75314

25 files changed:
mlir/include/mlir/Analysis/CMakeLists.txt
mlir/include/mlir/Analysis/ControlFlowInterfaces.h [new file with mode: 0644]
mlir/include/mlir/Analysis/ControlFlowInterfaces.td [new file with mode: 0644]
mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/Analysis/CMakeLists.txt
mlir/lib/Analysis/ControlFlowInterfaces.cpp [new file with mode: 0644]
mlir/lib/Dialect/LLVMIR/CMakeLists.txt
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/SPIRV/CMakeLists.txt
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/StandardOps/CMakeLists.txt
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/OperationSupport.cpp
mlir/test/lib/TestDialect/CMakeLists.txt
mlir/test/lib/TestDialect/TestDialect.cpp
mlir/test/lib/TestDialect/TestDialect.h
mlir/test/lib/TestDialect/TestOps.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

index 3d9a7ed..cc8c493 100644 (file)
@@ -3,6 +3,11 @@ mlir_tablegen(CallInterfaces.h.inc -gen-op-interface-decls)
 mlir_tablegen(CallInterfaces.cpp.inc -gen-op-interface-defs)
 add_public_tablegen_target(MLIRCallOpInterfacesIncGen)
 
+set(LLVM_TARGET_DEFINITIONS ControlFlowInterfaces.td)
+mlir_tablegen(ControlFlowInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(ControlFlowInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRControlFlowInterfacesIncGen)
+
 set(LLVM_TARGET_DEFINITIONS InferTypeOpInterface.td)
 mlir_tablegen(InferTypeOpInterface.h.inc -gen-op-interface-decls)
 mlir_tablegen(InferTypeOpInterface.cpp.inc -gen-op-interface-defs)
diff --git a/mlir/include/mlir/Analysis/ControlFlowInterfaces.h b/mlir/include/mlir/Analysis/ControlFlowInterfaces.h
new file mode 100644 (file)
index 0000000..87f8204
--- /dev/null
@@ -0,0 +1,43 @@
+//===- ControlFlowInterfaces.h - ControlFlow Interfaces ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definitions of the branch interfaces defined in
+// `ControlFlowInterfaces.td`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_CONTROLFLOWINTERFACES_H
+#define MLIR_ANALYSIS_CONTROLFLOWINTERFACES_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+class BranchOpInterface;
+
+namespace detail {
+/// Erase an operand from a branch operation that is used as a successor
+/// operand. `operandIndex` is the operand within `operands` to be erased.
+void eraseBranchSuccessorOperand(OperandRange operands, unsigned operandIndex,
+                                 Operation *op);
+
+/// Return the `BlockArgument` corresponding to operand `operandIndex` in some
+/// successor if `operandIndex` is within the range of `operands`, or None if
+/// `operandIndex` isn't a successor operand index.
+Optional<BlockArgument>
+getBranchSuccessorArgument(Optional<OperandRange> operands,
+                           unsigned operandIndex, Block *successor);
+
+/// Verify that the given operands match those of the given successor block.
+LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
+                                            Optional<OperandRange> operands);
+} // namespace detail
+
+#include "mlir/Analysis/ControlFlowInterfaces.h.inc"
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_CONTROLFLOWINTERFACES_H
diff --git a/mlir/include/mlir/Analysis/ControlFlowInterfaces.td b/mlir/include/mlir/Analysis/ControlFlowInterfaces.td
new file mode 100644 (file)
index 0000000..b34b7a9
--- /dev/null
@@ -0,0 +1,85 @@
+//===-- ControlFlowInterfaces.td - ControlFlow Interfaces --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a set of interfaces that can be used to define information
+// about control flow operations, e.g. branches.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_CONTROLFLOWINTERFACES
+#define MLIR_ANALYSIS_CONTROLFLOWINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// BranchOpInterface
+//===----------------------------------------------------------------------===//
+
+def BranchOpInterface : OpInterface<"BranchOpInterface"> {
+  let description = [{
+    This interface provides information for branching terminator operations,
+    i.e. terminator operations with successors.
+  }];
+  let methods = [
+    InterfaceMethod<[{
+        Returns a set of values that correspond to the arguments to the
+        successor at the given index. Returns None if the operands to the
+        successor are non-materialized values, i.e. they are internal to the
+        operation.
+      }],
+      "Optional<OperandRange>", "getSuccessorOperands", (ins "unsigned":$index)
+    >,
+    InterfaceMethod<[{
+        Return true if this operation can erase an operand to a successor block.
+      }],
+      "bool", "canEraseSuccessorOperand"
+    >,
+    InterfaceMethod<[{
+        Erase the operand at `operandIndex` from the `index`-th successor. This
+        should only be called if `canEraseSuccessorOperand` returns true.
+      }],
+      "void", "eraseSuccessorOperand",
+      (ins "unsigned":$index, "unsigned":$operandIndex), [{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp *op = static_cast<ConcreteOp *>(this);
+        Optional<OperandRange> operands = op->getSuccessorOperands(index);
+        assert(operands && "unable to query operands for successor");
+        detail::eraseBranchSuccessorOperand(*operands, operandIndex, *op);
+      }]
+    >,
+    InterfaceMethod<[{
+        Returns the `BlockArgument` corresponding to operand `operandIndex` in
+        some successor, or None if `operandIndex` isn't a successor operand
+        index.
+      }],
+      "Optional<BlockArgument>", "getSuccessorBlockArgument",
+      (ins "unsigned":$operandIndex), [{
+        Operation *opaqueOp = op;
+        for (unsigned i = 0, e = opaqueOp->getNumSuccessors(); i != e; ++i) {
+          if (Optional<BlockArgument> arg = detail::getBranchSuccessorArgument(
+                op.getSuccessorOperands(i), operandIndex,
+                opaqueOp->getSuccessor(i)))
+            return arg;
+        }
+        return llvm::None;
+      }]
+    >
+  ];
+
+  let verify = [{
+    auto concreteOp = cast<ConcreteOpType>($_op);
+    for (unsigned i = 0, e = $_op->getNumSuccessors(); i != e; ++i) {
+      Optional<OperandRange> operands = concreteOp.getSuccessorOperands(i);
+      if (failed(detail::verifyBranchSuccessorOperands($_op, i, operands)))
+        return failure();
+    }
+    return success();
+  }];
+}
+
+#endif // MLIR_ANALYSIS_CONTROLFLOWINTERFACES
index d1378b8..a8b0be0 100644 (file)
@@ -14,6 +14,7 @@
 #ifndef MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_
 #define MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_
 
+#include "mlir/Analysis/ControlFlowInterfaces.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/OpDefinition.h"
index b458540..2ab9041 100644 (file)
@@ -14,6 +14,7 @@
 #define LLVMIR_OPS
 
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Analysis/ControlFlowInterfaces.td"
 
 class LLVM_Builder<string builder> {
   string llvmBuilder = builder;
@@ -315,7 +316,9 @@ def LLVM_FPExtOp : LLVM_CastOp<"fpext", "CreateFPExt">;
 def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc">;
 
 // Call-related operations.
-def LLVM_InvokeOp : LLVM_Op<"invoke", [Terminator]>,
+def LLVM_InvokeOp : LLVM_Op<"invoke", [
+                      DeclareOpInterfaceMethods<BranchOpInterface>, Terminator
+                    ]>,
                     Arguments<(ins OptionalAttr<FlatSymbolRefAttr>:$callee,
                                Variadic<LLVM_Type>)>,
                     Results<(outs Variadic<LLVM_Type>)> {
@@ -458,11 +461,13 @@ def LLVM_FreezeOp : LLVM_OneResultOp<"freeze", [SameOperandsAndResultType]>,
 }
 
 // Terminators.
-def LLVM_BrOp : LLVM_TerminatorOp<"br", []> {
+def LLVM_BrOp : LLVM_TerminatorOp<"br",
+    [DeclareOpInterfaceMethods<BranchOpInterface>]> {
   let successors = (successor AnySuccessor:$dest);
   let assemblyFormat = "$dest attr-dict";
 }
-def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", []> {
+def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
+    [DeclareOpInterfaceMethods<BranchOpInterface>]> {
   let arguments = (ins LLVMI1:$condition);
   let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
   let assemblyFormat = "$condition `,` successors attr-dict";
index 03884af..5ef825a 100644 (file)
 
 include "mlir/Dialect/SPIRV/SPIRVBase.td"
 include "mlir/Analysis/CallInterfaces.td"
+include "mlir/Analysis/ControlFlowInterfaces.td"
 
 // -----
 
-def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> {
+def SPV_BranchOp : SPV_Op<"Branch",[
+    DeclareOpInterfaceMethods<BranchOpInterface>, InFunctionScope,
+    Terminator]> {
   let summary = "Unconditional branch to target block.";
 
   let description = [{
@@ -75,8 +78,9 @@ def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> {
 
 // -----
 
-def SPV_BranchConditionalOp : SPV_Op<"BranchConditional",
-                                     [InFunctionScope, Terminator]> {
+def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [
+    DeclareOpInterfaceMethods<BranchOpInterface>, InFunctionScope,
+    Terminator]> {
   let summary = [{
     If Condition is true, branch to true block, otherwise branch to false
     block.
index ea541c0..e223f17 100644 (file)
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_SPIRV_SPIRVOPS_H_
 #define MLIR_DIALECT_SPIRV_SPIRVOPS_H_
 
+#include "mlir/Analysis/ControlFlowInterfaces.h"
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/IR/Function.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
index 1e19c02..1082276 100644 (file)
@@ -15,6 +15,7 @@
 #define MLIR_DIALECT_STANDARDOPS_IR_OPS_H
 
 #include "mlir/Analysis/CallInterfaces.h"
+#include "mlir/Analysis/ControlFlowInterfaces.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpImplementation.h"
index 2fe0365..e44f8ff 100644 (file)
@@ -14,6 +14,7 @@
 #define STANDARD_OPS
 
 include "mlir/Analysis/CallInterfaces.td"
+include "mlir/Analysis/ControlFlowInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
 
 def Std_Dialect : Dialect {
@@ -331,7 +332,8 @@ def AtomicRMWOp : Std_Op<"atomic_rmw", [
 // BranchOp
 //===----------------------------------------------------------------------===//
 
-def BranchOp : Std_Op<"br", [Terminator]> {
+def BranchOp : Std_Op<"br",
+    [DeclareOpInterfaceMethods<BranchOpInterface>, Terminator]> {
   let summary = "branch operation";
   let description = [{
     The "br" operation represents a branch operation in a function.
@@ -668,7 +670,8 @@ def CmpIOp : Std_Op<"cmpi",
 // CondBranchOp
 //===----------------------------------------------------------------------===//
 
-def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
+def CondBranchOp : Std_Op<"cond_br",
+    [DeclareOpInterfaceMethods<BranchOpInterface>, Terminator]> {
   let summary = "conditional branch operation";
   let description = [{
     The "cond_br" operation represents a conditional branch operation in a
index 96545d6..7735dd1 100644 (file)
@@ -639,6 +639,10 @@ public:
   type_range getTypes() const { return {begin(), end()}; }
   auto getType() const { return getTypes(); }
 
+  /// Return the operand index of the first element of this range. The range
+  /// must not be empty.
+  unsigned getBeginOperandIndex() const;
+
 private:
   /// See `detail::indexed_accessor_range_base` for details.
   static OpOperand *offset_base(OpOperand *object, ptrdiff_t index) {
index 9eccde5..44ec43c 100644 (file)
@@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
   AffineAnalysis.cpp
   AffineStructures.cpp
   CallGraph.cpp
+  ControlFlowInterfaces.cpp
   Dominance.cpp
   InferTypeOpInterface.cpp
   Liveness.cpp
@@ -14,6 +15,7 @@ set(LLVM_OPTIONAL_SOURCES
 
 add_llvm_library(MLIRAnalysis
   CallGraph.cpp
+  ControlFlowInterfaces.cpp
   InferTypeOpInterface.cpp
   Liveness.cpp
   SliceAnalysis.cpp
@@ -26,6 +28,7 @@ add_llvm_library(MLIRAnalysis
 add_dependencies(MLIRAnalysis
   MLIRAffineOps
   MLIRCallOpInterfacesIncGen
+  MLIRControlFlowInterfacesIncGen
   MLIRTypeInferOpInterfaceIncGen
   MLIRLoopOps
   )
@@ -45,6 +48,7 @@ add_llvm_library(MLIRLoopAnalysis
 add_dependencies(MLIRLoopAnalysis
   MLIRAffineOps
   MLIRCallOpInterfacesIncGen
+  MLIRControlFlowInterfacesIncGen
   MLIRTypeInferOpInterfaceIncGen
   MLIRLoopOps
   )
diff --git a/mlir/lib/Analysis/ControlFlowInterfaces.cpp b/mlir/lib/Analysis/ControlFlowInterfaces.cpp
new file mode 100644 (file)
index 0000000..7d98f29
--- /dev/null
@@ -0,0 +1,101 @@
+//===- ControlFlowInterfaces.h - ControlFlow Interfaces -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/ControlFlowInterfaces.h"
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// ControlFlowInterfaces
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/ControlFlowInterfaces.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// BranchOpInterface
+//===----------------------------------------------------------------------===//
+
+/// Erase an operand from a branch operation that is used as a successor
+/// operand. 'operandIndex' is the operand within 'operands' to be erased.
+void mlir::detail::eraseBranchSuccessorOperand(OperandRange operands,
+                                               unsigned operandIndex,
+                                               Operation *op) {
+  assert(operandIndex < operands.size() &&
+         "invalid index for successor operands");
+
+  // Erase the operand from the operation.
+  size_t fullOperandIndex = operands.getBeginOperandIndex() + operandIndex;
+  op->eraseOperand(fullOperandIndex);
+
+  // If this operation has an OperandSegmentSizeAttr, keep it up to date.
+  auto operandSegmentAttr =
+      op->getAttrOfType<DenseElementsAttr>("operand_segment_sizes");
+  if (!operandSegmentAttr)
+    return;
+
+  // Find the segment containing the full operand index and decrement it.
+  // TODO: This seems like a general utility that could be added somewhere.
+  SmallVector<int32_t, 4> values(operandSegmentAttr.getValues<int32_t>());
+  unsigned currentSize = 0;
+  for (unsigned i = 0, e = values.size(); i != e; ++i) {
+    currentSize += values[i];
+    if (fullOperandIndex < currentSize) {
+      --values[i];
+      break;
+    }
+  }
+  op->setAttr("operand_segment_sizes",
+              DenseIntElementsAttr::get(operandSegmentAttr.getType(), values));
+}
+
+/// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
+/// successor if 'operandIndex' is within the range of 'operands', or None if
+/// `operandIndex` isn't a successor operand index.
+Optional<BlockArgument> mlir::detail::getBranchSuccessorArgument(
+    Optional<OperandRange> operands, unsigned operandIndex, Block *successor) {
+  // Check that the operands are valid.
+  if (!operands || operands->empty())
+    return llvm::None;
+
+  // Check to ensure that this operand is within the range.
+  unsigned operandsStart = operands->getBeginOperandIndex();
+  if (operandIndex < operandsStart ||
+      operandIndex >= (operandsStart + operands->size()))
+    return llvm::None;
+
+  // Index the successor.
+  unsigned argIndex = operandIndex - operandsStart;
+  return successor->getArgument(argIndex);
+}
+
+/// Verify that the given operands match those of the given successor block.
+LogicalResult
+mlir::detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
+                                            Optional<OperandRange> operands) {
+  if (!operands)
+    return success();
+
+  // Check the count.
+  unsigned operandCount = operands->size();
+  Block *destBB = op->getSuccessor(succNo);
+  if (operandCount != destBB->getNumArguments())
+    return op->emitError() << "branch has " << operandCount
+                           << " operands for successor #" << succNo
+                           << ", but target block has "
+                           << destBB->getNumArguments();
+
+  // Check the types.
+  auto operandIt = operands->begin();
+  for (unsigned i = 0; i != operandCount; ++i, ++operandIt) {
+    if ((*operandIt).getType() != destBB->getArgument(i).getType())
+      return op->emitError() << "type mismatch for bb argument #" << i
+                             << " of successor #" << succNo;
+  }
+  return success();
+}
index 8eafbd8..97b70ca 100644 (file)
@@ -4,7 +4,7 @@ add_mlir_dialect_library(MLIRLLVMIR
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
   )
-add_dependencies(MLIRLLVMIR MLIRLLVMOpsIncGen MLIRLLVMConversionsIncGen MLIROpenMP LLVMFrontendOpenMP LLVMAsmParser LLVMCore LLVMSupport)
+add_dependencies(MLIRLLVMIR MLIRControlFlowInterfacesIncGen MLIRLLVMOpsIncGen MLIRLLVMConversionsIncGen MLIROpenMP LLVMFrontendOpenMP LLVMAsmParser LLVMCore LLVMSupport)
 target_link_libraries(MLIRLLVMIR LLVMAsmParser LLVMCore LLVMSupport LLVMFrontendOpenMP MLIROpenMP MLIRIR)
 
 add_mlir_dialect_library(MLIRNVVMIR
index 7536914..567ddee 100644 (file)
@@ -154,6 +154,28 @@ static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
 }
 
 //===----------------------------------------------------------------------===//
+// LLVM::BrOp
+//===----------------------------------------------------------------------===//
+
+Optional<OperandRange> BrOp::getSuccessorOperands(unsigned index) {
+  assert(index == 0 && "invalid successor index");
+  return getOperands();
+}
+
+bool BrOp::canEraseSuccessorOperand() { return true; }
+
+//===----------------------------------------------------------------------===//
+// LLVM::CondBrOp
+//===----------------------------------------------------------------------===//
+
+Optional<OperandRange> CondBrOp::getSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  return index == 0 ? trueDestOperands() : falseDestOperands();
+}
+
+bool CondBrOp::canEraseSuccessorOperand() { return true; }
+
+//===----------------------------------------------------------------------===//
 // Printing/parsing for LLVM::LoadOp.
 //===----------------------------------------------------------------------===//
 
@@ -229,9 +251,16 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
 }
 
 ///===----------------------------------------------------------------------===//
-/// Verifying/Printing/Parsing for LLVM::InvokeOp.
+/// LLVM::InvokeOp
 ///===----------------------------------------------------------------------===//
 
+Optional<OperandRange> InvokeOp::getSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  return index == 0 ? normalDestOperands() : unwindDestOperands();
+}
+
+bool InvokeOp::canEraseSuccessorOperand() { return true; }
+
 static LogicalResult verify(InvokeOp op) {
   if (op.getNumResults() > 1)
     return op.emitOpError("must have 0 or 1 result");
index ad1bb4d..7e77f3e 100644 (file)
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRSPIRV
   )
 
 add_dependencies(MLIRSPIRV
+  MLIRControlFlowInterfacesIncGen
   MLIRSPIRVAvailabilityIncGen
   MLIRSPIRVCanonicalizationIncGen
   MLIRSPIRVEnumAvailabilityIncGen
index 6a63867..907f8f8 100644 (file)
@@ -943,9 +943,29 @@ static LogicalResult verify(spirv::BitcastOp bitcastOp) {
 }
 
 //===----------------------------------------------------------------------===//
+// spv.BranchOp
+//===----------------------------------------------------------------------===//
+
+Optional<OperandRange> spirv::BranchOp::getSuccessorOperands(unsigned index) {
+  assert(index == 0 && "invalid successor index");
+  return getOperands();
+}
+
+bool spirv::BranchOp::canEraseSuccessorOperand() { return true; }
+
+//===----------------------------------------------------------------------===//
 // spv.BranchConditionalOp
 //===----------------------------------------------------------------------===//
 
+Optional<OperandRange>
+spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
+  assert(index < 2 && "invalid successor index");
+  return index == kTrueIndex ? getTrueBlockArguments()
+                             : getFalseBlockArguments();
+}
+
+bool spirv::BranchConditionalOp::canEraseSuccessorOperand() { return true; }
+
 static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
                                             OperationState &state) {
   auto &builder = parser.getBuilder();
index c8af470..9b8ffcd 100644 (file)
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRStandardOps
 add_dependencies(MLIRStandardOps
 
   MLIRCallOpInterfacesIncGen
+  MLIRControlFlowInterfacesIncGen
   MLIREDSC
   MLIRIR
   MLIRStandardOpsIncGen
index 1059e66..6cb1f21 100644 (file)
@@ -482,7 +482,7 @@ Block *BranchOp::getDest() { return getSuccessor(); }
 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
 
 void BranchOp::eraseOperand(unsigned index) {
-  getOperation()->eraseSuccessorOperand(0, index);
+  getOperation()->eraseOperand(index);
 }
 
 void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
@@ -490,6 +490,13 @@ void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert<SimplifyBrToBlockWithSinglePred>(context);
 }
 
+Optional<OperandRange> BranchOp::getSuccessorOperands(unsigned index) {
+  assert(index == 0 && "invalid successor index");
+  return getOperands();
+}
+
+bool BranchOp::canEraseSuccessorOperand() { return true; }
+
 //===----------------------------------------------------------------------===//
 // CallOp
 //===----------------------------------------------------------------------===//
@@ -749,6 +756,13 @@ void CondBranchOp::getCanonicalizationPatterns(
   results.insert<SimplifyConstCondBranchPred>(context);
 }
 
+Optional<OperandRange> CondBranchOp::getSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  return index == trueIndex ? getTrueOperands() : getFalseOperands();
+}
+
+bool CondBranchOp::canEraseSuccessorOperand() { return true; }
+
 //===----------------------------------------------------------------------===//
 // Constant*Op
 //===----------------------------------------------------------------------===//
index bfd4b40..2af425d 100644 (file)
@@ -950,37 +950,13 @@ LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
   return success();
 }
 
-static LogicalResult verifySuccessor(Operation *op, unsigned succNo) {
-  Operation::operand_range operands = op->getSuccessorOperands(succNo);
-  unsigned operandCount = op->getNumSuccessorOperands(succNo);
-  Block *destBB = op->getSuccessor(succNo);
-  if (operandCount != destBB->getNumArguments())
-    return op->emitError() << "branch has " << operandCount
-                           << " operands for successor #" << succNo
-                           << ", but target block has "
-                           << destBB->getNumArguments();
-
-  auto operandIt = operands.begin();
-  for (unsigned i = 0, e = operandCount; i != e; ++i, ++operandIt) {
-    if ((*operandIt).getType() != destBB->getArgument(i).getType())
-      return op->emitError() << "type mismatch for bb argument #" << i
-                             << " of successor #" << succNo;
-  }
-
-  return success();
-}
-
 static LogicalResult verifyTerminatorSuccessors(Operation *op) {
   auto *parent = op->getParentRegion();
 
   // Verify that the operands lines up with the BB arguments in the successor.
-  for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
-    auto *succ = op->getSuccessor(i);
+  for (Block *succ : op->getSuccessors())
     if (succ->getParent() != parent)
       return op->emitError("reference to block defined in another region");
-    if (failed(verifySuccessor(op, i)))
-      return failure();
-  }
   return success();
 }
 
index 107fc48..25859a5 100644 (file)
@@ -183,6 +183,13 @@ Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
 OperandRange::OperandRange(Operation *op)
     : OperandRange(op->getOpOperands().data(), op->getNumOperands()) {}
 
+/// Return the operand index of the first element of this range. The range
+/// must not be empty.
+unsigned OperandRange::getBeginOperandIndex() const {
+  assert(!empty() && "range must not be empty");
+  return base->getOperandNumber();
+}
+
 //===----------------------------------------------------------------------===//
 // ResultRange
 
index 15459b9..d815009 100644 (file)
@@ -16,6 +16,7 @@ add_llvm_library(MLIRTestDialect
   TestPatterns.cpp
 )
 add_dependencies(MLIRTestDialect
+  MLIRControlFlowInterfacesIncGen
   MLIRTestOpsIncGen
   MLIRTypeInferOpInterfaceIncGen
 )
index 3ded7b9..649b547 100644 (file)
@@ -164,6 +164,17 @@ TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
 }
 
 //===----------------------------------------------------------------------===//
+// TestBranchOp
+//===----------------------------------------------------------------------===//
+
+Optional<OperandRange> TestBranchOp::getSuccessorOperands(unsigned index) {
+  assert(index == 0 && "invalid successor index");
+  return getOperands();
+}
+
+bool TestBranchOp::canEraseSuccessorOperand() { return true; }
+
+//===----------------------------------------------------------------------===//
 // Test IsolatedRegionOp - parse passthrough region arguments.
 //===----------------------------------------------------------------------===//
 
index 3003276..8228f31 100644 (file)
@@ -15,6 +15,7 @@
 #define MLIR_TESTDIALECT_H
 
 #include "mlir/Analysis/CallInterfaces.h"
+#include "mlir/Analysis/ControlFlowInterfaces.h"
 #include "mlir/Analysis/InferTypeOpInterface.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Dialect.h"
index 26205e7..5ee4a46 100644 (file)
@@ -11,6 +11,7 @@
 
 include "mlir/IR/OpBase.td"
 include "mlir/IR/OpAsmInterface.td"
+include "mlir/Analysis/ControlFlowInterfaces.td"
 include "mlir/Analysis/CallInterfaces.td"
 include "mlir/Analysis/InferTypeOpInterface.td"
 
@@ -446,6 +447,11 @@ def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_ty
    ]> {
   let arguments = (ins AnyTensor, AnyTensor);
   let results = (outs AnyTensor);
+
+  let extraClassDeclaration = [{
+    LogicalResult reifyReturnTypeShapes(OpBuilder &builder,
+                                        SmallVectorImpl<Value> &shapes);
+  }];
 }
 
 def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;
@@ -454,7 +460,8 @@ def UpdateAttr : Pat<(I32ElementsAttrOp $attr),
                      (I32ElementsAttrOp ConstantAttr<I32ElementsAttr, "0">),
                      [(IsNotScalar $attr)]>;
 
-def TestBranchOp : TEST_Op<"br", [Terminator]> {
+def TestBranchOp : TEST_Op<"br",
+    [DeclareOpInterfaceMethods<BranchOpInterface>, Terminator]> {
   let successors = (successor AnySuccessor:$target);
 }
 
index ebd82f9..7aa51bd 100644 (file)
@@ -1155,8 +1155,8 @@ void OpEmitter::genOpInterfaceMethods() {
       continue;
     auto interface = opTrait->getOpInterface();
     for (auto method : interface.getMethods()) {
-      // Don't declare if the method has a body.
-      if (method.getBody())
+      // Don't declare if the method has a body or a default implementation.
+      if (method.getBody() || method.getDefaultImplementation())
         continue;
       std::string args;
       llvm::raw_string_ostream os(args);