Match attributes in input pattern.
authorJacques Pienaar <jpienaar@google.com>
Mon, 7 Jan 2019 17:52:26 +0000 (09:52 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 22:00:55 +0000 (15:00 -0700)
Bind attributes similar to operands. Use to rewrite leakyreulo and const rewrite pattern. The attribute type/attributes are not currently checked so should only be used where the attributes match due to the construction of the op.

To support current attribute namespacing, convert __ in attribute name to "$" for matching purposes ('$' is not valid character in variable in TableGen).

Some simplification to make it simpler to specify indented ostream and avoid so many spaces. The goal is not to have perfectly formatted code generated but good enough so that its still easy to read for a user.

PiperOrigin-RevId: 228183639

mlir/include/mlir/TableGen/Operator.h
mlir/lib/TableGen/Operator.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/RewriterGen.cpp

index 523eb30e60a1bb6d06bcc1420ddb7bf0718286a1..977090b777a007a4cac022024d42529b2ff17663 100644 (file)
@@ -55,6 +55,10 @@ public:
 
   // Operations attribute accessors.
   struct Attribute {
+    std::string getName() const;
+    StringRef getReturnType() const;
+    StringRef getStorageType() const;
+
     llvm::StringInit *name;
     llvm::Record *record;
     bool isDerived;
index 085d1db2cc8d875e48a6093929907783043faec9..a126334207ff46c0d675a2ac26e742bc17f61602 100644 (file)
@@ -141,6 +141,23 @@ void Operator::populateOperandsAndAttributes() {
   }
 }
 
+std::string mlir::Operator::Attribute::getName() const {
+  std::string ret = name->getAsUnquotedString();
+  // TODO(jpienaar): Revise this post dialect prefixing attribute discussion.
+  auto split = StringRef(ret).split("__");
+  if (split.second.empty())
+    return ret;
+  return llvm::join_items("$", split.first, split.second);
+}
+
+StringRef mlir::Operator::Attribute::getReturnType() const {
+  return record->getValueAsString("returnType").trim();
+}
+
+StringRef mlir::Operator::Attribute::getStorageType() const {
+  return record->getValueAsString("storageType").trim();
+}
+
 bool mlir::Operator::Operand::hasMatcher() const {
   llvm::Init *matcher = defInit->getDef()->getValue("predicate")->getValue();
   return !isa<llvm::UnsetInit>(matcher);
index 99f69ccf16135e1a486518fa9cead21fa5b9efaf..965f08d53304f8ac73e3047ec50389e6c752b98d 100644 (file)
@@ -34,6 +34,9 @@ using namespace mlir;
 
 static const char *const generatedArgName = "_arg";
 
+// Helper macro that returns indented os.
+#define OUT(X) os.indent((X))
+
 // TODO(jpienaar): The builder body should probably be separate from the header.
 
 // Variation of method in FormatVariadic.h which takes a StringRef as input
@@ -164,8 +167,8 @@ void OpEmitter::emit(const Record &def, raw_ostream &os) {
   os << "> {\npublic:\n";
 
   // Build operation name.
-  os << "  static StringRef getOperationName() { return \""
-     << emitter.op.getOperationName() << "\"; };\n";
+  OUT(2) << "static StringRef getOperationName() { return \""
+         << emitter.op.getOperationName() << "\"; };\n";
 
   emitter.emitNamedOperands();
   emitter.emitBuilder();
@@ -176,8 +179,8 @@ void OpEmitter::emit(const Record &def, raw_ostream &os) {
   emitter.emitCanonicalizationPatterns();
   emitter.emitConstantFolder();
 
-  os << "private:\n  friend class ::mlir::OperationInst;\n";
-  os << "  explicit " << emitter.op.cppClassName()
+  os << "private:\n  friend class ::mlir::OperationInst;\n"
+     << "  explicit " << emitter.op.cppClassName()
      << "(const OperationInst* state) : Op(state) {}\n};\n";
   emitter.mapOverClassNamespaces(
       [&os](StringRef ns) { os << "} // end namespace " << ns << "\n"; });
@@ -190,22 +193,20 @@ void OpEmitter::emitAttrGetters() {
 
     // Emit the derived attribute body.
     if (attr.isDerived) {
-      os << "  " << def->getValueAsString("returnType").trim() << ' ' << name
-         << "() const {" << def->getValueAsString("body") << " }\n";
+      OUT(2) << attr.getReturnType() << ' ' << name << "() const {"
+             << def->getValueAsString("body") << " }\n";
       continue;
     }
 
     // Emit normal emitter.
-    os << "  " << def->getValueAsString("returnType").trim() << ' ' << name
-       << "() const {\n";
+    OUT(2) << attr.getReturnType() << ' ' << name << "() const {\n";
 
     // Return the queried attribute with the correct return type.
-    std::string attrVal =
-        formatv("this->getAttrOfType<{0}>(\"{1}\")",
-                def->getValueAsString("storageType").trim(), name);
-    os << "    return "
-       << formatv(def->getValueAsString("convertFromStorage"), attrVal)
-       << ";\n  }\n";
+    std::string attrVal = formatv("this->getAttrOfType<{0}>(\"{1}\")",
+                                  attr.getStorageType(), name);
+    OUT(4) << "return "
+           << formatv(def->getValueAsString("convertFromStorage"), attrVal)
+           << ";\n  }\n";
   }
 }
 
@@ -243,7 +244,7 @@ void OpEmitter::emitBuilder() {
   // 1. Stand-alone parameters
 
   std::vector<Record *> returnTypes = def.getValueAsListOfDefs("returnTypes");
-  os << "  static void build(Builder* builder, OperationState* result";
+  OUT(2) << "static void build(Builder* builder, OperationState* result";
 
   // Emit parameters for all return types
   for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
@@ -267,7 +268,7 @@ void OpEmitter::emitBuilder() {
 
   // Push all result types to the result
   if (!returnTypes.empty()) {
-    os << "    result->addTypes({returnType0";
+    OUT(4) << "result->addTypes({returnType0";
     for (unsigned i = 1, e = returnTypes.size(); i != e; ++i)
       os << ", returnType" << i;
     os << "});\n\n";
@@ -275,7 +276,7 @@ void OpEmitter::emitBuilder() {
 
   // Push all operands to the result
   if (op.getNumOperands() > 0) {
-    os << "    result->addOperands({" << getArgumentName(op, 0);
+    OUT(4) << "result->addOperands({" << getArgumentName(op, 0);
     for (int i = 1, e = op.getNumOperands(); i != e; ++i)
       os << ", " << getArgumentName(op, i);
     os << "});\n";
@@ -284,45 +285,45 @@ void OpEmitter::emitBuilder() {
   // Push all attributes to the result
   for (const auto &attr : op.getAttributes())
     if (!attr.isDerived)
-      os.indent(4) << formatv("result->addAttribute(\"{0}\", {0});\n",
-                              getAttributeName(attr));
-  os << "  }\n";
+      OUT(4) << formatv("result->addAttribute(\"{0}\", {0});\n",
+                        getAttributeName(attr));
+  OUT(2) << "}\n";
 
   // 2. Aggregated parameters
 
   // Signature
-  os << "  static void build(Builder* builder, OperationState* result, "
-     << "ArrayRef<Type> resultTypes, ArrayRef<Value*> args, "
-        "ArrayRef<NamedAttribute> attributes) {\n";
+  OUT(2) << "static void build(Builder* builder, OperationState* result, "
+         << "ArrayRef<Type> resultTypes, ArrayRef<Value*> args, "
+            "ArrayRef<NamedAttribute> attributes) {\n";
 
   // Result types
-  os << "    assert(resultTypes.size() == " << returnTypes.size()
-     << "u && \"mismatched number of return types\");\n"
-     << "    result->addTypes(resultTypes);\n";
+  OUT(4) << "assert(resultTypes.size() == " << returnTypes.size()
+         << "u && \"mismatched number of return types\");\n"
+         << "    result->addTypes(resultTypes);\n";
 
   // Operands
-  os << "    assert(args.size() == " << op.getNumOperands()
-     << "u && \"mismatched number of parameters\");\n"
-     << "    result->addOperands(args);\n\n";
+  OUT(4) << "assert(args.size() == " << op.getNumOperands()
+         << "u && \"mismatched number of parameters\");\n"
+         << "    result->addOperands(args);\n\n";
 
   // Attributes
   if (op.getNumAttributes() > 0) {
-    os << "    assert(!attributes.size() && \"no attributes expected\");\n"
-       << "  }\n";
+    OUT(4) << "assert(!attributes.size() && \"no attributes expected\");\n"
+           << "  }\n";
   } else {
-    os << "    assert(attributes.size() >= " << op.getNumAttributes()
-       << "u && \"not enough attributes\");\n"
-       << "    for (const auto& pair : attributes)\n"
-       << "      result->addAttribute(pair.first, pair.second);\n"
-       << "  }\n";
+    OUT(4) << "assert(attributes.size() >= " << op.getNumAttributes()
+           << "u && \"not enough attributes\");\n"
+           << "    for (const auto& pair : attributes)\n"
+           << "      result->addAttribute(pair.first, pair.second);\n"
+           << "  }\n";
   }
 }
 
 void OpEmitter::emitCanonicalizationPatterns() {
   if (!def.getValueAsBit("hasCanonicalizationPatterns"))
     return;
-  os << "  static void getCanonicalizationPatterns("
-     << "OwningRewritePatternList &results, MLIRContext* context);\n";
+  OUT(2) << "static void getCanonicalizationPatterns("
+         << "OwningRewritePatternList &results, MLIRContext* context);\n";
 }
 
 void OpEmitter::emitConstantFolder() {
@@ -363,7 +364,7 @@ void OpEmitter::emitVerifier() {
   if (!hasCustomVerify && op.getNumArgs() == 0)
     return;
 
-  os << "  bool verify() const {\n";
+  OUT(2) << "bool verify() const {\n";
   // Verify the attributes have the correct type.
   for (const auto &attr : op.getAttributes()) {
     if (attr.isDerived)
@@ -371,17 +372,15 @@ void OpEmitter::emitVerifier() {
 
     auto name = getAttributeName(attr);
     if (!hasStringAttribute(*attr.record, "storageType")) {
-      os << "    if (!this->getAttr(\"" << name
-         << "\")) return emitOpError(\"requires attribute '" << name
-         << "'\");\n";
+      OUT(4) << "if (!this->getAttr(\"" << name
+             << "\")) return emitOpError(\"requires attribute '" << name
+             << "'\");\n";
       continue;
     }
 
-    os << "    if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
-       << attr.record->getValueAsString("storageType").trim()
-       << ">()) return emitOpError(\"requires "
-       << attr.record->getValueAsString("returnType").trim() << " attribute '"
-       << name << "'\");\n";
+    OUT(4) << "if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
+           << attr.getStorageType() << ">()) return emitOpError(\"requires "
+           << attr.getReturnType() << " attribute '" << name << "'\");\n";
   }
 
   // TODO: Handle variadic.
@@ -392,17 +391,17 @@ void OpEmitter::emitVerifier() {
     if (operand.hasMatcher()) {
       auto pred =
           "if (!(" + operand.createTypeMatcherTemplate() + ")) return false;\n";
-      os.indent(4) << formatv(pred, "this->getInstruction()->getOperand(" +
-                                        Twine(opIndex) + ")->getType()");
+      OUT(4) << formatv(pred, "this->getInstruction()->getOperand(" +
+                                  Twine(opIndex) + ")->getType()");
     }
     ++opIndex;
   }
 
   if (hasCustomVerify)
-    os << "    " << codeInit->getValue() << "\n";
+    OUT(4) << codeInit->getValue() << "\n";
   else
-    os << "    return false;\n";
-  os << "  }\n";
+    OUT(4) << "return false;\n";
+  OUT(2) << "}\n";
 }
 
 void OpEmitter::emitTraits() {
index 6bc1366bd5845922aeaccdca1b7bb35c38656f63..c83480d113019b80816971dd7cb41980d3013af8 100644 (file)
@@ -150,6 +150,8 @@ static void matchOp(Record *pattern, DagInit *tree, int depth,
                         "' in pattern and op's definition");
   for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
     auto arg = tree->getArg(i);
+    auto opArg = op.getArg(i);
+
     if (auto argTree = dyn_cast<DagInit>(arg)) {
       os.indent(indent) << "{\n";
       os.indent(indent + 2) << formatv(
@@ -162,12 +164,11 @@ static void matchOp(Record *pattern, DagInit *tree, int depth,
 
     // Verify arguments.
     if (auto defInit = dyn_cast<DefInit>(arg)) {
-      auto opArg = op.getArg(i);
       // Verify operands.
       if (auto *operand = opArg.dyn_cast<Operator::Operand *>()) {
         // Skip verification where not needed due to definition of op.
         if (operand->defInit == defInit)
-          goto SkipOperandVerification;
+          goto StateCapture;
 
         if (!defInit->getDef()->isSubClassOf("Type"))
           PrintFatalError(pattern->getLoc(),
@@ -185,15 +186,24 @@ static void matchOp(Record *pattern, DagInit *tree, int depth,
                        formatv("op{0}->getOperand({1})->getType()", depth, i))
             << ")) return matchFailure();\n";
       }
+
+      // TODO(jpienaar): Verify attributes.
+      if (auto *attr = opArg.dyn_cast<Operator::Attribute *>()) {
+      }
     }
-  SkipOperandVerification:
-    // TODO(jpienaar): Verify attributes.
 
+  StateCapture:
     auto name = tree->getArgNameStr(i);
     if (name.empty())
       continue;
-    os.indent(indent) << "state->" << name << " = op" << depth
-                      << "->getOperand(" << i << ");\n";
+    if (opArg.is<Operator::Operand *>())
+      os.indent(indent) << "state->" << name << " = op" << depth
+                        << "->getOperand(" << i << ");\n";
+    if (auto attr = opArg.dyn_cast<Operator::Attribute *>()) {
+      os.indent(indent) << "state->" << name << " = op" << depth
+                        << "->getAttrOfType<" << attr->getStorageType()
+                        << ">(\"" << attr->getName() << "\");\n";
+    }
   }
 }
 
@@ -291,13 +301,18 @@ void Pattern::emit(StringRef rewriteName) {
     (os << ",\n").indent(6);
 
     // The argument in the result DAG pattern.
-    auto name = resultOp.getArgName(i);
+    auto name = resultTree->getArgNameStr(i);
+    auto opName = resultOp.getArgName(i);
     auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
     auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
-    if (!value)
-      PrintFatalError(pattern->getLoc(),
-                      Twine("attribute '") + name +
-                          "' needs to be constant initialized");
+    if (!value) {
+      if (boundArguments.find(name) == boundArguments.end())
+        PrintFatalError(pattern->getLoc(),
+                        Twine("referencing unbound variable '") + name + "'");
+      os << "/*" << opName << "=*/"
+         << "s." << name;
+      continue;
+    }
 
     // TODO(jpienaar): Refactor out into map to avoid recomputing these.
     auto argument = resultOp.getArg(i);