Generalize implicit terminator into an OpTrait
authorAlex Zinenko <zinenko@google.com>
Fri, 19 Jul 2019 13:35:10 +0000 (06:35 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Fri, 19 Jul 2019 18:40:51 +0000 (11:40 -0700)
Several groups of operations in different dialects (e.g. AffineForOp,
AffineIfOp; loop::ForOp, loop::IfOp) share the requirement for their regions to
contain 0 or 1 block, and for blocks to always have a specific terminator type.
Furthermore, this terminator may be omitted from the custom syntax.  Generalize
this behavior into OpTrait::SingleBlockImplicitTerminator, parameterized by the
terminator operation type.  This trait provides the verifier that checks the
presence of the terminator, and utility functions adding the terminator in case
of absence.

PiperOrigin-RevId: 258957180

mlir/include/mlir/AffineOps/AffineOps.h
mlir/include/mlir/AffineOps/AffineOps.td
mlir/include/mlir/Dialect/LoopOps/LoopOps.h
mlir/include/mlir/Dialect/LoopOps/LoopOps.td
mlir/include/mlir/IR/OpDefinition.h
mlir/lib/AffineOps/AffineOps.cpp
mlir/lib/Dialect/LoopOps/LoopOps.cpp
mlir/test/IR/invalid-ops.mlir

index cb82361..59f7fc7 100644 (file)
@@ -32,6 +32,7 @@
 namespace mlir {
 class AffineBound;
 class AffineValueMap;
+class AffineTerminatorOp;
 class FlatAffineConstraints;
 class OpBuilder;
 
index 66286ec..a23d1a3 100644 (file)
@@ -47,7 +47,9 @@ class Affine_Op<string mnemonic, list<OpTrait> traits = []> :
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
-def AffineForOp : Affine_Op<"for"> {
+def AffineForOp : Affine_Op<"for",
+    [NativeOpTrait<"SingleBlockImplicitTerminator<AffineTerminatorOp>::Impl">]>
+  {
   let summary = "for operation";
   let description = [{
     The "affine.for" operation represents an affine loop nest, defining an SSA
@@ -181,7 +183,9 @@ def AffineForOp : Affine_Op<"for"> {
   let hasCanonicalizer = 1;
 }
 
-def AffineIfOp : Affine_Op<"if"> {
+def AffineIfOp : Affine_Op<"if",
+    [NativeOpTrait<"SingleBlockImplicitTerminator<AffineTerminatorOp>::Impl">]>
+  {
   let summary = "if-then-else operation";
   let description = [{
     The "if" operation represents an if-then-else construct for conditionally
index ac6189b..90cc0b7 100644 (file)
@@ -30,6 +30,8 @@
 namespace mlir {
 namespace loop {
 
+class TerminatorOp;
+
 class LoopOpsDialect : public Dialect {
 public:
   LoopOpsDialect(MLIRContext *context);
index eb937de..58ffbc3 100644 (file)
@@ -47,7 +47,8 @@ class Loop_Op<string mnemonic, list<OpTrait> traits = []> :
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
-def ForOp : Loop_Op<"for"> {
+def ForOp : Loop_Op<"for",
+      [NativeOpTrait<"SingleBlockImplicitTerminator<TerminatorOp>::Impl">]> {
   let summary = "for operation";
   let description = [{
     The "loop.for" operation represents a loop nest taking 3 SSA value as
@@ -90,7 +91,8 @@ def ForOp : Loop_Op<"for"> {
   }];
 }
 
-def IfOp : Loop_Op<"if"> {
+def IfOp : Loop_Op<"if",
+      [NativeOpTrait<"SingleBlockImplicitTerminator<TerminatorOp>::Impl">]> {
   let summary = "if-then-else operation";
   let description = [{
     The "loop.if" operation represents an if-then-else construct for
index c275d03..2f33475 100644 (file)
@@ -54,6 +54,27 @@ public:
   explicit operator bool() const { return failed(*this); }
 };
 
+// These functions are out-of-line utilities, which avoids them being template
+// instantiated/duplicated.
+namespace impl {
+/// Insert an operation, generated by `buildTerminatorOp`, at the end of the
+/// region's only block if it does not have a terminator already. If the region
+/// is empty, insert a new block first. `buildTerminatorOp` should return the
+/// terminator operation to insert.
+void ensureRegionTerminator(
+    Region &region, Location loc,
+    llvm::function_ref<Operation *()> buildTerminatorOp);
+/// Templated version that fills the generates the provided operation type.
+template <typename OpTy>
+void ensureRegionTerminator(Region &region, Builder &builder, Location loc) {
+  ensureRegionTerminator(region, loc, [&] {
+    OperationState state(loc, OpTy::getOperationName());
+    OpTy::build(&builder, &state);
+    return Operation::create(state);
+  });
+}
+} // namespace impl
+
 /// This is the concrete base class that holds the operation pointer and has
 /// non-generic methods that only depend on State (to avoid having them
 /// instantiated on template types that don't affect them.
@@ -773,6 +794,54 @@ public:
   }
 };
 
+/// This class provides APIs and verifiers for ops with regions having a single
+/// block that must terminate with `TerminatorOpType`.
+template <typename TerminatorOpType> struct SingleBlockImplicitTerminator {
+  template <typename ConcreteType>
+  class Impl : public TraitBase<ConcreteType, Impl> {
+  public:
+    static LogicalResult verifyTrait(Operation *op) {
+      for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
+        Region &region = op->getRegion(i);
+
+        // Empty regions are fine.
+        if (region.empty())
+          continue;
+
+        // Non-empty regions must contain a single basic block.
+        if (std::next(region.begin()) != region.end())
+          return op->emitOpError("expects region #")
+                 << i << " to have 0 or 1 blocks";
+
+        // The block must terminate with TerminatorOpType.  If the block is
+        // empty, silently fail, the general block well-formedness verifier
+        // should complain instead.
+        Block &block = region.front();
+        if (block.empty())
+          return failure();
+        if (isa<TerminatorOpType>(block.back()))
+          continue;
+
+        return op->emitOpError("expects regions to end with '" +
+                               TerminatorOpType::getOperationName() + "'")
+                   .attachNote()
+               << "in custom textual format, the absence of terminator implies "
+                  "'"
+               << TerminatorOpType::getOperationName() << '\'';
+      }
+
+      return success();
+    }
+
+    /// Ensure that the given region has the terminator required by this trait.
+    static void ensureTerminator(Region &region, Builder &builder,
+                                 Location loc) {
+      ::mlir::impl::template ensureRegionTerminator<TerminatorOpType>(
+          region, builder, loc);
+    }
+  };
+};
+
 } // end namespace OpTrait
 
 //===----------------------------------------------------------------------===//
@@ -934,27 +1003,6 @@ ParseResult parseCastOp(OpAsmParser *parser, OperationState *result);
 void printCastOp(Operation *op, OpAsmPrinter *p);
 Value *foldCastOp(Operation *op);
 } // namespace impl
-
-// These functions are out-of-line utilities, which avoids them being template
-// instantiated/duplicated.
-namespace impl {
-/// Insert an operation, generated by `buildTerminatorOp`, at the end of the
-/// region's only block if it does not have a terminator already. If the region
-/// is empty, insert a new block first. `buildTerminatorOp` should return the
-/// terminator operation to insert.
-void ensureRegionTerminator(
-    Region &region, Location loc,
-    llvm::function_ref<Operation *()> buildTerminatorOp);
-/// Templated version that fills the generates the provided operation type.
-template <typename OpTy>
-void ensureRegionTerminator(Region &region, Builder &builder, Location loc) {
-  ensureRegionTerminator(region, loc, [&] {
-    OperationState state(loc, OpTy::getOperationName());
-    OpTy::build(&builder, &state);
-    return Operation::create(state);
-  });
-}
-} // namespace impl
 } // end namespace mlir
 
 #endif
index dbc3e93..e730ba5 100644 (file)
@@ -965,26 +965,6 @@ void AffineDmaWaitOp::getCanonicalizationPatterns(
 // AffineForOp
 //===----------------------------------------------------------------------===//
 
-// Check that if a "block" has a terminator, it is an `AffineTerminatorOp`.
-static LogicalResult checkHasAffineTerminator(OpState &op, Block &block) {
-  if (block.empty() || isa<AffineTerminatorOp>(block.back()))
-    return success();
-
-  return op.emitOpError("expects regions to end with '" +
-                        AffineTerminatorOp::getOperationName() + "'")
-             .attachNote()
-         << "in custom textual format, the absence of terminator implies '"
-         << AffineTerminatorOp::getOperationName() << "'";
-}
-
-// Insert `affine.terminator` at the end of the region's only block if it does
-// not have a terminator already.  If the region is empty, insert a new block
-// first.
-static void ensureAffineTerminator(Region &region, Builder &builder,
-                                   Location loc) {
-  impl::ensureRegionTerminator<AffineTerminatorOp>(region, builder, loc);
-}
-
 void AffineForOp::build(Builder *builder, OperationState *result,
                         ArrayRef<Value *> lbOperands, AffineMap lbMap,
                         ArrayRef<Value *> ubOperands, AffineMap ubMap,
@@ -1017,7 +997,7 @@ void AffineForOp::build(Builder *builder, OperationState *result,
   Block *body = new Block();
   body->addArgument(IndexType::get(builder->getContext()));
   bodyRegion->push_back(body);
-  ensureAffineTerminator(*bodyRegion, *builder, result->location);
+  ensureTerminator(*bodyRegion, *builder, result->location);
 
   // Set the operands list as resizable so that we can freely modify the bounds.
   result->setOperandListToResizable();
@@ -1031,12 +1011,6 @@ void AffineForOp::build(Builder *builder, OperationState *result, int64_t lb,
 }
 
 static LogicalResult verify(AffineForOp op) {
-  auto &bodyRegion = op.region();
-
-  // The body region must contain a single basic block.
-  if (bodyRegion.empty() || std::next(bodyRegion.begin()) != bodyRegion.end())
-    return op.emitOpError("expected body region to have a single block");
-
   // Check that the body defines as single block argument for the induction
   // variable.
   auto *body = op.getBody();
@@ -1046,9 +1020,6 @@ static LogicalResult verify(AffineForOp op) {
         "expected body to have a single index argument for the "
         "induction variable");
 
-  if (failed(checkHasAffineTerminator(op, *body)))
-    return failure();
-
   // Verify that there are enough operands for the bounds.
   AffineMap lowerBoundMap = op.getLowerBoundMap(),
             upperBoundMap = op.getUpperBoundMap();
@@ -1198,7 +1169,7 @@ ParseResult parseAffineForOp(OpAsmParser *parser, OperationState *result) {
   if (parser->parseRegion(*body, inductionVariable, builder.getIndexType()))
     return failure();
 
-  ensureAffineTerminator(*body, builder, result->location);
+  AffineForOp::ensureTerminator(*body, builder, result->location);
 
   // Parse the optional attribute list.
   if (parser->parseOptionalAttributeDict(result->attributes))
@@ -1486,17 +1457,6 @@ static LogicalResult verify(AffineIfOp op) {
 
   // Verify that the entry of each child region does not have arguments.
   for (auto &region : op.getOperation()->getRegions()) {
-    if (region.empty())
-      continue;
-
-    // TODO(riverriddle) We currently do not allow multiple blocks in child
-    // regions.
-    if (std::next(region.begin()) != region.end())
-      return op.emitOpError(
-          "expects only one block per 'then' or 'else' regions");
-    if (failed(checkHasAffineTerminator(op, region.front())))
-      return failure();
-
     for (auto &b : region)
       if (b.getNumArguments() != 0)
         return op.emitOpError(
@@ -1534,13 +1494,15 @@ ParseResult parseAffineIfOp(OpAsmParser *parser, OperationState *result) {
   // Parse the 'then' region.
   if (parser->parseRegion(*thenRegion, {}, {}))
     return failure();
-  ensureAffineTerminator(*thenRegion, parser->getBuilder(), result->location);
+  AffineIfOp::ensureTerminator(*thenRegion, parser->getBuilder(),
+                               result->location);
 
   // If we find an 'else' keyword then parse the 'else' region.
   if (!parser->parseOptionalKeyword("else")) {
     if (parser->parseRegion(*elseRegion, {}, {}))
       return failure();
-    ensureAffineTerminator(*elseRegion, parser->getBuilder(), result->location);
+    AffineIfOp::ensureTerminator(*elseRegion, parser->getBuilder(),
+                                 result->location);
   }
 
   // Parse the optional attribute list.
index b8ad79a..63e0da0 100644 (file)
@@ -49,28 +49,11 @@ LoopOpsDialect::LoopOpsDialect(MLIRContext *context)
 // ForOp
 //===----------------------------------------------------------------------===//
 
-// Check that if a "block" is not empty, it has a `TerminatorOp` terminator.
-static LogicalResult checkHasTerminator(OpState &op, Block &block) {
-  if (block.empty() || isa<TerminatorOp>(block.back()))
-    return success();
-
-  return op.emitOpError("expects regions to end with '" +
-                        TerminatorOp::getOperationName() + "'")
-             .attachNote()
-         << "in custom textual format, the absence of terminator implies '"
-         << TerminatorOp::getOperationName() << "'";
-}
-
-void mlir::loop::ensureLoopTerminator(Region &region, Builder &builder,
-                                      Location loc) {
-  impl::ensureRegionTerminator<TerminatorOp>(region, builder, loc);
-}
-
 void ForOp::build(Builder *builder, OperationState *result, Value *lb,
                   Value *ub, Value *step) {
   result->addOperands({lb, ub, step});
   Region *bodyRegion = result->addRegion();
-  ensureLoopTerminator(*bodyRegion, *builder, result->location);
+  ForOp::ensureTerminator(*bodyRegion, *builder, result->location);
   bodyRegion->front().addArgument(builder->getIndexType());
 }
 
@@ -86,8 +69,6 @@ LogicalResult verify(ForOp op) {
       !body->getArgument(0)->getType().isIndex())
     return op.emitOpError("expected body to have a single index argument for "
                           "the induction variable");
-  if (failed(checkHasTerminator(op, *body)))
-    return failure();
   return success();
 }
 
@@ -123,7 +104,7 @@ static ParseResult parseForOp(OpAsmParser *parser, OperationState *result) {
   if (parser->parseRegion(*body, inductionVariable, indexType))
     return failure();
 
-  ensureLoopTerminator(*body, builder, result->location);
+  ForOp::ensureTerminator(*body, builder, result->location);
 
   // Parse the optional attribute list.
   if (parser->parseOptionalAttributeDict(result->attributes))
@@ -150,9 +131,9 @@ void IfOp::build(Builder *builder, OperationState *result, Value *cond,
   result->addOperands(cond);
   Region *thenRegion = result->addRegion();
   Region *elseRegion = result->addRegion();
-  ensureLoopTerminator(*thenRegion, *builder, result->location);
+  IfOp::ensureTerminator(*thenRegion, *builder, result->location);
   if (withElseRegion)
-    ensureLoopTerminator(*elseRegion, *builder, result->location);
+    IfOp::ensureTerminator(*elseRegion, *builder, result->location);
 }
 
 static LogicalResult verify(IfOp op) {
@@ -161,13 +142,6 @@ static LogicalResult verify(IfOp op) {
     if (region.empty())
       continue;
 
-    // TODO(riverriddle) We currently do not allow multiple blocks in child
-    // regions.
-    if (std::next(region.begin()) != region.end())
-      return op.emitOpError("expected one block per 'then' or 'else' regions");
-    if (failed(checkHasTerminator(op, region.front())))
-      return failure();
-
     for (auto &b : region)
       if (b.getNumArguments() != 0)
         return op.emitOpError(
@@ -192,13 +166,13 @@ static ParseResult parseIfOp(OpAsmParser *parser, OperationState *result) {
   // Parse the 'then' region.
   if (parser->parseRegion(*thenRegion, {}, {}))
     return failure();
-  ensureLoopTerminator(*thenRegion, parser->getBuilder(), result->location);
+  IfOp::ensureTerminator(*thenRegion, parser->getBuilder(), result->location);
 
   // If we find an 'else' keyword then parse the 'else' region.
   if (!parser->parseOptionalKeyword("else")) {
     if (parser->parseRegion(*elseRegion, {}, {}))
       return failure();
-    ensureLoopTerminator(*elseRegion, parser->getBuilder(), result->location);
+    IfOp::ensureTerminator(*elseRegion, parser->getBuilder(), result->location);
   }
 
   // Parse the optional attribute list.
index 2991d12..36e0ebb 100644 (file)
@@ -729,7 +729,10 @@ func @std_for_step(%arg0: f32, %arg1: index) {
 func @std_for_step_nonnegative(%arg0: index) {
   // expected-error@+2 {{constant step operand must be nonnegative}}
   %c0 = constant 0 : index
-  "loop.for"(%arg0, %arg0, %c0) ({^bb0:}) : (index, index, index) -> ()
+  "loop.for"(%arg0, %arg0, %c0) ({
+    ^bb0(%arg1: index):
+      "loop.terminator"() : () -> ()
+  }) : (index, index, index) -> ()
   return
 }
 
@@ -747,7 +750,7 @@ func @std_for_one_region(%arg0: index) {
 // -----
 
 func @std_for_single_block(%arg0: index) {
-  // expected-error@+1 {{region #0 ('region') failed to verify constraint: region with 1 blocks}}
+  // expected-error@+1 {{expects region #0 to have 0 or 1 blocks}}
   "loop.for"(%arg0, %arg0, %arg0) (
     {
     ^bb1:
@@ -791,7 +794,7 @@ func @std_if_more_than_2_regions(%arg0: i1) {
 // -----
 
 func @std_if_not_one_block_per_region(%arg0: i1) {
-  // expected-error@+1 {{region #0 ('thenRegion') failed to verify constraint: region with 1 blocks}}
+  // expected-error@+1 {{expects region #0 to have 0 or 1 blocks}}
   "loop.if"(%arg0) ({
     ^bb0:
       "loop.terminator"() : () -> ()