[TableGen] Make sure op in pattern has the same number of arguments as definition
authorLei Zhang <antiagainst@google.com>
Mon, 8 Apr 2019 22:14:59 +0000 (15:14 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 9 Apr 2019 02:17:56 +0000 (19:17 -0700)
    When an op in the source pattern specifies more arguments than its definition, we
    will have out-of-bound query for op arguments from the definition. That will cause
    crashes. This change fixes it.

--

PiperOrigin-RevId: 242548415

mlir/include/mlir/TableGen/Pattern.h
mlir/lib/TableGen/Pattern.cpp

index e7856e6..cacf400 100644 (file)
@@ -160,10 +160,6 @@ public:
   // Precondition: isNativeCodeBuilder.
   llvm::StringRef getNativeCodeBuilder() const;
 
-  // Collects all recursively bound arguments involved in the DAG tree rooted
-  // from this node.
-  void collectBoundArguments(Pattern *pattern) const;
-
   // Returns true if this DAG construct means to replace with an existing SSA
   // value.
   bool isReplaceWithValue() const;
@@ -235,6 +231,10 @@ public:
   int getBenefit() const;
 
 private:
+  // Recursively collects all bound arguments inside the DAG tree rooted
+  // at `tree`.
+  void collectBoundArguments(DagNode tree);
+
   // The TableGen definition of this pattern.
   const llvm::Record &def;
 
index bc95c2f..2200346 100644 (file)
@@ -22,6 +22,7 @@
 
 #include "mlir/TableGen/Pattern.h"
 #include "llvm/ADT/Twine.h"
+#include "llvm/Support/FormatVariadic.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
 
@@ -141,36 +142,6 @@ StringRef tblgen::DagNode::getArgName(unsigned index) const {
   return node->getArgNameStr(index);
 }
 
-static void collectBoundArguments(const llvm::DagInit *tree,
-                                  tblgen::Pattern *pattern) {
-  auto &op = pattern->getDialectOp(tblgen::DagNode(tree));
-  if (llvm::StringInit *si = tree->getName()) {
-    auto name = si->getAsUnquotedString();
-    if (!name.empty())
-      pattern->getSourcePatternBoundResults().insert(name);
-  }
-
-  // TODO(jpienaar): Expand to multiple matches.
-  for (unsigned i = 0, e = tree->getNumArgs(); i != e; ++i) {
-    auto *arg = tree->getArg(i);
-
-    if (auto *argTree = dyn_cast<llvm::DagInit>(arg)) {
-      collectBoundArguments(argTree, pattern);
-      continue;
-    }
-
-    StringRef name = tree->getArgNameStr(i);
-    if (name.empty())
-      continue;
-
-    pattern->getSourcePatternBoundArgs().try_emplace(name, op.getArg(i));
-  }
-}
-
-void tblgen::DagNode::collectBoundArguments(tblgen::Pattern *pattern) const {
-  ::collectBoundArguments(node, pattern);
-}
-
 bool tblgen::DagNode::isReplaceWithValue() const {
   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
   return dagOpDef->getName() == "replaceWithValue";
@@ -194,7 +165,7 @@ llvm::StringRef tblgen::DagNode::getNativeCodeBuilder() const {
 
 tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
     : def(*def), recordOpMap(mapper) {
-  getSourcePattern().collectBoundArguments(this);
+  collectBoundArguments(getSourcePattern());
 }
 
 tblgen::DagNode tblgen::Pattern::getSourcePattern() const {
@@ -276,3 +247,34 @@ int tblgen::Pattern::getBenefit() const {
   }
   return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
 }
+
+void tblgen::Pattern::collectBoundArguments(DagNode tree) {
+  auto &op = getDialectOp(tree);
+  auto numOpArgs = op.getNumArgs();
+  auto numTreeArgs = tree.getNumArgs();
+
+  if (numOpArgs != numTreeArgs) {
+    PrintFatalError(def.getLoc(),
+                    formatv("op '{0}' argument number mismatch: "
+                            "{1} in pattern vs. {2} in definition",
+                            op.getOperationName(), numTreeArgs, numOpArgs));
+  }
+
+  // The name attached to the DAG node's operator is for representing the
+  // results generated from this op. It should be remembered as bound results.
+  auto treeName = tree.getOpName();
+  if (!treeName.empty())
+    boundResults.insert(treeName);
+
+  // TODO(jpienaar): Expand to multiple matches.
+  for (unsigned i = 0; i != numTreeArgs; ++i) {
+    if (auto treeArg = tree.getArgAsNestedDag(i)) {
+      // This DAG node argument is a DAG node itself. Go inside recursively.
+      collectBoundArguments(treeArg);
+    } else {
+      auto treeArgName = tree.getArgName(i);
+      if (!treeArgName.empty())
+        boundArguments.try_emplace(treeArgName, op.getArg(i));
+    }
+  }
+}