From 13c6e419ca68cd4a5434f4349db5433395e6fbf0 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Mon, 25 Nov 2019 17:26:16 -0800 Subject: [PATCH] Add support for AttrSizedOperandSegments/AttrSizedResultSegments Certain operations can have multiple variadic operands and their size relationship is not always known statically. For such cases, we need a per-op-instance specification to divide the operands into logical groups or segments. This can be modeled by attributes. This CL introduces C++ trait AttrSizedOperandSegments for operands and AttrSizedResultSegments for results. The C++ trait just guarantees such size attribute has the correct type (1D vector) and values (non-negative), etc. It serves as the basis for ODS sugaring that with ODS argument declarations we can further verify the number of elements match the number of ODS-declared operands and we can generate handy getter methods. PiperOrigin-RevId: 282467075 --- mlir/include/mlir/IR/OpBase.td | 17 +++- mlir/include/mlir/IR/OpDefinition.h | 39 ++++++++ mlir/include/mlir/TableGen/Operator.h | 4 +- mlir/lib/IR/Operation.cpp | 41 ++++++++ mlir/lib/TableGen/Operator.cpp | 12 +-- mlir/test/IR/traits.mlir | 100 ++++++++++++++++++++ mlir/test/lib/TestDialect/TestOps.td | 24 +++++ mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 141 +++++++++++++++++++++++----- mlir/tools/mlir-tblgen/RewriterGen.cpp | 6 +- 9 files changed, 345 insertions(+), 39 deletions(-) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 314acf6..f81063f 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1307,8 +1307,7 @@ class NativeOpTrait : OpTrait { // the value in `prop` as the trait name and the value in `params` as // parameters to construct the native trait class name. class ParamNativeOpTrait - : NativeOpTrait::Impl"> { -} + : NativeOpTrait::Impl">; // GenInternalOpTrait is an op trait that does not have direct C++ mapping but // affects op definition generator internals, like how op builders and @@ -1351,7 +1350,7 @@ def Symbol : NativeOpTrait<"Symbol">; // Op defines a symbol table. def SymbolTable : NativeOpTrait<"SymbolTable">; // Op is a terminator. -def Terminator : NativeOpTrait<"IsTerminator">; +def Terminator : NativeOpTrait<"IsTerminator">; // Op's regions have a single block with the specified terminator. class SingleBlockImplicitTerminator @@ -1381,6 +1380,18 @@ def SameVariadicOperandSize : GenInternalOpTrait<"SameVariadicOperandSize">; // to have the same array size. def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">; +// Uses an attribute named `operand_segment_sizes` to specify how many actual +// operand each ODS-declared operand (variadic or not) corresponds to. +// This trait is used for ops that have multiple variadic operands but do +// not know statically their size relationship. The attribute must be a 1D +// vector that has the same number of elements as the number of ODS declared +// operands. That means even if some operands are non-variadic, the attribute +// still need to have an element for its size, which is always 1. +def AttrSizedOperandSegments : NativeOpTrait<"AttrSizedOperandSegments">; +// Similar to AttrSizedOperandSegments, but used for results. The attribute +// should be named as `result_segment_sizes`. +def AttrSizedResultSegments : NativeOpTrait<"AttrSizedResultSegments">; + //===----------------------------------------------------------------------===// // OpInterface definitions //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index ebe373c..89ab014 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -386,6 +386,8 @@ LogicalResult verifyResultsAreBoolLike(Operation *op); LogicalResult verifyResultsAreFloatLike(Operation *op); LogicalResult verifyResultsAreIntegerLike(Operation *op); LogicalResult verifyIsTerminator(Operation *op); +LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName); +LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName); } // namespace impl /// Helper class for implementing traits. Clients are not expected to interact @@ -907,6 +909,43 @@ template struct HasParent { }; }; +/// A trait for operations that have an attribute specifying operand segments. +/// +/// Certain operations can have multiple variadic operands and their size +/// relationship is not always known statically. For such cases, we need +/// a per-op-instance specification to divide the operands into logical groups +/// or segments. This can be modeled by attributes. The attribute will be named +/// as `operand_segment_sizes`. +/// +/// This trait verifies the attribute for specifying operand segments has +/// the correct type (1D vector) and values (non-negative), etc. +template +class AttrSizedOperandSegments + : public TraitBase { +public: + static StringRef getOperandSegmentSizeAttr() { + return "operand_segment_sizes"; + } + + static LogicalResult verifyTrait(Operation *op) { + return ::mlir::OpTrait::impl::verifyOperandSizeAttr( + op, getOperandSegmentSizeAttr()); + } +}; + +/// Similar to AttrSizedOperandSegments but used for results. +template +class AttrSizedResultSegments + : public TraitBase { +public: + static StringRef getResultSegmentSizeAttr() { return "result_segment_sizes"; } + + static LogicalResult verifyTrait(Operation *op) { + return ::mlir::OpTrait::impl::verifyResultSizeAttr( + op, getResultSegmentSizeAttr()); + } +}; + } // end namespace OpTrait //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index 95df9cb..7b636dd 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -136,10 +136,10 @@ public: Argument getArg(int index) const; StringRef getArgName(int index) const; - // Returns true if this op has the given MLIR C++ `trait`. + // Returns the trait wrapper for the given MLIR C++ `trait`. // TODO: We should add a C++ wrapper class for TableGen OpTrait instead of // requiring the raw MLIR trait here. - bool hasTrait(llvm::StringRef trait) const; + const OpTrait *getTrait(llvm::StringRef trait) const; // Returns the number of regions. unsigned getNumRegions() const; diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index f53f38d..973b833 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -957,6 +957,47 @@ LogicalResult OpTrait::impl::verifyResultsAreIntegerLike(Operation *op) { return success(); } +static LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName, + bool isOperand) { + auto sizeAttr = op->getAttrOfType(attrName); + if (!sizeAttr) + return op->emitOpError("requires 1D vector attribute '") << attrName << "'"; + + auto sizeAttrType = sizeAttr.getType().dyn_cast(); + if (!sizeAttrType || sizeAttrType.getRank() != 1) + return op->emitOpError("requires 1D vector attribute '") << attrName << "'"; + + if (llvm::any_of(sizeAttr.getIntValues(), [](const APInt &element) { + return !element.isNonNegative(); + })) + return op->emitOpError("'") + << attrName << "' attribute cannot have negative elements"; + + size_t totalCount = std::accumulate( + sizeAttr.begin(), sizeAttr.end(), 0, + [](unsigned all, APInt one) { return all + one.getZExtValue(); }); + + if (isOperand && totalCount != op->getNumOperands()) + return op->emitOpError("operand count (") + << op->getNumOperands() << ") does not match with the total size (" + << totalCount << ") specified in attribute '" << attrName << "'"; + else if (!isOperand && totalCount != op->getNumResults()) + return op->emitOpError("result count (") + << op->getNumResults() << ") does not match with the total size (" + << totalCount << ") specified in attribute '" << attrName << "'"; + return success(); +} + +LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, + StringRef attrName) { + return verifyValueSizeAttr(op, attrName, /*isOperand=*/true); +} + +LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, + StringRef attrName) { + return verifyValueSizeAttr(op, attrName, /*isOperand=*/false); +} + //===----------------------------------------------------------------------===// // BinaryOp implementation //===----------------------------------------------------------------------===// diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index 927f275..4529208 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -145,20 +145,20 @@ StringRef tblgen::Operator::getArgName(int index) const { return argumentValues->getArgName(index)->getValue(); } -bool tblgen::Operator::hasTrait(StringRef trait) const { - for (auto t : getTraits()) { +const tblgen::OpTrait *tblgen::Operator::getTrait(StringRef trait) const { + for (const auto &t : traits) { if (auto opTrait = dyn_cast(&t)) { if (opTrait->getTrait() == trait) - return true; + return opTrait; } else if (auto opTrait = dyn_cast(&t)) { if (opTrait->getTrait() == trait) - return true; + return opTrait; } else if (auto opTrait = dyn_cast(&t)) { if (opTrait->getTrait() == trait) - return true; + return opTrait; } } - return false; + return nullptr; } unsigned tblgen::Operator::getNumRegions() const { return regions.size(); } diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir index 69804f1..449a4b3 100644 --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -215,3 +215,103 @@ func @failedSingleBlockImplicitTerminator_missing_terminator() { }) : () -> () func @foo() { } + +// ----- + +func @failedMissingOperandSizeAttr(%arg: i32) { + // expected-error @+1 {{requires 1D vector attribute 'operand_segment_sizes'}} + "test.attr_sized_operands"(%arg, %arg, %arg, %arg) : (i32, i32, i32, i32) -> () +} + +// ----- + +func @failedOperandSizeAttrWrongType(%arg: i32) { + // expected-error @+1 {{requires 1D vector attribute 'operand_segment_sizes'}} + "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[1, 1, 1, 1]>: tensor<4xi32>} : (i32, i32, i32, i32) -> () +} + +// ----- + +func @failedOperandSizeAttrWrongRank(%arg: i32) { + // expected-error @+1 {{requires 1D vector attribute 'operand_segment_sizes'}} + "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[[1, 1], [1, 1]]>: vector<2x2xi32>} : (i32, i32, i32, i32) -> () +} + +// ----- + +func @failedOperandSizeAttrNegativeValue(%arg: i32) { + // expected-error @+1 {{'operand_segment_sizes' attribute cannot have negative elements}} + "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[1, 1, -1, 1]>: vector<4xi32>} : (i32, i32, i32, i32) -> () +} + +// ----- + +func @failedOperandSizeAttrWrongTotalSize(%arg: i32) { + // expected-error @+1 {{operand count (4) does not match with the total size (3) specified in attribute 'operand_segment_sizes'}} + "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[0, 1, 1, 1]>: vector<4xi32>} : (i32, i32, i32, i32) -> () +} + +// ----- + +func @failedOperandSizeAttrWrongCount(%arg: i32) { + // expected-error @+1 {{'operand_segment_sizes' attribute for specifiying operand segments must have 4 elements}} + "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[2, 1, 1]>: vector<3xi32>} : (i32, i32, i32, i32) -> () +} + +// ----- + +func @succeededOperandSizeAttr(%arg: i32) { + // CHECK: test.attr_sized_operands + "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[0, 2, 1, 1]>: vector<4xi32>} : (i32, i32, i32, i32) -> () + return +} + +// ----- + +func @failedMissingResultSizeAttr() { + // expected-error @+1 {{requires 1D vector attribute 'result_segment_sizes'}} + %0:4 = "test.attr_sized_results"() : () -> (i32, i32, i32, i32) +} + +// ----- + +func @failedResultSizeAttrWrongType() { + // expected-error @+1 {{requires 1D vector attribute 'result_segment_sizes'}} + %0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[1, 1, 1, 1]>: tensor<4xi32>} : () -> (i32, i32, i32, i32) +} + +// ----- + +func @failedResultSizeAttrWrongRank() { + // expected-error @+1 {{requires 1D vector attribute 'result_segment_sizes'}} + %0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[[1, 1], [1, 1]]>: vector<2x2xi32>} : () -> (i32, i32, i32, i32) +} + +// ----- + +func @failedResultSizeAttrNegativeValue() { + // expected-error @+1 {{'result_segment_sizes' attribute cannot have negative elements}} + %0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[1, 1, -1, 1]>: vector<4xi32>} : () -> (i32, i32, i32, i32) +} + +// ----- + +func @failedResultSizeAttrWrongTotalSize() { + // expected-error @+1 {{result count (4) does not match with the total size (3) specified in attribute 'result_segment_sizes'}} + %0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[0, 1, 1, 1]>: vector<4xi32>} : () -> (i32, i32, i32, i32) +} + +// ----- + +func @failedResultSizeAttrWrongCount() { + // expected-error @+1 {{'result_segment_sizes' attribute for specifiying result segments must have 4 elements}} + %0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[2, 1, 1]>: vector<3xi32>} : () -> (i32, i32, i32, i32) +} + +// ----- + +func @succeededResultSizeAttr() { + // CHECK: test.attr_sized_results + %0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[0, 2, 1, 1]>: vector<4xi32>} : () -> (i32, i32, i32, i32) + return +} diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index e8ca8b8..6bb0cbc 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -413,6 +413,30 @@ def TestBranchOp : TEST_Op<"br", [Terminator]> { let arguments = (ins Variadic:$operands); } +def AttrSizedOperandOp : TEST_Op<"attr_sized_operands", + [AttrSizedOperandSegments]> { + let arguments = (ins + Variadic:$a, + Variadic:$b, + I32:$c, + Variadic:$d, + I32ElementsAttr:$operand_segment_sizes + ); +} + +def AttrSizedResultOp : TEST_Op<"attr_sized_results", + [AttrSizedResultSegments]> { + let arguments = (ins + I32ElementsAttr:$result_segment_sizes + ); + let results = (outs + Variadic:$a, + Variadic:$b, + I32:$c, + Variadic:$d + ); +} + //===----------------------------------------------------------------------===// // Test Patterns //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 538aa6e..864f773 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -40,18 +40,18 @@ static const char *const tblgenNamePrefix = "tblgen_"; static const char *const generatedArgName = "tblgen_arg"; static const char *const builderOpState = "tblgen_state"; -// The logic to calculate the dynamic value range for an static operand/result +// The logic to calculate the actual value range for a declared operand/result // of an op with variadic operands/results. Note that this logic is not for // general use; it assumes all variadic operands/results must have the same // number of values. // -// {0}: The list of whether each static operand/result is variadic. +// {0}: The list of whether each declared operand/result is variadic. // {1}: The total number of non-variadic operands/results. // {2}: The total number of variadic operands/results. -// {3}: The total number of dynamic values. -// {4}: The begin iterator of the dynamic values. -// {5}: "operand" or "result" -const char *valueRangeCalcCode = R"( +// {3}: The total number of actual values. +// {4}: The begin iterator of the actual values. +// {5}: "operand" or "result". +const char *sameVariadicSizeValueRangeCalcCode = R"( bool isVariadic[] = {{{0}}; int prevVariadicCount = 0; for (unsigned i = 0; i < index; ++i) @@ -70,6 +70,22 @@ const char *valueRangeCalcCode = R"( return {{std::next({4}, offset), std::next({4}, offset + size)}; )"; +// The logic to calculate the actual value range for a declared operand/result +// of an op with variadic operands/results. Note that this logic is assumes +// the op has an attribute specifying the size of each operand/result segment +// (variadic or not). +// +// {0}: The name of the attribute specifying the segment sizes. +// {1}: The begin iterator of the actual values. +const char *attrSizedSegmentValueRangeCalcCode = R"( + auto sizeAttr = getAttrOfType("{0}"); + unsigned start = 0; + for (unsigned i = 0; i < index; ++i) + start += (*(sizeAttr.begin() + i)).getZExtValue(); + unsigned end = start + (*(sizeAttr.begin() + index)).getZExtValue(); + return {{std::next({1}, start), std::next({1}, end)}; +)"; + static const char *const opCommentHeader = R"( //===----------------------------------------------------------------------===// // {0} {1} @@ -239,6 +255,10 @@ class OpClass : public Class { public: explicit OpClass(StringRef name, StringRef extraClassDeclaration = ""); + // Sets whether this OpClass should generate the using directive for its + // associate operand adaptor class. + void setHasOperandAdaptorClass(bool has); + // Adds an op trait. void addTrait(Twine trait); @@ -249,6 +269,7 @@ public: private: StringRef extraClassDeclaration; SmallVector traits; + bool hasOperandAdaptor; }; } // end anonymous namespace @@ -401,7 +422,10 @@ void Class::writeDefTo(raw_ostream &os) const { } OpClass::OpClass(StringRef name, StringRef extraClassDeclaration) - : Class(name), extraClassDeclaration(extraClassDeclaration) {} + : Class(name), extraClassDeclaration(extraClassDeclaration), + hasOperandAdaptor(true) {} + +void OpClass::setHasOperandAdaptorClass(bool has) { hasOperandAdaptor = has; } // Adds the given trait to this op. void OpClass::addTrait(Twine trait) { traits.push_back(trait.str()); } @@ -412,7 +436,8 @@ void OpClass::writeDeclTo(raw_ostream &os) const { os << ", " << trait; os << "> {\npublic:\n"; os << " using Op::Op;\n"; - os << " using OperandAdaptor = " << className << "OperandAdaptor;\n"; + if (hasOperandAdaptor) + os << " using OperandAdaptor = " << className << "OperandAdaptor;\n"; bool hasPrivateMethod = false; for (const auto &method : methods) { @@ -667,12 +692,27 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass, const int numVariadicOperands = op.getNumVariadicOperands(); const int numNormalOperands = numOperands - numVariadicOperands; - if (numVariadicOperands > 1 && - !op.hasTrait("OpTrait::SameVariadicOperandSize")) { + const auto *sameVariadicSize = + op.getTrait("OpTrait::SameVariadicOperandSize"); + const auto *attrSizedOperands = + op.getTrait("OpTrait::AttrSizedOperandSegments"); + + if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) { PrintFatalError(op.getLoc(), "op has multiple variadic operands but no " "specification over their sizes"); } + if (numVariadicOperands < 2 && attrSizedOperands) { + PrintFatalError(op.getLoc(), "op must have at least two variadic operands " + "to use 'AttrSizedOperandSegments' trait"); + } + + if (attrSizedOperands && sameVariadicSize) { + PrintFatalError(op.getLoc(), + "op cannot have both 'AttrSizedOperandSegments' and " + "'SameVariadicOperandSize' traits"); + } + // First emit a "sink" getter method upon which we layer all nicer named // getter methods. auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index"); @@ -681,6 +721,9 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass, // We still need to match the return type, which is a range. m.body() << " return {std::next(" << rangeBeginCall << ", index), std::next(" << rangeBeginCall << ", index + 1)};"; + } else if (attrSizedOperands) { + m.body() << formatv(attrSizedSegmentValueRangeCalcCode, + "operand_segment_sizes", rangeBeginCall); } else { // Because the op can have arbitrarily interleaved variadic and non-variadic // operands, we need to embed a list in the "sink" getter method for @@ -692,9 +735,9 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass, } std::string isVariadicList = llvm::join(isVariadic, ", "); - m.body() << formatv(valueRangeCalcCode, isVariadicList, numNormalOperands, - numVariadicOperands, rangeSizeCall, rangeBeginCall, - "operand"); + m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList, + numNormalOperands, numVariadicOperands, rangeSizeCall, + rangeBeginCall, "operand"); } // Then we emit nicer named getter methods by redirecting to the "sink" getter @@ -716,6 +759,9 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass, } void OpEmitter::genNamedOperandGetters() { + if (op.getTrait("OpTrait::AttrSizedOperandSegments")) + opClass.setHasOperandAdaptorClass(false); + generateNamedOperandGetters( op, opClass, /*rangeType=*/"Operation::operand_range", /*rangeBeginCall=*/"getOperation()->operand_begin()", @@ -731,18 +777,36 @@ void OpEmitter::genNamedResultGetters() { // If we have more than one variadic results, we need more complicated logic // to calculate the value range for each result. - if (numVariadicResults > 1 && - !op.hasTrait("OpTrait::SameVariadicResultSize")) { + const auto *sameVariadicSize = op.getTrait("OpTrait::SameVariadicResultSize"); + const auto *attrSizedResults = + op.getTrait("OpTrait::AttrSizedResultSegments"); + + if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) { PrintFatalError(op.getLoc(), "op has multiple variadic results but no " "specification over their sizes"); } + if (numVariadicResults < 2 && attrSizedResults) { + PrintFatalError(op.getLoc(), "op must have at least two variadic results " + "to use 'AttrSizedResultSegments' trait"); + } + + if (attrSizedResults && sameVariadicSize) { + PrintFatalError(op.getLoc(), + "op cannot have both 'AttrSizedResultSegments' and " + "'SameVariadicResultSize' traits"); + } + auto &m = opClass.newMethod("Operation::result_range", "getODSResults", "unsigned index"); if (numVariadicResults == 0) { m.body() << " return {std::next(getOperation()->result_begin(), index), " "std::next(getOperation()->result_begin(), index + 1)};"; + } else if (attrSizedResults) { + m.body() << formatv(attrSizedSegmentValueRangeCalcCode, + "result_segment_sizes", + "getOperation()->result_begin()"); } else { llvm::SmallVector isVariadic; isVariadic.reserve(numResults); @@ -751,8 +815,9 @@ void OpEmitter::genNamedResultGetters() { } std::string isVariadicList = llvm::join(isVariadic, ", "); - m.body() << formatv(valueRangeCalcCode, isVariadicList, numNormalResults, - numVariadicResults, "getOperation()->getNumResults()", + m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList, + numNormalResults, numVariadicResults, + "getOperation()->getNumResults()", "getOperation()->result_begin()", "result"); } @@ -952,11 +1017,11 @@ void OpEmitter::genBuilder() { // use the first operand or attribute's type as all result types // to facilitate different call patterns. if (op.getNumVariadicResults() == 0) { - if (op.hasTrait("OpTrait::SameOperandsAndResultType")) { + if (op.getTrait("OpTrait::SameOperandsAndResultType")) { genUseOperandAsResultTypeSeparateParamBuilder(); genUseOperandAsResultTypeCollectiveParamBuilder(); } - if (op.hasTrait("OpTrait::FirstAttrDerivedResultType")) + if (op.getTrait("OpTrait::FirstAttrDerivedResultType")) genUseAttrAsResultTypeBuilder(); } } @@ -1243,18 +1308,38 @@ void OpEmitter::genVerifier() { body << " }\n"; } - genOperandResultVerifier(body, op.getOperands(), "operand"); - genOperandResultVerifier(body, op.getResults(), "result"); + const char *code = R"( + auto sizeAttr = getAttrOfType("{0}"); + auto numElements = sizeAttr.getType().cast().getNumElements(); + if (numElements != {1}) {{ + return emitOpError("'{0}' attribute for specifiying {2} segments " + "must have {1} elements"); + } + )"; for (auto &trait : op.getTraits()) { - if (auto t = dyn_cast(&trait)) { + if (auto *t = dyn_cast(&trait)) { body << tgfmt(" if (!($0)) {\n " "return emitOpError(\"failed to verify that $1\");\n }\n", &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx), t->getDescription()); + } else if (auto *t = dyn_cast(&trait)) { + if (t->getTrait() == "OpTrait::AttrSizedOperandSegments") { + body << formatv(code, "operand_segment_sizes", op.getNumOperands(), + "operand"); + } else if (t->getTrait() == "OpTrait::AttrSizedResultSegments") { + body << formatv(code, "result_segment_sizes", op.getNumResults(), + "result"); + } } } + // These should happen after we verified the traits because + // getODSOperands()/getODSResults() may depend on traits (e.g., + // AttrSizedOperandSegments/AttrSizedResultSegments). + genOperandResultVerifier(body, op.getOperands(), "operand"); + genOperandResultVerifier(body, op.getResults(), "result"); + genRegionVerifier(body); if (hasCustomVerify) { @@ -1405,7 +1490,7 @@ void OpEmitter::genOpAsmInterface() { // TODO: We could also add a flag to allow operations to opt in to this // generation, even if they only have a single operation. int numResults = op.getNumResults(); - if (numResults <= 1 || op.hasTrait("OpAsmOpInterface::Trait")) + if (numResults <= 1 || op.getTrait("OpAsmOpInterface::Trait")) return; SmallVector resultNames(numResults); @@ -1484,13 +1569,19 @@ static void emitOpClasses(const std::vector &defs, raw_ostream &os, } for (auto *def : defs) { Operator op(*def); + const auto *attrSizedOperands = + op.getTrait("OpTrait::AttrSizedOperandSegments"); if (emitDecl) { os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations"); - OpOperandAdaptorEmitter::emitDecl(op, os); + // We cannot generate the operand adaptor class if operand getters depend + // on an attribute. + if (!attrSizedOperands) + OpOperandAdaptorEmitter::emitDecl(op, os); OpEmitter::emitDecl(op, os); } else { os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions"); - OpOperandAdaptorEmitter::emitDef(op, os); + if (!attrSizedOperands) + OpOperandAdaptorEmitter::emitDef(op, os); OpEmitter::emitDef(op, os); } } diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index ac2976a..d2776e0 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -761,8 +761,8 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, // special cases listed below, DRR needs to supply types for all results // when building an op. bool isSameOperandsAndResultType = - resultOp.hasTrait("OpTrait::SameOperandsAndResultType"); - bool useFirstAttr = resultOp.hasTrait("OpTrait::FirstAttrDerivedResultType"); + resultOp.getTrait("OpTrait::SameOperandsAndResultType"); + bool useFirstAttr = resultOp.getTrait("OpTrait::FirstAttrDerivedResultType"); if (isSameOperandsAndResultType || useFirstAttr) { // We know how to deduce the result type for ops with these traits and we've @@ -780,7 +780,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, } bool isBroadcastable = - resultOp.hasTrait("OpTrait::BroadcastableTwoOperandsOneResult"); + resultOp.getTrait("OpTrait::BroadcastableTwoOperandsOneResult"); bool usePartialResults = valuePackName != resultValue; if (isBroadcastable || usePartialResults || depth > 0 || resultIndex < 0) { -- 2.7.4