[mlir] Add traits for verifying the number of successors and providing relevant acces...
authorRiver Riddle <riddleriver@gmail.com>
Thu, 5 Mar 2020 20:39:46 +0000 (12:39 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Thu, 5 Mar 2020 20:49:59 +0000 (12:49 -0800)
This allows for simplifying OpDefGen, as well providing specializing accessors for the different successor counts. This mirrors the existing traits for operands and results.

Differential Revision: https://reviews.llvm.org/D75313

mlir/include/mlir/IR/OpDefinition.h
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Dialect/SPIRV/control-flow-ops.mlir
mlir/test/mlir-tblgen/op-decl.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

index efbcf0a..1a81321 100644 (file)
@@ -381,6 +381,10 @@ LogicalResult verifyResultsAreBoolLike(Operation *op);
 LogicalResult verifyResultsAreFloatLike(Operation *op);
 LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
 LogicalResult verifyIsTerminator(Operation *op);
+LogicalResult verifyZeroSuccessor(Operation *op);
+LogicalResult verifyOneSuccessor(Operation *op);
+LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors);
+LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
 LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
 LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
 } // namespace impl
@@ -410,6 +414,9 @@ protected:
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Operand Traits
+
 namespace detail {
 /// Utility trait base that provides accessors for derived traits that have
 /// multiple operands.
@@ -522,6 +529,9 @@ template <typename ConcreteType>
 class VariadicOperands
     : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {};
 
+//===----------------------------------------------------------------------===//
+// Result Traits
+
 /// This class provides return value APIs for ops that are known to have
 /// zero results.
 template <typename ConcreteType>
@@ -644,6 +654,123 @@ template <typename ConcreteType>
 class VariadicResults
     : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {};
 
+//===----------------------------------------------------------------------===//
+// Terminator Traits
+
+/// This class provides the API for ops that are known to be terminators.
+template <typename ConcreteType>
+class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
+public:
+  static AbstractOperation::OperationProperties getTraitProperties() {
+    return static_cast<AbstractOperation::OperationProperties>(
+        OperationProperty::Terminator);
+  }
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyIsTerminator(op);
+  }
+
+  unsigned getNumSuccessorOperands(unsigned index) {
+    return this->getOperation()->getNumSuccessorOperands(index);
+  }
+};
+
+/// This class provides verification for ops that are known to have zero
+/// successors.
+template <typename ConcreteType>
+class ZeroSuccessor : public TraitBase<ConcreteType, ZeroSuccessor> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyZeroSuccessor(op);
+  }
+};
+
+namespace detail {
+/// Utility trait base that provides accessors for derived traits that have
+/// multiple successors.
+template <typename ConcreteType, template <typename> class TraitType>
+struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> {
+  using succ_iterator = Operation::succ_iterator;
+  using succ_range = SuccessorRange;
+
+  /// Return the number of successors.
+  unsigned getNumSuccessors() {
+    return this->getOperation()->getNumSuccessors();
+  }
+
+  /// Return the successor at `index`.
+  Block *getSuccessor(unsigned i) {
+    return this->getOperation()->getSuccessor(i);
+  }
+
+  /// Set the successor at `index`.
+  void setSuccessor(Block *block, unsigned i) {
+    return this->getOperation()->setSuccessor(block, i);
+  }
+
+  /// Successor iterator access.
+  succ_iterator succ_begin() { return this->getOperation()->succ_begin(); }
+  succ_iterator succ_end() { return this->getOperation()->succ_end(); }
+  succ_range getSuccessors() { return this->getOperation()->getSuccessors(); }
+};
+} // end namespace detail
+
+/// This class provides APIs for ops that are known to have a single successor.
+template <typename ConcreteType>
+class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> {
+public:
+  Block *getSuccessor() { return this->getOperation()->getSuccessor(0); }
+  void setSuccessor(Block *succ) {
+    this->getOperation()->setSuccessor(succ, 0);
+  }
+
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyOneSuccessor(op);
+  }
+};
+
+/// This class provides the API for ops that are known to have a specified
+/// number of successors.
+template <unsigned N>
+class NSuccessors {
+public:
+  static_assert(N > 1, "use ZeroSuccessor/OneSuccessor for N < 2");
+
+  template <typename ConcreteType>
+  class Impl : public detail::MultiSuccessorTraitBase<ConcreteType,
+                                                      NSuccessors<N>::Impl> {
+  public:
+    static LogicalResult verifyTrait(Operation *op) {
+      return impl::verifyNSuccessors(op, N);
+    }
+  };
+};
+
+/// This class provides APIs for ops that are known to have at least a specified
+/// number of successors.
+template <unsigned N>
+class AtLeastNSuccessors {
+public:
+  template <typename ConcreteType>
+  class Impl
+      : public detail::MultiSuccessorTraitBase<ConcreteType,
+                                               AtLeastNSuccessors<N>::Impl> {
+  public:
+    static LogicalResult verifyTrait(Operation *op) {
+      return impl::verifyAtLeastNSuccessors(op, N);
+    }
+  };
+};
+
+/// This class provides the API for ops which have an unknown number of
+/// successors.
+template <typename ConcreteType>
+class VariadicSuccessors
+    : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> {
+};
+
+//===----------------------------------------------------------------------===//
+// Misc Traits
+
 /// This class provides verification for ops that are known to have the same
 /// operand shape: all operands are scalars, vectors/tensors of the same
 /// shape.
@@ -789,41 +916,6 @@ public:
   }
 };
 
-/// This class provides the API for ops that are known to be terminators.
-template <typename ConcreteType>
-class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
-public:
-  static AbstractOperation::OperationProperties getTraitProperties() {
-    return static_cast<AbstractOperation::OperationProperties>(
-        OperationProperty::Terminator);
-  }
-  static LogicalResult verifyTrait(Operation *op) {
-    return impl::verifyIsTerminator(op);
-  }
-
-  unsigned getNumSuccessors() {
-    return this->getOperation()->getNumSuccessors();
-  }
-  unsigned getNumSuccessorOperands(unsigned index) {
-    return this->getOperation()->getNumSuccessorOperands(index);
-  }
-
-  Block *getSuccessor(unsigned index) {
-    return this->getOperation()->getSuccessor(index);
-  }
-
-  void setSuccessor(Block *block, unsigned index) {
-    return this->getOperation()->setSuccessor(block, index);
-  }
-
-  void addSuccessorOperand(unsigned index, Value value) {
-    return this->getOperation()->addSuccessorOperand(index, value);
-  }
-  void addSuccessorOperands(unsigned index, ArrayRef<Value> values) {
-    return this->getOperation()->addSuccessorOperand(index, values);
-  }
-};
-
 /// This class provides the API for ops that are known to be isolated from
 /// above.
 template <typename ConcreteType>
index 2bf1969..6a63867 100644 (file)
@@ -1894,7 +1894,7 @@ static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
     return false;
 
   auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
-  return branchOp && branchOp.getSuccessor(0) == &dstBlock;
+  return branchOp && branchOp.getSuccessor() == &dstBlock;
 }
 
 static LogicalResult verify(spirv::LoopOp loopOp) {
index 2131497..1059e66 100644 (file)
@@ -477,9 +477,9 @@ struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
 };
 } // end anonymous namespace.
 
-Block *BranchOp::getDest() { return getSuccessor(0); }
+Block *BranchOp::getDest() { return getSuccessor(); }
 
-void BranchOp::setDest(Block *block) { return setSuccessor(block, 0); }
+void BranchOp::setDest(Block *block) { return setSuccessor(block); }
 
 void BranchOp::eraseOperand(unsigned index) {
   getOperation()->eraseSuccessorOperand(0, index);
index 49185eb..bfd4b40 100644 (file)
@@ -942,6 +942,14 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
   return success();
 }
 
+LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
+  Block *block = op->getBlock();
+  // Verify that the operation is at the end of the respective parent block.
+  if (!block || &block->back() != op)
+    return op->emitOpError("must be the last operation in the parent block");
+  return success();
+}
+
 static LogicalResult verifySuccessor(Operation *op, unsigned succNo) {
   Operation::operand_range operands = op->getSuccessorOperands(succNo);
   unsigned operandCount = op->getNumSuccessorOperands(succNo);
@@ -976,18 +984,40 @@ static LogicalResult verifyTerminatorSuccessors(Operation *op) {
   return success();
 }
 
-LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
-  Block *block = op->getBlock();
-  // Verify that the operation is at the end of the respective parent block.
-  if (!block || &block->back() != op)
-    return op->emitOpError("must be the last operation in the parent block");
-
-  // Verify the state of the successor blocks.
-  if (op->getNumSuccessors() != 0 && failed(verifyTerminatorSuccessors(op)))
-    return failure();
+LogicalResult OpTrait::impl::verifyZeroSuccessor(Operation *op) {
+  if (op->getNumSuccessors() != 0) {
+    return op->emitOpError("requires 0 successors but found ")
+           << op->getNumSuccessors();
+  }
   return success();
 }
 
+LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) {
+  if (op->getNumSuccessors() != 1) {
+    return op->emitOpError("requires 1 successor but found ")
+           << op->getNumSuccessors();
+  }
+  return verifyTerminatorSuccessors(op);
+}
+LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op,
+                                               unsigned numSuccessors) {
+  if (op->getNumSuccessors() != numSuccessors) {
+    return op->emitOpError("requires ")
+           << numSuccessors << " successors but found "
+           << op->getNumSuccessors();
+  }
+  return verifyTerminatorSuccessors(op);
+}
+LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op,
+                                                      unsigned numSuccessors) {
+  if (op->getNumSuccessors() < numSuccessors) {
+    return op->emitOpError("requires at least ")
+           << numSuccessors << " successors but found "
+           << op->getNumSuccessors();
+  }
+  return verifyTerminatorSuccessors(op);
+}
+
 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) {
   for (auto resultType : op->getResultTypes()) {
     auto elementType = getTensorOrVectorElementType(resultType);
index d45f05b..19bf61d 100644 (file)
@@ -356,7 +356,7 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
   // Emit branches.  We need to look up the remapped blocks and ignore the block
   // arguments that were transformed into PHI nodes.
   if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
-    builder.CreateBr(blockMapping[brOp.getSuccessor(0)]);
+    builder.CreateBr(blockMapping[brOp.getSuccessor()]);
     return success();
   }
   if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
index cc7e09f..a6dcac1 100644 (file)
@@ -24,7 +24,7 @@ func @branch_argument() -> () {
 // -----
 
 func @missing_accessor() -> () {
-  // expected-error @+1 {{has incorrect number of successors: expected 1 but found 0}}
+  // expected-error @+1 {{requires 1 successor but found 0}}
   spv.Branch
 }
 
@@ -32,7 +32,7 @@ func @missing_accessor() -> () {
 
 func @wrong_accessor_count() -> () {
   %true = spv.constant true
-  // expected-error @+1 {{incorrect number of successors: expected 1 but found 2}}
+  // expected-error @+1 {{requires 1 successor but found 2}}
   "spv.Branch"()[^one, ^two] : () -> ()
 ^one:
   spv.Return
@@ -116,7 +116,7 @@ func @wrong_condition_type() -> () {
 
 func @wrong_accessor_count() -> () {
   %true = spv.constant true
-  // expected-error @+1 {{incorrect number of successors: expected 2 but found 1}}
+  // expected-error @+1 {{requires 2 successors but found 1}}
   "spv.BranchConditional"(%true)[^one] : (i1) -> ()
 ^one:
   spv.Return
index 61f0c56..f07f995 100644 (file)
@@ -54,7 +54,7 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect, NoSideEffect]> {
 // CHECK:   ArrayRef<Value> tblgen_operands;
 // CHECK: };
 
-// CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNResults<1>::Impl, OpTrait::HasNoSideEffect, OpTrait::AtLeastNOperands<1>::Impl
+// CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::HasNoSideEffect, OpTrait::AtLeastNOperands<1>::Impl
 // CHECK: public:
 // CHECK:   using Op::Op;
 // CHECK:   using OperandAdaptor = AOpOperandAdaptor;
index 8c6ba60..ebd82f9 100644 (file)
@@ -1390,26 +1390,8 @@ void OpEmitter::genRegionVerifier(OpMethodBody &body) {
 }
 
 void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
-  unsigned numSuccessors = op.getNumSuccessors();
-
-  const char *checkSuccessorSizeCode = R"(
-  if (this->getOperation()->getNumSuccessors() {0} {1}) {
-    return emitOpError("has incorrect number of successors: expected{2} {1}"
-                       " but found ")
-             << this->getOperation()->getNumSuccessors();
-  }
-  )";
-
-  // Verify this op has the correct number of successors.
-  unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
-  if (numVariadicSuccessors == 0) {
-    body << formatv(checkSuccessorSizeCode, "!=", numSuccessors, "");
-  } else if (numVariadicSuccessors != numSuccessors) {
-    body << formatv(checkSuccessorSizeCode, "<",
-                    numSuccessors - numVariadicSuccessors, " at least");
-  }
-
   // If we have no successors, there is nothing more to do.
+  unsigned numSuccessors = op.getNumSuccessors();
   if (numSuccessors == 0)
     return;
 
@@ -1441,31 +1423,44 @@ void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
   body << "  }\n";
 }
 
+/// Add a size count trait to the given operation class.
+static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
+                              int numNonVariadic, int numVariadic) {
+  if (numVariadic != 0) {
+    if (numNonVariadic == numVariadic)
+      opClass.addTrait("OpTrait::Variadic" + traitKind + "s");
+    else
+      opClass.addTrait("OpTrait::AtLeastN" + traitKind + "s<" +
+                       Twine(numNonVariadic - numVariadic) + ">::Impl");
+    return;
+  }
+  switch (numNonVariadic) {
+  case 0:
+    opClass.addTrait("OpTrait::Zero" + traitKind);
+    break;
+  case 1:
+    opClass.addTrait("OpTrait::One" + traitKind);
+    break;
+  default:
+    opClass.addTrait("OpTrait::N" + traitKind + "s<" + Twine(numNonVariadic) +
+                     ">::Impl");
+    break;
+  }
+}
+
 void OpEmitter::genTraits() {
   int numResults = op.getNumResults();
   int numVariadicResults = op.getNumVariadicResults();
 
   // Add return size trait.
-  if (numVariadicResults != 0) {
-    if (numResults == numVariadicResults)
-      opClass.addTrait("OpTrait::VariadicResults");
-    else
-      opClass.addTrait("OpTrait::AtLeastNResults<" +
-                       Twine(numResults - numVariadicResults) + ">::Impl");
-  } else {
-    switch (numResults) {
-    case 0:
-      opClass.addTrait("OpTrait::ZeroResult");
-      break;
-    case 1:
-      opClass.addTrait("OpTrait::OneResult");
-      break;
-    default:
-      opClass.addTrait("OpTrait::NResults<" + Twine(numResults) + ">::Impl");
-      break;
-    }
-  }
+  addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
+
+  // Add successor size trait.
+  unsigned numSuccessors = op.getNumSuccessors();
+  unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
+  addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors);
 
+  // Add the native and interface traits.
   for (const auto &trait : op.getTraits()) {
     if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
       opClass.addTrait(opTrait->getTrait());