[mlir] Add method to populate default attributes
authorJacques Pienaar <jpienaar@google.com>
Fri, 8 Jul 2022 18:31:12 +0000 (11:31 -0700)
committerJacques Pienaar <jpienaar@google.com>
Fri, 8 Jul 2022 18:31:13 +0000 (11:31 -0700)
Previously default attributes were only usable by way of the ODS generated
accessors, but this was undesirable as
1. The ODS getters could construct Attribute each get request;
2. For non-C++ uses this would require either duplicating some of tee default
   attribute generating or generating additional bindings to generate methods;
3. Accessing op.getAttr("foo") and op.getFoo() would return different results;
Generate method to populate default attributes that can be used to address
these.

This merely adds this facility but does not employ by default on any path.

Differential Revision: https://reviews.llvm.org/D128962

mlir/include/mlir/IR/ExtensibleDialect.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/IR/ExtensibleDialect.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/unittests/IR/OperationSupportTest.cpp

index ee83ef5576df9829bf52360ae5ccbc624cbc061f..65f9f190f57ec20cf10c7f834252da6c9f98e1ed 100644 (file)
@@ -431,6 +431,7 @@ private:
   OperationName::PrintAssemblyFn printFn;
   OperationName::FoldHookFn foldHookFn;
   OperationName::GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
+  OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn;
 
   friend ExtensibleDialect;
 };
index c98993a2cb93bd3daf804481a1701978f101c2fa..81b0603fa899570c636300eacfc6d97957db83d8 100644 (file)
@@ -182,6 +182,10 @@ public:
   static void getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {}
 
+  /// This hook populates any unset default attrs.
+  static void populateDefaultAttrs(const RegisteredOperationName &,
+                                   NamedAttrList &) {}
+
 protected:
   /// If the concrete type didn't implement a custom verifier hook, just fall
   /// back to this one which accepts everything.
@@ -1869,6 +1873,10 @@ private:
     OpState::printOpName(op, p, defaultDialect);
     return cast<ConcreteType>(op).print(p);
   }
+  /// Implementation of `PopulateDefaultAttrsFn` OperationName hook.
+  static OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn() {
+    return ConcreteType::populateDefaultAttrs;
+  }
   /// Implementation of `VerifyInvariantsFn` OperationName hook.
   static LogicalResult verifyInvariants(Operation *op) {
     static_assert(hasNoDataMembers(),
index d6a231b1941e7509fad9da5717605f7dec96d769..70509bdb2c5153c384d0ba1911d14871890045f5 100644 (file)
@@ -467,6 +467,15 @@ public:
     setAttrs(attrs.getDictionary(getContext()));
   }
 
+  /// Sets default attributes on unset attributes.
+  void populateDefaultAttrs() {
+    if (auto registered = getRegisteredInfo()) {
+      NamedAttrList attrs(getAttrDictionary());
+      registered->populateDefaultAttrs(attrs);
+      setAttrs(attrs.getDictionary(getContext()));
+    }
+  }
+
   //===--------------------------------------------------------------------===//
   // Blocks
   //===--------------------------------------------------------------------===//
index de09ca58a4092e7a5e1f9fd8e8c32213e885d59b..2c480d6ca52dae7614ba0d30a81edda18421f801 100644 (file)
@@ -36,6 +36,7 @@ class Dialect;
 class DictionaryAttr;
 class ElementsAttr;
 class MutableOperandRangeRange;
+class NamedAttrList;
 class Operation;
 struct OperationState;
 class OpAsmParser;
@@ -69,6 +70,10 @@ public:
   using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
   using ParseAssemblyFn =
       llvm::unique_function<ParseResult(OpAsmParser &, OperationState &) const>;
+  // Note: RegisteredOperationName is passed as reference here as the derived
+  // class is defined below.
+  using PopulateDefaultAttrsFn = llvm::unique_function<void(
+      const RegisteredOperationName &, NamedAttrList &) const>;
   using PrintAssemblyFn =
       llvm::unique_function<void(Operation *, OpAsmPrinter &, StringRef) const>;
   using VerifyInvariantsFn =
@@ -112,6 +117,7 @@ protected:
     GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
     HasTraitFn hasTraitFn;
     ParseAssemblyFn parseAssemblyFn;
+    PopulateDefaultAttrsFn populateDefaultAttrsFn;
     PrintAssemblyFn printAssemblyFn;
     VerifyInvariantsFn verifyInvariantsFn;
     VerifyRegionInvariantsFn verifyRegionInvariantsFn;
@@ -254,7 +260,8 @@ public:
            T::getParseAssemblyFn(), T::getPrintAssemblyFn(),
            T::getVerifyInvariantsFn(), T::getVerifyRegionInvariantsFn(),
            T::getFoldHookFn(), T::getGetCanonicalizationPatternsFn(),
-           T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames());
+           T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames(),
+           T::getPopulateDefaultAttrsFn());
   }
   /// The use of this method is in general discouraged in favor of
   /// 'insert<CustomOp>(dialect)'.
@@ -266,7 +273,8 @@ public:
          FoldHookFn &&foldHook,
          GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
          detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
-         ArrayRef<StringRef> attrNames);
+         ArrayRef<StringRef> attrNames,
+         PopulateDefaultAttrsFn &&populateDefaultAttrs);
 
   /// Return the dialect this operation is registered to.
   Dialect &getDialect() const { return *impl->dialect; }
@@ -364,6 +372,10 @@ public:
     return impl->attributeNames;
   }
 
+  /// This hook implements the method to populate defaults attributes that are
+  /// unset.
+  void populateDefaultAttrs(NamedAttrList &attrs) const;
+
   /// Represent the operation name as an opaque pointer. (Used to support
   /// PointerLikeTypeTraits).
   static RegisteredOperationName getFromOpaquePointer(const void *pointer) {
index 3e96b83031d2941e93957975a43f1f1b965907e0..0dcc971ca2e5a0a36d42f1341e83578943813698 100644 (file)
@@ -447,7 +447,8 @@ void ExtensibleDialect::registerDynamicOp(
       std::move(op->printFn), std::move(op->verifyFn),
       std::move(op->verifyRegionFn), std::move(op->foldHookFn),
       std::move(op->getCanonicalizationPatternsFn),
-      detail::InterfaceMap::get<>(), std::move(hasTraitFn), {});
+      detail::InterfaceMap::get<>(), std::move(hasTraitFn), {},
+      std::move(op->getPopulateDefaultAttrsFn));
 }
 
 bool ExtensibleDialect::classof(const Dialect *dialect) {
index 2a84362635ac61455043f5b76f62aca134f94666..273faa89b826cce083d7aaf940931a44022f4de7 100644 (file)
@@ -707,6 +707,10 @@ RegisteredOperationName::parseAssembly(OpAsmParser &parser,
   return impl->parseAssemblyFn(parser, result);
 }
 
+void RegisteredOperationName::populateDefaultAttrs(NamedAttrList &attrs) const {
+  impl->populateDefaultAttrsFn(*this, attrs);
+}
+
 void RegisteredOperationName::insert(
     StringRef name, Dialect &dialect, TypeID typeID,
     ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
@@ -714,7 +718,8 @@ void RegisteredOperationName::insert(
     VerifyRegionInvariantsFn &&verifyRegionInvariants, FoldHookFn &&foldHook,
     GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
     detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
-    ArrayRef<StringRef> attrNames) {
+    ArrayRef<StringRef> attrNames,
+    PopulateDefaultAttrsFn &&populateDefaultAttrs) {
   MLIRContext *ctx = dialect.getContext();
   auto &ctxImpl = ctx->getImpl();
   assert(ctxImpl.multiThreadedExecutionContext == 0 &&
@@ -769,6 +774,7 @@ void RegisteredOperationName::insert(
   impl.verifyInvariantsFn = std::move(verifyInvariants);
   impl.verifyRegionInvariantsFn = std::move(verifyRegionInvariants);
   impl.attributeNames = cachedAttrNames;
+  impl.populateDefaultAttrsFn = std::move(populateDefaultAttrs);
 }
 
 //===----------------------------------------------------------------------===//
index dd3487bb2b0cd0e846e713535cf6090197ac3f0a..3330fdf3c28a4e4bebadcb849cf51b2d7da3e59b 100644 (file)
@@ -430,6 +430,9 @@ private:
   // Generates getters for named successors.
   void genNamedSuccessorGetters();
 
+  // Generates the method to populate default attributes.
+  void genPopulateDefaultAttributes();
+
   // Generates builder methods for the operation.
   void genBuilder();
 
@@ -823,6 +826,7 @@ OpEmitter::OpEmitter(const Operator &op,
   genAttrSetters();
   genOptionalAttrRemovers();
   genBuilder();
+  genPopulateDefaultAttributes();
   genParser();
   genPrinter();
   genVerifier();
@@ -1587,6 +1591,45 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
        << llvm::join(resultTypes, ", ") << "});\n\n";
 }
 
+void OpEmitter::genPopulateDefaultAttributes() {
+  // All done if no attributes have default values.
+  if (llvm::all_of(op.getAttributes(), [](const NamedAttribute &named) {
+        return !named.attr.hasDefaultValue();
+      }))
+    return;
+
+  SmallVector<MethodParameter> paramList;
+  paramList.emplace_back("const ::mlir::RegisteredOperationName &", "opName");
+  paramList.emplace_back("::mlir::NamedAttrList &", "attributes");
+  auto *m = opClass.addStaticMethod("void", "populateDefaultAttrs", paramList);
+  ERROR_IF_PRUNED(m, "populateDefaultAttrs", op);
+  auto &body = m->body();
+  body.indent();
+
+  // Set default attributes that are unset.
+  body << "auto attrNames = opName.getAttributeNames();\n";
+  body << "::mlir::Builder " << odsBuilder
+       << "(attrNames.front().getContext());\n";
+  StringMap<int> attrIndex;
+  for (const auto &it : llvm::enumerate(emitHelper.getAttrMetadata())) {
+    attrIndex[it.value().first] = it.index();
+  }
+  for (const NamedAttribute &namedAttr : op.getAttributes()) {
+    auto &attr = namedAttr.attr;
+    if (!attr.hasDefaultValue())
+      continue;
+    auto index = attrIndex[namedAttr.name];
+    body << "if (!attributes.get(attrNames[" << index << "])) {\n";
+    FmtContext fctx;
+    fctx.withBuilder(odsBuilder);
+    std::string defaultValue = std::string(
+        tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
+    body.indent() << formatv(" attributes.append(attrNames[{0}], {1});\n",
+                             index, defaultValue);
+    body.unindent() << "}\n";
+  }
+}
+
 void OpEmitter::genInferredTypeCollectiveParamBuilder() {
   SmallVector<MethodParameter> paramList;
   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
@@ -1869,7 +1912,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
   auto numResults = op.getNumResults();
   resultTypeNames.reserve(numResults);
 
-  paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
+  paramList.emplace_back("::mlir::OpBuilder &", odsBuilder);
   paramList.emplace_back("::mlir::OperationState &", builderOpState);
 
   switch (typeParamKind) {
@@ -2879,7 +2922,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
           tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
       body << "  if (!attr)\n    attr = " << defaultValue << ";\n";
     }
-    body << "  return attr;\n";
+    body << "return attr;\n";
   };
 
   {
index 2511a5d3b6bfd7e017c566346d94b73148154fde..b8cbc6d1e6c0c674d85568750dfa7f58dc4ca3eb 100644 (file)
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/OperationSupport.h"
+#include "../../test/lib/Dialect/Test/TestDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/ADT/BitVector.h"
@@ -271,4 +272,22 @@ TEST(NamedAttrListTest, TestAppendAssign) {
   attrs.assign({});
   ASSERT_TRUE(attrs.empty());
 }
+
+TEST(OperandStorageTest, PopulateDefaultAttrs) {
+  MLIRContext context;
+  context.getOrLoadDialect<test::TestDialect>();
+  Builder builder(&context);
+
+  OpBuilder b(&context);
+  auto req1 = b.getI32IntegerAttr(10);
+  auto req2 = b.getI32IntegerAttr(60);
+  Operation *op = b.create<test::OpAttrMatch1>(b.getUnknownLoc(), req1, nullptr,
+                                               nullptr, req2);
+  EXPECT_EQ(op->getAttr("default_valued_attr"), nullptr);
+  op->populateDefaultAttrs();
+  auto opt = op->getAttr("default_valued_attr");
+  EXPECT_NE(opt, nullptr) << *op;
+
+  op->destroy();
+}
 } // namespace