[mlir][AttrTypeGen] Add support for specifying a "accessor" type of a parameter
authorRiver Riddle <riddleriver@gmail.com>
Wed, 25 Aug 2021 09:26:56 +0000 (09:26 +0000)
committerRiver Riddle <riddleriver@gmail.com>
Wed, 25 Aug 2021 09:27:36 +0000 (09:27 +0000)
This allows for using a different type when accessing a parameter than the
one used for storage. This allows for returning parameters by reference,
enables using more optimized/convient reference results, and more.

Differential Revision: https://reviews.llvm.org/D108593

mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/AttrOrTypeDef.h
mlir/lib/TableGen/AttrOrTypeDef.cpp
mlir/test/mlir-tblgen/attrdefs.td
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

index 59088ae..e1b7e28 100644 (file)
@@ -2836,20 +2836,24 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
 
 // 'Parameters' should be subclasses of this or simple strings (which is a
 // shorthand for AttrOrTypeParameter<"C++Type">).
-class AttrOrTypeParameter<string type, string desc> {
+class AttrOrTypeParameter<string type, string desc, string accessorType = ""> {
   // Custom memory allocation code for storage constructor.
   code allocator = ?;
   // Custom comparator used to compare two instances for equality.
   code comparator = ?;
   // The C++ type of this parameter.
   string cppType = type;
+  // The C++ type of the accessor for this parameter.
+  string cppAccessorType = !if(!empty(accessorType), type, accessorType);
   // One-line human-readable description of the argument.
   string summary = desc;
   // The format string for the asm syntax (documentation only).
   string syntax = ?;
 }
-class AttrParameter<string type, string desc> : AttrOrTypeParameter<type, desc>;
-class TypeParameter<string type, string desc> : AttrOrTypeParameter<type, desc>;
+class AttrParameter<string type, string desc, string accessorType = "">
+ : AttrOrTypeParameter<type, desc, accessorType>;
+class TypeParameter<string type, string desc, string accessorType = "">
+ : AttrOrTypeParameter<type, desc, accessorType>;
 
 // For StringRefs, which require allocation.
 class StringRefParameter<string desc = ""> :
index ab07f43..2029c0e 100644 (file)
@@ -196,6 +196,9 @@ public:
   // Get the C++ type of this parameter.
   StringRef getCppType() const;
 
+  // Get the C++ accessor type of this parameter.
+  StringRef getCppAccessorType() const;
+
   // Get a description of this parameter for documentation purposes.
   Optional<StringRef> getSummary() const;
 
index c439b3c..2a0ad96 100644 (file)
@@ -210,6 +210,15 @@ StringRef AttrOrTypeParameter::getCppType() const {
       "which inherit from AttrOrTypeParameter\n");
 }
 
+StringRef AttrOrTypeParameter::getCppAccessorType() const {
+  if (auto *param = dyn_cast<llvm::DefInit>(def->getArg(index))) {
+    if (Optional<StringRef> type =
+            param->getDef()->getValueAsOptionalString("cppAccessorType"))
+      return *type;
+  }
+  return getCppType();
+}
+
 Optional<StringRef> AttrOrTypeParameter::getSummary() const {
   auto *parameterType = def->getArg(index);
   if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
index c42da94..7954fe3 100644 (file)
@@ -135,3 +135,14 @@ def E_AttrWithTypeBuilder : TestAttr<"AttrWithTypeBuilder"> {
 // DEF-LABEL: struct AttrWithTypeBuilderAttrStorage
 // DEF: AttrWithTypeBuilderAttrStorage (::mlir::IntegerAttr attr)
 // DEF-NEXT: : ::mlir::AttributeStorage(attr.getType()), attr(attr)
+
+def F_ParamWithAccessorTypeAttr : TestAttr<"ParamWithAccessorType"> {
+  let parameters = (ins AttrParameter<"std::string", "", "StringRef">:$param);
+}
+
+// DECL-LABEL: class ParamWithAccessorTypeAttr
+// DECL: StringRef getParam()
+// DEF: ParamWithAccessorTypeAttrStorage
+// DEF-NEXT: ParamWithAccessorTypeAttrStorage (std::string param)
+// DEF: StringRef ParamWithAccessorTypeAttr::getParam()
+
index 5b1b803..d254d78 100644 (file)
@@ -413,7 +413,8 @@ void DefGenerator::emitDefDecl(const AttrOrTypeDef &def) {
     for (AttrOrTypeParameter &parameter : parameters) {
       SmallString<16> name = parameter.getName();
       name[0] = llvm::toUpper(name[0]);
-      os << formatv("    {0} get{1}() const;\n", parameter.getCppType(), name);
+      os << formatv("    {0} get{1}() const;\n", parameter.getCppAccessorType(),
+                    name);
     }
   }
 
@@ -859,7 +860,7 @@ void DefGenerator::emitDefDef(const AttrOrTypeDef &def) {
         SmallString<16> name = param.getName();
         name[0] = llvm::toUpper(name[0]);
         os << formatv("{0} {3}::get{1}() const {{ return getImpl()->{2}; }\n",
-                      param.getCppType(), name, paramStorageName,
+                      param.getCppAccessorType(), name, paramStorageName,
                       def.getCppClassName());
       }
     }