Use aggregate-parameter builder for ops having autogen type-deduction builder
authorLei Zhang <antiagainst@google.com>
Fri, 15 Nov 2019 15:33:21 +0000 (07:33 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 15 Nov 2019 15:33:54 +0000 (07:33 -0800)
Thus far DRR always invokes the separate-parameter builder (i.e., requiring
a separate parameter for each result-type/operand/attribute) for creating
ops, no matter whether we can auto-generate a builder with type-deduction
ability or not.

This CL changes the path for ops that we can auto-generate type-deduction
builders, i.e., with SameOperandsAndResultType/FirstAttrDerivedResultType
traits. Now they are going through a aggregate-parameter builder (i.e.,
requiring one parameter for all result-types/operands/attributes).
attributes.)

It is expected this approach will be more friendly for future shape inference
function autogen and calling those autogen'd shape inference function without
excessive packing and repacking operand/attribute lists.
Also, it would enable better support for creating ops with optional attributes
because we are not required to provide an Attribute() as placeholder for
an optional attribute anymore.

PiperOrigin-RevId: 280654800

mlir/g3doc/DeclarativeRewrites.md
mlir/test/mlir-tblgen/op-attribute.td
mlir/test/mlir-tblgen/op-result.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/RewriterGen.cpp

index 94eb841..5c344a1 100644 (file)
@@ -224,10 +224,11 @@ In the above, we are using `BOp`'s result for building `COp`.
 #### Building operations
 
 Given that `COp` was specified with table-driven op definition, there will be
-several `build()` methods generated for it. One of them has a separate argument
-in the signature for each argument appearing in the op's `arguments` list:
-`void COp::build(..., Value *input, Attribute attr)`. The pattern in the above
-calls this `build()` method for constructing the `COp`.
+several `build()` methods generated for it. One of them has aggregated
+parameters for result types, operands, and attributes in the signature: `void
+COp::build(..., ArrayRef<Type> resultTypes, Array<Value *> operands,
+ArrayRef<NamedAttribute> attr)`. The pattern in the above calls this `build()`
+method for constructing the `COp`.
 
 In general, arguments in the the result pattern will be passed directly to the
 `build()` method to leverage the auto-generated `build()` method, list them in
@@ -246,16 +247,29 @@ that has result type deduction ability via `OpBuilder` in ODS. For example,
 in the following pattern
 
 ```tblgen
-def : Pat<(AOp $input, $attr), (COp (BOp) $attr)>;
+def : Pat<(AOp $input, $attr), (COp (AOp $input, $attr) $attr)>;
 ```
 
-`BOp` is generated via a nested result pattern; DRR won't be able to deduce the
-result type for it. A custom builder for `BOp` should be defined and it should
-deduce the result type by itself.
+`AOp` is generated via a nested result pattern; DRR won't be able to deduce the
+result type for it. A custom builder for `AOp` should be defined and it should
+deduce the result type by itself. The builder should have the a separate
+parameter for each operand and attribute and deduce the result type internally
+by itself. For example, for the above `AOp`, a possible builder is:
 
-Failing to define such a builder will result in an error at C++ compilation
-time saying the call to `BOp::build()` cannot be resolved because of the number
-of parameters mismatch.
+```c++
+
+void AOp::build(Builder *builder, OperationState &state,
+                Value *input, Attribute attr) {
+  state.addOperands({input});
+  state.addAttribute("a_attr", attr);
+  Type type = ...; // Deduce result type here
+  state.addTypes({type});
+}
+```
+
+Failing to define such a builder will result in an error at C++ compilation time
+saying the call to `AOp::build()` cannot be resolved because of the number of
+parameters mismatch.
 
 #### Generating DAG of operations
 
@@ -658,5 +672,11 @@ providing include paths via `-I`. For example,
 mlir-tblgen --gen-rewriters -I /path/to/mlir/include /path/to/input/td/file
 ```
 
+### Compilation error: no matching member function for call to 'build'
+
+This is because DRR is failing to call a `build()` mehtod with result type
+deduction ability. See [building operations](#building-operations) for more
+details.
+
 [TableGen]: https://llvm.org/docs/TableGen/index.html
 [OpBase]: https://github.com/tensorflow/mlir/blob/master/include/mlir/IR/OpBase.td
index 7fe249b..61fe70f 100644 (file)
@@ -56,8 +56,7 @@ def AOp : NS_Op<"a_op", []> {
 
 // CHECK:      void AOp::build(
 // CHECK-SAME:   ArrayRef<NamedAttribute> attributes
-// CHECK:        for (const auto& pair : attributes)
-// CHECK-NEXT:     tblgen_state.addAttribute(pair.first, pair.second);
+// CHECK:      tblgen_state.addAttributes(attributes);
 
 // Test verify method
 // ---
index 9176dfc..9979480 100644 (file)
@@ -45,8 +45,8 @@ def OpD : NS_Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
 }
 
 // CHECK-LABEL: OpD definitions
-// CHECK: void OpD::build(Builder *, OperationState &tblgen_state, Value *x, TypeAttr attr, FloatAttr f32)
-// CHECK: tblgen_state.addTypes({attr.getValue()});
+// CHECK: void OpD::build(Builder *, OperationState &tblgen_state, ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes)
+// CHECK: tblgen_state.addTypes({attr.second.cast<TypeAttr>().getValue()});
 
 def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
   let arguments = (ins I32:$x, F32Attr:$attr);
@@ -54,8 +54,8 @@ def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
 }
 
 // CHECK-LABEL: OpE definitions
-// CHECK: void OpE::build(Builder *, OperationState &tblgen_state, Value *x, FloatAttr attr)
-// CHECK: tblgen_state.addTypes({attr.getType()});
+// CHECK: void OpE::build(Builder *, OperationState &tblgen_state, ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes)
+// CHECK: tblgen_state.addTypes({attr.second.getType()});
 
 def OpF : NS_Op<"one_variadic_result_op", []> {
   let results = (outs Variadic<I32>:$x);
index 17dd173..46803b5 100644 (file)
@@ -488,8 +488,13 @@ private:
 
   // Generates the build() method that takes each operand/attribute as a
   // stand-alone parameter. This build() method uses first operand's type
-  // as all result's types.
-  void genUseOperandAsResultTypeBuilder();
+  // as all results' types.
+  void genUseOperandAsResultTypeSeparateParamBuilder();
+
+  // Generates the build() method that takes all operands/attributes
+  // collectively as one parameter. This build() method uses first operand's
+  // type as all results' types.
+  void genUseOperandAsResultTypeCollectiveParamBuilder();
 
   // Generates the build() method that takes each operand/attribute as a
   // stand-alone parameter. This build() method uses first attribute's type
@@ -813,7 +818,40 @@ void OpEmitter::genCollectiveTypeParamBuilder() {
   m.body() << formatv("  {0}.addTypes(resultTypes);\n", builderOpState);
 }
 
-void OpEmitter::genUseOperandAsResultTypeBuilder() {
+void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
+  // If this op has a variadic result, we cannot generate this builder because
+  // we don't know how many results to create.
+  if (op.getNumVariadicResults() != 0)
+    return;
+
+  int numResults = op.getNumResults();
+
+  // Signature
+  std::string params =
+      std::string("Builder *, OperationState &") + builderOpState +
+      ", ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes";
+  auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static);
+  auto &body = m.body();
+
+  // Result types
+  SmallVector<std::string, 2> resultTypes(numResults, "operands[0]->getType()");
+  body << "  " << builderOpState << ".addTypes({"
+       << llvm::join(resultTypes, ", ") << "});\n\n";
+
+  // Operands
+  body << "  " << builderOpState << ".addOperands(operands);\n\n";
+
+  // Attributes
+  body << "  " << builderOpState << ".addAttributes(attributes);\n";
+
+  // Create the correct number of regions
+  if (int numRegions = op.getNumRegions()) {
+    for (int i = 0; i < numRegions; ++i)
+      m.body() << "  (void)" << builderOpState << ".addRegion();\n";
+  }
+}
+
+void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
   std::string paramList;
   llvm::SmallVector<std::string, 4> resultNames;
   buildParamList(paramList, resultNames, TypeParamKind::None);
@@ -836,29 +874,32 @@ void OpEmitter::genUseOperandAsResultTypeBuilder() {
 }
 
 void OpEmitter::genUseAttrAsResultTypeBuilder() {
-  std::string paramList;
-  llvm::SmallVector<std::string, 4> resultNames;
-  buildParamList(paramList, resultNames, TypeParamKind::None);
-
-  auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
-  genCodeForAddingArgAndRegionForBuilder(m.body());
-
-  auto numResults = op.getNumResults();
-  if (numResults == 0)
-    return;
+  std::string params =
+      std::string("Builder *, OperationState &") + builderOpState +
+      ", ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes";
+  auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static);
+  auto &body = m.body();
 
   // Push all result types to the operation state
   std::string resultType;
   const auto &namedAttr = op.getAttribute(0);
+
+  body << "  for (auto attr : attributes) {\n";
+  body << "    if (attr.first != \"" << namedAttr.name << "\") continue;\n";
   if (namedAttr.attr.isTypeAttr()) {
-    resultType = formatv("{0}.getValue()", namedAttr.name);
+    resultType = "attr.second.cast<TypeAttr>().getValue()";
   } else {
-    resultType = formatv("{0}.getType()", namedAttr.name);
+    resultType = "attr.second.getType()";
   }
-  m.body() << "  " << builderOpState << ".addTypes({" << resultType;
-  for (int i = 1; i != numResults; ++i)
-    m.body() << ", " << resultType;
-  m.body() << "});\n\n";
+  SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
+  body << "    " << builderOpState << ".addTypes({"
+       << llvm::join(resultTypes, ", ") << "});\n";
+  body << "  }\n";
+
+  // Operands
+  body << "  " << builderOpState << ".addOperands(operands);\n\n";
+  // Attributes
+  body << "  " << builderOpState << ".addAttributes(attributes);\n";
 }
 
 void OpEmitter::genBuilder() {
@@ -907,8 +948,10 @@ void OpEmitter::genBuilder() {
   //    use the first operand or attribute's type as all result types
   // to facilitate different call patterns.
   if (op.getNumVariadicResults() == 0) {
-    if (op.hasTrait("OpTrait::SameOperandsAndResultType"))
-      genUseOperandAsResultTypeBuilder();
+    if (op.hasTrait("OpTrait::SameOperandsAndResultType")) {
+      genUseOperandAsResultTypeSeparateParamBuilder();
+      genUseOperandAsResultTypeCollectiveParamBuilder();
+    }
     if (op.hasTrait("OpTrait::FirstAttrDerivedResultType"))
       genUseAttrAsResultTypeBuilder();
   }
@@ -946,9 +989,7 @@ void OpEmitter::genCollectiveParamBuilder() {
   body << "  " << builderOpState << ".addOperands(operands);\n\n";
 
   // Attributes
-  body << "  for (const auto& pair : attributes)\n"
-       << "    " << builderOpState
-       << ".addAttribute(pair.first, pair.second);\n";
+  body << "  " << builderOpState << ".addAttributes(attributes);\n";
 
   // Create the correct number of regions
   if (int numRegions = op.getNumRegions()) {
index 2f7e8e0..ac2976a 100644 (file)
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
 
-using namespace llvm;
 using namespace mlir;
 using namespace mlir::tblgen;
 
+using llvm::formatv;
+using llvm::Record;
+using llvm::RecordKeeper;
+
 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
 
 namespace llvm {
@@ -121,6 +124,22 @@ private:
   // result value name.
   std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
 
+  using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;
+
+  // Emits a local variable for each value and attribute to be used for creating
+  // an op.
+  void createSeparateLocalVarsForOpArgs(DagNode node,
+                                        ChildNodeIndexNameMap &childNodeNames);
+
+  // Emits the concrete arguments used to call a op's builder.
+  void supplyValuesForOpArgs(DagNode node,
+                             const ChildNodeIndexNameMap &childNodeNames);
+
+  // Emits the local variables for holding all values as a whole and all named
+  // attributes as a whole to be used for creating an op.
+  void createAggregateLocalVarsForOpArgs(
+      DagNode node, const ChildNodeIndexNameMap &childNodeNames);
+
   // Returns the C++ expression to construct a constant attribute of the given
   // `value` for the given attribute kind `attr`.
   std::string handleConstantAttr(Attribute attr, StringRef value);
@@ -529,7 +548,6 @@ void PatternEmitter::emitRewriteLogic() {
     PrintFatalError(loc, error);
   }
 
-  os.indent(4) << "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n";
   os.indent(4) << "auto loc = rewriter.getFusedLoc({";
   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
     os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
@@ -555,18 +573,18 @@ void PatternEmitter::emitRewriteLogic() {
     os.indent(4) << "rewriter.eraseOp(op0);\n";
   } else {
     // Process replacement result patterns.
-    os.indent(4) << "SmallVector<Value *, 4> tblgen_values;";
+    os.indent(4) << "SmallVector<Value *, 4> tblgen_repl_values;\n";
     for (int i = replStartIndex; i < numResultPatterns; ++i) {
       DagNode resultTree = pattern.getResultPattern(i);
       auto val = handleResultPattern(resultTree, offsets[i], 0);
       os.indent(4) << "\n";
       // Resolve each symbol for all range use so that we can loop over them.
       os << symbolInfoMap.getAllRangeUse(
-          val, "    for (auto *v : {0}) {{ tblgen_values.push_back(v); }",
+          val, "    for (auto *v : {0}) {{ tblgen_repl_values.push_back(v); }",
           "\n");
     }
     os.indent(4) << "\n";
-    os.indent(4) << "rewriter.replaceOp(op0, tblgen_values);\n";
+    os.indent(4) << "rewriter.replaceOp(op0, tblgen_repl_values);\n";
   }
 
   LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
@@ -705,7 +723,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
 
   // A map to collect all nested DAG child nodes' names, with operand index as
   // the key. This includes both bound and unbound child nodes.
-  llvm::DenseMap<unsigned, std::string> childNodeNames;
+  ChildNodeIndexNameMap childNodeNames;
 
   // First go through all the child nodes who are nested DAG constructs to
   // create ops for them and remember the symbol names for them, so that we can
@@ -739,6 +757,80 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
                           valuePackName);
   os.indent(4) << "{\n";
 
+  // Right now ODS don't have general type inference support. Except a few
+  // special cases listed below, DRR needs to supply types for all results
+  // when building an op.
+  bool isSameOperandsAndResultType =
+      resultOp.hasTrait("OpTrait::SameOperandsAndResultType");
+  bool useFirstAttr = resultOp.hasTrait("OpTrait::FirstAttrDerivedResultType");
+
+  if (isSameOperandsAndResultType || useFirstAttr) {
+    // We know how to deduce the result type for ops with these traits and we've
+    // generated builders taking aggregrate parameters. Use those builders to
+    // create the ops.
+
+    // First prepare local variables for op arguments used in builder call.
+    createAggregateLocalVarsForOpArgs(tree, childNodeNames);
+    // Then create the op.
+    os.indent(6) << formatv(
+        "{0} = rewriter.create<{1}>(loc, tblgen_values, tblgen_attrs);\n",
+        valuePackName, resultOp.getQualCppClassName());
+    os.indent(4) << "}\n";
+    return resultValue;
+  }
+
+  bool isBroadcastable =
+      resultOp.hasTrait("OpTrait::BroadcastableTwoOperandsOneResult");
+  bool usePartialResults = valuePackName != resultValue;
+
+  if (isBroadcastable || usePartialResults || depth > 0 || resultIndex < 0) {
+    // For these cases (broadcastable ops, op results used both as auxiliary
+    // values and replacement values, ops in nested patterns, auxiliary ops), we
+    // still need to supply the result types when building the op. But because
+    // we don't generate a builder automatically with ODS for them, it's the
+    // developer's responsiblity to make sure such a builder (with result type
+    // deduction ability) exists. We go through the separate-parameter builder
+    // here given that it's easier for developers to write compared to
+    // aggregate-parameter builders.
+    createSeparateLocalVarsForOpArgs(tree, childNodeNames);
+    os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
+                            resultOp.getQualCppClassName());
+    supplyValuesForOpArgs(tree, childNodeNames);
+    os << "\n      );\n";
+    os.indent(4) << "}\n";
+    return resultValue;
+  }
+
+  // If depth == 0 and resultIndex >= 0, it means we are replacing the values
+  // generated from the source pattern root op. Then we can use the source
+  // pattern's value types to determine the value type of the generated op
+  // here.
+
+  // First prepare local variables for op arguments used in builder call.
+  createAggregateLocalVarsForOpArgs(tree, childNodeNames);
+
+  // Then prepare the result types. We need to specify the types for all
+  // results.
+  os.indent(6) << formatv(
+      "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n");
+  int numResults = resultOp.getNumResults();
+  if (numResults != 0) {
+    for (int i = 0; i < numResults; ++i)
+      os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) {{"
+                              "tblgen_types.push_back(v->getType()); }\n",
+                              resultIndex + i);
+  }
+  os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc, tblgen_types, "
+                          "tblgen_values, tblgen_attrs);\n",
+                          valuePackName, resultOp.getQualCppClassName());
+  os.indent(4) << "}\n";
+  return resultValue;
+}
+
+void PatternEmitter::createSeparateLocalVarsForOpArgs(
+    DagNode node, ChildNodeIndexNameMap &childNodeNames) {
+  Operator &resultOp = node.getDialectOp(opMap);
+
   // Now prepare operands used for building this op:
   // * If the operand is non-variadic, we create a `Value*` local variable.
   // * If the operand is variadic, we create a `SmallVector<Value*>` local
@@ -752,15 +844,16 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
       // We do not need special handling for attributes.
       continue;
     }
+
     std::string varName;
     if (operand->isVariadic()) {
       varName = formatv("tblgen_values_{0}", valueIndex++);
       os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName);
       std::string range;
-      if (tree.isNestedDagArg(argIndex)) {
+      if (node.isNestedDagArg(argIndex)) {
         range = childNodeNames[argIndex];
       } else {
-        range = tree.getArgName(argIndex);
+        range = node.getArgName(argIndex);
       }
       // Resolve the symbol for all range use so that we have a uniform way of
       // capturing the values.
@@ -770,12 +863,12 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
     } else {
       varName = formatv("tblgen_value_{0}", valueIndex++);
       os.indent(6) << formatv("Value *{0} = ", varName);
-      if (tree.isNestedDagArg(argIndex)) {
+      if (node.isNestedDagArg(argIndex)) {
         os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
       } else {
-        DagLeaf leaf = tree.getArgAsLeaf(argIndex);
+        DagLeaf leaf = node.getArgAsLeaf(argIndex);
         auto symbol =
-            symbolInfoMap.getValueAndRangeUse(tree.getArgName(argIndex));
+            symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
         if (leaf.isNativeCodeCall()) {
           os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
         } else {
@@ -788,74 +881,37 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
     // Update to use the newly created local variable for building the op later.
     childNodeNames[argIndex] = varName;
   }
+}
 
-  // Then we create the builder call.
-
-  // Right now we don't have general type inference in MLIR. Except a few
-  // special cases listed below, we need to supply types for all results
-  // when building an op.
-  bool isSameOperandsAndResultType =
-      resultOp.hasTrait("OpTrait::SameOperandsAndResultType");
-  bool isBroadcastable =
-      resultOp.hasTrait("OpTrait::BroadcastableTwoOperandsOneResult");
-  bool useFirstAttr = resultOp.hasTrait("OpTrait::FirstAttrDerivedResultType");
-  bool usePartialResults = valuePackName != resultValue;
-
-  if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr ||
-      usePartialResults || depth > 0 || resultIndex < 0) {
-    os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
-                            resultOp.getQualCppClassName());
-  } else {
-    // If depth == 0 and resultIndex >= 0, it means we are replacing the values
-    // generated from the source pattern root op. Then we can use the source
-    // pattern's value types to determine the value type of the generated op
-    // here.
-
-    // We need to specify the types for all results.
-    int numResults = resultOp.getNumResults();
-    if (numResults != 0) {
-      os.indent(6) << "tblgen_types.clear();\n";
-      for (int i = 0; i < numResults; ++i) {
-        os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) "
-                                "tblgen_types.push_back(v->getType());\n",
-                                resultIndex + i);
-      }
-    }
-
-    os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
-                            resultOp.getQualCppClassName());
-    if (numResults != 0)
-      os.indent(6) << ", tblgen_types";
-  }
-
-  // Add arguments for the builder call.
-  for (int argIndex = 0; argIndex != numOpArgs; ++argIndex) {
+void PatternEmitter::supplyValuesForOpArgs(
+    DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
+  Operator &resultOp = node.getDialectOp(opMap);
+  for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
+       argIndex != numOpArgs; ++argIndex) {
     // Start each argment on its own line.
     (os << ",\n").indent(8);
 
     Argument opArg = resultOp.getArg(argIndex);
     // Handle the case of operand first.
     if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
-      if (!operand->name.empty()) {
+      if (!operand->name.empty())
         os << "/*" << operand->name << "=*/";
-      }
-      os << childNodeNames[argIndex];
-      // TODO(jpienaar): verify types
+      os << childNodeNames.lookup(argIndex);
       continue;
     }
 
     // The argument in the op definition.
     auto opArgName = resultOp.getArgName(argIndex);
-    if (auto subTree = tree.getArgAsNestedDag(argIndex)) {
+    if (auto subTree = node.getArgAsNestedDag(argIndex)) {
       if (!subTree.isNativeCodeCall())
         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                              "for creating attribute");
       os << formatv("/*{0}=*/{1}", opArgName,
                     handleReplaceWithNativeCodeCall(subTree));
     } else {
-      auto leaf = tree.getArgAsLeaf(argIndex);
+      auto leaf = node.getArgAsLeaf(argIndex);
       // The argument in the result DAG pattern.
-      auto patArgName = tree.getArgName(argIndex);
+      auto patArgName = node.getArgName(argIndex);
       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
         // TODO(jpienaar): Refactor out into map to avoid recomputing these.
         if (!opArg.is<NamedAttribute *>())
@@ -868,10 +924,74 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
       os << handleOpArgument(leaf, patArgName);
     }
   }
-  os << "\n      );\n";
-  os.indent(4) << "}\n";
+}
 
-  return resultValue;
+void PatternEmitter::createAggregateLocalVarsForOpArgs(
+    DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
+  Operator &resultOp = node.getDialectOp(opMap);
+
+  os.indent(6) << formatv(
+      "SmallVector<Value *, 4> tblgen_values; (void)tblgen_values;\n");
+  os.indent(6) << formatv(
+      "SmallVector<NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs;\n");
+
+  for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
+    if (const auto *attr =
+            resultOp.getArg(argIndex).dyn_cast<NamedAttribute *>()) {
+      const char *addAttrCmd = "if ({1}) {{"
+                               "  tblgen_attrs.emplace_back(rewriter."
+                               "getIdentifier(\"{0}\"), {1}); }\n";
+      // The argument in the op definition.
+      auto opArgName = resultOp.getArgName(argIndex);
+      if (auto subTree = node.getArgAsNestedDag(argIndex)) {
+        if (!subTree.isNativeCodeCall())
+          PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
+                               "for creating attribute");
+        os.indent(6) << formatv(addAttrCmd, opArgName,
+                                handleReplaceWithNativeCodeCall(subTree));
+      } else {
+        auto leaf = node.getArgAsLeaf(argIndex);
+        // The argument in the result DAG pattern.
+        auto patArgName = node.getArgName(argIndex);
+        os.indent(6) << formatv(addAttrCmd, opArgName,
+                                handleOpArgument(leaf, patArgName));
+      }
+      continue;
+    }
+
+    const auto *operand =
+        resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
+    std::string varName;
+    if (operand->isVariadic()) {
+      std::string range;
+      if (node.isNestedDagArg(argIndex)) {
+        range = childNodeNames.lookup(argIndex);
+      } else {
+        range = node.getArgName(argIndex);
+      }
+      // Resolve the symbol for all range use so that we have a uniform way of
+      // capturing the values.
+      range = symbolInfoMap.getValueAndRangeUse(range);
+      os.indent(6) << formatv(
+          "for (auto *v : {0}) tblgen_values.push_back(v);\n", range);
+    } else {
+      os.indent(6) << formatv("tblgen_values.push_back(", varName);
+      if (node.isNestedDagArg(argIndex)) {
+        os << symbolInfoMap.getValueAndRangeUse(
+            childNodeNames.lookup(argIndex));
+      } else {
+        DagLeaf leaf = node.getArgAsLeaf(argIndex);
+        auto symbol =
+            symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
+        if (leaf.isNativeCodeCall()) {
+          os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
+        } else {
+          os << symbol;
+        }
+      }
+      os << ");\n";
+    }
+  }
 }
 
 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {