From: Lei Zhang Date: Thu, 25 Apr 2019 21:45:37 +0000 (-0700) Subject: [TableGen] Support multiple variadic operands/results X-Git-Tag: llvmorg-11-init~1466^2~1885 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=6749c21d6ec263891e9bb55a7203f2d2a8bb4d5f;p=platform%2Fupstream%2Fllvm.git [TableGen] Support multiple variadic operands/results Certain ops can have multiple variadic operands/results, e.g., `tf.DynamicStitch`. Even if an op has only one variadic operand/result, it is not necessarily the very last one, e.g., `tf.RaggedGather`. This CL enhances TableGen subsystem to be able to represent such cases. In order to deduce the operand/result value range for each variadic operand, currently we only support variadic operands/results all of the same size. So two new traits, `SameVariadicOperandSize` and `SameVariadicResultSize` are introduced. -- PiperOrigin-RevId: 245310628 --- diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 0e4ab44..300b86e 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -759,6 +759,17 @@ def Terminator : NativeOpTrait<"IsTerminator">; def FirstAttrDerivedResultType : GenInternalOpTrait<"FirstAttrDerivedResultType">; +// All variadic operands of the op have the same number of values. +// A variadic operand contains an array of values whose array size is only +// known at runtime. This trait requires all variadic operands of an op +// to have the same array size. +def SameVariadicOperandSize : GenInternalOpTrait<"SameVariadicOperandSize">; +// All variadic results of the op have the same number of values. +// A variadic result contains an array of values whose array size is only +// known at runtime. This trait requires all variadic results of an op +// to have the same array size. +def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">; + //===----------------------------------------------------------------------===// // Op definitions //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/TableGen/Argument.h b/mlir/include/mlir/TableGen/Argument.h index a37bc700..8390939 100644 --- a/mlir/include/mlir/TableGen/Argument.h +++ b/mlir/include/mlir/TableGen/Argument.h @@ -48,10 +48,12 @@ struct NamedAttribute { Attribute attr; }; -// A struct wrapping an op operand/result and its name together +// A struct wrapping an op operand/result's constraint and its name together struct NamedTypeConstraint { - // Returns true if this operand has constraint that need to be satisfied. + // Returns true if this operand/result has constraint to be satisfied. bool hasPredicate() const; + // Returns true if this operand/result is variadic. + bool isVariadic() const; llvm::StringRef name; TypeConstraint constraint; diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index 7233626..e96b29c 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -82,8 +82,8 @@ public: // Returns the `index`-th result's name. StringRef getResultName(int index) const; - // Returns true if this operation has a variadic result. - bool hasVariadicResult() const; + // Returns the number of variadic results in this operation. + unsigned getNumVariadicResults() const; // Op attribute interators. using attribute_iterator = const NamedAttribute *; @@ -112,8 +112,8 @@ public: return operands[index]; } - // Returns true if this operation has a variadic operand. - bool hasVariadicOperand() const; + // Returns the number of variadic operands in this operation. + unsigned getNumVariadicOperands() const; // Returns the total number of arguments. int getNumArgs() const { return arguments.size(); } diff --git a/mlir/lib/TableGen/Argument.cpp b/mlir/lib/TableGen/Argument.cpp index f7af63f..7432e0f 100644 --- a/mlir/lib/TableGen/Argument.cpp +++ b/mlir/lib/TableGen/Argument.cpp @@ -23,3 +23,7 @@ using namespace mlir; bool tblgen::NamedTypeConstraint::hasPredicate() const { return !constraint.getPredicate().isNull(); } + +bool tblgen::NamedTypeConstraint::isVariadic() const { + return constraint.isVariadic(); +} diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index 6f0f3ea..3854728 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -82,8 +82,10 @@ StringRef tblgen::Operator::getResultName(int index) const { return results->getArgNameStr(index); } -bool tblgen::Operator::hasVariadicResult() const { - return !results.empty() && results.back().constraint.isVariadic(); +unsigned tblgen::Operator::getNumVariadicResults() const { + return std::count_if( + results.begin(), results.end(), + [](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); }); } int tblgen::Operator::getNumNativeAttributes() const { @@ -98,8 +100,10 @@ const tblgen::NamedAttribute &tblgen::Operator::getAttribute(int index) const { return attributes[index]; } -bool tblgen::Operator::hasVariadicOperand() const { - return !operands.empty() && operands.back().constraint.isVariadic(); +unsigned tblgen::Operator::getNumVariadicOperands() const { + return std::count_if( + operands.begin(), operands.end(), + [](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); }); } StringRef tblgen::Operator::getArgName(int index) const { @@ -222,13 +226,6 @@ void tblgen::Operator::populateOpStructure() { } } - // Verify that only the last operand can be variadic. - for (int i = 0, e = operands.size() - 1; i < e; ++i) { - if (operands[i].constraint.isVariadic()) - PrintFatalError(def.getLoc(), - "only the last operand allowed to be variadic"); - } - auto *resultsDag = def.getValueAsDag("results"); auto *outsOp = dyn_cast(resultsDag->getOperator()); if (!outsOp || outsOp->getDef()->getName() != "outs") { @@ -246,13 +243,6 @@ void tblgen::Operator::populateOpStructure() { results.push_back({name, TypeConstraint(resultDef)}); } - // Verify that only the last result can be variadic. - for (int i = 0, e = results.size() - 1; i < e; ++i) { - if (results[i].constraint.isVariadic()) - PrintFatalError(def.getLoc(), - "only the last result allowed to be variadic"); - } - auto traitListInit = def.getValueAsListInit("traits"); if (!traitListInit) return; diff --git a/mlir/test/mlir-tblgen/op-builder.td b/mlir/test/mlir-tblgen/op-builder.td deleted file mode 100644 index 7be9018..0000000 --- a/mlir/test/mlir-tblgen/op-builder.td +++ /dev/null @@ -1,28 +0,0 @@ -// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s - -include "mlir/IR/OpBase.td" - - -def NS_OpA : Op<"op_same_value_type", [SameValueType]> { - let arguments = (ins Tensor:$input); - let results = (outs Tensor:$result); -} - -// Test that with SameValueType trait we can generate a builder without -// requiring result type -// --- - -// CHECK-LABEL: OpA::build(Builder *, OperationState *tblgen_state, Value *input) -// CHECK: tblgen_state->addTypes({input->getType()}); - -def NS_OpB : Op<"op_same_value_type_variadic_input", [SameValueType]> { - let arguments = (ins Variadic:$input); - let results = (outs Tensor:$result); -} - -// Test that if the only operand is variadic, we acess the first value in the -// pack to set result type -// --- - -// CHECK-LABEL: OpB::build(Builder *, OperationState *tblgen_state, ArrayRef input) -// CHECK: tblgen_state->addTypes({input.front()->getType()}); diff --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td index 132d382..936fb7c 100644 --- a/mlir/test/mlir-tblgen/op-operand.td +++ b/mlir/test/mlir-tblgen/op-operand.td @@ -2,7 +2,7 @@ include "mlir/IR/OpBase.td" -def OpA : Op<"one_operand_op", []> { +def OpA : Op<"one_normal_operand_op", []> { let arguments = (ins I32:$input); } @@ -10,7 +10,7 @@ def OpA : Op<"one_operand_op", []> { // CHECK: void OpA::build // CHECK-SAME: Value *input -// CHECK: tblgen_state->addOperands({input}); +// CHECK: tblgen_state->operands.push_back(input); // CHECK: void OpA::build // CHECK-SAME: ArrayRef operands @@ -21,11 +21,72 @@ def OpA : Op<"one_operand_op", []> { // CHECK: if (!((this->getOperation()->getOperand(0)->getType().isInteger(32)))) // CHECK-NEXT: return emitOpError("operand #0 must be 32-bit integer"); -def OpB : Op<"variadic_operand_op", []> { +def OpB : Op<"one_variadic_operand_op", []> { let arguments = (ins Variadic:$input); } // CHECK-LABEL: OpB::build -// CHECK-SAME: ArrayRef input -// CHECK-NOT: assert -// CHECK: tblgen_state->addOperands(input); +// CHECK-SAME: ArrayRef input +// CHECK-NOT: assert +// CHECK: tblgen_state->addOperands(input); + +def OpC : Op<"all_variadic_inputs_op", [SameVariadicOperandSize]> { + let arguments = (ins Variadic:$input1, Variadic:$input2); +} + +// CHECK-LABEL: Operation::operand_range OpC::input1() +// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 0) / 2; +// CHECK-NEXT: unsigned offset = 0 + variadicOperandSize * 0; +// CHECK-NEXT: return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)}; + +// CHECK-LABEL: Operation::operand_range OpC::input2() +// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 0) / 2; +// CHECK-NEXT: unsigned offset = 0 + variadicOperandSize * 1; +// CHECK-NEXT: return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)}; + +// CHECK-LABEL: OpC::build +// CHECK-NEXT: tblgen_state->addOperands(input1); +// CHECK-NEXT: tblgen_state->addOperands(input2); + +def OpD : Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]> { + let arguments = (ins Variadic:$input1, Tensor:$input2, Variadic:$input3); +} + +// CHECK-LABEL: Operation::operand_range OpD::input1() +// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 1) / 2; +// CHECK-NEXT: unsigned offset = 0 + variadicOperandSize * 0; +// CHECK-NEXT: return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)}; + +// CHECK-LABEL: Value *OpD::input2() +// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 1) / 2; +// CHECK-NEXT: unsigned offset = 0 + variadicOperandSize * 1; +// CHECK-NEXT: return this->getOperand(offset); + +// CHECK-LABEL: Operation::operand_range OpD::input3() +// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 1) / 2; +// CHECK-NEXT: unsigned offset = 1 + variadicOperandSize * 1; +// CHECK-NEXT: return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)}; + +// CHECK-LABEL: OpD::build +// CHECK-NEXT: tblgen_state->addOperands(input1); +// CHECK-NEXT: tblgen_state->operands.push_back(input2); +// CHECK-NEXT: tblgen_state->addOperands(input3); + +def OpE : Op<"one_variadic_among_multi_normal_inputs_op", []> { + let arguments = (ins Tensor:$input1, Tensor:$input2, Variadic:$input3, Tensor:$input4, Tensor:$input5); +} + +// CHECK-LABEL: Value *OpE::input1() +// CHECK-NEXT: return this->getOperation()->getOperand(0); + +// CHECK-LABEL: Value *OpE::input2() +// CHECK-NEXT: return this->getOperation()->getOperand(1); + +// CHECK-LABEL: Operation::operand_range OpE::input3() +// CHECK-NEXT: return {std::next(operand_begin(), 2), std::next(operand_begin(), 2 + this->getNumOperands() - 4)}; + +// CHECK-LABEL: Value *OpE::input4() +// CHECK-NEXT: return this->getOperation()->getOperand(this->getNumOperands() - 2); + +// CHECK-LABEL: Value *OpE::input5() +// CHECK-NEXT: return this->getOperation()->getOperand(this->getNumOperands() - 1); diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td index 06bc3b4..714c62c 100644 --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -2,82 +2,160 @@ include "mlir/IR/OpBase.td" -def OneResultOp : Op<"one_result_op", []> { +def OpA : Op<"one_normal_result_op", []> { let results = (outs I32:$result); } -// CHECK-LABEL: OneResultOp definitions +// CHECK-LABEL: Value *OpA::result() +// CHECK-NEXT: return this->getOperation()->getResult(0) -// CHECK: void OneResultOp::build +// CHECK-LABEL: void OpA::build // CHECK: ArrayRef resultTypes, ArrayRef operands // CHECK: assert(resultTypes.size() == 1u && "mismatched number of return types"); // CHECK-NEXT: tblgen_state->addTypes(resultTypes); -// CHECK: LogicalResult OneResultOp::verify() { +// CHECK-LABEL: LogicalResult OpA::verify() // CHECK: if (!((this->getOperation()->getResult(0)->getType().isInteger(32)))) // CHECK-NEXT: return emitOpError("result #0 must be 32-bit integer"); - -def SameTypeOp : Op<"same_type_op", [SameValueType]> { +def OpB : Op<"same_input_output_type_op", [SameValueType]> { let arguments = (ins I32:$x); let results = (outs I32:$y); } -// CHECK-LABEL: SameTypeOp definitions -// CHECK: void SameTypeOp::build(Builder *, OperationState *tblgen_state, Type y, Value *x) -// CHECK: tblgen_state->addTypes({y}); -// CHECK: void SameTypeOp::build(Builder *, OperationState *tblgen_state, Value *x) +// CHECK-LABEL: OpB definitions +// CHECK: void OpB::build(Builder *, OperationState *tblgen_state, Type y, Value *x) +// CHECK: tblgen_state->types.push_back(y); +// CHECK: void OpB::build(Builder *, OperationState *tblgen_state, Value *x) // CHECK: tblgen_state->addTypes({x->getType()}); -def ThreeResultOp : Op<"three_result_op", []> { +def OpC : Op<"three_normal_result_op", []> { let results = (outs I32:$x, /*unnamed*/I32, I32:$z); } -// CHECK-LABEL: ThreeResultOp definitions -// CHECK: void ThreeResultOp::build(Builder *, OperationState *tblgen_state, Type x, Type resultType1, Type z) -// CHECK: tblgen_state->addTypes({x, resultType1, z}); +// CHECK-LABEL: OpC definitions +// CHECK: void OpC::build(Builder *, OperationState *tblgen_state, Type x, Type resultType1, Type z) +// CHECK-NEXT: tblgen_state->types.push_back(x) +// CHECK-NEXT: tblgen_state->types.push_back(resultType1) +// CHECK-NEXT: tblgen_state->types.push_back(z) def IntegerTypeAttr : TypeAttrBase<"IntegerType", "Integer type attribute">; -def TypeAttrResultTypeOp : Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> { +def OpD : Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> { let arguments = (ins I32:$x, IntegerTypeAttr:$attr, F32Attr:$f32); let results = (outs Tensor:$y); } -// CHECK-LABEL: TypeAttrResultTypeOp definitions -// CHECK: void TypeAttrResultTypeOp::build(Builder *, OperationState *tblgen_state, Value *x, TypeAttr attr, FloatAttr f32) +// CHECK-LABEL: OpD definitions +// CHECK: void OpD::build(Builder *, OperationState *tblgen_state, Value *x, TypeAttr attr, FloatAttr f32) // CHECK: tblgen_state->addTypes({attr.getValue()}); -def ValueAttrResultTypeOp : Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> { +def OpE : Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> { let arguments = (ins I32:$x, F32Attr:$attr); let results = (outs Tensor:$y); } -// CHECK-LABEL: ValueAttrResultTypeOp definitions -// CHECK: void ValueAttrResultTypeOp::build(Builder *, OperationState *tblgen_state, Value *x, FloatAttr attr) +// CHECK-LABEL: OpE definitions +// CHECK: void OpE::build(Builder *, OperationState *tblgen_state, Value *x, FloatAttr attr) // CHECK: tblgen_state->addTypes({attr.getType()}); -def VariadicResultAloneOp : Op<"variadic_alone_op", []> { +def OpF : Op<"one_variadic_result_op", []> { let results = (outs Variadic:$x); } -// CHECK-LABEL: VariadicResultAloneOp definitions +// CHECK-LABEL: Operation::result_range OpF::x() +// CHECK-NEXT: return {std::next(result_begin(), 0), std::next(result_begin(), 0 + this->getNumResults() - 0)}; + +// CHECK-LABEL: void OpF::build +// CHECK-SAME: ArrayRef x +// CHECK-NOT: assert +// CHECK: tblgen_state->addTypes(x); -// CHECK-LABEL: void VariadicResultAloneOp::build -// CHECK-SAME: ArrayRef x -// CHECK-NOT: assert -// CHECK: tblgen_state->addTypes(x); +def OpG : Op<"one_normal_and_one_variadic_result_op", []> { -def VariadicResultOp : Op<"variadic_op", []> { let results = (outs I32:$x, Variadic:$y); } -// CHECK-LABEL: VariadicResultOp definitions +// CHECK-LABEL: OpG definitions -// CHECK: void VariadicResultOp::build(Builder *, OperationState *tblgen_state, Type x, ArrayRef y) -// CHECK: tblgen_state->addTypes({x}); -// CHECK: tblgen_state->addTypes(y); +// CHECK: void OpG::build(Builder *, OperationState *tblgen_state, Type x, ArrayRef y) +// CHECK-NEXT: tblgen_state->types.push_back(x); +// CHECK-NEXT: tblgen_state->addTypes(y); -// CHECK: void VariadicResultOp::build +// CHECK: void OpG::build // CHECK: ArrayRef resultTypes // CHECK: assert(resultTypes.size() >= 1u && "mismatched number of return types"); // CHECK-NEXT: tblgen_state->addTypes(resultTypes); + + +def OpH : Op<"all_variadic_results_op", [SameVariadicResultSize]> { + let results = (outs Variadic:$output1, Variadic:$output2); +} + +// CHECK-LABEL: Operation::result_range OpH::output1() +// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 0) / 2; +// CHECK-NEXT: unsigned offset = 0 + variadicResultSize * 0; +// CHECK-NEXT: return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)}; + +// CHECK-LABEL: Operation::result_range OpH::output2() +// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 0) / 2; +// CHECK-NEXT: unsigned offset = 0 + variadicResultSize * 1; +// CHECK-NEXT: return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)}; + + +// CHECK-LABEL: OpH::build +// CHECK-NEXT: tblgen_state->addTypes(output1); +// CHECK-NEXT: tblgen_state->addTypes(output2); + +def OpI : Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]> { + let results = (outs Variadic:$output1, Tensor:$output2, Variadic:$output3); +} + +// CHECK-LABEL: Operation::result_range OpI::output1() +// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 1) / 2; +// CHECK-NEXT: unsigned offset = 0 + variadicResultSize * 0; +// CHECK-NEXT: return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)}; + +// CHECK-LABEL: Value *OpI::output2() +// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 1) / 2; +// CHECK-NEXT: unsigned offset = 0 + variadicResultSize * 1; +// CHECK-NEXT: return this->getResult(offset); + +// CHECK-LABEL: Operation::result_range OpI::output3() +// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 1) / 2; +// CHECK-NEXT: unsigned offset = 1 + variadicResultSize * 1; +// CHECK-NEXT: return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)}; + +// CHECK-LABEL: OpI::build +// CHECK-NEXT: tblgen_state->addTypes(output1); +// CHECK-NEXT: tblgen_state->types.push_back(output2); +// CHECK-NEXT: tblgen_state->addTypes(output3); + +def OpJ : Op<"one_variadic_among_multi_normal_results_op", []> { + let results = (outs Tensor:$output1, Tensor:$output2, Variadic:$output3, Tensor:$output4, Tensor:$output5); +} + +// CHECK-LABEL: Value *OpJ::output1() +// CHECK-NEXT: return this->getOperation()->getResult(0); + +// CHECK-LABEL: Value *OpJ::output2() +// CHECK-NEXT: return this->getOperation()->getResult(1); + +// CHECK-LABEL: Operation::result_range OpJ::output3() +// CHECK-NEXT: return {std::next(result_begin(), 2), std::next(result_begin(), 2 + this->getNumResults() - 4)}; + +// CHECK-LABEL: Value *OpJ::output4() +// CHECK-NEXT: return this->getOperation()->getResult(this->getNumResults() - 2); + +// CHECK-LABEL: Value *OpJ::output5() +// CHECK-NEXT: return this->getOperation()->getResult(this->getNumResults() - 1); + +// Test that if the only operand is variadic, we acess the first value in the +// pack to set result type +// --- +def OpK : Op<"only_input_is_variadic_with_same_value_type_op", [SameValueType]> { + let arguments = (ins Variadic:$input); + let results = (outs Tensor:$result); +} + +// CHECK-LABEL: OpK::build(Builder *, OperationState *tblgen_state, ArrayRef input) +// CHECK: tblgen_state->addTypes({input.front()->getType()}); diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp index 89bbeeb..794001d 100644 --- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp @@ -75,10 +75,14 @@ static StringLoc findNextVariable(StringRef str) { return {startPos, endPos - startPos}; } -// Check if `name` is the name of the variadic argument of `op`. The variadic -// argument can only appear at the last position in the list of arguments. -static bool isVariadicArgumentName(const tblgen::Operator &op, StringRef name) { - return op.hasVariadicOperand() && op.getArgName(op.getNumArgs() - 1) == name; +// Check if `name` is the name of the variadic operand of `op`. The variadic +// operand can only appear at the last position in the list of operands. +static bool isVariadicOperandName(const tblgen::Operator &op, StringRef name) { + unsigned numOperands = op.getNumOperands(); + if (numOperands == 0) + return false; + const auto &operand = op.getOperand(numOperands - 1); + return operand.isVariadic() && operand.name == name; } // Check if `result` is a known name of a result of `op`. @@ -127,9 +131,9 @@ static bool emitOneBuilder(const Record &record, raw_ostream &os) { // First, insert the non-matched part as is. bs << builderStrRef.substr(0, loc.pos); // Then, rewrite the name based on its kind. - bool isVariadicArg = isVariadicArgumentName(op, name); + bool isVariadicOperand = isVariadicOperandName(op, name); if (isOperandName(op, name)) { - auto result = isVariadicArg + auto result = isVariadicOperand ? formatv("lookupValues(op.{0}())", name) : formatv("valueMapping.lookup(op.{0}())", name); bs << result; diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index dd68db3..e9be2b2 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -251,8 +251,9 @@ OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) { } void OpMethodBody::writeTo(raw_ostream &os) const { - os << body; - if (body.empty() || body.back() != '\n') + auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; }); + os << bodyRef; + if (bodyRef.empty() || bodyRef.back() != '\n') os << "\n"; } @@ -455,35 +456,153 @@ void OpEmitter::genAttrGetters() { } void OpEmitter::genNamedOperandGetters() { - for (int i = 0, e = op.getNumOperands(); i != e; ++i) { + const unsigned numOperands = op.getNumOperands(); + const unsigned numVariadicOperands = op.getNumVariadicOperands(); + const unsigned numNormalOperands = numOperands - numVariadicOperands; + + // Special case for ops without variadic operands: the i-th value is for the + // i-th operand defined in the op. + // Special case for ops with one variadic operand: the variadic operand can + // appear at any place, so the i-th value may not necessarily belong to the + // i-th operand definition. we need to calculate the index (range) for each + // operand. + if (numVariadicOperands <= 1) { + bool emittedVariadicOperand = false; + for (unsigned i = 0; i != numOperands; ++i) { + const auto &operand = op.getOperand(i); + if (operand.name.empty()) + continue; + + if (operand.isVariadic()) { + auto &m = opClass.newMethod("Operation::operand_range", operand.name); + m.body() << formatv( + " return {{std::next(operand_begin(), {0}), " + "std::next(operand_begin(), {0} + this->getNumOperands() - {1})};", + i, numNormalOperands); + emittedVariadicOperand = true; + } else { + auto &m = opClass.newMethod("Value *", operand.name); + m.body() << " return this->getOperation()->getOperand("; + if (emittedVariadicOperand) + m.body() << "this->getNumOperands() - " << numOperands - i; + else + m.body() << i; + m.body() << ");\n"; + } + } + return; + } + + // If we have more than one variadic operands, we need more complicated logic + // to calculate the value range for each operand. + + if (!op.hasTrait("SameVariadicOperandSize")) { + PrintFatalError(op.getLoc(), "op has multiple variadic operands but no " + "specification over their sizes"); + } + + unsigned emittedNormalOperands = 0; + unsigned emittedVariadicOperands = 0; + + for (unsigned i = 0; i != numOperands; ++i) { const auto &operand = op.getOperand(i); if (operand.name.empty()) continue; - if (!operand.constraint.isVariadic()) { - auto &m = opClass.newMethod("Value *", operand.name); - m.body() << " return this->getOperation()->getOperand(" << i << ");\n"; - } else { - assert(i + 1 == e && "only the last operand can be variadic"); + const char *code = R"( + unsigned variadicOperandSize = (this->getNumOperands() - {0}) / {1}; + unsigned offset = {2} + variadicOperandSize * {3}; + return )"; + auto sizeAndOffset = + formatv(code, numNormalOperands, numVariadicOperands, + emittedNormalOperands, emittedVariadicOperands); - const char *const code = R"( - assert(getOperation()->getNumOperands() >= {0}); - return {std::next(operand_begin(), {0}), operand_end()}; - )"; + if (operand.isVariadic()) { auto &m = opClass.newMethod("Operation::operand_range", operand.name); - m.body() << formatv(code, i); + m.body() << sizeAndOffset + << "{std::next(operand_begin(), offset), " + "std::next(operand_begin(), offset + variadicOperandSize)};"; + ++emittedVariadicOperands; + } else { + auto &m = opClass.newMethod("Value *", operand.name); + m.body() << sizeAndOffset << "this->getOperand(offset);"; + ++emittedNormalOperands; } } } void OpEmitter::genNamedResultGetters() { - for (int i = 0, e = op.getNumResults(); i != e; ++i) { + const unsigned numResults = op.getNumResults(); + const unsigned numVariadicResults = op.getNumVariadicResults(); + const unsigned numNormalResults = numResults - numVariadicResults; + + // Special case for ops without variadic results: the i-th value is for the + // i-th result defined in the op. + // Special case for ops with one variadic result: the variadic result can + // appear at any place, so the i-th value may not necessarily belong to the + // i-th result definition. we need to calculate the index (range) for each + // result. + if (numVariadicResults <= 1) { + bool emittedVariadicResult = false; + for (unsigned i = 0; i != numResults; ++i) { + const auto &result = op.getResult(i); + if (result.name.empty()) + continue; + + if (result.isVariadic()) { + auto &m = opClass.newMethod("Operation::result_range", result.name); + m.body() << formatv( + " return {{std::next(result_begin(), {0}), " + "std::next(result_begin(), {0} + this->getNumResults() - {1})};", + i, numNormalResults); + emittedVariadicResult = true; + } else { + auto &m = opClass.newMethod("Value *", result.name); + m.body() << " return this->getOperation()->getResult("; + if (emittedVariadicResult) + m.body() << "this->getNumResults() - " << numResults - i; + else + m.body() << i; + m.body() << ");\n"; + } + } + return; + } + + // If we have more than one variadic results, we need more complicated logic + // to calculate the value range for each result. + + if (!op.hasTrait("SameVariadicResultSize")) { + PrintFatalError(op.getLoc(), "op has multiple variadic results but no " + "specification over their sizes"); + } + + unsigned emittedNormalResults = 0; + unsigned emittedVariadicResults = 0; + + for (unsigned i = 0; i != numResults; ++i) { const auto &result = op.getResult(i); - if (result.constraint.isVariadic() || result.name.empty()) + if (result.name.empty()) continue; - auto &m = opClass.newMethod("Value *", result.name); - m.body() << " return this->getOperation()->getResult(" << i << ");\n"; + const char *code = R"( + unsigned variadicResultSize = (this->getNumResults() - {0}) / {1}; + unsigned offset = {2} + variadicResultSize * {3}; + return )"; + auto sizeAndOffset = formatv(code, numNormalResults, numVariadicResults, + emittedNormalResults, emittedVariadicResults); + + if (result.isVariadic()) { + auto &m = opClass.newMethod("Operation::result_range", result.name); + m.body() << sizeAndOffset + << "{std::next(result_begin(), offset), " + "std::next(result_begin(), offset + variadicResultSize)};"; + ++emittedVariadicResults; + } else { + auto &m = opClass.newMethod("Value *", result.name); + m.body() << sizeAndOffset << "this->getResult(offset);"; + ++emittedNormalResults; + } } } @@ -505,12 +624,12 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType, // Emit parameters for all return types if (!useOperandType && !useAttrType) { for (unsigned i = 0; i != numResults; ++i) { - std::string resultName = op.getResultName(i); + const auto &result = op.getResult(i); + std::string resultName = result.name; if (resultName.empty()) resultName = formatv("resultType{0}", i); - bool isVariadic = op.getResultTypeConstraint(i).isVariadic(); - paramList.append(isVariadic ? ", ArrayRef " : ", Type "); + paramList.append(result.isVariadic() ? ", ArrayRef " : ", Type "); paramList.append(resultName); resultNames.emplace_back(std::move(resultName)); @@ -520,12 +639,13 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType, // Emit parameters for all arguments (operands and attributes). int numOperands = 0; int numAttrs = 0; + for (int i = 0, e = op.getNumArgs(); i < e; ++i) { auto argument = op.getArg(i); if (argument.is()) { - auto &operand = op.getOperand(numOperands); - paramList.append(operand.constraint.isVariadic() ? ", ArrayRef " - : ", Value *"); + const auto &operand = op.getOperand(numOperands); + paramList.append(operand.isVariadic() ? ", ArrayRef " + : ", Value *"); paramList.append(getArgumentName(op, numOperands)); ++numOperands; } else { @@ -542,33 +662,22 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType, } if (numOperands + numAttrs != op.getNumArgs()) - return PrintFatalError( - "op arguments must be either operands or attributes"); + PrintFatalError("op arguments must be either operands or attributes"); - auto &method = - opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); - - bool hasVariadicOperand = op.hasVariadicOperand(); + auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); // Push all result types to the result if (numResults > 0) { if (!useOperandType && !useAttrType) { - bool hasVariadicResult = op.hasVariadicResult(); - int numNonVariadicResults = - numResults - static_cast(hasVariadicResult); - - if (numNonVariadicResults > 0) { - method.body() << " " << builderOpState << "->addTypes({" - << resultNames.front(); - for (int i = 1; i < numNonVariadicResults; ++i) { - method.body() << ", " << resultNames[i]; + for (unsigned i = 0; i < numResults; ++i) { + const auto &result = op.getResult(i); + m.body() << " " << builderOpState; + if (result.isVariadic()) { + m.body() << "->addTypes("; + } else { + m.body() << "->types.push_back("; } - method.body() << "});\n"; - } - - if (hasVariadicResult) { - method.body() << " " << builderOpState << "->addTypes(" - << resultNames.back() << ");\n"; + m.body() << resultNames[i] << ");\n"; } } else { std::string resultType; @@ -580,32 +689,27 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType, resultType = formatv("{0}.getType()", namedAttr.name); } } else { - const char *index = - (numOperands == 1 && hasVariadicOperand) ? ".front()" : ""; + const char *index = op.getOperand(0).isVariadic() ? ".front()" : ""; resultType = formatv("{0}{1}->getType()", getArgumentName(op, 0), index).str(); } - method.body() << " " << builderOpState << "->addTypes({" << resultType; + m.body() << " " << builderOpState << "->addTypes({" << resultType; for (unsigned i = 1; i != numResults; ++i) - method.body() << ", " << resultType; - method.body() << "});\n\n"; + m.body() << ", " << resultType; + m.body() << "});\n\n"; } } // Push all operands to the result - int numNonVariadicOperands = - numOperands - static_cast(hasVariadicOperand); - if (numNonVariadicOperands > 0) { - method.body() << " " << builderOpState << "->addOperands({" - << getArgumentName(op, 0); - for (int i = 1; i < numNonVariadicOperands; ++i) { - method.body() << ", " << getArgumentName(op, i); + for (unsigned i = 0; i < numOperands; ++i) { + const auto &operand = op.getOperand(i); + m.body() << " " << builderOpState; + if (operand.isVariadic()) { + m.body() << "->addOperands("; + } else { + m.body() << "->operands.push_back("; } - method.body() << "});\n"; - } - if (hasVariadicOperand) { - method.body() << " " << builderOpState << "->addOperands(" - << getArgumentName(op, numOperands - 1) << ");\n"; + m.body() << getArgumentName(op, i) << ");\n"; } // Push all attributes to the result @@ -613,12 +717,12 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType, if (!namedAttr.attr.isDerivedAttr()) { bool emitNotNullCheck = namedAttr.attr.isOptional(); if (emitNotNullCheck) { - method.body() << formatv(" if ({0}) ", namedAttr.name) << "{\n"; + m.body() << formatv(" if ({0}) ", namedAttr.name) << "{\n"; } - method.body() << formatv(" {0}->addAttribute(\"{1}\", {1});\n", - builderOpState, namedAttr.name); + m.body() << formatv(" {0}->addAttribute(\"{1}\", {1});\n", + builderOpState, namedAttr.name); if (emitNotNullCheck) { - method.body() << " }\n"; + m.body() << " }\n"; } } } @@ -646,13 +750,13 @@ void OpEmitter::genBuilder() { } } - auto numResults = op.getNumResults(); - bool hasVariadicResult = op.hasVariadicResult(); - int numNonVariadicResults = numResults - int(hasVariadicResult); + unsigned numResults = op.getNumResults(); + unsigned numVariadicResults = op.getNumVariadicResults(); + unsigned numNonVariadicResults = numResults - numVariadicResults; - auto numOperands = op.getNumOperands(); - bool hasVariadicOperand = op.hasVariadicOperand(); - int numNonVariadicOperands = numOperands - int(hasVariadicOperand); + unsigned numOperands = op.getNumOperands(); + unsigned numVariadicOperands = op.getNumVariadicOperands(); + unsigned numNonVariadicOperands = numOperands - numVariadicOperands; // Generate default builders that requires all result type, operands, and // attributes as parameters. @@ -681,15 +785,16 @@ void OpEmitter::genBuilder() { auto &body = m.body(); // Result types - if (!(hasVariadicResult && numNonVariadicResults == 0)) + if (numVariadicResults == 0 || numNonVariadicResults != 0) body << " assert(resultTypes.size()" - << (hasVariadicResult ? " >= " : " == ") << numNonVariadicResults + << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults << "u && \"mismatched number of return types\");\n"; body << " " << builderOpState << "->addTypes(resultTypes);\n"; // Operands - if (!(hasVariadicOperand && numNonVariadicOperands == 0)) - body << " assert(operands.size()" << (hasVariadicOperand ? " >= " : " == ") + if (numVariadicOperands == 0 || numNonVariadicOperands != 0) + body << " assert(operands.size()" + << (numVariadicOperands != 0 ? " >= " : " == ") << numNonVariadicOperands << "u && \"mismatched number of parameters\");\n"; body << " " << builderOpState << "->addOperands(operands);\n\n"; @@ -703,7 +808,7 @@ void OpEmitter::genBuilder() { bool useOperandType = op.hasTrait("SameOperandsAndResultType"); bool useAttrType = op.hasTrait("FirstAttrDerivedResultType"); - if (!op.hasVariadicResult() && (useOperandType || useAttrType)) + if (numVariadicResults == 0 && (useOperandType || useAttrType)) genStandaloneParamBuilder(useOperandType, useAttrType); } @@ -824,7 +929,7 @@ void OpEmitter::genVerifier() { auto verifyValue = [&](const tblgen::NamedTypeConstraint &value, int index, bool isOperand) -> void { // TODO: Handle variadic operand/result verification. - if (value.constraint.isVariadic()) + if (value.isVariadic()) return; // TODO: Commonality between matchers could be extracted to have a more @@ -869,12 +974,12 @@ void OpEmitter::genVerifier() { } void OpEmitter::genTraits() { - auto numResults = op.getNumResults(); - bool hasVariadicResult = op.hasVariadicResult(); + unsigned numResults = op.getNumResults(); + unsigned numVariadicResults = op.getNumVariadicResults(); // Add return size trait. - if (hasVariadicResult) { - if (numResults == 1) + if (numVariadicResults != 0) { + if (numResults == numVariadicResults) opClass.addTrait("VariadicResults"); else opClass.addTrait("AtLeastNResults<" + Twine(numResults - 1) + ">::Impl"); @@ -898,12 +1003,12 @@ void OpEmitter::genTraits() { } // Add variadic size trait and normal op traits. - auto numOperands = op.getNumOperands(); - bool hasVariadicOperand = op.hasVariadicOperand(); + unsigned numOperands = op.getNumOperands(); + unsigned numVariadicOperands = op.getNumVariadicOperands(); // Add operand size trait. - if (hasVariadicOperand) { - if (numOperands == 1) + if (numVariadicOperands != 0) { + if (numOperands == numVariadicOperands) opClass.addTrait("VariadicOperands"); else opClass.addTrait("AtLeastNOperands<" + Twine(numOperands - 1) + diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index dad3035..501b7a1 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -440,7 +440,7 @@ void PatternEmitter::emit(StringRef rewriteName) { const Operator &rootOp = pattern.getSourceRootOp(); auto rootName = rootOp.getOperationName(); - if (rootOp.hasVariadicResult()) + if (rootOp.getNumVariadicResults() != 0) PrintFatalError( loc, "replacing op with variadic results not supported right now");