[mlir][AttrDefGen] Add support for specifying the value type of an attribute
authorRiver Riddle <riddleriver@gmail.com>
Thu, 4 Mar 2021 20:37:22 +0000 (12:37 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Thu, 4 Mar 2021 21:04:05 +0000 (13:04 -0800)
The value type of the attribute can be specified by either overriding the typeBuilder field on the AttrDef, or by providing a parameter of type `AttributeSelfTypeParameter`. This removes the need to define custom storage class constructors for attributes that have a value type other than NoneType.

Differential Revision: https://reviews.llvm.org/D97590

mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/AttrOrTypeDef.h
mlir/lib/TableGen/AttrOrTypeDef.cpp
mlir/test/lib/Dialect/Test/TestAttrDefs.td
mlir/test/lib/Dialect/Test/TestAttributes.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/attrdefs.td
mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

index db38107..c18034c 100644 (file)
@@ -2639,6 +2639,14 @@ class AttrDef<Dialect dialect, string name,
   // The name of the C++ Attribute class.
   string cppClassName = name # "Attr";
 
+  // A code block used to build the value 'Type' of an Attribute when
+  // initializing its storage instance. This field is optional, and if not
+  // present the attribute will have its value type set to `NoneType`. This code
+  // block may reference any of the attributes parameters via
+  // `$_<parameter-name`. If one of the parameters of the attribute is of type
+  // `AttributeSelfTypeParameter`, this field is ignored.
+  code typeBuilder = ?;
+
   // The predicate for when this def is used as a constraint.
   let predicate = CPred<"$_self.isa<" # dialect.cppNamespace #
                                  "::" # cppClassName # ">()">;
@@ -2704,4 +2712,10 @@ class ArrayRefOfSelfAllocationParameter<string arrayOf, string desc> :
   }];
 }
 
+// This is a special parameter used for AttrDefs that represents a `mlir::Type`
+// that is also used as the value `Type` of the attribute. Only one parameter
+// of the attribute may be of this type.
+class AttributeSelfTypeParameter<string desc> :
+    AttrOrTypeParameter<"::mlir::Type", desc> {}
+
 #endif // OP_BASE
index 8ce7520..84a5a04 100644 (file)
@@ -130,7 +130,10 @@ public:
   // Returns whether the AttrOrTypeDef is defined.
   operator bool() const { return def != nullptr; }
 
-private:
+  // Return the underlying def.
+  const llvm::Record *getDef() const { return def; }
+
+protected:
   const llvm::Record *def;
 
   // The builders of this type definition.
@@ -145,6 +148,12 @@ private:
 class AttrDef : public AttrOrTypeDef {
 public:
   using AttrOrTypeDef::AttrOrTypeDef;
+
+  // Returns the attributes value type builder code block, or None if it doesn't
+  // have one.
+  Optional<StringRef> getTypeBuilder() const;
+
+  static bool classof(const AttrOrTypeDef *def);
 };
 
 //===----------------------------------------------------------------------===//
@@ -183,6 +192,9 @@ public:
   // Get the assembly syntax documentation.
   StringRef getSyntax() const;
 
+  // Return the underlying def of this parameter.
+  const llvm::Init *getDef() const;
+
 private:
   /// The underlying tablegen parameter list this parameter is a part of.
   const llvm::DagInit *def;
@@ -190,6 +202,17 @@ private:
   unsigned index;
 };
 
+//===----------------------------------------------------------------------===//
+// AttributeSelfTypeParameter
+//===----------------------------------------------------------------------===//
+
+// A wrapper class for the AttributeSelfTypeParameter tblgen class. This
+// represents a parameter of mlir::Type that is the value type of an AttrDef.
+class AttributeSelfTypeParameter : public AttrOrTypeParameter {
+public:
+  static bool classof(const AttrOrTypeParameter *param);
+};
+
 } // end namespace tblgen
 } // end namespace mlir
 
index e82f0f0..037dc4d 100644 (file)
@@ -154,6 +154,18 @@ bool AttrOrTypeDef::operator<(const AttrOrTypeDef &other) const {
 }
 
 //===----------------------------------------------------------------------===//
+// AttrDef
+//===----------------------------------------------------------------------===//
+
+Optional<StringRef> AttrDef::getTypeBuilder() const {
+  return def->getValueAsOptionalString("typeBuilder");
+}
+
+bool AttrDef::classof(const AttrOrTypeDef *def) {
+  return def->getDef()->isSubClassOf("AttrDef");
+}
+
+//===----------------------------------------------------------------------===//
 // AttrOrTypeParameter
 //===----------------------------------------------------------------------===//
 
@@ -219,3 +231,18 @@ StringRef AttrOrTypeParameter::getSyntax() const {
   llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
                         "defs which inherit from AttrOrTypeParameter");
 }
+
+const llvm::Init *AttrOrTypeParameter::getDef() const {
+  return def->getArg(index);
+}
+
+//===----------------------------------------------------------------------===//
+// AttributeSelfTypeParameter
+//===----------------------------------------------------------------------===//
+
+bool AttributeSelfTypeParameter::classof(const AttrOrTypeParameter *param) {
+  const llvm::Init *paramDef = param->getDef();
+  if (auto *paramDefInit = dyn_cast<llvm::DefInit>(paramDef))
+    return paramDefInit->getDef()->isSubClassOf("AttributeSelfTypeParameter");
+  return false;
+}
index 8b3ebaa..33a94a8 100644 (file)
@@ -41,4 +41,17 @@ def CompoundAttrA : Test_Attr<"CompoundA"> {
   );
 }
 
+// An attribute testing AttributeSelfTypeParameter.
+def AttrWithSelfTypeParam : Test_Attr<"AttrWithSelfTypeParam"> {
+  let mnemonic = "attr_with_self_type_param";
+  let parameters = (ins AttributeSelfTypeParameter<"">:$type);
+}
+
+// An attribute testing AttributeSelfTypeParameter.
+def AttrWithTypeBuilder : Test_Attr<"AttrWithTypeBuilder"> {
+  let mnemonic = "attr_with_type_builder";
+  let parameters = (ins "::mlir::IntegerAttr":$attr);
+  let typeBuilder = "$_attr.getType()";
+}
+
 #endif // TEST_ATTRDEFS
index 39328b6..d13cd5a 100644 (file)
 using namespace mlir;
 using namespace mlir::test;
 
+//===----------------------------------------------------------------------===//
+// AttrWithSelfTypeParamAttr
+//===----------------------------------------------------------------------===//
+
+Attribute AttrWithSelfTypeParamAttr::parse(MLIRContext *context,
+                                           DialectAsmParser &parser,
+                                           Type type) {
+  Type selfType;
+  if (parser.parseType(selfType))
+    return Attribute();
+  return get(context, selfType);
+}
+
+void AttrWithSelfTypeParamAttr::print(DialectAsmPrinter &printer) const {
+  printer << "attr_with_self_type_param " << getType();
+}
+
+//===----------------------------------------------------------------------===//
+// AttrWithTypeBuilderAttr
+//===----------------------------------------------------------------------===//
+
+Attribute AttrWithTypeBuilderAttr::parse(MLIRContext *context,
+                                         DialectAsmParser &parser, Type type) {
+  IntegerAttr element;
+  if (parser.parseAttribute(element))
+    return Attribute();
+  return get(context, element);
+}
+
+void AttrWithTypeBuilderAttr::print(DialectAsmPrinter &printer) const {
+  printer << "attr_with_type_builder " << getAttr();
+}
+
+//===----------------------------------------------------------------------===//
+// CompoundAAttr
+//===----------------------------------------------------------------------===//
+
 Attribute CompoundAAttr::parse(MLIRContext *context, DialectAsmParser &parser,
                                Type type) {
   int widthOfSomething;
index e09e003..9d91dc4 100644 (file)
@@ -420,6 +420,14 @@ def OperandsHaveSameType :
   let arguments = (ins AnyType:$x, AnyType:$y);
 }
 
+def ResultHasSameTypeAsAttr :
+    TEST_Op<"result_has_same_type_as_attr",
+            [AllTypesMatch<["attr", "result"]>]> {
+  let arguments = (ins AnyAttr:$attr);
+  let results = (outs AnyType:$result);
+  let assemblyFormat = "$attr `->` type($result) attr-dict";
+}
+
 def OperandZeroAndResultHaveSameType :
     TEST_Op<"operand0_and_result_have_same_type",
             [AllTypesMatch<["x", "res"]>]> {
index 36ea2cb..802c2a1 100644 (file)
@@ -23,7 +23,7 @@ include "mlir/IR/OpBase.td"
 // DEF-NEXT: ::llvm::StringRef mnemonic, ::mlir::Type type) {
 // DEF: if (mnemonic == ::mlir::test::CompoundAAttr::getMnemonic()) return ::mlir::test::CompoundAAttr::parse(context, parser, type);
 // DEF-NEXT: if (mnemonic == ::mlir::test::IndexAttr::getMnemonic()) return ::mlir::test::IndexAttr::parse(context, parser, type);
-// DEF-NEXT: return ::mlir::Attribute();
+// DEF: return ::mlir::Attribute();
 
 def Test_Dialect: Dialect {
 // DECL-NOT: TestDialect
@@ -49,7 +49,7 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
       "::mlir::test::SimpleTypeA": $exampleTdType,
       "SomeCppStruct": $exampleCppType,
       ArrayRefParameter<"int", "Matrix dimensions">:$dims,
-      "::mlir::Type":$inner
+      AttributeSelfTypeParameter<"">:$inner
   );
 
   let genVerifyDecl = 1;
@@ -66,6 +66,20 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
 // DECL: int getWidthOfSomething() const;
 // DECL: ::mlir::test::SimpleTypeA getExampleTdType() const;
 // DECL: SomeCppStruct getExampleCppType() const;
+
+// Check that AttributeSelfTypeParameter is handled properly.
+// DEF-LABEL: struct CompoundAAttrStorage
+// DEF: CompoundAAttrStorage (
+// DEF-NEXT: : ::mlir::AttributeStorage(inner),
+
+// DEF: bool operator==(const KeyTy &key) const {
+// DEF-NEXT: return key == KeyTy(widthOfSomething, exampleTdType, exampleCppType, dims, getType());
+
+// DEF: static CompoundAAttrStorage *construct
+// DEF: return new (allocator.allocate<CompoundAAttrStorage>())
+// DEF-NEXT: CompoundAAttrStorage(widthOfSomething, exampleTdType, exampleCppType, dims, inner);
+
+// DEF: ::mlir::Type CompoundAAttr::getInner() const { return getImpl()->getType(); }
 }
 
 def C_IndexAttr : TestAttr<"Index"> {
@@ -94,3 +108,14 @@ def D_SingleParameterAttr : TestAttr<"SingleParameter"> {
 // DECL-LABEL: class SingleParameterAttr
 // DECL-NEXT:                   detail::SingleParameterAttrStorage
 }
+
+// An attribute testing AttributeSelfTypeParameter.
+def E_AttrWithTypeBuilder : TestAttr<"AttrWithTypeBuilder"> {
+  let mnemonic = "attr_with_type_builder";
+  let parameters = (ins "::mlir::IntegerAttr":$attr);
+  let typeBuilder = "$_attr.getType()";
+}
+
+// DEF-LABEL: struct AttrWithTypeBuilderAttrStorage
+// DEF: AttrWithTypeBuilderAttrStorage (::mlir::IntegerAttr attr)
+// DEF-NEXT: : ::mlir::AttributeStorage(attr.getType()), attr(attr)
index 8c167ff..75e1681 100644 (file)
@@ -3,3 +3,9 @@
 // CHECK-LABEL: func private @compoundA()
 // CHECK-SAME: #test.cmpnd_a<1, !test.smpla, [5, 6]>
 func private @compoundA() attributes {foo = #test.cmpnd_a<1, !test.smpla, [5, 6]>}
+
+// CHECK: test.result_has_same_type_as_attr #test<"attr_with_self_type_param i32"> -> i32
+%a = test.result_has_same_type_as_attr #test<"attr_with_self_type_param i32"> -> i32
+
+// CHECK: test.result_has_same_type_as_attr #test<"attr_with_type_builder 10 : i16"> -> i16
+%b = test.result_has_same_type_as_attr #test<"attr_with_type_builder 10 : i16"> -> i16
index 3df9fd9..b7e5c7d 100644 (file)
@@ -485,31 +485,90 @@ static void emitStorageParameterAllocation(const AttrOrTypeDef &def,
   }
 }
 
+/// Builds a code block that initializes the attribute storage of 'def'.
+/// Attribute initialization is separated from Type initialization given that
+/// the Attribute also needs to initialize its self-type, which has multiple
+/// means of initialization.
+static std::string buildAttributeStorageParamInitializer(
+    const AttrOrTypeDef &def, ArrayRef<AttrOrTypeParameter> parameters) {
+  std::string paramInitializer;
+  llvm::raw_string_ostream paramOS(paramInitializer);
+  paramOS << "::mlir::AttributeStorage(";
+
+  // If this is an attribute, we need to check for value type initialization.
+  Optional<size_t> selfParamIndex;
+  for (auto it : llvm::enumerate(parameters)) {
+    const auto *selfParam = dyn_cast<AttributeSelfTypeParameter>(&it.value());
+    if (!selfParam)
+      continue;
+    if (selfParamIndex) {
+      llvm::PrintFatalError(def.getLoc(),
+                            "Only one attribute parameter can be marked as "
+                            "AttributeSelfTypeParameter");
+    }
+    paramOS << selfParam->getName();
+    selfParamIndex = it.index();
+  }
+
+  // If we didn't find a self param, but the def has a type builder we use that
+  // to construct the type.
+  if (!selfParamIndex) {
+    const AttrDef &attrDef = cast<AttrDef>(def);
+    if (Optional<StringRef> typeBuilder = attrDef.getTypeBuilder()) {
+      FmtContext fmtContext;
+      for (const AttrOrTypeParameter &param : parameters)
+        fmtContext.addSubst(("_" + param.getName()).str(), param.getName());
+      paramOS << tgfmt(*typeBuilder, &fmtContext);
+    }
+  }
+  paramOS << ")";
+
+  // Append the parameters to the initializer.
+  for (auto it : llvm::enumerate(parameters))
+    if (it.index() != selfParamIndex)
+      paramOS << llvm::formatv(", {0}({0})", it.value().getName());
+
+  return paramOS.str();
+}
+
 void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
-  SmallVector<AttrOrTypeParameter, 4> parameters;
-  def.getParameters(parameters);
+  SmallVector<AttrOrTypeParameter, 4> params;
+  def.getParameters(params);
 
-  // Collect the parameter names and types.
-  auto parameterNames =
-      map_range(parameters, [](AttrOrTypeParameter parameter) {
-        return parameter.getName();
-      });
+  // Collect the parameter types.
   auto parameterTypes =
-      map_range(parameters, [](AttrOrTypeParameter parameter) {
+      llvm::map_range(params, [](const AttrOrTypeParameter &parameter) {
         return parameter.getCppType();
       });
-  auto parameterList = join(parameterNames, ", ");
-  auto parameterTypeList = join(parameterTypes, ", ");
+  std::string parameterTypeList = llvm::join(parameterTypes, ", ");
+
+  // Collect the parameter initializer.
+  std::string paramInitializer;
+  if (isAttrGenerator) {
+    paramInitializer = buildAttributeStorageParamInitializer(def, params);
+
+  } else {
+    llvm::raw_string_ostream initOS(paramInitializer);
+    llvm::interleaveComma(params, initOS, [&](const AttrOrTypeParameter &it) {
+      initOS << llvm::formatv("{0}({0})", it.getName());
+    });
+  }
+
+  // Construct the parameter list that is used when a concrete instance of the
+  // storage exists.
+  auto nonStaticParameterNames = llvm::map_range(params, [](const auto &param) {
+    return isa<AttributeSelfTypeParameter>(param) ? "getType()"
+                                                  : param.getName();
+  });
 
   // 1) Emit most of the storage class up until the hashKey body.
   os << formatv(
       defStorageClassBeginStr, def.getStorageNamespace(),
       def.getStorageClassName(),
       ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs,
-                          parameters, /*prependComma=*/false),
-      ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNameInitializer,
-                          parameters, /*prependComma=*/false),
-      parameterList, parameterTypeList, valueType);
+                          params, /*prependComma=*/false),
+      paramInitializer, llvm::join(nonStaticParameterNames, ", "),
+      parameterTypeList, valueType);
 
   // 2) Emit the haskKey method.
   os << "  static ::llvm::hash_code hashKey(const KeyTy &key) {\n";
@@ -517,7 +576,7 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
   // Extract each parameter from the key.
   os << "      return ::llvm::hash_combine(";
   llvm::interleaveComma(
-      llvm::seq<unsigned>(0, parameters.size()), os,
+      llvm::seq<unsigned>(0, params.size()), os,
       [&](unsigned it) { os << "std::get<" << it << ">(key)"; });
   os << ");\n    }\n";
 
@@ -535,9 +594,9 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
     // First, unbox the parameters.
     os << formatv(defStorageClassConstructorBeginStr, def.getStorageClassName(),
                   valueType);
-    for (unsigned i = 0, e = parameters.size(); i < e; ++i) {
+    for (unsigned i = 0, e = params.size(); i < e; ++i) {
       os << formatv("      auto {0} = std::get<{1}>(key);\n",
-                    parameters[i].getName(), i);
+                    params[i].getName(), i);
     }
 
     // Second, reassign the parameter variables with allocation code, if it's
@@ -545,14 +604,18 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
     emitStorageParameterAllocation(def, os);
 
     // Last, return an allocated copy.
+    auto parameterNames = llvm::map_range(
+        params, [](const auto &param) { return param.getName(); });
     os << formatv(defStorageClassConstructorEndStr, def.getStorageClassName(),
-                  parameterList);
+                  llvm::join(parameterNames, ", "));
   }
 
   // 4) Emit the parameters as storage class members.
-  for (auto parameter : parameters) {
-    os << "      " << parameter.getCppType() << " " << parameter.getName()
-       << ";\n";
+  for (const AttrOrTypeParameter &parameter : params) {
+    // Attribute value types are not stored as fields in the storage.
+    if (!isa<AttributeSelfTypeParameter>(parameter))
+      os << "      " << parameter.getCppType() << " " << parameter.getName()
+         << ";\n";
   }
   os << "  };\n";
 
@@ -708,10 +771,14 @@ void DefGenerator::emitDefDef(const AttrOrTypeDef &def) {
     // Otherwise, let the user define the exact accessor definition.
     if (def.genAccessors() && def.genStorageClass()) {
       for (const AttrOrTypeParameter &parameter : parameters) {
+        StringRef paramStorageName = isa<AttributeSelfTypeParameter>(parameter)
+                                         ? "getType()"
+                                         : parameter.getName();
+
         SmallString<16> name = parameter.getName();
         name[0] = llvm::toUpper(name[0]);
         os << formatv("{0} {3}::get{1}() const {{ return getImpl()->{2}; }\n",
-                      parameter.getCppType(), name, parameter.getName(),
+                      parameter.getCppType(), name, paramStorageName,
                       def.getCppClassName());
       }
     }