[ODS] Fix operation argument population to avoid crash
authorLei Zhang <antiagainst@google.com>
Thu, 14 Nov 2019 19:02:52 +0000 (11:02 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 14 Nov 2019 19:03:29 +0000 (11:03 -0800)
The `Operator` class keeps an `arguments` field, which contains pointers
to `operands` and `attributes` elements. Thus it must be populated after
`operands` and `attributes` are finalized so to have stable pointers.
SmallVector may re-allocate when still having new elements added, which
will invalidate pointers.

PiperOrigin-RevId: 280466896

mlir/include/mlir/TableGen/Operator.h
mlir/lib/TableGen/Operator.cpp
mlir/lib/TableGen/Pattern.cpp
mlir/test/lib/TestDialect/TestOps.td
mlir/test/mlir-tblgen/pattern.mlir

index 1d82976..95df9cb 100644 (file)
@@ -172,6 +172,10 @@ public:
   // Returns the dialect of the op.
   const Dialect &getDialect() const { return dialect; }
 
+  // Prints the contents in this operator to the given `os`. This is used for
+  // debugging purposes.
+  void print(llvm::raw_ostream &os) const;
+
 private:
   // Populates the vectors containing operands, attributes, results and traits.
   void populateOpStructure();
index 7d926d9..8afffd0 100644 (file)
@@ -27,6 +27,8 @@
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
 
+#define DEBUG_TYPE "mlir-tblgen-operator"
+
 using namespace mlir;
 
 using llvm::DagInit;
@@ -205,12 +207,11 @@ void tblgen::Operator::populateOpStructure() {
   auto derivedAttrClass = recordKeeper.getClass("DerivedAttr");
   numNativeAttributes = 0;
 
-  // The argument ordering is operands, native attributes, derived
-  // attributes.
   DagInit *argumentValues = def.getValueAsDag("arguments");
-  unsigned i = 0;
+  unsigned numArgs = argumentValues->getNumArgs();
+
   // Handle operands and native attributes.
-  for (unsigned e = argumentValues->getNumArgs(); i != e; ++i) {
+  for (unsigned i = 0; i != numArgs; ++i) {
     auto arg = argumentValues->getArg(i);
     auto givenName = argumentValues->getArgNameStr(i);
     auto argDefInit = dyn_cast<DefInit>(arg);
@@ -222,7 +223,6 @@ void tblgen::Operator::populateOpStructure() {
     if (argDef->isSubClassOf(typeConstraintClass)) {
       operands.push_back(
           NamedTypeConstraint{givenName, TypeConstraint(argDefInit)});
-      arguments.emplace_back(&operands.back());
     } else if (argDef->isSubClassOf(attrClass)) {
       if (givenName.empty())
         PrintFatalError(argDef->getLoc(), "attributes must be named");
@@ -230,7 +230,6 @@ void tblgen::Operator::populateOpStructure() {
         PrintFatalError(argDef->getLoc(),
                         "derived attributes not allowed in argument list");
       attributes.push_back({givenName, Attribute(argDef)});
-      arguments.emplace_back(&attributes.back());
       ++numNativeAttributes;
     } else {
       PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving "
@@ -258,6 +257,22 @@ void tblgen::Operator::populateOpStructure() {
     }
   }
 
+  // Populate `arguments`. This must happen after we've finalized `operands` and
+  // `attributes` because we will put their elements' pointers in `arguments`.
+  // SmallVector may perform re-allocation under the hood when adding new
+  // elements.
+  int operandIndex = 0, attrIndex = 0;
+  for (unsigned i = 0; i != numArgs; ++i) {
+    Record *argDef = dyn_cast<DefInit>(argumentValues->getArg(i))->getDef();
+
+    if (argDef->isSubClassOf(typeConstraintClass)) {
+      arguments.emplace_back(&operands[operandIndex++]);
+    } else {
+      assert(argDef->isSubClassOf(attrClass));
+      arguments.emplace_back(&attributes[attrIndex++]);
+    }
+  }
+
   auto *resultsDag = def.getValueAsDag("results");
   auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator());
   if (!outsOp || outsOp->getDef()->getName() != "outs") {
@@ -298,6 +313,8 @@ void tblgen::Operator::populateOpStructure() {
     }
     regions.push_back({name, Region(regionInit->getDef())});
   }
+
+  LLVM_DEBUG(print(llvm::dbgs()));
 }
 
 ArrayRef<llvm::SMLoc> tblgen::Operator::getLoc() const { return def.getLoc(); }
@@ -317,3 +334,13 @@ bool tblgen::Operator::hasSummary() const {
 StringRef tblgen::Operator::getSummary() const {
   return def.getValueAsString("summary");
 }
+
+void tblgen::Operator::print(llvm::raw_ostream &os) const {
+  os << "op '" << getOperationName() << "'\n";
+  for (Argument arg : arguments) {
+    if (auto *attr = arg.dyn_cast<NamedAttribute *>())
+      os << "[attribute] " << attr->name << '\n';
+    else
+      os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n';
+  }
+}
index ddec0ba..d3c1ddd 100644 (file)
@@ -211,6 +211,7 @@ int tblgen::SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
 
 std::string
 tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
+  LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
   switch (kind) {
   case Kind::Attr: {
     auto type =
index 2972793..a0e1cd6 100644 (file)
@@ -468,6 +468,30 @@ def OpInterleavedOperandAttribute2 : TEST_Op<"interleaved_operand_attr2"> {
   );
 }
 
+def ManyArgsOp : TEST_Op<"many_arguments"> {
+  let arguments = (ins
+    I32:$input1, I32:$input2, I32:$input3, I32:$input4, I32:$input5,
+    I32:$input6, I32:$input7, I32:$input8, I32:$input9,
+    I64Attr:$attr1, I64Attr:$attr2, I64Attr:$attr3, I64Attr:$attr4,
+    I64Attr:$attr5, I64Attr:$attr6, I64Attr:$attr7, I64Attr:$attr8,
+    I64Attr:$attr9
+  );
+}
+
+// Test that DRR does not blow up when seeing lots of arguments.
+def : Pat<(ManyArgsOp
+            $input1, $input2, $input3, $input4, $input5,
+            $input6, $input7, $input8, $input9,
+            ConstantAttr<I64Attr, "42">,
+            $attr2, $attr3, $attr4, $attr5, $attr6,
+            $attr7, $attr8, $attr9),
+          (ManyArgsOp
+            $input1, $input2, $input3, $input4, $input5,
+            $input6, $input7, $input8, $input9,
+            ConstantAttr<I64Attr, "24">,
+            $attr2, $attr3, $attr4, $attr5, $attr6,
+            $attr7, $attr8, $attr9)>;
+
 // Test that we can capture and reference interleaved operands and attributes.
 def : Pat<(OpInterleavedOperandAttribute1 $input1, $attr1, $input2, $attr2),
           (OpInterleavedOperandAttribute2 $input1, $attr1, $input2, $attr2)>;
index df21c77..7586d84 100644 (file)
@@ -71,6 +71,18 @@ func @verifyAllAttrConstraintOf() -> (i32, i32, i32) {
   return %0, %1, %2: i32, i32, i32
 }
 
+// CHECK-LABEL: verifyManyArgs
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func @verifyManyArgs(%arg: i32) {
+  // CHECK: "test.many_arguments"(%[[ARG]], %[[ARG]], %[[ARG]], %[[ARG]], %[[ARG]], %[[ARG]], %[[ARG]], %[[ARG]], %[[ARG]])
+  // CHECK-SAME: {attr1 = 24 : i64, attr2 = 42 : i64, attr3 = 42 : i64, attr4 = 42 : i64, attr5 = 42 : i64, attr6 = 42 : i64, attr7 = 42 : i64, attr8 = 42 : i64, attr9 = 42 : i64}
+  "test.many_arguments"(%arg, %arg, %arg, %arg, %arg, %arg, %arg, %arg, %arg) {
+    attr1 = 42, attr2 = 42, attr3 = 42, attr4 = 42, attr5 = 42,
+    attr6 = 42, attr7 = 42, attr8 = 42, attr9 = 42
+  } : (i32, i32, i32, i32, i32, i32, i32, i32, i32) -> ()
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // Test Symbol Binding
 //===----------------------------------------------------------------------===//