[mlir] Plumb through default attribute populate for extensible dialect.
authorJacques Pienaar <jpienaar@google.com>
Wed, 13 Jul 2022 16:05:04 +0000 (09:05 -0700)
committerJacques Pienaar <jpienaar@google.com>
Wed, 13 Jul 2022 16:05:04 +0000 (09:05 -0700)
mlir/include/mlir/IR/ExtensibleDialect.h
mlir/lib/IR/ExtensibleDialect.cpp

index 65f9f19..84dc567 100644 (file)
@@ -362,7 +362,8 @@ public:
       OperationName::PrintAssemblyFn &&printFn,
       OperationName::FoldHookFn &&foldHookFn,
       OperationName::GetCanonicalizationPatternsFn
-          &&getCanonicalizationPatternsFn);
+          &&getCanonicalizationPatternsFn,
+      OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn);
 
   /// Returns the op typeID.
   TypeID getTypeID() { return typeID; }
@@ -405,15 +406,23 @@ public:
     getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns);
   }
 
+  /// Set the hook populating default attributes.
+  void setPopulateDefaultAttrsFn(
+      OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrs) {
+    populateDefaultAttrsFn = std::move(populateDefaultAttrs);
+  }
+
 private:
-  DynamicOpDefinition(StringRef name, ExtensibleDialect *dialect,
-                      OperationName::VerifyInvariantsFn &&verifyFn,
-                      OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
-                      OperationName::ParseAssemblyFn &&parseFn,
-                      OperationName::PrintAssemblyFn &&printFn,
-                      OperationName::FoldHookFn &&foldHookFn,
-                      OperationName::GetCanonicalizationPatternsFn
-                          &&getCanonicalizationPatternsFn);
+  DynamicOpDefinition(
+      StringRef name, ExtensibleDialect *dialect,
+      OperationName::VerifyInvariantsFn &&verifyFn,
+      OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
+      OperationName::ParseAssemblyFn &&parseFn,
+      OperationName::PrintAssemblyFn &&printFn,
+      OperationName::FoldHookFn &&foldHookFn,
+      OperationName::GetCanonicalizationPatternsFn
+          &&getCanonicalizationPatternsFn,
+      OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn);
 
   /// Unique identifier for this operation.
   TypeID typeID;
@@ -431,7 +440,7 @@ private:
   OperationName::PrintAssemblyFn printFn;
   OperationName::FoldHookFn foldHookFn;
   OperationName::GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
-  OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn;
+  OperationName::PopulateDefaultAttrsFn populateDefaultAttrsFn;
 
   friend ExtensibleDialect;
 };
index 0dcc971..148ba9f 100644 (file)
@@ -295,13 +295,15 @@ DynamicOpDefinition::DynamicOpDefinition(
     OperationName::PrintAssemblyFn &&printFn,
     OperationName::FoldHookFn &&foldHookFn,
     OperationName::GetCanonicalizationPatternsFn
-        &&getCanonicalizationPatternsFn)
+        &&getCanonicalizationPatternsFn,
+    OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn)
     : typeID(dialect->allocateTypeID()),
       name((dialect->getNamespace() + "." + name).str()), dialect(dialect),
       verifyFn(std::move(verifyFn)), verifyRegionFn(std::move(verifyRegionFn)),
       parseFn(std::move(parseFn)), printFn(std::move(printFn)),
       foldHookFn(std::move(foldHookFn)),
-      getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)) {}
+      getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)),
+      populateDefaultAttrsFn(std::move(populateDefaultAttrsFn)) {}
 
 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
     StringRef name, ExtensibleDialect *dialect,
@@ -336,25 +338,31 @@ std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
   auto getCanonicalizationPatternsFn = [](RewritePatternSet &, MLIRContext *) {
   };
 
+  auto populateDefaultAttrsFn = [](const RegisteredOperationName &,
+                                   NamedAttrList &) {};
+
   return DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
                                   std::move(verifyRegionFn), std::move(parseFn),
                                   std::move(printFn), std::move(foldHookFn),
-                                  std::move(getCanonicalizationPatternsFn));
+                                  std::move(getCanonicalizationPatternsFn),
+                                  std::move(populateDefaultAttrsFn));
 }
 
-std::unique_ptr<DynamicOpDefinition>
-DynamicOpDefinition::get(StringRef name, ExtensibleDialect *dialect,
-                         OperationName::VerifyInvariantsFn &&verifyFn,
-                         OperationName::VerifyInvariantsFn &&verifyRegionFn,
-                         OperationName::ParseAssemblyFn &&parseFn,
-                         OperationName::PrintAssemblyFn &&printFn,
-                         OperationName::FoldHookFn &&foldHookFn,
-                         OperationName::GetCanonicalizationPatternsFn
-                             &&getCanonicalizationPatternsFn) {
+std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
+    StringRef name, ExtensibleDialect *dialect,
+    OperationName::VerifyInvariantsFn &&verifyFn,
+    OperationName::VerifyInvariantsFn &&verifyRegionFn,
+    OperationName::ParseAssemblyFn &&parseFn,
+    OperationName::PrintAssemblyFn &&printFn,
+    OperationName::FoldHookFn &&foldHookFn,
+    OperationName::GetCanonicalizationPatternsFn
+        &&getCanonicalizationPatternsFn,
+    OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn) {
   return std::unique_ptr<DynamicOpDefinition>(new DynamicOpDefinition(
       name, dialect, std::move(verifyFn), std::move(verifyRegionFn),
       std::move(parseFn), std::move(printFn), std::move(foldHookFn),
-      std::move(getCanonicalizationPatternsFn)));
+      std::move(getCanonicalizationPatternsFn),
+      std::move(populateDefaultAttrsFn)));
 }
 
 //===----------------------------------------------------------------------===//
@@ -448,7 +456,7 @@ void ExtensibleDialect::registerDynamicOp(
       std::move(op->verifyRegionFn), std::move(op->foldHookFn),
       std::move(op->getCanonicalizationPatternsFn),
       detail::InterfaceMap::get<>(), std::move(hasTraitFn), {},
-      std::move(op->getPopulateDefaultAttrsFn));
+      std::move(op->populateDefaultAttrsFn));
 }
 
 bool ExtensibleDialect::classof(const Dialect *dialect) {