[TableGen] Support multiple variadic operands/results
authorLei Zhang <antiagainst@google.com>
Thu, 25 Apr 2019 21:45:37 +0000 (14:45 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 6 May 2019 15:16:54 +0000 (08:16 -0700)
    Certain ops can have multiple variadic operands/results, e.g., `tf.DynamicStitch`.
    Even if an op has only one variadic operand/result, it is not necessarily the
    very last one, e.g., `tf.RaggedGather`. This CL enhances TableGen subsystem to be
    able to represent such cases.

    In order to deduce the operand/result value range for each variadic operand,
    currently we only support variadic operands/results all of the same size.
    So two new traits, `SameVariadicOperandSize` and `SameVariadicResultSize` are
    introduced.

--

PiperOrigin-RevId: 245310628

mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/Argument.h
mlir/include/mlir/TableGen/Operator.h
mlir/lib/TableGen/Argument.cpp
mlir/lib/TableGen/Operator.cpp
mlir/test/mlir-tblgen/op-builder.td [deleted file]
mlir/test/mlir-tblgen/op-operand.td
mlir/test/mlir-tblgen/op-result.td
mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/RewriterGen.cpp

index 0e4ab44..300b86e 100644 (file)
@@ -759,6 +759,17 @@ def Terminator       : NativeOpTrait<"IsTerminator">;
 def FirstAttrDerivedResultType :
   GenInternalOpTrait<"FirstAttrDerivedResultType">;
 
+// All variadic operands of the op have the same number of values.
+// A variadic operand contains an array of values whose array size is only
+// known at runtime. This trait requires all variadic operands of an op
+// to have the same array size.
+def SameVariadicOperandSize : GenInternalOpTrait<"SameVariadicOperandSize">;
+// All variadic results of the op have the same number of values.
+// A variadic result contains an array of values whose array size is only
+// known at runtime. This trait requires all variadic results of an op
+// to have the same array size.
+def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">;
+
 //===----------------------------------------------------------------------===//
 // Op definitions
 //===----------------------------------------------------------------------===//
index a37bc70..8390939 100644 (file)
@@ -48,10 +48,12 @@ struct NamedAttribute {
   Attribute attr;
 };
 
-// A struct wrapping an op operand/result and its name together
+// A struct wrapping an op operand/result's constraint and its name together
 struct NamedTypeConstraint {
-  // Returns true if this operand has constraint that need to be satisfied.
+  // Returns true if this operand/result has constraint to be satisfied.
   bool hasPredicate() const;
+  // Returns true if this operand/result is variadic.
+  bool isVariadic() const;
 
   llvm::StringRef name;
   TypeConstraint constraint;
index 7233626..e96b29c 100644 (file)
@@ -82,8 +82,8 @@ public:
   // Returns the `index`-th result's name.
   StringRef getResultName(int index) const;
 
-  // Returns true if this operation has a variadic result.
-  bool hasVariadicResult() const;
+  // Returns the number of variadic results in this operation.
+  unsigned getNumVariadicResults() const;
 
   // Op attribute interators.
   using attribute_iterator = const NamedAttribute *;
@@ -112,8 +112,8 @@ public:
     return operands[index];
   }
 
-  // Returns true if this operation has a variadic operand.
-  bool hasVariadicOperand() const;
+  // Returns the number of variadic operands in this operation.
+  unsigned getNumVariadicOperands() const;
 
   // Returns the total number of arguments.
   int getNumArgs() const { return arguments.size(); }
index f7af63f..7432e0f 100644 (file)
@@ -23,3 +23,7 @@ using namespace mlir;
 bool tblgen::NamedTypeConstraint::hasPredicate() const {
   return !constraint.getPredicate().isNull();
 }
+
+bool tblgen::NamedTypeConstraint::isVariadic() const {
+  return constraint.isVariadic();
+}
index 6f0f3ea..3854728 100644 (file)
@@ -82,8 +82,10 @@ StringRef tblgen::Operator::getResultName(int index) const {
   return results->getArgNameStr(index);
 }
 
-bool tblgen::Operator::hasVariadicResult() const {
-  return !results.empty() && results.back().constraint.isVariadic();
+unsigned tblgen::Operator::getNumVariadicResults() const {
+  return std::count_if(
+      results.begin(), results.end(),
+      [](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
 }
 
 int tblgen::Operator::getNumNativeAttributes() const {
@@ -98,8 +100,10 @@ const tblgen::NamedAttribute &tblgen::Operator::getAttribute(int index) const {
   return attributes[index];
 }
 
-bool tblgen::Operator::hasVariadicOperand() const {
-  return !operands.empty() && operands.back().constraint.isVariadic();
+unsigned tblgen::Operator::getNumVariadicOperands() const {
+  return std::count_if(
+      operands.begin(), operands.end(),
+      [](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
 }
 
 StringRef tblgen::Operator::getArgName(int index) const {
@@ -222,13 +226,6 @@ void tblgen::Operator::populateOpStructure() {
     }
   }
 
-  // Verify that only the last operand can be variadic.
-  for (int i = 0, e = operands.size() - 1; i < e; ++i) {
-    if (operands[i].constraint.isVariadic())
-      PrintFatalError(def.getLoc(),
-                      "only the last operand allowed to be variadic");
-  }
-
   auto *resultsDag = def.getValueAsDag("results");
   auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator());
   if (!outsOp || outsOp->getDef()->getName() != "outs") {
@@ -246,13 +243,6 @@ void tblgen::Operator::populateOpStructure() {
     results.push_back({name, TypeConstraint(resultDef)});
   }
 
-  // Verify that only the last result can be variadic.
-  for (int i = 0, e = results.size() - 1; i < e; ++i) {
-    if (results[i].constraint.isVariadic())
-      PrintFatalError(def.getLoc(),
-                      "only the last result allowed to be variadic");
-  }
-
   auto traitListInit = def.getValueAsListInit("traits");
   if (!traitListInit)
     return;
diff --git a/mlir/test/mlir-tblgen/op-builder.td b/mlir/test/mlir-tblgen/op-builder.td
deleted file mode 100644 (file)
index 7be9018..0000000
+++ /dev/null
@@ -1,28 +0,0 @@
-// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
-
-include "mlir/IR/OpBase.td"
-
-
-def NS_OpA : Op<"op_same_value_type", [SameValueType]> {
-  let arguments = (ins Tensor:$input);
-  let results = (outs Tensor:$result);
-}
-
-// Test that with SameValueType trait we can generate a builder without
-// requiring result type
-// ---
-
-// CHECK-LABEL: OpA::build(Builder *, OperationState *tblgen_state, Value *input)
-// CHECK: tblgen_state->addTypes({input->getType()});
-
-def NS_OpB : Op<"op_same_value_type_variadic_input", [SameValueType]> {
-  let arguments = (ins Variadic<Tensor>:$input);
-  let results = (outs Tensor:$result);
-}
-
-// Test that if the only operand is variadic, we acess the first value in the
-// pack to set result type
-// ---
-
-// CHECK-LABEL: OpB::build(Builder *, OperationState *tblgen_state, ArrayRef<Value *> input)
-// CHECK: tblgen_state->addTypes({input.front()->getType()});
index 132d382..936fb7c 100644 (file)
@@ -2,7 +2,7 @@
 
 include "mlir/IR/OpBase.td"
 
-def OpA : Op<"one_operand_op", []> {
+def OpA : Op<"one_normal_operand_op", []> {
   let arguments = (ins I32:$input);
 }
 
@@ -10,7 +10,7 @@ def OpA : Op<"one_operand_op", []> {
 
 // CHECK:      void OpA::build
 // CHECK-SAME:   Value *input
-// CHECK:        tblgen_state->addOperands({input});
+// CHECK:        tblgen_state->operands.push_back(input);
 
 // CHECK:      void OpA::build
 // CHECK-SAME:   ArrayRef<Value *> operands
@@ -21,11 +21,72 @@ def OpA : Op<"one_operand_op", []> {
 // CHECK:        if (!((this->getOperation()->getOperand(0)->getType().isInteger(32))))
 // CHECK-NEXT:     return emitOpError("operand #0 must be 32-bit integer");
 
-def OpB : Op<"variadic_operand_op", []> {
+def OpB : Op<"one_variadic_operand_op", []> {
   let arguments = (ins Variadic<I32>:$input);
 }
 
 // CHECK-LABEL: OpB::build
-// CHECK-SAME:  ArrayRef<Value *> input
-// CHECK-NOT:   assert
-// CHECK:       tblgen_state->addOperands(input);
+// CHECK-SAME:    ArrayRef<Value *> input
+// CHECK-NOT:     assert
+// CHECK:         tblgen_state->addOperands(input);
+
+def OpC : Op<"all_variadic_inputs_op", [SameVariadicOperandSize]> {
+  let arguments = (ins Variadic<Tensor>:$input1, Variadic<Tensor>:$input2);
+}
+
+// CHECK-LABEL: Operation::operand_range OpC::input1()
+// CHECK-NEXT:    unsigned variadicOperandSize = (this->getNumOperands() - 0) / 2;
+// CHECK-NEXT:    unsigned offset = 0 + variadicOperandSize * 0;
+// CHECK-NEXT:    return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)};
+
+// CHECK-LABEL: Operation::operand_range OpC::input2()
+// CHECK-NEXT:    unsigned variadicOperandSize = (this->getNumOperands() - 0) / 2;
+// CHECK-NEXT:    unsigned offset = 0 + variadicOperandSize * 1;
+// CHECK-NEXT:    return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)};
+
+// CHECK-LABEL: OpC::build
+// CHECK-NEXT: tblgen_state->addOperands(input1);
+// CHECK-NEXT: tblgen_state->addOperands(input2);
+
+def OpD : Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]> {
+  let arguments = (ins Variadic<Tensor>:$input1, Tensor:$input2, Variadic<Tensor>:$input3);
+}
+
+// CHECK-LABEL: Operation::operand_range OpD::input1()
+// CHECK-NEXT:    unsigned variadicOperandSize = (this->getNumOperands() - 1) / 2;
+// CHECK-NEXT:    unsigned offset = 0 + variadicOperandSize * 0;
+// CHECK-NEXT:    return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)};
+
+// CHECK-LABEL: Value *OpD::input2()
+// CHECK-NEXT:    unsigned variadicOperandSize = (this->getNumOperands() - 1) / 2;
+// CHECK-NEXT:    unsigned offset = 0 + variadicOperandSize * 1;
+// CHECK-NEXT:    return this->getOperand(offset);
+
+// CHECK-LABEL: Operation::operand_range OpD::input3()
+// CHECK-NEXT:    unsigned variadicOperandSize = (this->getNumOperands() - 1) / 2;
+// CHECK-NEXT:    unsigned offset = 1 + variadicOperandSize * 1;
+// CHECK-NEXT:    return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)};
+
+// CHECK-LABEL: OpD::build
+// CHECK-NEXT: tblgen_state->addOperands(input1);
+// CHECK-NEXT: tblgen_state->operands.push_back(input2);
+// CHECK-NEXT: tblgen_state->addOperands(input3);
+
+def OpE : Op<"one_variadic_among_multi_normal_inputs_op", []> {
+  let arguments = (ins Tensor:$input1, Tensor:$input2, Variadic<Tensor>:$input3, Tensor:$input4, Tensor:$input5);
+}
+
+// CHECK-LABEL: Value *OpE::input1()
+// CHECK-NEXT:    return this->getOperation()->getOperand(0);
+
+// CHECK-LABEL: Value *OpE::input2()
+// CHECK-NEXT:    return this->getOperation()->getOperand(1);
+
+// CHECK-LABEL: Operation::operand_range OpE::input3()
+// CHECK-NEXT:    return {std::next(operand_begin(), 2), std::next(operand_begin(), 2 + this->getNumOperands() - 4)};
+
+// CHECK-LABEL: Value *OpE::input4()
+// CHECK-NEXT:    return this->getOperation()->getOperand(this->getNumOperands() - 2);
+
+// CHECK-LABEL: Value *OpE::input5()
+// CHECK-NEXT:    return this->getOperation()->getOperand(this->getNumOperands() - 1);
index 06bc3b4..714c62c 100644 (file)
 
 include "mlir/IR/OpBase.td"
 
-def OneResultOp : Op<"one_result_op", []> {
+def OpA : Op<"one_normal_result_op", []> {
   let results = (outs I32:$result);
 }
 
-// CHECK-LABEL: OneResultOp definitions
+// CHECK-LABEL: Value *OpA::result()
+// CHECK-NEXT:    return this->getOperation()->getResult(0)
 
-// CHECK:       void OneResultOp::build
+// CHECK-LABEL: void OpA::build
 // CHECK:         ArrayRef<Type> resultTypes, ArrayRef<Value *> operands
 // CHECK:         assert(resultTypes.size() == 1u && "mismatched number of return types");
 // CHECK-NEXT:    tblgen_state->addTypes(resultTypes);
 
-// CHECK:       LogicalResult OneResultOp::verify() {
+// CHECK-LABEL: LogicalResult OpA::verify()
 // CHECK:         if (!((this->getOperation()->getResult(0)->getType().isInteger(32))))
 // CHECK-NEXT:      return emitOpError("result #0 must be 32-bit integer");
 
-
-def SameTypeOp : Op<"same_type_op", [SameValueType]> {
+def OpB : Op<"same_input_output_type_op", [SameValueType]> {
   let arguments = (ins I32:$x);
   let results = (outs I32:$y);
 }
 
-// CHECK-LABEL: SameTypeOp definitions
-// CHECK: void SameTypeOp::build(Builder *, OperationState *tblgen_state, Type y, Value *x)
-// CHECK:   tblgen_state->addTypes({y});
-// CHECK: void SameTypeOp::build(Builder *, OperationState *tblgen_state, Value *x)
+// CHECK-LABEL: OpB definitions
+// CHECK: void OpB::build(Builder *, OperationState *tblgen_state, Type y, Value *x)
+// CHECK:   tblgen_state->types.push_back(y);
+// CHECK: void OpB::build(Builder *, OperationState *tblgen_state, Value *x)
 // CHECK:   tblgen_state->addTypes({x->getType()});
 
-def ThreeResultOp : Op<"three_result_op", []> {
+def OpC : Op<"three_normal_result_op", []> {
   let results = (outs I32:$x, /*unnamed*/I32, I32:$z);
 }
 
-// CHECK-LABEL: ThreeResultOp definitions
-// CHECK: void ThreeResultOp::build(Builder *, OperationState *tblgen_state, Type x, Type resultType1, Type z)
-// CHECK:   tblgen_state->addTypes({x, resultType1, z});
+// CHECK-LABEL: OpC definitions
+// CHECK:       void OpC::build(Builder *, OperationState *tblgen_state, Type x, Type resultType1, Type z)
+// CHECK-NEXT:   tblgen_state->types.push_back(x)
+// CHECK-NEXT:   tblgen_state->types.push_back(resultType1)
+// CHECK-NEXT:   tblgen_state->types.push_back(z)
 
 def IntegerTypeAttr : TypeAttrBase<"IntegerType", "Integer type attribute">;
-def TypeAttrResultTypeOp : Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
+def OpD : Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
   let arguments = (ins I32:$x, IntegerTypeAttr:$attr, F32Attr:$f32);
   let results = (outs Tensor:$y);
 }
 
-// CHECK-LABEL: TypeAttrResultTypeOp definitions
-// CHECK: void TypeAttrResultTypeOp::build(Builder *, OperationState *tblgen_state, Value *x, TypeAttr attr, FloatAttr f32)
+// CHECK-LABEL: OpD definitions
+// CHECK: void OpD::build(Builder *, OperationState *tblgen_state, Value *x, TypeAttr attr, FloatAttr f32)
 // CHECK: tblgen_state->addTypes({attr.getValue()});
 
-def ValueAttrResultTypeOp : Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
+def OpE : Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
   let arguments = (ins I32:$x, F32Attr:$attr);
   let results = (outs Tensor:$y);
 }
 
-// CHECK-LABEL: ValueAttrResultTypeOp definitions
-// CHECK: void ValueAttrResultTypeOp::build(Builder *, OperationState *tblgen_state, Value *x, FloatAttr attr)
+// CHECK-LABEL: OpE definitions
+// CHECK: void OpE::build(Builder *, OperationState *tblgen_state, Value *x, FloatAttr attr)
 // CHECK: tblgen_state->addTypes({attr.getType()});
 
-def VariadicResultAloneOp : Op<"variadic_alone_op", []> {
+def OpF : Op<"one_variadic_result_op", []> {
   let results = (outs Variadic<I32>:$x);
 }
 
-// CHECK-LABEL: VariadicResultAloneOp definitions
+// CHECK-LABEL: Operation::result_range OpF::x()
+// CHECK-NEXT:    return {std::next(result_begin(), 0), std::next(result_begin(), 0 + this->getNumResults() - 0)};
+
+// CHECK-LABEL: void OpF::build
+// CHECK-SAME:    ArrayRef<Type> x
+// CHECK-NOT:     assert
+// CHECK:         tblgen_state->addTypes(x);
 
-// CHECK-LABEL: void VariadicResultAloneOp::build
-// CHECK-SAME:  ArrayRef<Type> x
-// CHECK-NOT:   assert
-// CHECK:       tblgen_state->addTypes(x);
+def OpG : Op<"one_normal_and_one_variadic_result_op", []> {
 
-def VariadicResultOp : Op<"variadic_op", []> {
   let results = (outs I32:$x, Variadic<I32>:$y);
 }
 
-// CHECK-LABEL: VariadicResultOp definitions
+// CHECK-LABEL: OpG definitions
 
-// CHECK: void VariadicResultOp::build(Builder *, OperationState *tblgen_state, Type x, ArrayRef<Type> y)
-// CHECK:   tblgen_state->addTypes({x});
-// CHECK:   tblgen_state->addTypes(y);
+// CHECK:      void OpG::build(Builder *, OperationState *tblgen_state, Type x, ArrayRef<Type> y)
+// CHECK-NEXT:   tblgen_state->types.push_back(x);
+// CHECK-NEXT:   tblgen_state->addTypes(y);
 
-// CHECK:                      void VariadicResultOp::build
+// CHECK:                      void OpG::build
 // CHECK:         ArrayRef<Type> resultTypes
 // CHECK:         assert(resultTypes.size() >= 1u && "mismatched number of return types");
 // CHECK-NEXT:    tblgen_state->addTypes(resultTypes);
+
+
+def OpH : Op<"all_variadic_results_op", [SameVariadicResultSize]> {
+  let results = (outs Variadic<Tensor>:$output1, Variadic<Tensor>:$output2);
+}
+
+// CHECK-LABEL: Operation::result_range OpH::output1()
+// CHECK-NEXT:      unsigned variadicResultSize = (this->getNumResults() - 0) / 2;
+// CHECK-NEXT:      unsigned offset = 0 + variadicResultSize * 0;
+// CHECK-NEXT:      return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)};
+
+// CHECK-LABEL: Operation::result_range OpH::output2()
+// CHECK-NEXT:      unsigned variadicResultSize = (this->getNumResults() - 0) / 2;
+// CHECK-NEXT:      unsigned offset = 0 + variadicResultSize * 1;
+// CHECK-NEXT:      return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)};
+
+
+// CHECK-LABEL: OpH::build
+// CHECK-NEXT:    tblgen_state->addTypes(output1);
+// CHECK-NEXT:    tblgen_state->addTypes(output2);
+
+def OpI : Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]> {
+  let results = (outs Variadic<Tensor>:$output1, Tensor:$output2, Variadic<Tensor>:$output3);
+}
+
+// CHECK-LABEL: Operation::result_range OpI::output1()
+// CHECK-NEXT:    unsigned variadicResultSize = (this->getNumResults() - 1) / 2;
+// CHECK-NEXT:    unsigned offset = 0 + variadicResultSize * 0;
+// CHECK-NEXT:    return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)};
+
+// CHECK-LABEL: Value *OpI::output2()
+// CHECK-NEXT:    unsigned variadicResultSize = (this->getNumResults() - 1) / 2;
+// CHECK-NEXT:    unsigned offset = 0 + variadicResultSize * 1;
+// CHECK-NEXT:    return this->getResult(offset);
+
+// CHECK-LABEL: Operation::result_range OpI::output3()
+// CHECK-NEXT:    unsigned variadicResultSize = (this->getNumResults() - 1) / 2;
+// CHECK-NEXT:    unsigned offset = 1 + variadicResultSize * 1;
+// CHECK-NEXT:    return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)};
+
+// CHECK-LABEL: OpI::build
+// CHECK-NEXT:    tblgen_state->addTypes(output1);
+// CHECK-NEXT:    tblgen_state->types.push_back(output2);
+// CHECK-NEXT:    tblgen_state->addTypes(output3);
+
+def OpJ : Op<"one_variadic_among_multi_normal_results_op", []> {
+  let results = (outs Tensor:$output1, Tensor:$output2, Variadic<Tensor>:$output3, Tensor:$output4, Tensor:$output5);
+}
+
+// CHECK-LABEL: Value *OpJ::output1()
+// CHECK-NEXT:    return this->getOperation()->getResult(0);
+
+// CHECK-LABEL: Value *OpJ::output2()
+// CHECK-NEXT:    return this->getOperation()->getResult(1);
+
+// CHECK-LABEL: Operation::result_range OpJ::output3()
+// CHECK-NEXT:    return {std::next(result_begin(), 2), std::next(result_begin(), 2 + this->getNumResults() - 4)};
+
+// CHECK-LABEL: Value *OpJ::output4()
+// CHECK-NEXT:    return this->getOperation()->getResult(this->getNumResults() - 2);
+
+// CHECK-LABEL: Value *OpJ::output5()
+// CHECK-NEXT:    return this->getOperation()->getResult(this->getNumResults() - 1);
+
+// Test that if the only operand is variadic, we acess the first value in the
+// pack to set result type
+// ---
+def OpK : Op<"only_input_is_variadic_with_same_value_type_op", [SameValueType]> {
+  let arguments = (ins Variadic<Tensor>:$input);
+  let results = (outs Tensor:$result);
+}
+
+// CHECK-LABEL: OpK::build(Builder *, OperationState *tblgen_state, ArrayRef<Value *> input)
+// CHECK: tblgen_state->addTypes({input.front()->getType()});
index 89bbeeb..794001d 100644 (file)
@@ -75,10 +75,14 @@ static StringLoc findNextVariable(StringRef str) {
   return {startPos, endPos - startPos};
 }
 
-// Check if `name` is the name of the variadic argument of `op`.  The variadic
-// argument can only appear at the last position in the list of arguments.
-static bool isVariadicArgumentName(const tblgen::Operator &op, StringRef name) {
-  return op.hasVariadicOperand() && op.getArgName(op.getNumArgs() - 1) == name;
+// Check if `name` is the name of the variadic operand of `op`.  The variadic
+// operand can only appear at the last position in the list of operands.
+static bool isVariadicOperandName(const tblgen::Operator &op, StringRef name) {
+  unsigned numOperands = op.getNumOperands();
+  if (numOperands == 0)
+    return false;
+  const auto &operand = op.getOperand(numOperands - 1);
+  return operand.isVariadic() && operand.name == name;
 }
 
 // Check if `result` is a known name of a result of `op`.
@@ -127,9 +131,9 @@ static bool emitOneBuilder(const Record &record, raw_ostream &os) {
     // First, insert the non-matched part as is.
     bs << builderStrRef.substr(0, loc.pos);
     // Then, rewrite the name based on its kind.
-    bool isVariadicArg = isVariadicArgumentName(op, name);
+    bool isVariadicOperand = isVariadicOperandName(op, name);
     if (isOperandName(op, name)) {
-      auto result = isVariadicArg
+      auto result = isVariadicOperand
                         ? formatv("lookupValues(op.{0}())", name)
                         : formatv("valueMapping.lookup(op.{0}())", name);
       bs << result;
index dd68db3..e9be2b2 100644 (file)
@@ -251,8 +251,9 @@ OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) {
 }
 
 void OpMethodBody::writeTo(raw_ostream &os) const {
-  os << body;
-  if (body.empty() || body.back() != '\n')
+  auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; });
+  os << bodyRef;
+  if (bodyRef.empty() || bodyRef.back() != '\n')
     os << "\n";
 }
 
@@ -455,35 +456,153 @@ void OpEmitter::genAttrGetters() {
 }
 
 void OpEmitter::genNamedOperandGetters() {
-  for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
+  const unsigned numOperands = op.getNumOperands();
+  const unsigned numVariadicOperands = op.getNumVariadicOperands();
+  const unsigned numNormalOperands = numOperands - numVariadicOperands;
+
+  // Special case for ops without variadic operands: the i-th value is for the
+  // i-th operand defined in the op.
+  // Special case for ops with one variadic operand: the variadic operand can
+  // appear at any place, so the i-th value may not necessarily belong to the
+  // i-th operand definition. we need to calculate the index (range) for each
+  // operand.
+  if (numVariadicOperands <= 1) {
+    bool emittedVariadicOperand = false;
+    for (unsigned i = 0; i != numOperands; ++i) {
+      const auto &operand = op.getOperand(i);
+      if (operand.name.empty())
+        continue;
+
+      if (operand.isVariadic()) {
+        auto &m = opClass.newMethod("Operation::operand_range", operand.name);
+        m.body() << formatv(
+            "  return {{std::next(operand_begin(), {0}), "
+            "std::next(operand_begin(), {0} + this->getNumOperands() - {1})};",
+            i, numNormalOperands);
+        emittedVariadicOperand = true;
+      } else {
+        auto &m = opClass.newMethod("Value *", operand.name);
+        m.body() << "  return this->getOperation()->getOperand(";
+        if (emittedVariadicOperand)
+          m.body() << "this->getNumOperands() - " << numOperands - i;
+        else
+          m.body() << i;
+        m.body() << ");\n";
+      }
+    }
+    return;
+  }
+
+  // If we have more than one variadic operands, we need more complicated logic
+  // to calculate the value range for each operand.
+
+  if (!op.hasTrait("SameVariadicOperandSize")) {
+    PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
+                                 "specification over their sizes");
+  }
+
+  unsigned emittedNormalOperands = 0;
+  unsigned emittedVariadicOperands = 0;
+
+  for (unsigned i = 0; i != numOperands; ++i) {
     const auto &operand = op.getOperand(i);
     if (operand.name.empty())
       continue;
 
-    if (!operand.constraint.isVariadic()) {
-      auto &m = opClass.newMethod("Value *", operand.name);
-      m.body() << "  return this->getOperation()->getOperand(" << i << ");\n";
-    } else {
-      assert(i + 1 == e && "only the last operand can be variadic");
+    const char *code = R"(
+  unsigned variadicOperandSize = (this->getNumOperands() - {0}) / {1};
+  unsigned offset = {2} + variadicOperandSize * {3};
+  return )";
+    auto sizeAndOffset =
+        formatv(code, numNormalOperands, numVariadicOperands,
+                emittedNormalOperands, emittedVariadicOperands);
 
-      const char *const code = R"(
-        assert(getOperation()->getNumOperands() >= {0});
-        return {std::next(operand_begin(), {0}), operand_end()};
-      )";
+    if (operand.isVariadic()) {
       auto &m = opClass.newMethod("Operation::operand_range", operand.name);
-      m.body() << formatv(code, i);
+      m.body() << sizeAndOffset
+               << "{std::next(operand_begin(), offset), "
+                  "std::next(operand_begin(), offset + variadicOperandSize)};";
+      ++emittedVariadicOperands;
+    } else {
+      auto &m = opClass.newMethod("Value *", operand.name);
+      m.body() << sizeAndOffset << "this->getOperand(offset);";
+      ++emittedNormalOperands;
     }
   }
 }
 
 void OpEmitter::genNamedResultGetters() {
-  for (int i = 0, e = op.getNumResults(); i != e; ++i) {
+  const unsigned numResults = op.getNumResults();
+  const unsigned numVariadicResults = op.getNumVariadicResults();
+  const unsigned numNormalResults = numResults - numVariadicResults;
+
+  // Special case for ops without variadic results: the i-th value is for the
+  // i-th result defined in the op.
+  // Special case for ops with one variadic result: the variadic result can
+  // appear at any place, so the i-th value may not necessarily belong to the
+  // i-th result definition. we need to calculate the index (range) for each
+  // result.
+  if (numVariadicResults <= 1) {
+    bool emittedVariadicResult = false;
+    for (unsigned i = 0; i != numResults; ++i) {
+      const auto &result = op.getResult(i);
+      if (result.name.empty())
+        continue;
+
+      if (result.isVariadic()) {
+        auto &m = opClass.newMethod("Operation::result_range", result.name);
+        m.body() << formatv(
+            "  return {{std::next(result_begin(), {0}), "
+            "std::next(result_begin(), {0} + this->getNumResults() - {1})};",
+            i, numNormalResults);
+        emittedVariadicResult = true;
+      } else {
+        auto &m = opClass.newMethod("Value *", result.name);
+        m.body() << "  return this->getOperation()->getResult(";
+        if (emittedVariadicResult)
+          m.body() << "this->getNumResults() - " << numResults - i;
+        else
+          m.body() << i;
+        m.body() << ");\n";
+      }
+    }
+    return;
+  }
+
+  // If we have more than one variadic results, we need more complicated logic
+  // to calculate the value range for each result.
+
+  if (!op.hasTrait("SameVariadicResultSize")) {
+    PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
+                                 "specification over their sizes");
+  }
+
+  unsigned emittedNormalResults = 0;
+  unsigned emittedVariadicResults = 0;
+
+  for (unsigned i = 0; i != numResults; ++i) {
     const auto &result = op.getResult(i);
-    if (result.constraint.isVariadic() || result.name.empty())
+    if (result.name.empty())
       continue;
 
-    auto &m = opClass.newMethod("Value *", result.name);
-    m.body() << "  return this->getOperation()->getResult(" << i << ");\n";
+    const char *code = R"(
+  unsigned variadicResultSize = (this->getNumResults() - {0}) / {1};
+  unsigned offset = {2} + variadicResultSize * {3};
+  return )";
+    auto sizeAndOffset = formatv(code, numNormalResults, numVariadicResults,
+                                 emittedNormalResults, emittedVariadicResults);
+
+    if (result.isVariadic()) {
+      auto &m = opClass.newMethod("Operation::result_range", result.name);
+      m.body() << sizeAndOffset
+               << "{std::next(result_begin(), offset), "
+                  "std::next(result_begin(), offset + variadicResultSize)};";
+      ++emittedVariadicResults;
+    } else {
+      auto &m = opClass.newMethod("Value *", result.name);
+      m.body() << sizeAndOffset << "this->getResult(offset);";
+      ++emittedNormalResults;
+    }
   }
 }
 
@@ -505,12 +624,12 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
   // Emit parameters for all return types
   if (!useOperandType && !useAttrType) {
     for (unsigned i = 0; i != numResults; ++i) {
-      std::string resultName = op.getResultName(i);
+      const auto &result = op.getResult(i);
+      std::string resultName = result.name;
       if (resultName.empty())
         resultName = formatv("resultType{0}", i);
 
-      bool isVariadic = op.getResultTypeConstraint(i).isVariadic();
-      paramList.append(isVariadic ? ", ArrayRef<Type> " : ", Type ");
+      paramList.append(result.isVariadic() ? ", ArrayRef<Type> " : ", Type ");
       paramList.append(resultName);
 
       resultNames.emplace_back(std::move(resultName));
@@ -520,12 +639,13 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
   // Emit parameters for all arguments (operands and attributes).
   int numOperands = 0;
   int numAttrs = 0;
+
   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
     auto argument = op.getArg(i);
     if (argument.is<tblgen::NamedTypeConstraint *>()) {
-      auto &operand = op.getOperand(numOperands);
-      paramList.append(operand.constraint.isVariadic() ? ", ArrayRef<Value *> "
-                                                       : ", Value *");
+      const auto &operand = op.getOperand(numOperands);
+      paramList.append(operand.isVariadic() ? ", ArrayRef<Value *> "
+                                            : ", Value *");
       paramList.append(getArgumentName(op, numOperands));
       ++numOperands;
     } else {
@@ -542,33 +662,22 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
   }
 
   if (numOperands + numAttrs != op.getNumArgs())
-    return PrintFatalError(
-        "op arguments must be either operands or attributes");
+    PrintFatalError("op arguments must be either operands or attributes");
 
-  auto &method =
-      opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
-
-  bool hasVariadicOperand = op.hasVariadicOperand();
+  auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
 
   // Push all result types to the result
   if (numResults > 0) {
     if (!useOperandType && !useAttrType) {
-      bool hasVariadicResult = op.hasVariadicResult();
-      int numNonVariadicResults =
-          numResults - static_cast<int>(hasVariadicResult);
-
-      if (numNonVariadicResults > 0) {
-        method.body() << "  " << builderOpState << "->addTypes({"
-                      << resultNames.front();
-        for (int i = 1; i < numNonVariadicResults; ++i) {
-          method.body() << ", " << resultNames[i];
+      for (unsigned i = 0; i < numResults; ++i) {
+        const auto &result = op.getResult(i);
+        m.body() << "  " << builderOpState;
+        if (result.isVariadic()) {
+          m.body() << "->addTypes(";
+        } else {
+          m.body() << "->types.push_back(";
         }
-        method.body() << "});\n";
-      }
-
-      if (hasVariadicResult) {
-        method.body() << "  " << builderOpState << "->addTypes("
-                      << resultNames.back() << ");\n";
+        m.body() << resultNames[i] << ");\n";
       }
     } else {
       std::string resultType;
@@ -580,32 +689,27 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
           resultType = formatv("{0}.getType()", namedAttr.name);
         }
       } else {
-        const char *index =
-            (numOperands == 1 && hasVariadicOperand) ? ".front()" : "";
+        const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
         resultType =
             formatv("{0}{1}->getType()", getArgumentName(op, 0), index).str();
       }
-      method.body() << "  " << builderOpState << "->addTypes({" << resultType;
+      m.body() << "  " << builderOpState << "->addTypes({" << resultType;
       for (unsigned i = 1; i != numResults; ++i)
-        method.body() << ", " << resultType;
-      method.body() << "});\n\n";
+        m.body() << ", " << resultType;
+      m.body() << "});\n\n";
     }
   }
 
   // Push all operands to the result
-  int numNonVariadicOperands =
-      numOperands - static_cast<int>(hasVariadicOperand);
-  if (numNonVariadicOperands > 0) {
-    method.body() << "  " << builderOpState << "->addOperands({"
-                  << getArgumentName(op, 0);
-    for (int i = 1; i < numNonVariadicOperands; ++i) {
-      method.body() << ", " << getArgumentName(op, i);
+  for (unsigned i = 0; i < numOperands; ++i) {
+    const auto &operand = op.getOperand(i);
+    m.body() << "  " << builderOpState;
+    if (operand.isVariadic()) {
+      m.body() << "->addOperands(";
+    } else {
+      m.body() << "->operands.push_back(";
     }
-    method.body() << "});\n";
-  }
-  if (hasVariadicOperand) {
-    method.body() << "  " << builderOpState << "->addOperands("
-                  << getArgumentName(op, numOperands - 1) << ");\n";
+    m.body() << getArgumentName(op, i) << ");\n";
   }
 
   // Push all attributes to the result
@@ -613,12 +717,12 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
     if (!namedAttr.attr.isDerivedAttr()) {
       bool emitNotNullCheck = namedAttr.attr.isOptional();
       if (emitNotNullCheck) {
-        method.body() << formatv("  if ({0}) ", namedAttr.name) << "{\n";
+        m.body() << formatv("  if ({0}) ", namedAttr.name) << "{\n";
       }
-      method.body() << formatv("  {0}->addAttribute(\"{1}\", {1});\n",
-                               builderOpState, namedAttr.name);
+      m.body() << formatv("  {0}->addAttribute(\"{1}\", {1});\n",
+                          builderOpState, namedAttr.name);
       if (emitNotNullCheck) {
-        method.body() << "  }\n";
+        m.body() << "  }\n";
       }
     }
   }
@@ -646,13 +750,13 @@ void OpEmitter::genBuilder() {
     }
   }
 
-  auto numResults = op.getNumResults();
-  bool hasVariadicResult = op.hasVariadicResult();
-  int numNonVariadicResults = numResults - int(hasVariadicResult);
+  unsigned numResults = op.getNumResults();
+  unsigned numVariadicResults = op.getNumVariadicResults();
+  unsigned numNonVariadicResults = numResults - numVariadicResults;
 
-  auto numOperands = op.getNumOperands();
-  bool hasVariadicOperand = op.hasVariadicOperand();
-  int numNonVariadicOperands = numOperands - int(hasVariadicOperand);
+  unsigned numOperands = op.getNumOperands();
+  unsigned numVariadicOperands = op.getNumVariadicOperands();
+  unsigned numNonVariadicOperands = numOperands - numVariadicOperands;
 
   // Generate default builders that requires all result type, operands, and
   // attributes as parameters.
@@ -681,15 +785,16 @@ void OpEmitter::genBuilder() {
   auto &body = m.body();
 
   // Result types
-  if (!(hasVariadicResult && numNonVariadicResults == 0))
+  if (numVariadicResults == 0 || numNonVariadicResults != 0)
     body << "  assert(resultTypes.size()"
-         << (hasVariadicResult ? " >= " : " == ") << numNonVariadicResults
+         << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
          << "u && \"mismatched number of return types\");\n";
   body << "  " << builderOpState << "->addTypes(resultTypes);\n";
 
   // Operands
-  if (!(hasVariadicOperand && numNonVariadicOperands == 0))
-    body << "  assert(operands.size()" << (hasVariadicOperand ? " >= " : " == ")
+  if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
+    body << "  assert(operands.size()"
+         << (numVariadicOperands != 0 ? " >= " : " == ")
          << numNonVariadicOperands
          << "u && \"mismatched number of parameters\");\n";
   body << "  " << builderOpState << "->addOperands(operands);\n\n";
@@ -703,7 +808,7 @@ void OpEmitter::genBuilder() {
 
   bool useOperandType = op.hasTrait("SameOperandsAndResultType");
   bool useAttrType = op.hasTrait("FirstAttrDerivedResultType");
-  if (!op.hasVariadicResult() && (useOperandType || useAttrType))
+  if (numVariadicResults == 0 && (useOperandType || useAttrType))
     genStandaloneParamBuilder(useOperandType, useAttrType);
 }
 
@@ -824,7 +929,7 @@ void OpEmitter::genVerifier() {
   auto verifyValue = [&](const tblgen::NamedTypeConstraint &value, int index,
                          bool isOperand) -> void {
     // TODO: Handle variadic operand/result verification.
-    if (value.constraint.isVariadic())
+    if (value.isVariadic())
       return;
 
     // TODO: Commonality between matchers could be extracted to have a more
@@ -869,12 +974,12 @@ void OpEmitter::genVerifier() {
 }
 
 void OpEmitter::genTraits() {
-  auto numResults = op.getNumResults();
-  bool hasVariadicResult = op.hasVariadicResult();
+  unsigned numResults = op.getNumResults();
+  unsigned numVariadicResults = op.getNumVariadicResults();
 
   // Add return size trait.
-  if (hasVariadicResult) {
-    if (numResults == 1)
+  if (numVariadicResults != 0) {
+    if (numResults == numVariadicResults)
       opClass.addTrait("VariadicResults");
     else
       opClass.addTrait("AtLeastNResults<" + Twine(numResults - 1) + ">::Impl");
@@ -898,12 +1003,12 @@ void OpEmitter::genTraits() {
   }
 
   // Add variadic size trait and normal op traits.
-  auto numOperands = op.getNumOperands();
-  bool hasVariadicOperand = op.hasVariadicOperand();
+  unsigned numOperands = op.getNumOperands();
+  unsigned numVariadicOperands = op.getNumVariadicOperands();
 
   // Add operand size trait.
-  if (hasVariadicOperand) {
-    if (numOperands == 1)
+  if (numVariadicOperands != 0) {
+    if (numOperands == numVariadicOperands)
       opClass.addTrait("VariadicOperands");
     else
       opClass.addTrait("AtLeastNOperands<" + Twine(numOperands - 1) +
index dad3035..501b7a1 100644 (file)
@@ -440,7 +440,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
   const Operator &rootOp = pattern.getSourceRootOp();
   auto rootName = rootOp.getOperationName();
 
-  if (rootOp.hasVariadicResult())
+  if (rootOp.getNumVariadicResults() != 0)
     PrintFatalError(
         loc, "replacing op with variadic results not supported right now");