[DRR] Allow interleaved operands and attributes
authorLei Zhang <antiagainst@google.com>
Tue, 22 Oct 2019 03:47:49 +0000 (20:47 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 22 Oct 2019 03:48:17 +0000 (20:48 -0700)
Previously DRR assumes attributes to appear after operands. This was the
previous requirements on ODS, but that has changed some time ago. Fix
DRR to also support interleaved operands and attributes.

PiperOrigin-RevId: 275983485

mlir/include/mlir/TableGen/Operator.h
mlir/lib/TableGen/Operator.cpp
mlir/test/lib/TestDialect/TestOps.td
mlir/test/mlir-tblgen/pattern.mlir
mlir/tools/mlir-tblgen/RewriterGen.cpp

index 8b304d5..1d82976 100644 (file)
@@ -124,6 +124,14 @@ public:
   // Returns the total number of arguments.
   int getNumArgs() const { return arguments.size(); }
 
+  using arg_iterator = const Argument *;
+  using arg_range = llvm::iterator_range<arg_iterator>;
+
+  // Op argument (attribute or operand) iterators.
+  arg_iterator arg_begin() const;
+  arg_iterator arg_end() const;
+  arg_range getArgs() const;
+
   // Op argument (attribute or operand) accessors.
   Argument getArg(int index) const;
   StringRef getArgName(int index) const;
index 60fecf7..7d926d9 100644 (file)
@@ -126,6 +126,18 @@ unsigned tblgen::Operator::getNumVariadicOperands() const {
       [](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
 }
 
+tblgen::Operator::arg_iterator tblgen::Operator::arg_begin() const {
+  return arguments.begin();
+}
+
+tblgen::Operator::arg_iterator tblgen::Operator::arg_end() const {
+  return arguments.end();
+}
+
+tblgen::Operator::arg_range tblgen::Operator::getArgs() const {
+  return {arg_begin(), arg_end()};
+}
+
 StringRef tblgen::Operator::getArgName(int index) const {
   DagInit *argumentValues = def.getValueAsDag("arguments");
   return argumentValues->getArgName(index)->getValue();
index 1157eb8..b8443ac 100644 (file)
@@ -426,6 +426,28 @@ def OpJ : TEST_Op<"op_j">, Arguments<(ins)>, Results<(outs I32)>;
 def OpK : TEST_Op<"op_k">, Arguments<(ins)>, Results<(outs I32)>;
 def : Pat<(OpJ), (OpK)>;
 
+def OpInterleavedOperandAttribute1 : TEST_Op<"interleaved_operand_attr1"> {
+  let arguments = (ins
+    I32:$input1,
+    I64Attr:$attr1,
+    I32:$input2,
+    I64Attr:$attr2
+  );
+}
+
+def OpInterleavedOperandAttribute2 : TEST_Op<"interleaved_operand_attr2"> {
+  let arguments = (ins
+    I32:$input1,
+    I64Attr:$attr1,
+    I32:$input2,
+    I64Attr:$attr2
+  );
+}
+
+// Test that we can capture and reference interleaved operands and attributes.
+def : Pat<(OpInterleavedOperandAttribute1 $input1, $attr1, $input2, $attr2),
+          (OpInterleavedOperandAttribute2 $input1, $attr1, $input2, $attr2)>;
+
 // Test NativeCodeCall.
 def OpNativeCodeCall1 : TEST_Op<"native_code_call1"> {
   let arguments = (ins
index 7ebd56c..df21c77 100644 (file)
@@ -24,6 +24,14 @@ func @verifyZeroArg() -> i32 {
   return %0 : i32
 }
 
+// CHECK-LABEL: verifyInterleavedOperandAttribute
+// CHECK-SAME:    %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+func @verifyInterleavedOperandAttribute(%arg0: i32, %arg1: i32) {
+  // CHECK: "test.interleaved_operand_attr2"(%[[ARG0]], %[[ARG1]]) {attr1 = 15 : i64, attr2 = 42 : i64}
+  "test.interleaved_operand_attr1"(%arg0, %arg1) {attr1 = 15, attr2 = 42} : (i32, i32) -> ()
+  return
+}
+
 // CHECK-LABEL: verifyBenefit
 func @verifyBenefit(%arg0 : i32) -> i32 {
   %0 = "test.op_d"(%arg0) : (i32) -> i32
index 684e619..2f7e8e0 100644 (file)
@@ -81,13 +81,13 @@ private:
   // `tree`.
   void emitOpMatch(DagNode tree, int depth);
 
-  // Emits C++ statements for matching the `index`-th argument of the given DAG
-  // `tree` as an operand.
-  void emitOperandMatch(DagNode tree, int index, int depth, int indent);
+  // Emits C++ statements for matching the `argIndex`-th argument of the given
+  // DAG `tree` as an operand.
+  void emitOperandMatch(DagNode tree, int argIndex, int depth, int indent);
 
-  // Emits C++ statements for matching the `index`-th argument of the given DAG
-  // `tree` as an attribute.
-  void emitAttributeMatch(DagNode tree, int index, int depth, int indent);
+  // Emits C++ statements for matching the `argIndex`-th argument of the given
+  // DAG `tree` as an attribute.
+  void emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent);
 
   //===--------------------------------------------------------------------===//
   // Rewrite utilities
@@ -260,11 +260,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
                           << '\n');
 }
 
-void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
+void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
                                       int indent) {
   Operator &op = tree.getDialectOp(opMap);
-  auto *operand = op.getArg(index).get<NamedTypeConstraint *>();
-  auto matcher = tree.getArgAsLeaf(index);
+  auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
+  auto matcher = tree.getArgAsLeaf(argIndex);
 
   // If a constraint is specified, we need to generate C++ statements to
   // check the constraint.
@@ -272,7 +272,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
     if (!matcher.isOperandMatcher()) {
       PrintFatalError(
           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
-                       op.getOperationName(), index + 1));
+                       op.getOperationName(), argIndex + 1));
     }
 
     // Only need to verify if the matcher's type is different from the one
@@ -281,12 +281,12 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
       if (operand->isVariadic()) {
         auto error = formatv(
             "further constrain op {0}'s variadic operand #{1} unsupported now",
-            op.getOperationName(), index);
+            op.getOperationName(), argIndex);
         PrintFatalError(loc, error);
       }
       auto self =
           formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()",
-                  depth, index);
+                  depth, argIndex);
       os.indent(indent) << "if (!("
                         << tgfmt(matcher.getConditionTemplate(),
                                  &fmtCtx.withSelf(self))
@@ -295,17 +295,23 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
   }
 
   // Capture the value
-  auto name = tree.getArgName(index);
+  auto name = tree.getArgName(argIndex);
   if (!name.empty()) {
+    // We need to subtract the number of attributes before this operand to get
+    // the index in the operand list.
+    auto numPrevAttrs = std::count_if(
+        op.arg_begin(), op.arg_begin() + argIndex,
+        [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
+
     os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
-                                 name, depth, index);
+                                 name, depth, argIndex - numPrevAttrs);
   }
 }
 
-void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
+void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
                                         int indent) {
   Operator &op = tree.getDialectOp(opMap);
-  auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
+  auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
   const auto &attr = namedAttr->attr;
 
   os.indent(indent) << "{\n";
@@ -328,12 +334,12 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
     os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
   }
 
-  auto matcher = tree.getArgAsLeaf(index);
+  auto matcher = tree.getArgAsLeaf(argIndex);
   if (!matcher.isUnspecified()) {
     if (!matcher.isAttrMatcher()) {
       PrintFatalError(
           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
-                       op.getOperationName(), index + 1));
+                       op.getOperationName(), argIndex + 1));
     }
 
     // If a constraint is specified, we need to generate C++ statements to
@@ -345,7 +351,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
   }
 
   // Capture the value
-  auto name = tree.getArgName(index);
+  auto name = tree.getArgName(argIndex);
   if (!name.empty()) {
     os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
   }
@@ -683,6 +689,10 @@ int PatternEmitter::getNodeValueCount(DagNode node) {
 
 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
                                              int depth) {
+  LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
+  LLVM_DEBUG(tree.print(llvm::dbgs()));
+  LLVM_DEBUG(llvm::dbgs() << '\n');
+
   Operator &resultOp = tree.getDialectOp(opMap);
   auto numOpArgs = resultOp.getNumArgs();
 
@@ -734,12 +744,16 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
   // * If the operand is variadic, we create a `SmallVector<Value*>` local
   //   variable.
 
-  int argIndex = 0;   // The current index to this op's ODS argument
   int valueIndex = 0; // An index for uniquing local variable names.
-  for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) {
-    const auto &operand = resultOp.getOperand(argIndex);
+  for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
+    const auto *operand =
+        resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
+    if (!operand) {
+      // We do not need special handling for attributes.
+      continue;
+    }
     std::string varName;
-    if (operand.isVariadic()) {
+    if (operand->isVariadic()) {
       varName = formatv("tblgen_values_{0}", valueIndex++);
       os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName);
       std::string range;
@@ -814,22 +828,22 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
       os.indent(6) << ", tblgen_types";
   }
 
-  // Add operands for the builder all.
-  for (int i = 0; i < argIndex; ++i) {
-    const auto &operand = resultOp.getOperand(i);
-    // Start each operand on its own line.
+  // Add arguments for the builder call.
+  for (int argIndex = 0; argIndex != numOpArgs; ++argIndex) {
+    // Start each argment on its own line.
     (os << ",\n").indent(8);
-    if (!operand.name.empty()) {
-      os << "/*" << operand.name << "=*/";
+
+    Argument opArg = resultOp.getArg(argIndex);
+    // Handle the case of operand first.
+    if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
+      if (!operand->name.empty()) {
+        os << "/*" << operand->name << "=*/";
+      }
+      os << childNodeNames[argIndex];
+      // TODO(jpienaar): verify types
+      continue;
     }
-    os << childNodeNames[i];
-    // TODO(jpienaar): verify types
-  }
 
-  // Add attributes for the builder call.
-  for (; argIndex != numOpArgs; ++argIndex) {
-    // Start each attribute on its own line.
-    (os << ",\n").indent(8);
     // The argument in the op definition.
     auto opArgName = resultOp.getArgName(argIndex);
     if (auto subTree = tree.getArgAsNestedDag(argIndex)) {
@@ -844,8 +858,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
       auto patArgName = tree.getArgName(argIndex);
       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
         // TODO(jpienaar): Refactor out into map to avoid recomputing these.
-        auto argument = resultOp.getArg(argIndex);
-        if (!argument.is<NamedAttribute *>())
+        if (!opArg.is<NamedAttribute *>())
           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
         if (!patArgName.empty())
           os << "/*" << patArgName << "=*/";