Define a `NoTerminator` traits that allows operations with a single block region...
authorMehdi Amini <joker.eph@gmail.com>
Thu, 11 Mar 2021 23:58:02 +0000 (23:58 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Thu, 25 Mar 2021 03:59:03 +0000 (03:59 +0000)
In particular for Graph Regions, the terminator needs is just a
historical artifact of the generalization of MLIR from CFG region.
Operations like Module don't need a terminator, and before Module
migrated to be an operation with region there wasn't any needed.

To validate the feature, the ModuleOp is migrated to use this trait and
the ModuleTerminator operation is deleted.

This patch is likely to break clients, if you're in this case:

- you may iterate on a ModuleOp with `getBody()->without_terminator()`,
  the solution is simple: just remove the ->without_terminator!
- you created a builder with `Builder::atBlockTerminator(module_body)`,
  just use `Builder::atBlockEnd(module_body)` instead.
- you were handling ModuleTerminator: it isn't needed anymore.
- for generic code, a `Block::mayNotHaveTerminator()` may be used.

Differential Revision: https://reviews.llvm.org/D98468

60 files changed:
mlir/docs/LangRef.md
mlir/docs/Traits.md
mlir/docs/Tutorials/Toy/Ch-6.md
mlir/docs/Tutorials/UnderstandingTheIRStructure.md
mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
mlir/include/mlir/IR/BuiltinOps.h
mlir/include/mlir/IR/BuiltinOps.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/OpImplementation.h
mlir/include/mlir/IR/Region.h
mlir/include/mlir/IR/RegionKindInterface.h
mlir/include/mlir/IR/RegionKindInterface.td
mlir/include/mlir/Parser.h
mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Block.cpp
mlir/lib/IR/BuiltinDialect.cpp
mlir/lib/IR/SymbolTable.cpp
mlir/lib/IR/Verifier.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
mlir/lib/Transforms/SymbolDCE.cpp
mlir/lib/Transforms/Utils/RegionUtils.cpp
mlir/test/Bindings/Python/context_managers.py
mlir/test/Bindings/Python/dialects.py
mlir/test/Bindings/Python/dialects/builtin.py
mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
mlir/test/Bindings/Python/dialects/linalg/ops.py
mlir/test/Bindings/Python/insertion_point.py
mlir/test/Bindings/Python/ir_operation.py
mlir/test/Bindings/Python/ods_helpers.py
mlir/test/Bindings/Python/pass_manager.py
mlir/test/CAPI/ir.c
mlir/test/CAPI/pass.c
mlir/test/IR/invalid-module-op.mlir
mlir/test/IR/invalid.mlir
mlir/test/IR/module-op.mlir
mlir/test/IR/print-ir-defuse.mlir
mlir/test/IR/print-ir-nesting.mlir
mlir/test/IR/region.mlir
mlir/test/Transforms/test-legalizer-analysis.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
mlir/test/lib/Transforms/TestConvVectorization.cpp
mlir/tools/mlir-tblgen/OpFormatGen.cpp

index 82cbc973e1fd3d9aa0699a002289da65e15f6ca1..f380987e90aa54694956e8a5ca7ef3c0f2c63859 100644 (file)
@@ -351,13 +351,18 @@ value-id-and-type-list ::= value-id-and-type (`,` value-id-and-type)*
 block-arg-list ::= `(` value-id-and-type-list? `)`
 ```
 
-A *Block* is an ordered list of operations, concluding with a single
-[terminator operation](#terminator-operations). In [SSACFG
+A *Block* is a list of operations. In [SSACFG
 regions](#control-flow-and-ssacfg-regions), each block represents a compiler
 [basic block](https://en.wikipedia.org/wiki/Basic_block) where instructions
 inside the block are executed in order and terminator operations implement
 control flow branches between basic blocks.
 
+A region with a single block may not include a [terminator
+operation](#terminator-operations). The enclosing op can opt-out of this
+requirement with the `NoTerminator` trait. The top-level `ModuleOp` is an
+example of such operation which defined this trait and whose block body does
+not have a terminator.
+
 Blocks in MLIR take a list of block arguments, notated in a function-like
 way. Block arguments are bound to values specified by the semantics of
 individual operations. Block arguments of the entry block of a region are also
index 8a21827578121d04c4a9bdbecf3eff179bd34198..add2c80f1d6d52b378c7d172fd4371e7c9aea163 100644 (file)
@@ -323,13 +323,20 @@ index expression that can express the equivalent of the memory-layout
 specification of the MemRef type. See [the -normalize-memrefs pass].
 (https://mlir.llvm.org/docs/Passes/#-normalize-memrefs-normalize-memrefs)
 
+### Single Block Region
+
+*   `OpTrait::SingleBlock` -- `SingleBlock`
+
+This trait provides APIs and verifiers for operations with regions that have a
+single block.
+
 ### Single Block with Implicit Terminator
 
-*   `OpTrait::SingleBlockImplicitTerminator<typename TerminatorOpType>` :
+*   `OpTrait::SingleBlockImplicitTerminator<typename TerminatorOpType>` --
     `SingleBlockImplicitTerminator<string op>`
 
-This trait provides APIs and verifiers for operations with regions that have a
-single block that must terminate with `TerminatorOpType`.
+This trait implies the `SingleBlock` above, but adds the additional requirement
+that the single block must terminate with `TerminatorOpType`.
 
 ### SymbolTable
 
@@ -344,3 +351,10 @@ This trait is used for operations that define a
 
 This trait provides verification and functionality for operations that are known
 to be [terminators](LangRef.md#terminator-operations).
+
+*   `OpTrait::NoTerminator` -- `NoTerminator`
+
+This trait removes the requirement on regions held by an operation to have
+[terminator operations](LangRef.md#terminator-operations) at the end of a block.
+This requires that these regions have a single block. An example of operation
+using this trait is the top-level `ModuleOp`.
index c54c8d36a2c92c87136240f1a256b68074af5a75..9c5d838c447c13008b22881ff935b9fafa3b71c2 100644 (file)
@@ -63,7 +63,7 @@ everything to the LLVM dialect.
 ```c++
   mlir::ConversionTarget target(getContext());
   target.addLegalDialect<mlir::LLVMDialect>();
-  target.addLegalOp<mlir::ModuleOp, mlir::ModuleTerminatorOp>();
+  target.addLegalOp<mlir::ModuleOp>();
 ```
 
 ### Type Converter
index 69560d29561862b961636ee790f0848562efa74e..3b32d6a409715d939ba33f20005b49fcf7f28a1e 100644 (file)
@@ -110,7 +110,6 @@ llvm-project/mlir/test/IR/print-ir-nesting.mlir`:
     "dialect.innerop6"() : () -> ()
     "dialect.innerop7"() : () -> ()
   }) {"other attribute" = 42 : i64} : () -> ()
-  "module_terminator"() : () -> ()
 }) : () -> ()
 ```
 
@@ -147,7 +146,6 @@ visiting op: 'module' with 0 operands and 0 results
              0 nested regions:
             visiting op: 'dialect.innerop7' with 0 operands and 0 results
              0 nested regions:
-      visiting op: 'module_terminator' with 0 operands and 0 results
        0 nested regions:
 ```
 
index 3fd48c5fd892f9cec7ed26e528440b2591290dec..f334faa29ca8837a35170c7b00cd04ce9235308f 100644 (file)
@@ -174,7 +174,7 @@ void ToyToLLVMLoweringPass::runOnOperation() {
   // final target for this lowering. For this lowering, we are only targeting
   // the LLVM dialect.
   LLVMConversionTarget target(getContext());
-  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
+  target.addLegalOp<ModuleOp>();
 
   // During this lowering, we will also be lowering the MemRef types, that are
   // currently being operated on, to a representation in LLVM. To perform this
index 3fd48c5fd892f9cec7ed26e528440b2591290dec..f334faa29ca8837a35170c7b00cd04ce9235308f 100644 (file)
@@ -174,7 +174,7 @@ void ToyToLLVMLoweringPass::runOnOperation() {
   // final target for this lowering. For this lowering, we are only targeting
   // the LLVM dialect.
   LLVMConversionTarget target(getContext());
-  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
+  target.addLegalOp<ModuleOp>();
 
   // During this lowering, we will also be lowering the MemRef types, that are
   // currently being operated on, to a representation in LLVM. To perform this
index cf43b7cd7305daca99aeb99b2256a9cf3f53b48a..bc341e624fcd47bb9c66d26c3b4026d932fb995a 100644 (file)
@@ -15,6 +15,7 @@
 
 #include "mlir/IR/FunctionSupport.h"
 #include "mlir/IR/OwningOpRef.h"
+#include "mlir/IR/RegionKindInterface.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/CastInterfaces.h"
index 4d14b8868e473ec5041fe9a84ea0eefac2ae26a1..f91f4e2e613a305ee0efb1fa858a50aed242975a 100644 (file)
@@ -15,6 +15,7 @@
 #define BUILTIN_OPS
 
 include "mlir/IR/BuiltinDialect.td"
+include "mlir/IR/RegionKindInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/CastInterfaces.td"
@@ -159,17 +160,17 @@ def FuncOp : Builtin_Op<"func", [
 //===----------------------------------------------------------------------===//
 
 def ModuleOp : Builtin_Op<"module", [
-  AffineScope, IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol,
-  SingleBlockImplicitTerminator<"ModuleTerminatorOp">
-]> {
+  AffineScope, IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol] 
+  # GraphRegionNoTerminator.traits> {
   let summary = "A top level container operation";
   let description = [{
     A `module` represents a top-level container operation. It contains a single
-    SSACFG region containing a single block which can contain any
-    operations. Operations within this region cannot implicitly capture values
-    defined outside the module, i.e. Modules are `IsolatedFromAbove`. Modules
-    have an optional symbol name which can be used to refer to them in
-    operations.
+   [graph region](#control-flow-and-ssacfg-regions) containing a single block
+   which can contain any operations and does not have a terminator. Operations
+   within this region cannot implicitly capture values defined outside the module,
+   i.e. Modules are [IsolatedFromAbove](Traits.md#isolatedfromabove). Modules have
+   an optional [symbol name](SymbolsAndSymbolTables.md) which can be used to refer
+   to them in operations.
 
     Example:
 
@@ -213,22 +214,6 @@ def ModuleOp : Builtin_Op<"module", [
   let skipDefaultBuilders = 1;
 }
 
-//===----------------------------------------------------------------------===//
-// ModuleTerminatorOp
-//===----------------------------------------------------------------------===//
-
-def ModuleTerminatorOp : Builtin_Op<"module_terminator", [
-  Terminator, HasParent<"ModuleOp">
-]> {
-  let summary = "A pseudo op that marks the end of a module";
-  let description = [{
-    `module_terminator` is a special terminator operation for the body of a
-    `module`, it has no semantic meaning beyond keeping the body of a `module`
-    well-formed.
-  }];
-  let assemblyFormat = "attr-dict";
-}
-
 //===----------------------------------------------------------------------===//
 // UnrealizedConversionCastOp
 //===----------------------------------------------------------------------===//
index 2785dff7f591f32cbd35f93eed045ef38e749392..3ea9bb41518eeedb770d00826cc42339a1c48bd8 100644 (file)
@@ -1827,10 +1827,16 @@ def ElementwiseMappable {
   ];
 }
 
+// Op's regions have a single block.
+def SingleBlock : NativeOpTrait<"SingleBlock">;
+
 // Op's regions have a single block with the specified terminator.
 class SingleBlockImplicitTerminator<string op>
     : ParamNativeOpTrait<"SingleBlockImplicitTerminator", op>;
 
+// Op's regions don't have terminator.
+def NoTerminator : NativeOpTrait<"NoTerminator">;
+
 // Op's parent operation is the provided one.
 class HasParent<string op>
     : ParamNativeOpTrait<"HasParent", op>;
index b101370fca8e291ea40196feabd07f9801950d30..b488dc12a4818ffaf4e7c5b7029cf31bf9db99fe 100644 (file)
@@ -654,6 +654,11 @@ class VariadicResults
 //===----------------------------------------------------------------------===//
 // Terminator Traits
 
+/// This class indicates that the regions associated with this op don't have
+/// terminators.
+template <typename ConcreteType>
+class NoTerminator : public TraitBase<ConcreteType, NoTerminator> {};
+
 /// This class provides the API for ops that are known to be terminators.
 template <typename ConcreteType>
 class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
@@ -757,6 +762,87 @@ class VariadicSuccessors
     : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> {
 };
 
+//===----------------------------------------------------------------------===//
+// SingleBlock
+
+/// This class provides APIs and verifiers for ops with regions having a single
+/// block.
+template <typename ConcreteType>
+struct SingleBlock : public TraitBase<ConcreteType, SingleBlock> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
+      Region &region = op->getRegion(i);
+
+      // Empty regions are fine.
+      if (region.empty())
+        continue;
+
+      // Non-empty regions must contain a single basic block.
+      if (!llvm::hasSingleElement(region))
+        return op->emitOpError("expects region #")
+               << i << " to have 0 or 1 blocks";
+
+      if (!ConcreteType::template hasTrait<NoTerminator>()) {
+        Block &block = region.front();
+        if (block.empty())
+          return op->emitOpError() << "expects a non-empty block";
+      }
+    }
+    return success();
+  }
+
+  Block *getBody(unsigned idx = 0) {
+    Region &region = this->getOperation()->getRegion(idx);
+    assert(!region.empty() && "unexpected empty region");
+    return &region.front();
+  }
+  Region &getBodyRegion(unsigned idx = 0) {
+    return this->getOperation()->getRegion(idx);
+  }
+
+  //===------------------------------------------------------------------===//
+  // Single Region Utilities
+  //===------------------------------------------------------------------===//
+
+  /// The following are a set of methods only enabled when the parent
+  /// operation has a single region. Each of these methods take an additional
+  /// template parameter that represents the concrete operation so that we
+  /// can use SFINAE to disable the methods for non-single region operations.
+  template <typename OpT, typename T = void>
+  using enable_if_single_region =
+      typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;
+
+  template <typename OpT = ConcreteType>
+  enable_if_single_region<OpT, Block::iterator> begin() {
+    return getBody()->begin();
+  }
+  template <typename OpT = ConcreteType>
+  enable_if_single_region<OpT, Block::iterator> end() {
+    return getBody()->end();
+  }
+  template <typename OpT = ConcreteType>
+  enable_if_single_region<OpT, Operation &> front() {
+    return *begin();
+  }
+
+  /// Insert the operation into the back of the body.
+  template <typename OpT = ConcreteType>
+  enable_if_single_region<OpT> push_back(Operation *op) {
+    insert(Block::iterator(getBody()->end()), op);
+  }
+
+  /// Insert the operation at the given insertion point.
+  template <typename OpT = ConcreteType>
+  enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
+    insert(Block::iterator(insertPt), op);
+  }
+  template <typename OpT = ConcreteType>
+  enable_if_single_region<OpT> insert(Block::iterator insertPt, Operation *op) {
+    getBody()->getOperations().insert(insertPt, op);
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // SingleBlockImplicitTerminator
 
@@ -765,8 +851,9 @@ class VariadicSuccessors
 template <typename TerminatorOpType>
 struct SingleBlockImplicitTerminator {
   template <typename ConcreteType>
-  class Impl : public TraitBase<ConcreteType, Impl> {
+  class Impl : public SingleBlock<ConcreteType> {
   private:
+    using Base = SingleBlock<ConcreteType>;
     /// Builds a terminator operation without relying on OpBuilder APIs to avoid
     /// cyclic header inclusion.
     static Operation *buildTerminator(OpBuilder &builder, Location loc) {
@@ -780,22 +867,14 @@ struct SingleBlockImplicitTerminator {
     using ImplicitTerminatorOpT = TerminatorOpType;
 
     static LogicalResult verifyTrait(Operation *op) {
+      if (failed(Base::verifyTrait(op)))
+        return failure();
       for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
         Region &region = op->getRegion(i);
-
         // Empty regions are fine.
         if (region.empty())
           continue;
-
-        // Non-empty regions must contain a single basic block.
-        if (std::next(region.begin()) != region.end())
-          return op->emitOpError("expects region #")
-                 << i << " to have 0 or 1 blocks";
-
-        Block &block = region.front();
-        if (block.empty())
-          return op->emitOpError() << "expects a non-empty block";
-        Operation &terminator = block.back();
+        Operation &terminator = region.front().back();
         if (isa<TerminatorOpType>(terminator))
           continue;
 
@@ -828,40 +907,15 @@ struct SingleBlockImplicitTerminator {
                                            buildTerminator);
     }
 
-    Block *getBody(unsigned idx = 0) {
-      Region &region = this->getOperation()->getRegion(idx);
-      assert(!region.empty() && "unexpected empty region");
-      return &region.front();
-    }
-    Region &getBodyRegion(unsigned idx = 0) {
-      return this->getOperation()->getRegion(idx);
-    }
-
     //===------------------------------------------------------------------===//
     // Single Region Utilities
     //===------------------------------------------------------------------===//
+    using Base::getBody;
 
-    /// The following are a set of methods only enabled when the parent
-    /// operation has a single region. Each of these methods take an additional
-    /// template parameter that represents the concrete operation so that we
-    /// can use SFINAE to disable the methods for non-single region operations.
     template <typename OpT, typename T = void>
     using enable_if_single_region =
         typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;
 
-    template <typename OpT = ConcreteType>
-    enable_if_single_region<OpT, Block::iterator> begin() {
-      return getBody()->begin();
-    }
-    template <typename OpT = ConcreteType>
-    enable_if_single_region<OpT, Block::iterator> end() {
-      return getBody()->end();
-    }
-    template <typename OpT = ConcreteType>
-    enable_if_single_region<OpT, Operation &> front() {
-      return *begin();
-    }
-
     /// Insert the operation into the back of the body, before the terminator.
     template <typename OpT = ConcreteType>
     enable_if_single_region<OpT> push_back(Operation *op) {
@@ -886,6 +940,27 @@ struct SingleBlockImplicitTerminator {
   };
 };
 
+/// Check is an op defines the `ImplicitTerminatorOpT` member. This is intended
+/// to be used with `llvm::is_detected`.
+template <class T>
+using has_implicit_terminator_t = typename T::ImplicitTerminatorOpT;
+
+/// Support to check if an operation has the SingleBlockImplicitTerminator
+/// trait. We can't just use `hasTrait` because this class is templated on a
+/// specific terminator op.
+template <class Op, bool hasTerminator =
+                        llvm::is_detected<has_implicit_terminator_t, Op>::value>
+struct hasSingleBlockImplicitTerminator {
+  static constexpr bool value = std::is_base_of<
+      typename OpTrait::SingleBlockImplicitTerminator<
+          typename Op::ImplicitTerminatorOpT>::template Impl<Op>,
+      Op>::value;
+};
+template <class Op>
+struct hasSingleBlockImplicitTerminator<Op, false> {
+  static constexpr bool value = false;
+};
+
 //===----------------------------------------------------------------------===//
 // Misc Traits
 
index 99561c3a089beb3593e1945ba5cdbc10d78bf6a3..5250899a9b447850244e4ca87d39c3430fc4f58a 100644 (file)
@@ -92,8 +92,13 @@ public:
   virtual void printGenericOp(Operation *op) = 0;
 
   /// Prints a region.
+  /// If 'printEntryBlockArgs' is false, the arguments of the
+  /// block are not printed. If 'printBlockTerminator' is false, the terminator
+  /// operation of the block is not printed. If printEmptyBlock is true, then
+  /// the block header is printed even if the block is empty.
   virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
-                           bool printBlockTerminators = true) = 0;
+                           bool printBlockTerminators = true,
+                           bool printEmptyBlock = false) = 0;
 
   /// Renumber the arguments for the specified region to the same names as the
   /// SSA values in namesToUse.  This may only be used for IsolatedFromAbove
index 8888862dd10c4178f1aaa30e99366ed13d81e828..bc35e2d231a504f2658649572cb4c2176fd43dc4 100644 (file)
@@ -43,6 +43,10 @@ public:
 
   using BlockListType = llvm::iplist<Block>;
   BlockListType &getBlocks() { return blocks; }
+  Block &emplaceBlock() {
+    push_back(new Block);
+    return back();
+  }
 
   // Iteration over the blocks in the region.
   using iterator = BlockListType::iterator;
index c1a1fa8074a71c3b362d5e00109ee2023a57c690..a4b77d65cd65445a6787b441da00af6849221e3a 100644 (file)
@@ -28,6 +28,16 @@ enum class RegionKind {
   Graph,
 };
 
+namespace OpTrait {
+/// A trait that specifies that an operation only defines graph regions.
+template <typename ConcreteType>
+class HasOnlyGraphRegion : public TraitBase<ConcreteType, HasOnlyGraphRegion> {
+public:
+  static RegionKind getRegionKind(unsigned index) { return RegionKind::Graph; }
+  static bool hasSSADominance(unsigned index) { return false; }
+};
+} // namespace OpTrait
+
 } // namespace mlir
 
 #include "mlir/IR/RegionKindInterface.h.inc"
index 1a6f739be172d58bd932550516735fb8fb37a8cc..59b235186455f9d8eee9a22ec5ea9d5518ba4a27 100644 (file)
@@ -50,4 +50,17 @@ def RegionKindInterface : OpInterface<"RegionKindInterface"> {
   ];
 }
 
+def HasOnlyGraphRegion : NativeOpTrait<"HasOnlyGraphRegion">;
+
+// Op's regions that don't need a terminator: requires some other traits
+// so it defines a list that must be concatenated.
+def GraphRegionNoTerminator {
+  list<OpTrait> traits = [
+    NoTerminator,
+    SingleBlock,
+    RegionKindInterface,
+    HasOnlyGraphRegion
+  ];
+}
+
 #endif // MLIR_IR_REGIONKINDINTERFACE
index cec60474d23d2a7426b31cc83583214d6bb97051..907f31824628b0b4c702c6585e84ae01d7f9fa8a 100644 (file)
@@ -25,6 +25,7 @@ class StringRef;
 
 namespace mlir {
 namespace detail {
+
 /// Given a block containing operations that have just been parsed, if the block
 /// contains a single operation of `ContainerOpT` type then remove it from the
 /// block and return it. If the block does not contain just that operation,
@@ -37,12 +38,11 @@ inline OwningOpRef<ContainerOpT> constructContainerOpForParserIfNecessary(
     Block *parsedBlock, MLIRContext *context, Location sourceFileLoc) {
   static_assert(
       ContainerOpT::template hasTrait<OpTrait::OneRegion>() &&
-          std::is_base_of<typename OpTrait::SingleBlockImplicitTerminator<
-                              typename ContainerOpT::ImplicitTerminatorOpT>::
-                              template Impl<ContainerOpT>,
-                          ContainerOpT>::value,
+          (ContainerOpT::template hasTrait<OpTrait::NoTerminator>() ||
+           OpTrait::template hasSingleBlockImplicitTerminator<
+               ContainerOpT>::value),
       "Expected `ContainerOpT` to have a single region with a single "
-      "block that has an implicit terminator");
+      "block that has an implicit terminator or does not require one");
 
   // Check to see if we parsed a single instance of this operation.
   if (llvm::hasSingleElement(*parsedBlock)) {
index dc1d37e766d03ddd78e4a7b5ac6aaac6cb4cf399..6598efe3e082830de64f7ac18472c836bca4caec 100644 (file)
@@ -16,8 +16,6 @@ class ModuleOp:
     super().__init__(self.build_generic(results=[], operands=[], loc=loc,
                                         ip=ip))
     body = self.regions[0].blocks.append()
-    with InsertionPoint(body):
-      Operation.create("module_terminator")
 
   @property
   def body(self):
index 4452dda43f331b3bd11580c15e3e899a829337f5..7a24b75640ec1d4f5c91b46c32a26178b4f23257 100644 (file)
@@ -156,8 +156,8 @@ struct AsyncAPI {
 
 /// Adds Async Runtime C API declarations to the module.
 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
-  auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(),
-                                                         module.getBody());
+  auto builder =
+      ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
 
   auto addFuncDecl = [&](StringRef name, FunctionType type) {
     if (module.lookupSymbol(name))
@@ -207,8 +207,8 @@ static void addCRuntimeDeclarations(ModuleOp module) {
   using namespace mlir::LLVM;
 
   MLIRContext *ctx = module.getContext();
-  ImplicitLocOpBuilder builder(module.getLoc(),
-                               module.getBody()->getTerminator());
+  auto builder =
+      ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
 
   auto voidTy = LLVMVoidType::get(ctx);
   auto i64 = IntegerType::get(ctx, 64);
@@ -232,15 +232,14 @@ static void addResumeFunction(ModuleOp module) {
     return;
 
   MLIRContext *ctx = module.getContext();
-
-  OpBuilder moduleBuilder(module.getBody()->getTerminator());
-  Location loc = module.getLoc();
+  auto loc = module.getLoc();
+  auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
 
   auto voidTy = LLVM::LLVMVoidType::get(ctx);
   auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
 
   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
-      loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
+      kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
   resumeOp.setPrivate();
 
   auto *block = resumeOp.addEntryBlock();
index 81c93987595397a639f9f4a58e6b109c84379249..60a47a9d7befcd71d6c553c449d8795481129d0f 100644 (file)
@@ -342,7 +342,7 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
   auto function = [&] {
     if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
       return function;
-    return OpBuilder(module.getBody()->getTerminator())
+    return OpBuilder::atBlockEnd(module.getBody())
         .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
   }();
   return builder.create<LLVM::CallOp>(
index 5b62ca455dea305b67c8399759d60856a40e6f1b..ba65865b6095e968558b968ee6c6a22f6e161be1 100644 (file)
@@ -99,7 +99,7 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnOperation() {
 
 LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::declareVulkanLaunchFunc(
     Location loc, gpu::LaunchFuncOp launchOp) {
-  OpBuilder builder(getOperation().getBody()->getTerminator());
+  auto builder = OpBuilder::atBlockEnd(getOperation().getBody());
 
   // Workgroup size is written into the kernel. So to properly modelling
   // vulkan launch, we have to skip local workgroup size configuration here.
index 47968ea458ce9fee5fd2e8019399c866df244ee5..118941539e0c2ba036f3c0ff2a2a8698115a3b0e 100644 (file)
@@ -291,7 +291,7 @@ LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
 
 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
   ModuleOp module = getOperation();
-  OpBuilder builder(module.getBody()->getTerminator());
+  auto builder = OpBuilder::atBlockEnd(module.getBody());
 
   if (!module.lookupSymbol(kSetEntryPoint)) {
     builder.create<LLVM::LLVMFuncOp>(
index f55c5a814bed79dd1fc832ef456b05eb6c333358..2eccac860bf24d28697370ddced0d8453366d72d 100644 (file)
@@ -227,7 +227,7 @@ void ConvertLinalgToLLVMPass::runOnOperation() {
 
   LLVMConversionTarget target(getContext());
   target.addIllegalOp<RangeOp>();
-  target.addLegalOp<ModuleOp, ModuleTerminatorOp, LLVM::DialectCastOp>();
+  target.addLegalOp<ModuleOp, LLVM::DialectCastOp>();
   if (failed(applyPartialConversion(module, target, std::move(patterns))))
     signalPassFailure();
 }
index d91444d42af889a06cdeabe8c7475ccdded6527d..6038a9841cf95dc84d9eb9a4ea4fcf7e507bd6f5 100644 (file)
@@ -35,7 +35,7 @@ void LinalgToSPIRVPass::runOnOperation() {
   populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
 
   // Allow builtin ops.
-  target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
+  target->addLegalOp<ModuleOp>();
   target->addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
     return typeConverter.isSignatureLegal(op.getType()) &&
            typeConverter.isLegal(&op.getBody());
index 36d484fafe66484bdf2c92ab840a6d23e7519d70..5260c503cce92149708eb349168749f6f60996ed 100644 (file)
@@ -216,7 +216,7 @@ void ConvertLinalgToStandardPass::runOnOperation() {
   ConversionTarget target(getContext());
   target.addLegalDialect<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
                          StandardOpsDialect>();
-  target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
+  target.addLegalOp<ModuleOp, FuncOp, ReturnOp>();
   target.addLegalOp<linalg::ReshapeOp, linalg::RangeOp>();
   RewritePatternSet patterns(&getContext());
   populateLinalgToStandardConversionPatterns(patterns);
index d3fc60a5eb6b3f315be92caff58d2e98518ff96c..172f63ba2268890ae9f0acd4faf62137b199c85c 100644 (file)
@@ -1358,7 +1358,7 @@ public:
   matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
 
-    rewriter.replaceOpWithNewOp<ModuleTerminatorOp>(moduleEndOp);
+    rewriter.eraseOp(moduleEndOp);
     return success();
   }
 };
index f064bb4fc2ad97895a1d64e879987764c9e6dcad..293128c8446bda0bbc543e739aea2dd2b3d4a5d0 100644 (file)
@@ -48,10 +48,8 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
   target.addIllegalDialect<spirv::SPIRVDialect>();
   target.addLegalDialect<LLVM::LLVMDialect>();
 
-  // Set `ModuleOp` and `ModuleTerminatorOp` as legal for `spv.module`
-  // conversion.
+  // Set `ModuleOp` as legal for `spv.module` conversion.
   target.addLegalOp<ModuleOp>();
-  target.addLegalOp<ModuleTerminatorOp>();
   if (failed(applyPartialConversion(module, target, std::move(patterns))))
     signalPassFailure();
 }
index 2626995b3c93c0bca3bda33d2ca96c2aa5c2eb64..e0342f6162c533c4def0f481a6f46700ad3e16cc 100644 (file)
@@ -675,7 +675,7 @@ void ConvertShapeToStandardPass::runOnOperation() {
   ConversionTarget target(ctx);
   target.addLegalDialect<memref::MemRefDialect, StandardOpsDialect, SCFDialect,
                          tensor::TensorDialect>();
-  target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>();
+  target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp>();
 
   // Setup conversion patterns.
   RewritePatternSet patterns(&ctx);
index 1915b499fbdb13f909d7b56f3ab0bbcf2fd866cf..e170df2948fef69eaac9f02299ac062eb73bd813 100644 (file)
@@ -40,7 +40,7 @@ void LowerVectorToSPIRVPass::runOnOperation() {
   RewritePatternSet patterns(context);
   populateVectorToSPIRVPatterns(typeConverter, patterns);
 
-  target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
+  target->addLegalOp<ModuleOp>();
   target->addLegalOp<FuncOp>();
 
   if (failed(applyFullConversion(module, *target, std::move(patterns))))
index d511b4f8be5d766bf44fed907a548d17fccb5e7a..268c603c64466a588e0a874697d9da503a58b82f 100644 (file)
@@ -199,7 +199,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
   // TODO: Derive outlined function name from the parent FuncOp (support
   // multiple nested async.execute operations).
   FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
-  symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator()));
+  symbolTable.insert(func);
 
   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
 
index 21ca1c3a82c2afd442d398ab95a64291416bdc40..1b07ccd9c1eddb1083ebd4dd38642fba9bf3ce5e 100644 (file)
@@ -42,8 +42,7 @@ struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
 
     populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
     populateReturnOpTypeConversionPattern(patterns, typeConverter);
-    target.addLegalOp<ModuleOp, ModuleTerminatorOp, memref::TensorLoadOp,
-                      memref::BufferCastOp>();
+    target.addLegalOp<ModuleOp, memref::TensorLoadOp, memref::BufferCastOp>();
 
     target.markUnknownOpDynamicallyLegal([&](Operation *op) {
       return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
index 8cd20c7777adf669be6eefb221a6afda42a49028..bae9c97b1b40f5f3f08d2c483098ed7612998f64 100644 (file)
@@ -418,7 +418,8 @@ private:
 
   /// Print the given region.
   void printRegion(Region &region, bool printEntryBlockArgs,
-                   bool printBlockTerminators) override {
+                   bool printBlockTerminators,
+                   bool printEmptyBlock = false) override {
     if (region.empty())
       return;
 
@@ -2324,7 +2325,7 @@ public:
 
   /// Print the given region.
   void printRegion(Region &region, bool printEntryBlockArgs,
-                   bool printBlockTerminators) override;
+                   bool printBlockTerminators, bool printEmptyBlock) override;
 
   /// Renumber the arguments for the specified region to the same names as the
   /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
@@ -2440,7 +2441,7 @@ void OperationPrinter::printGenericOp(Operation *op) {
     os << " (";
     interleaveComma(op->getRegions(), [&](Region &region) {
       printRegion(region, /*printEntryBlockArgs=*/true,
-                  /*printBlockTerminators=*/true);
+                  /*printBlockTerminators=*/true, /*printEmptyBlock=*/true);
     });
     os << ')';
   }
@@ -2541,12 +2542,18 @@ void OperationPrinter::printSuccessorAndUseList(Block *successor,
 }
 
 void OperationPrinter::printRegion(Region &region, bool printEntryBlockArgs,
-                                   bool printBlockTerminators) {
+                                   bool printBlockTerminators,
+                                   bool printEmptyBlock) {
   os << " {" << newLine;
   if (!region.empty()) {
     auto *entryBlock = &region.front();
-    print(entryBlock, printEntryBlockArgs && entryBlock->getNumArguments() != 0,
-          printBlockTerminators);
+    // Force printing the block header if printEmptyBlock is set and the block
+    // is empty or if printEntryBlockArgs is set and there are arguments to
+    // print.
+    bool shouldAlwaysPrintBlockHeader =
+        (printEmptyBlock && entryBlock->empty()) ||
+        (printEntryBlockArgs && entryBlock->getNumArguments() != 0);
+    print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators);
     for (auto &b : llvm::drop_begin(region.getBlocks(), 1))
       print(&b);
   }
index 07e2e5c007fe6a86c976931db399db970df13164..a24b639f7a45b02f0d87a77c6f273ace27be0370 100644 (file)
@@ -294,6 +294,21 @@ Block *Block::splitBlock(iterator splitBefore) {
   return newBB;
 }
 
+/// Returns true if this block may be valid without terminator. That is if:
+/// - it does not have a parent region.
+/// - Or the parent region have a single block and:
+///    - This region does not have a parent op.
+///    - Or the parent op is unregistered.
+///    - Or the parent op has the NoTerminator trait.
+static bool mayNotHaveTerminator(Block *block) {
+  if (!block->getParent())
+    return true;
+  if (!llvm::hasSingleElement(*block->getParent()))
+    return false;
+  Operation *op = block->getParentOp();
+  return !op || op->mightHaveTrait<OpTrait::NoTerminator>();
+}
+
 //===----------------------------------------------------------------------===//
 // Predecessors
 //===----------------------------------------------------------------------===//
@@ -314,9 +329,11 @@ unsigned PredecessorIterator::getSuccessorIndex() const {
 SuccessorRange::SuccessorRange() : SuccessorRange(nullptr, 0) {}
 
 SuccessorRange::SuccessorRange(Block *block) : SuccessorRange() {
-  if (Operation *term = block->getTerminator())
+  if (!llvm::hasSingleElement(*block->getParent())) {
+    Operation *term = block->getTerminator();
     if ((count = term->getNumSuccessors()))
       base = term->getBlockOperands().data();
+  }
 }
 
 SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange() {
index 1035961f51c1c65d0184bc9c139a0f0d8913c1f3..e1706f2c93152612cce7c08c1e29fe63a894d31d 100644 (file)
@@ -209,7 +209,7 @@ FuncOp FuncOp::clone() {
 
 void ModuleOp::build(OpBuilder &builder, OperationState &state,
                      Optional<StringRef> name) {
-  ensureTerminator(*state.addRegion(), builder, state.location);
+  state.addRegion()->emplaceBlock();
   if (name) {
     state.attributes.push_back(builder.getNamedAttr(
         mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(*name)));
index 4620a5bcb381253cda11f927e7f1b38e6b3f6382..8d5ba2e1622443f845b58928bf9b3a92652ef3da 100644 (file)
@@ -161,11 +161,17 @@ void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
   // TODO: consider if SymbolTable's constructor should behave the same.
   if (!symbol->getParentOp()) {
     auto &body = symbolTableOp->getRegion(0).front();
-    if (insertPt == Block::iterator() || insertPt == body.end())
-      insertPt = Block::iterator(body.getTerminator());
-
-    assert(insertPt->getParentOp() == symbolTableOp &&
-           "expected insertPt to be in the associated module operation");
+    if (insertPt == Block::iterator()) {
+      insertPt = Block::iterator(body.end());
+    } else {
+      assert((insertPt == body.end() ||
+              insertPt->getParentOp() == symbolTableOp) &&
+             "expected insertPt to be in the associated module operation");
+    }
+    // Insert before the terminator, if any.
+    if (insertPt == Block::iterator(body.end()) && !body.empty() &&
+        std::prev(body.end())->hasTrait<OpTrait::IsTerminator>())
+      insertPt = std::prev(body.end());
 
     body.getOperations().insert(insertPt, symbol);
   }
@@ -291,11 +297,14 @@ void SymbolTable::walkSymbolTables(
 Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
                                        StringRef symbol) {
   assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
+  Region &region = symbolTableOp->getRegion(0);
+  if (region.empty())
+    return nullptr;
 
   // Look for a symbol with the given name.
   Identifier symbolNameId = Identifier::get(SymbolTable::getSymbolAttrName(),
                                             symbolTableOp->getContext());
-  for (auto &op : symbolTableOp->getRegion(0).front().without_terminator())
+  for (auto &op : region.front())
     if (getNameIfSymbol(&op, symbolNameId) == symbol)
       return &op;
   return nullptr;
index 6aadab97023fb3e139b7dbfeb33578f7385afd64..c01581d45b4cc1c5fb8db5f2ebc58cb983cd64b6 100644 (file)
@@ -113,17 +113,36 @@ LogicalResult OperationVerifier::verifyRegion(Region &region) {
   return success();
 }
 
+/// Returns true if this block may be valid without terminator. That is if:
+/// - it does not have a parent region.
+/// - Or the parent region have a single block and:
+///    - This region does not have a parent op.
+///    - Or the parent op is unregistered.
+///    - Or the parent op has the NoTerminator trait.
+static bool mayNotHaveTerminator(Block *block) {
+  if (!block->getParent())
+    return true;
+  if (!llvm::hasSingleElement(*block->getParent()))
+    return false;
+  Operation *op = block->getParentOp();
+  return !op || op->mightHaveTrait<OpTrait::NoTerminator>();
+}
+
 LogicalResult OperationVerifier::verifyBlock(Block &block) {
   for (auto arg : block.getArguments())
     if (arg.getOwner() != &block)
       return emitError(block, "block argument not owned by block");
 
   // Verify that this block has a terminator.
-  if (block.empty())
-    return emitError(block, "block with no terminator");
+
+  if (block.empty()) {
+    if (mayNotHaveTerminator(&block))
+      return success();
+    return emitError(block, "empty block: expect at least a terminator");
+  }
 
   // Verify the non-terminator operations separately so that we can verify
-  // they has no successors.
+  // they have no successors.
   for (auto &op : llvm::make_range(block.begin(), std::prev(block.end()))) {
     if (op.getNumSuccessors() != 0)
       return op.emitError(
@@ -137,8 +156,13 @@ LogicalResult OperationVerifier::verifyBlock(Block &block) {
   Operation &terminator = block.back();
   if (failed(verifyOperation(terminator)))
     return failure();
+
+  if (mayNotHaveTerminator(&block))
+    return success();
+
   if (!terminator.mightHaveTrait<OpTrait::IsTerminator>())
-    return block.back().emitError("block with no terminator");
+    return block.back().emitError("block with no terminator, has ")
+           << terminator;
 
   // Verify that this block is not branching to a block of a different
   // region.
@@ -176,13 +200,14 @@ LogicalResult OperationVerifier::verifyOperation(Operation &op) {
   unsigned numRegions = op.getNumRegions();
   for (unsigned i = 0; i < numRegions; i++) {
     Region &region = op.getRegion(i);
+    RegionKind kind =
+        kindInterface ? kindInterface.getRegionKind(i) : RegionKind::SSACFG;
     // Check that Graph Regions only have a single basic block. This is
     // similar to the code in SingleBlockImplicitTerminator, but doesn't
     // require the trait to be specified. This arbitrary limitation is
     // designed to limit the number of cases that have to be handled by
     // transforms and conversions until the concept stabilizes.
-    if (op.isRegistered() && kindInterface &&
-        kindInterface.getRegionKind(i) == RegionKind::Graph) {
+    if (op.isRegistered() && kind == RegionKind::Graph) {
       // Empty regions are fine.
       if (region.empty())
         continue;
index ad80204ac496d75380e67f0a7b1c3d5bc5cc9f69..381338d4226debcda51bbce688db5f0a1cccbd56 100644 (file)
@@ -2121,7 +2121,7 @@ ParseResult TopLevelOperationParser::parse(Block *topLevelBlock,
       auto &parsedOps = (*topLevelOp)->getRegion(0).front().getOperations();
       auto &destOps = topLevelBlock->getOperations();
       destOps.splice(destOps.empty() ? destOps.end() : std::prev(destOps.end()),
-                     parsedOps, parsedOps.begin(), std::prev(parsedOps.end()));
+                     parsedOps, parsedOps.begin(), parsedOps.end());
       return success();
     }
 
index 28b58cffe17001bd4fdfb340863bd333fe6eb981..8044ff62f6427bae6b2b75b3edf54b16fccfc054 100644 (file)
@@ -269,10 +269,11 @@ private:
 
   /// Globals are inserted before the first function, if any.
   Block::iterator getGlobalInsertPt() {
-    auto i = module.getBody()->begin();
-    while (!isa<LLVMFuncOp, ModuleTerminatorOp>(i))
-      ++i;
-    return i;
+    auto it = module.getBody()->begin();
+    auto endIt = module.getBody()->end();
+    while (it != endIt && !isa<LLVMFuncOp>(it))
+      ++it;
+    return it;
   }
 
   /// Functions are always inserted before the module terminator.
index 5d94245489c03ac19c724c821fff30f07cdb1af0..2c65d635d4c4466a046bffcd962003a49a45d2f0 100644 (file)
@@ -61,8 +61,7 @@ void SymbolDCE::runOnOperation() {
     if (!nestedSymbolTable->hasTrait<OpTrait::SymbolTable>())
       return;
     for (auto &block : nestedSymbolTable->getRegion(0)) {
-      for (Operation &op :
-           llvm::make_early_inc_range(block.without_terminator())) {
+      for (Operation &op : llvm::make_early_inc_range(block)) {
         if (isa<SymbolOpInterface>(&op) && !liveSymbols.count(&op))
           op.erase();
       }
@@ -84,7 +83,7 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
   // are known to be live.
   for (auto &block : symbolTableOp->getRegion(0)) {
     // Add all non-symbols or symbols that can't be discarded.
-    for (Operation &op : block.without_terminator()) {
+    for (Operation &op : block) {
       SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
       if (!symbol) {
         worklist.push_back(&op);
index 47635c3bbf4956cd618ac0b7126a8fb5cf527ee6..f0e03c537a9c9bba5b7c6bf827c5ddc36355e996 100644 (file)
@@ -314,6 +314,7 @@ static LogicalResult deleteDeadness(RewriterBase &rewriter,
   for (Region &region : regions) {
     if (region.empty())
       continue;
+    bool hasSingleBlock = llvm::hasSingleElement(region);
 
     // Delete every operation that is not live. Graph regions may have cycles
     // in the use-def graph, so we must explicitly dropAllUses() from each
@@ -321,7 +322,8 @@ static LogicalResult deleteDeadness(RewriterBase &rewriter,
     // guarantees that in SSA CFG regions value uses are removed before defs,
     // which makes dropAllUses() a no-op.
     for (Block *block : llvm::post_order(&region.front())) {
-      eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
+      if (!hasSingleBlock)
+        eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
       for (Operation &childOp :
            llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
         if (!liveMap.wasProvenLive(&childOp)) {
index 33a89381a4164884ab8c11a6624c75ced9f3e8b9..b93fcf70ac4820c68289b036afe530db2927784a 100644 (file)
@@ -62,7 +62,7 @@ run(testLocationEnterExit)
 def testInsertionPointEnterExit():
   ctx1 = Context()
   m = Module.create(Location.unknown(ctx1))
-  ip = InsertionPoint.at_block_terminator(m.body)
+  ip = InsertionPoint(m.body)
 
   with ip:
     assert InsertionPoint.current is ip
index 128f64cab199c79a9d3e412a1470358c99444f36..41f4239e2b660db931bc2f834caa8f180e61b44a 100644 (file)
@@ -77,7 +77,7 @@ def testCustomOpView():
     ctx.allow_unregistered_dialects = True
     m = Module.create()
 
-    with InsertionPoint.at_block_terminator(m.body):
+    with InsertionPoint(m.body):
       f32 = F32Type.get()
       # Create via dialects context collection.
       input1 = createInput()
index 80dea68bae3662e1bd2a2af2886790c98c8f0ee8..1f4847dce892c81a2f42eea3be3ebf0ec599e18d 100644 (file)
@@ -18,7 +18,7 @@ def testFromPyFunc():
     m = builtin.ModuleOp()
     f32 = F32Type.get()
     f64 = F64Type.get()
-    with InsertionPoint.at_block_terminator(m.body):
+    with InsertionPoint(m.body):
       # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
       # CHECK: return %arg0 : f64
       @builtin.FuncOp.from_py_func(f64)
@@ -95,7 +95,7 @@ def testFromPyFuncErrors():
     m = builtin.ModuleOp()
     f32 = F32Type.get()
     f64 = F64Type.get()
-    with InsertionPoint.at_block_terminator(m.body):
+    with InsertionPoint(m.body):
       try:
 
         @builtin.FuncOp.from_py_func(f64, results=[f64])
index 573999c975258ed22bcf3fe1cdc47a58ee5aeb6a..397b32c93c22b087e6ccaa27fc083caf5c1b574e 100644 (file)
@@ -32,7 +32,7 @@ with Context() as ctx, Location.unknown():
   i8 = IntegerType.get_signless(8)
   i16 = IntegerType.get_signless(16)
   i32 = IntegerType.get_signless(32)
-  with InsertionPoint.at_block_terminator(module.body):
+  with InsertionPoint(module.body):
 
     # Note that these all have the same indexing maps. We verify the first and
     # then do more permutation tests on casting and body generation
index 0615dd37bfdded62fdaacf247362301446b844e2..04a6ac8def843038d81b428649564a8ff4319732 100644 (file)
@@ -17,7 +17,7 @@ def testStructuredOpOnTensors():
     module = Module.create()
     f32 = F32Type.get()
     tensor_type = RankedTensorType.get((2, 3, 4), f32)
-    with InsertionPoint.at_block_terminator(module.body):
+    with InsertionPoint(module.body):
       func = builtin.FuncOp(name="matmul_test",
                             type=FunctionType.get(
                                 inputs=[tensor_type, tensor_type],
@@ -40,7 +40,7 @@ def testStructuredOpOnBuffers():
     module = Module.create()
     f32 = F32Type.get()
     memref_type = MemRefType.get((2, 3, 4), f32)
-    with InsertionPoint.at_block_terminator(module.body):
+    with InsertionPoint(module.body):
       func = builtin.FuncOp(name="matmul_test",
                             type=FunctionType.get(
                                 inputs=[memref_type, memref_type, memref_type],
index 1a2de37428cc2b86d30d93736b4fd31a064a768c..2e53aa64b999f8989e9f9b793d35e238a7b4ec51 100644 (file)
@@ -129,8 +129,13 @@ run(test_insert_at_block_terminator_missing)
 def test_insert_at_end_with_terminator_errors():
   with Context() as ctx, Location.unknown():
     ctx.allow_unregistered_dialects = True
-    m = Module.create()  # Module is created with a terminator.
-    with InsertionPoint(m.body):
+    module = Module.parse(r"""
+      func @foo() -> () {
+        return
+      }
+    """)
+    entry_block = module.body.operations[0].regions[0].blocks[0]
+    with InsertionPoint(entry_block):
       try:
         Operation.create("custom.op1", results=[], operands=[])
       except IndexError as e:
index d154e77077efe588de1d3fbe07ad5592b26f7db8..847c1093cd37fcd8d8ad981b2accf7089e49c42d 100644 (file)
@@ -64,7 +64,6 @@ def testTraverseOpRegionBlockIterators():
   # CHECK:         BLOCK 0:
   # CHECK:           OP 0: %0 = "custom.addi"
   # CHECK:           OP 1: return
-  # CHECK:    OP 1: module_terminator
   walk_operations("", op)
 
 run(testTraverseOpRegionBlockIterators)
@@ -101,7 +100,6 @@ def testTraverseOpRegionBlockIndices():
   # CHECK:         BLOCK 0:
   # CHECK:           OP 0: %0 = "custom.addi"
   # CHECK:           OP 1: return
-  # CHECK:    OP 1: module_terminator
   walk_operations("", module.operation)
 
 run(testTraverseOpRegionBlockIndices)
@@ -546,9 +544,9 @@ run(testSingleResultProperty)
 def testPrintInvalidOperation():
   ctx = Context()
   with Location.unknown(ctx):
-    module = Operation.create("module", regions=1)
-    # This block does not have a terminator, it may crash the custom printer.
-    # Verify that we fallback to the generic printer for safety.
+    module = Operation.create("module", regions=2)
+    # This module has two region and is invalid verify that we fallback
+    # to the generic printer for safety.
     block = module.regions[0].blocks.append()
     # CHECK: // Verification failed, printing generic form
     # CHECK: "module"() ( {
index badeac37034f33402fc0c99ee79b123b2de76d72..5aa9bef22a6ba74bc1ea27488a28d75eae3b10fe 100644 (file)
@@ -29,7 +29,7 @@ def testOdsBuildDefaultImplicitRegions():
   with Context() as ctx, Location.unknown():
     ctx.allow_unregistered_dialects = True
     m = Module.create()
-    with InsertionPoint.at_block_terminator(m.body):
+    with InsertionPoint(m.body):
       op = TestFixedRegionsOp.build_generic(results=[], operands=[])
       # CHECK: NUM_REGIONS: 2
       print(f"NUM_REGIONS: {len(op.regions)}")
@@ -84,7 +84,7 @@ def testOdsBuildDefaultNonVariadic():
   with Context() as ctx, Location.unknown():
     ctx.allow_unregistered_dialects = True
     m = Module.create()
-    with InsertionPoint.at_block_terminator(m.body):
+    with InsertionPoint(m.body):
       v0 = add_dummy_value()
       v1 = add_dummy_value()
       t0 = IntegerType.get_signless(8)
@@ -111,7 +111,7 @@ def testOdsBuildDefaultSizedVariadic():
   with Context() as ctx, Location.unknown():
     ctx.allow_unregistered_dialects = True
     m = Module.create()
-    with InsertionPoint.at_block_terminator(m.body):
+    with InsertionPoint(m.body):
       v0 = add_dummy_value()
       v1 = add_dummy_value()
       v2 = add_dummy_value()
@@ -187,7 +187,7 @@ def testOdsBuildDefaultCastError():
   with Context() as ctx, Location.unknown():
     ctx.allow_unregistered_dialects = True
     m = Module.create()
-    with InsertionPoint.at_block_terminator(m.body):
+    with InsertionPoint(m.body):
       v0 = add_dummy_value()
       v1 = add_dummy_value()
       t0 = IntegerType.get_signless(8)
index 61ff64f67be65ab56b80bf8e3b05fae15ff56e00..35e6d980c9f39d591361b8ba1e53d50eb6b1d7c5 100644 (file)
@@ -91,6 +91,5 @@ def testRunPipeline():
 # CHECK: Operations encountered:
 # CHECK: func              , 1
 # CHECK: module            , 1
-# CHECK: module_terminator , 1
 # CHECK: std.return        , 1
 run(testRunPipeline)
index beb73102615ef46eb4fa8f1a9347e71d6a1bac65..40ef39b19d26607f8ac7184035c5864425a52e92 100644 (file)
@@ -293,7 +293,7 @@ int collectStats(MlirOperation operation) {
   fprintf(stderr, "Number of op results: %u\n", stats.numOpResults);
   // clang-format off
   // CHECK-LABEL: @stats
-  // CHECK: Number of operations: 13
+  // CHECK: Number of operations: 12
   // CHECK: Number of attributes: 4
   // CHECK: Number of blocks: 3
   // CHECK: Number of regions: 3
index b7b9e373feb24f1e167d25b06b9f42eecdfc3482..d73aba1c7379c3fda49c610629a1b45a24f9d691 100644 (file)
@@ -42,7 +42,6 @@ void testRunPassOnModule() {
   // Run the print-op-stats pass on the top-level module:
   // CHECK-LABEL: Operations encountered:
   // CHECK: func              , 1
-  // CHECK: module_terminator , 1
   // CHECK: std.addi          , 1
   // CHECK: std.return        , 1
   {
@@ -84,7 +83,6 @@ void testRunPassOnNestedModule() {
 
   // Run the print-op-stats pass on functions under the top-level module:
   // CHECK-LABEL: Operations encountered:
-  // CHECK-NOT: module_terminator
   // CHECK: func              , 1
   // CHECK: std.addi          , 1
   // CHECK: std.return        , 1
@@ -101,7 +99,6 @@ void testRunPassOnNestedModule() {
   }
   // Run the print-op-stats pass on functions under the nested module:
   // CHECK-LABEL: Operations encountered:
-  // CHECK-NOT: module_terminator
   // CHECK: func              , 1
   // CHECK: std.addf          , 1
   // CHECK: std.return        , 1
index 520821a7b0b42d92a44b90c5ed3edd8bf3bb9d65..741a3a9b2dc94dd4af69a2a42630f52b7bd1bb84 100644 (file)
@@ -19,31 +19,12 @@ func @module_op() {
   // expected-error@+1 {{region should have no arguments}}
   module {
   ^bb1(%arg: i32):
-    "module_terminator"() : () -> ()
-  }
-  return
-}
-
-// -----
-
-func @module_op() {
-  // expected-error@below {{expects regions to end with 'module_terminator'}}
-  // expected-note@below {{the absence of terminator implies 'module_terminator'}}
-  module {
-    return
   }
   return
 }
 
 // -----
 
-func @module_op() {
-  // expected-error@+1 {{expects parent op 'module'}}
-  "module_terminator"() : () -> ()
-}
-
-// -----
-
 // expected-error@+1 {{can only contain attributes with dialect-prefixed names}}
 module attributes {attr} {
 }
index 2909416771fc99f50253c7eeb98f054adda109ef..39e72d2cb8daaaed6b86b9f3f86a15eb04a350a3 100644 (file)
@@ -120,7 +120,7 @@ func @block_redef() {
 
 // -----
 
-func @no_terminator() {   // expected-error {{block with no terminator}}
+func @no_terminator() {   // expected-error {{empty block: expect at least a terminator}}
 ^bb40:
   return
 ^bb41:
index b610c0076ac2da57ed43443a43827a7f81021a6b..d99806c92a902be9bcaf50263920a4348fe79247 100644 (file)
@@ -4,16 +4,14 @@
 module {
 }
 
-// CHECK: module {
-// CHECK-NEXT: }
-module {
-  "module_terminator"() : () -> ()
-}
+// -----
 
 // CHECK: module attributes {foo.attr = true} {
 module attributes {foo.attr = true} {
 }
 
+// -----
+
 // CHECK: module {
 module {
   // CHECK-NEXT: "foo.result_op"() : () -> i32
index 78c5804119250a240587a3aa43e87c1532bfa344..55c8494f83c460acec16ab5d11259c57f305a3bc 100644 (file)
@@ -18,8 +18,6 @@
 // CHECK: Has 0 results:
 // CHECK: Visiting op 'dialect.op3' with 0 operands:
 // CHECK: Has 0 results:
-// CHECK: Visiting op 'module_terminator' with 0 operands:
-// CHECK: Has 0 results:
 // CHECK: Visiting op 'module' with 0 operands:
 // CHECK: Has 0 results:
 
index 4682753947550c2ec73a9d5fce4bed5db9e4afbd..92259a6e04561a6718ecf64206476eb268c6eb7a 100644 (file)
@@ -3,7 +3,7 @@
 // CHECK: visiting op: 'module' with 0 operands and 0 results
 // CHECK:  1 nested regions:
 // CHECK:   Region with 1 blocks:
-// CHECK:     Block with 0 arguments, 0 successors, and 3 operations
+// CHECK:     Block with 0 arguments, 0 successors, and 2 operations
 module {
 
 
@@ -52,6 +52,4 @@ module {
     "dialect.innerop7"() : () -> ()
   }) : () -> ()
 
-// CHECK:       visiting op: 'module_terminator' with 0 operands and 0 results
-
 } // module
index 465ae511aad2329c5abb84d4bd05026509a15391..8f9d707b6f18081c3487cf22c420eed6757f8488 100644 (file)
@@ -73,3 +73,11 @@ func @named_region_has_wrong_number_of_blocks() {
     }) : () -> ()
     return
 }
+
+// -----
+
+// Region with single block and not terminator.
+// CHECK: unregistered_without_terminator
+"test.unregistered_without_terminator"() ( {
+  ^bb0:  // no predecessors
+}) : () -> ()
index cd0b936ada78e41f65a434d261c0a41ce03329d3..ca3c72cae6ff906883865bf1cc86bfd15be449a8 100644 (file)
@@ -1,6 +1,5 @@
 // RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -verify-diagnostics -test-legalize-mode=analysis %s | FileCheck %s
 // expected-remark@-2 {{op 'module' is legalizable}}
-// expected-remark@-3 {{op 'module_terminator' is legalizable}}
 
 // expected-remark@+1 {{op 'func' is legalizable}}
 func @test(%arg0: f32) {
index c7b4b582d9dd23cbfee3f0c04109eb137501aa54..c0285a3623fa1f87f9c9b63512aa114160b04812 100644 (file)
@@ -33,6 +33,16 @@ void mlir::test::registerTestDialect(DialectRegistry &registry) {
 
 namespace {
 
+/// Testing the correctness of some traits.
+static_assert(
+    llvm::is_detected<OpTrait::has_implicit_terminator_t,
+                      SingleBlockImplicitTerminatorOp>::value,
+    "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
+static_assert(OpTrait::hasSingleBlockImplicitTerminator<
+                  SingleBlockImplicitTerminatorOp>::value,
+              "hasSingleBlockImplicitTerminator does not match "
+              "SingleBlockImplicitTerminatorOp");
+
 // Test support for interacting with the AsmPrinter.
 struct TestOpAsmInterface : public OpAsmDialectInterface {
   using OpAsmDialectInterface::OpAsmDialectInterface;
index ec85b7e38c436c27131476e7bbb1e49ffd843c5d..1cc0c62b8691f5b20564949abd1d59989de0508d 100644 (file)
@@ -573,7 +573,7 @@ struct TestLegalizePatternDriver
 
     // Define the conversion target used for the test.
     ConversionTarget target(getContext());
-    target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
+    target.addLegalOp<ModuleOp>();
     target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp,
                       TerminatorOp>();
     target
@@ -702,7 +702,7 @@ struct TestRemappedValue
     patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
 
     mlir::ConversionTarget target(getContext());
-    target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>();
+    target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>();
     // We make OneVResOneVOperandOp1 legal only when it has more that one
     // operand. This will trigger the conversion that will replace one-operand
     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
@@ -969,9 +969,8 @@ struct TestMergeBlocksPatternDriver
     patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
         context);
     ConversionTarget target(*context);
-    target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp, TerminatorOp,
-                      TestBranchOp, TestTypeConsumerOp, TestTypeProducerOp,
-                      TestReturnOp>();
+    target.addLegalOp<FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
+                      TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
     target.addIllegalOp<ILLegalOpF>();
 
     /// Expect the op to have a single block after legalization.
index 7bf298904780af08e1c97ebdc7d03b3b84bc3055..3ccfa93e0a0d379593cee9c9525be84d74449bc5 100644 (file)
@@ -56,7 +56,7 @@ void TestConvVectorization::runOnOperation() {
   ConversionTarget target(*context);
   target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect,
                          VectorDialect>();
-  target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
+  target.addLegalOp<ModuleOp, FuncOp, ReturnOp>();
   target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
 
   SmallVector<RewritePatternSet, 4> stage1Patterns;
index abf77a55004ecc538eb449a4d8a5d3de06c699e2..e5d8db9342cc1b7be4f813518462ce076ca2d762 100644 (file)
@@ -449,6 +449,14 @@ struct OperationFormat {
         llvm::any_of(op.getTraits(), [](const OpTrait &trait) {
           return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
         });
+
+    hasSingleBlockTrait =
+        hasImplicitTermTrait ||
+        llvm::any_of(op.getTraits(), [](const OpTrait &trait) {
+          if (auto *native = dyn_cast<NativeOpTrait>(&trait))
+            return native->getTrait() == "::mlir::OpTrait::SingleBlock";
+          return false;
+        });
   }
 
   /// Generate the operation parser from this format.
@@ -484,6 +492,9 @@ struct OperationFormat {
   /// trait.
   bool hasImplicitTermTrait;
 
+  /// A flag indicating if this operation has the SingleBlock trait.
+  bool hasSingleBlockTrait;
+
   /// A map of buildable types to indices.
   llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
 
@@ -679,6 +690,14 @@ const char *regionListEnsureTerminatorParserCode = R"(
     ensureTerminator(*region, parser.getBuilder(), result.location);
 )";
 
+/// The code snippet used to ensure a list of regions have a block.
+///
+/// {0}: The name of the region list.
+const char *regionListEnsureSingleBlockParserCode = R"(
+  for (auto &region : {0}Regions)
+    if (region.empty()) *{0}Region.emplaceBlock();
+)";
+
 /// The code snippet used to generate a parser call for an optional region.
 ///
 /// {0}: The name of the region.
@@ -705,6 +724,13 @@ const char *regionEnsureTerminatorParserCode = R"(
   ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
 )";
 
+/// The code snippet used to ensure a region has a block.
+///
+/// {0}: The name of the region.
+const char *regionEnsureSingleBlockParserCode = R"(
+  if ({0}Region->empty()) {0}Region->emplaceBlock();
+)";
+
 /// The code snippet used to generate a parser call for a successor list.
 ///
 /// {0}: The name for the successor list.
@@ -1134,6 +1160,9 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
         body << "  if (!" << region->name << "Region->empty()) {\n  ";
         if (hasImplicitTermTrait)
           body << llvm::formatv(regionEnsureTerminatorParserCode, region->name);
+        else if (hasSingleBlockTrait)
+          body << llvm::formatv(regionEnsureSingleBlockParserCode,
+                                region->name);
       }
     }
 
@@ -1217,11 +1246,14 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
     bool isVariadic = region->getVar()->isVariadic();
     body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode,
                           region->getVar()->name);
-    if (hasImplicitTermTrait) {
+    if (hasImplicitTermTrait)
       body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode
                                        : regionEnsureTerminatorParserCode,
                             region->getVar()->name);
-    }
+    else if (hasSingleBlockTrait)
+      body << llvm::formatv(isVariadic ? regionListEnsureSingleBlockParserCode
+                                       : regionEnsureSingleBlockParserCode,
+                            region->getVar()->name);
 
   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
     bool isVariadic = successor->getVar()->isVariadic();
@@ -1246,6 +1278,8 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
     body << llvm::formatv(regionListParserCode, "full");
     if (hasImplicitTermTrait)
       body << llvm::formatv(regionListEnsureTerminatorParserCode, "full");
+    else if (hasSingleBlockTrait)
+      body << llvm::formatv(regionListEnsureSingleBlockParserCode, "full");
 
   } else if (isa<SuccessorsDirective>(element)) {
     body << llvm::formatv(successorListParserCode, "full");