From: River Riddle Date: Thu, 5 Mar 2020 20:39:46 +0000 (-0800) Subject: [mlir] Add traits for verifying the number of successors and providing relevant acces... X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c0fd5e657e5d38a480d65b4e8f6f7a835afd6c76;p=platform%2Fupstream%2Fllvm.git [mlir] Add traits for verifying the number of successors and providing relevant accessors. This allows for simplifying OpDefGen, as well providing specializing accessors for the different successor counts. This mirrors the existing traits for operands and results. Differential Revision: https://reviews.llvm.org/D75313 --- diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index efbcf0a..1a81321 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -381,6 +381,10 @@ LogicalResult verifyResultsAreBoolLike(Operation *op); LogicalResult verifyResultsAreFloatLike(Operation *op); LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op); LogicalResult verifyIsTerminator(Operation *op); +LogicalResult verifyZeroSuccessor(Operation *op); +LogicalResult verifyOneSuccessor(Operation *op); +LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors); +LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors); LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName); LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName); } // namespace impl @@ -410,6 +414,9 @@ protected: } }; +//===----------------------------------------------------------------------===// +// Operand Traits + namespace detail { /// Utility trait base that provides accessors for derived traits that have /// multiple operands. @@ -522,6 +529,9 @@ template class VariadicOperands : public detail::MultiOperandTraitBase {}; +//===----------------------------------------------------------------------===// +// Result Traits + /// This class provides return value APIs for ops that are known to have /// zero results. template @@ -644,6 +654,123 @@ template class VariadicResults : public detail::MultiResultTraitBase {}; +//===----------------------------------------------------------------------===// +// Terminator Traits + +/// This class provides the API for ops that are known to be terminators. +template +class IsTerminator : public TraitBase { +public: + static AbstractOperation::OperationProperties getTraitProperties() { + return static_cast( + OperationProperty::Terminator); + } + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyIsTerminator(op); + } + + unsigned getNumSuccessorOperands(unsigned index) { + return this->getOperation()->getNumSuccessorOperands(index); + } +}; + +/// This class provides verification for ops that are known to have zero +/// successors. +template +class ZeroSuccessor : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyZeroSuccessor(op); + } +}; + +namespace detail { +/// Utility trait base that provides accessors for derived traits that have +/// multiple successors. +template class TraitType> +struct MultiSuccessorTraitBase : public TraitBase { + using succ_iterator = Operation::succ_iterator; + using succ_range = SuccessorRange; + + /// Return the number of successors. + unsigned getNumSuccessors() { + return this->getOperation()->getNumSuccessors(); + } + + /// Return the successor at `index`. + Block *getSuccessor(unsigned i) { + return this->getOperation()->getSuccessor(i); + } + + /// Set the successor at `index`. + void setSuccessor(Block *block, unsigned i) { + return this->getOperation()->setSuccessor(block, i); + } + + /// Successor iterator access. + succ_iterator succ_begin() { return this->getOperation()->succ_begin(); } + succ_iterator succ_end() { return this->getOperation()->succ_end(); } + succ_range getSuccessors() { return this->getOperation()->getSuccessors(); } +}; +} // end namespace detail + +/// This class provides APIs for ops that are known to have a single successor. +template +class OneSuccessor : public TraitBase { +public: + Block *getSuccessor() { return this->getOperation()->getSuccessor(0); } + void setSuccessor(Block *succ) { + this->getOperation()->setSuccessor(succ, 0); + } + + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyOneSuccessor(op); + } +}; + +/// This class provides the API for ops that are known to have a specified +/// number of successors. +template +class NSuccessors { +public: + static_assert(N > 1, "use ZeroSuccessor/OneSuccessor for N < 2"); + + template + class Impl : public detail::MultiSuccessorTraitBase::Impl> { + public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyNSuccessors(op, N); + } + }; +}; + +/// This class provides APIs for ops that are known to have at least a specified +/// number of successors. +template +class AtLeastNSuccessors { +public: + template + class Impl + : public detail::MultiSuccessorTraitBase::Impl> { + public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyAtLeastNSuccessors(op, N); + } + }; +}; + +/// This class provides the API for ops which have an unknown number of +/// successors. +template +class VariadicSuccessors + : public detail::MultiSuccessorTraitBase { +}; + +//===----------------------------------------------------------------------===// +// Misc Traits + /// This class provides verification for ops that are known to have the same /// operand shape: all operands are scalars, vectors/tensors of the same /// shape. @@ -789,41 +916,6 @@ public: } }; -/// This class provides the API for ops that are known to be terminators. -template -class IsTerminator : public TraitBase { -public: - static AbstractOperation::OperationProperties getTraitProperties() { - return static_cast( - OperationProperty::Terminator); - } - static LogicalResult verifyTrait(Operation *op) { - return impl::verifyIsTerminator(op); - } - - unsigned getNumSuccessors() { - return this->getOperation()->getNumSuccessors(); - } - unsigned getNumSuccessorOperands(unsigned index) { - return this->getOperation()->getNumSuccessorOperands(index); - } - - Block *getSuccessor(unsigned index) { - return this->getOperation()->getSuccessor(index); - } - - void setSuccessor(Block *block, unsigned index) { - return this->getOperation()->setSuccessor(block, index); - } - - void addSuccessorOperand(unsigned index, Value value) { - return this->getOperation()->addSuccessorOperand(index, value); - } - void addSuccessorOperands(unsigned index, ArrayRef values) { - return this->getOperation()->addSuccessorOperand(index, values); - } -}; - /// This class provides the API for ops that are known to be isolated from /// above. template diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 2bf1969..6a63867 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1894,7 +1894,7 @@ static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) { return false; auto branchOp = dyn_cast(srcBlock.back()); - return branchOp && branchOp.getSuccessor(0) == &dstBlock; + return branchOp && branchOp.getSuccessor() == &dstBlock; } static LogicalResult verify(spirv::LoopOp loopOp) { diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index 2131497..1059e66 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -477,9 +477,9 @@ struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern { }; } // end anonymous namespace. -Block *BranchOp::getDest() { return getSuccessor(0); } +Block *BranchOp::getDest() { return getSuccessor(); } -void BranchOp::setDest(Block *block) { return setSuccessor(block, 0); } +void BranchOp::setDest(Block *block) { return setSuccessor(block); } void BranchOp::eraseOperand(unsigned index) { getOperation()->eraseSuccessorOperand(0, index); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 49185eb..bfd4b40 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -942,6 +942,14 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { return success(); } +LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { + Block *block = op->getBlock(); + // Verify that the operation is at the end of the respective parent block. + if (!block || &block->back() != op) + return op->emitOpError("must be the last operation in the parent block"); + return success(); +} + static LogicalResult verifySuccessor(Operation *op, unsigned succNo) { Operation::operand_range operands = op->getSuccessorOperands(succNo); unsigned operandCount = op->getNumSuccessorOperands(succNo); @@ -976,18 +984,40 @@ static LogicalResult verifyTerminatorSuccessors(Operation *op) { return success(); } -LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { - Block *block = op->getBlock(); - // Verify that the operation is at the end of the respective parent block. - if (!block || &block->back() != op) - return op->emitOpError("must be the last operation in the parent block"); - - // Verify the state of the successor blocks. - if (op->getNumSuccessors() != 0 && failed(verifyTerminatorSuccessors(op))) - return failure(); +LogicalResult OpTrait::impl::verifyZeroSuccessor(Operation *op) { + if (op->getNumSuccessors() != 0) { + return op->emitOpError("requires 0 successors but found ") + << op->getNumSuccessors(); + } return success(); } +LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { + if (op->getNumSuccessors() != 1) { + return op->emitOpError("requires 1 successor but found ") + << op->getNumSuccessors(); + } + return verifyTerminatorSuccessors(op); +} +LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, + unsigned numSuccessors) { + if (op->getNumSuccessors() != numSuccessors) { + return op->emitOpError("requires ") + << numSuccessors << " successors but found " + << op->getNumSuccessors(); + } + return verifyTerminatorSuccessors(op); +} +LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, + unsigned numSuccessors) { + if (op->getNumSuccessors() < numSuccessors) { + return op->emitOpError("requires at least ") + << numSuccessors << " successors but found " + << op->getNumSuccessors(); + } + return verifyTerminatorSuccessors(op); +} + LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { for (auto resultType : op->getResultTypes()) { auto elementType = getTensorOrVectorElementType(resultType); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index d45f05b..19bf61d 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -356,7 +356,7 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst, // Emit branches. We need to look up the remapped blocks and ignore the block // arguments that were transformed into PHI nodes. if (auto brOp = dyn_cast(opInst)) { - builder.CreateBr(blockMapping[brOp.getSuccessor(0)]); + builder.CreateBr(blockMapping[brOp.getSuccessor()]); return success(); } if (auto condbrOp = dyn_cast(opInst)) { diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir index cc7e09f..a6dcac1 100644 --- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir +++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir @@ -24,7 +24,7 @@ func @branch_argument() -> () { // ----- func @missing_accessor() -> () { - // expected-error @+1 {{has incorrect number of successors: expected 1 but found 0}} + // expected-error @+1 {{requires 1 successor but found 0}} spv.Branch } @@ -32,7 +32,7 @@ func @missing_accessor() -> () { func @wrong_accessor_count() -> () { %true = spv.constant true - // expected-error @+1 {{incorrect number of successors: expected 1 but found 2}} + // expected-error @+1 {{requires 1 successor but found 2}} "spv.Branch"()[^one, ^two] : () -> () ^one: spv.Return @@ -116,7 +116,7 @@ func @wrong_condition_type() -> () { func @wrong_accessor_count() -> () { %true = spv.constant true - // expected-error @+1 {{incorrect number of successors: expected 2 but found 1}} + // expected-error @+1 {{requires 2 successors but found 1}} "spv.BranchConditional"(%true)[^one] : (i1) -> () ^one: spv.Return diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td index 61f0c56..f07f995 100644 --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -54,7 +54,7 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect, NoSideEffect]> { // CHECK: ArrayRef tblgen_operands; // CHECK: }; -// CHECK: class AOp : public Op::Impl, OpTrait::HasNoSideEffect, OpTrait::AtLeastNOperands<1>::Impl +// CHECK: class AOp : public Op::Impl, OpTrait::ZeroSuccessor, OpTrait::HasNoSideEffect, OpTrait::AtLeastNOperands<1>::Impl // CHECK: public: // CHECK: using Op::Op; // CHECK: using OperandAdaptor = AOpOperandAdaptor; diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 8c6ba60..ebd82f9 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1390,26 +1390,8 @@ void OpEmitter::genRegionVerifier(OpMethodBody &body) { } void OpEmitter::genSuccessorVerifier(OpMethodBody &body) { - unsigned numSuccessors = op.getNumSuccessors(); - - const char *checkSuccessorSizeCode = R"( - if (this->getOperation()->getNumSuccessors() {0} {1}) { - return emitOpError("has incorrect number of successors: expected{2} {1}" - " but found ") - << this->getOperation()->getNumSuccessors(); - } - )"; - - // Verify this op has the correct number of successors. - unsigned numVariadicSuccessors = op.getNumVariadicSuccessors(); - if (numVariadicSuccessors == 0) { - body << formatv(checkSuccessorSizeCode, "!=", numSuccessors, ""); - } else if (numVariadicSuccessors != numSuccessors) { - body << formatv(checkSuccessorSizeCode, "<", - numSuccessors - numVariadicSuccessors, " at least"); - } - // If we have no successors, there is nothing more to do. + unsigned numSuccessors = op.getNumSuccessors(); if (numSuccessors == 0) return; @@ -1441,31 +1423,44 @@ void OpEmitter::genSuccessorVerifier(OpMethodBody &body) { body << " }\n"; } +/// Add a size count trait to the given operation class. +static void addSizeCountTrait(OpClass &opClass, StringRef traitKind, + int numNonVariadic, int numVariadic) { + if (numVariadic != 0) { + if (numNonVariadic == numVariadic) + opClass.addTrait("OpTrait::Variadic" + traitKind + "s"); + else + opClass.addTrait("OpTrait::AtLeastN" + traitKind + "s<" + + Twine(numNonVariadic - numVariadic) + ">::Impl"); + return; + } + switch (numNonVariadic) { + case 0: + opClass.addTrait("OpTrait::Zero" + traitKind); + break; + case 1: + opClass.addTrait("OpTrait::One" + traitKind); + break; + default: + opClass.addTrait("OpTrait::N" + traitKind + "s<" + Twine(numNonVariadic) + + ">::Impl"); + break; + } +} + void OpEmitter::genTraits() { int numResults = op.getNumResults(); int numVariadicResults = op.getNumVariadicResults(); // Add return size trait. - if (numVariadicResults != 0) { - if (numResults == numVariadicResults) - opClass.addTrait("OpTrait::VariadicResults"); - else - opClass.addTrait("OpTrait::AtLeastNResults<" + - Twine(numResults - numVariadicResults) + ">::Impl"); - } else { - switch (numResults) { - case 0: - opClass.addTrait("OpTrait::ZeroResult"); - break; - case 1: - opClass.addTrait("OpTrait::OneResult"); - break; - default: - opClass.addTrait("OpTrait::NResults<" + Twine(numResults) + ">::Impl"); - break; - } - } + addSizeCountTrait(opClass, "Result", numResults, numVariadicResults); + + // Add successor size trait. + unsigned numSuccessors = op.getNumSuccessors(); + unsigned numVariadicSuccessors = op.getNumVariadicSuccessors(); + addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors); + // Add the native and interface traits. for (const auto &trait : op.getTraits()) { if (auto opTrait = dyn_cast(&trait)) opClass.addTrait(opTrait->getTrait());