From b8dc04a00504fbefe020821757769c68cbcd49a2 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Tue, 16 Apr 2019 13:20:38 -0700 Subject: [PATCH] [TableGen] Fix builder for ops with one variadic input and SameValueType For ops with the SameValueType trait, we generate a builder without requiring result type; we get the result type from the operand. However, if the operand is variadic, we need to index into the first value in the pack. -- PiperOrigin-RevId: 243866647 --- mlir/test/mlir-tblgen/op-builder.td | 28 ++++++++++++++++++++++++++++ mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 8 ++++++-- 2 files changed, 34 insertions(+), 2 deletions(-) create mode 100644 mlir/test/mlir-tblgen/op-builder.td diff --git a/mlir/test/mlir-tblgen/op-builder.td b/mlir/test/mlir-tblgen/op-builder.td new file mode 100644 index 0000000..7be9018 --- /dev/null +++ b/mlir/test/mlir-tblgen/op-builder.td @@ -0,0 +1,28 @@ +// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s + +include "mlir/IR/OpBase.td" + + +def NS_OpA : Op<"op_same_value_type", [SameValueType]> { + let arguments = (ins Tensor:$input); + let results = (outs Tensor:$result); +} + +// Test that with SameValueType trait we can generate a builder without +// requiring result type +// --- + +// CHECK-LABEL: OpA::build(Builder *, OperationState *tblgen_state, Value *input) +// CHECK: tblgen_state->addTypes({input->getType()}); + +def NS_OpB : Op<"op_same_value_type_variadic_input", [SameValueType]> { + let arguments = (ins Variadic:$input); + let results = (outs Tensor:$result); +} + +// Test that if the only operand is variadic, we acess the first value in the +// pack to set result type +// --- + +// CHECK-LABEL: OpB::build(Builder *, OperationState *tblgen_state, ArrayRef input) +// CHECK: tblgen_state->addTypes({input.front()->getType()}); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 398fe78..f6200f1 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -548,6 +548,8 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType, auto &method = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); + bool hasVariadicOperand = op.hasVariadicOperand(); + // Push all result types to the result if (numResults > 0) { if (!useOperandType && !useAttrType) { @@ -578,7 +580,10 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType, resultType = formatv("{0}.getType()", namedAttr.name); } } else { - resultType = formatv("{0}->getType()", getArgumentName(op, 0)).str(); + const char *index = + (numOperands == 1 && hasVariadicOperand) ? ".front()" : ""; + resultType = + formatv("{0}{1}->getType()", getArgumentName(op, 0), index).str(); } method.body() << " " << builderOpState << "->addTypes({" << resultType; for (unsigned i = 1; i != numResults; ++i) @@ -588,7 +593,6 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType, } // Push all operands to the result - bool hasVariadicOperand = op.hasVariadicOperand(); int numNonVariadicOperands = numOperands - static_cast(hasVariadicOperand); if (numNonVariadicOperands > 0) { -- 2.7.4