[mlir] FunctionOpInterface: arg and result attrs dispatch to interface
authorJeff Niu <jeff@modular.com>
Tue, 6 Dec 2022 20:55:43 +0000 (12:55 -0800)
committerJeff Niu <jeff@modular.com>
Thu, 8 Dec 2022 19:32:38 +0000 (11:32 -0800)
This patch removes the `arg_attrs` and `res_attrs` named attributes as a
requirement for FunctionOpInterface and replaces them with interface
methods for the getters, setters, and removers of the relevent
attributes. This allows operations to use their own storage for the
argument and result attributes.

Depends on D139471

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D139472

36 files changed:
mlir/examples/toy/Ch2/include/toy/Ops.td
mlir/examples/toy/Ch2/mlir/Dialect.cpp
mlir/examples/toy/Ch3/include/toy/Ops.td
mlir/examples/toy/Ch3/mlir/Dialect.cpp
mlir/examples/toy/Ch4/include/toy/Ops.td
mlir/examples/toy/Ch4/mlir/Dialect.cpp
mlir/examples/toy/Ch5/include/toy/Ops.td
mlir/examples/toy/Ch5/mlir/Dialect.cpp
mlir/examples/toy/Ch6/include/toy/Ops.td
mlir/examples/toy/Ch6/mlir/Dialect.cpp
mlir/examples/toy/Ch7/include/toy/Ops.td
mlir/examples/toy/Ch7/mlir/Dialect.cpp
mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
mlir/include/mlir/Dialect/Func/IR/FuncOps.td
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/include/mlir/IR/FunctionImplementation.h
mlir/include/mlir/IR/FunctionInterfaces.h
mlir/include/mlir/IR/FunctionInterfaces.td
mlir/include/mlir/IR/OpBase.td
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/lib/Dialect/Func/IR/FuncOps.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/IR/FunctionImplementation.cpp
mlir/lib/IR/FunctionInterfaces.cpp
mlir/test/IR/invalid-func-op.mlir

index 380536b..4e2fb9e 100644 (file)
@@ -134,7 +134,9 @@ def FuncOp : Toy_Op<"func", [
 
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    TypeAttrOf<FunctionType>:$function_type
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let regions = (region AnyRegion:$body);
 
index ac12c5c..201f9c7 100644 (file)
@@ -218,8 +218,9 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
 void FuncOp::print(mlir::OpAsmPrinter &p) {
   // Dispatch to the FunctionOpInterface provided utility method that prints the
   // function operation.
-  mlir::function_interface_impl::printFunctionOp(p, *this,
-                                                 /*isVariadic=*/false);
+  mlir::function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 //===----------------------------------------------------------------------===//
index e526fe5..1a4e6a1 100644 (file)
@@ -133,7 +133,9 @@ def FuncOp : Toy_Op<"func", [
 
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    TypeAttrOf<FunctionType>:$function_type
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let regions = (region AnyRegion:$body);
 
index 75cb57e..4bd1055 100644 (file)
@@ -205,8 +205,9 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
 void FuncOp::print(mlir::OpAsmPrinter &p) {
   // Dispatch to the FunctionOpInterface provided utility method that prints the
   // function operation.
-  mlir::function_interface_impl::printFunctionOp(p, *this,
-                                                 /*isVariadic=*/false);
+  mlir::function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 //===----------------------------------------------------------------------===//
index 4956b0e..cbece47 100644 (file)
@@ -163,7 +163,9 @@ def FuncOp : Toy_Op<"func", [
 
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    TypeAttrOf<FunctionType>:$function_type
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let regions = (region AnyRegion:$body);
 
index 2d5a369..3a02ea3 100644 (file)
@@ -294,8 +294,9 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
 void FuncOp::print(mlir::OpAsmPrinter &p) {
   // Dispatch to the FunctionOpInterface provided utility method that prints the
   // function operation.
-  mlir::function_interface_impl::printFunctionOp(p, *this,
-                                                 /*isVariadic=*/false);
+  mlir::function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 /// Returns the region on the function operation that is callable.
index f4e7b08..70e482d 100644 (file)
@@ -163,7 +163,9 @@ def FuncOp : Toy_Op<"func", [
 
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    TypeAttrOf<FunctionType>:$function_type
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let regions = (region AnyRegion:$body);
 
index 280bf31..49ce3d9 100644 (file)
@@ -294,8 +294,9 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
 void FuncOp::print(mlir::OpAsmPrinter &p) {
   // Dispatch to the FunctionOpInterface provided utility method that prints the
   // function operation.
-  mlir::function_interface_impl::printFunctionOp(p, *this,
-                                                 /*isVariadic=*/false);
+  mlir::function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 /// Returns the region on the function operation that is callable.
index ea9323e..cf2bc3f 100644 (file)
@@ -163,7 +163,9 @@ def FuncOp : Toy_Op<"func", [
 
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    TypeAttrOf<FunctionType>:$function_type
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let regions = (region AnyRegion:$body);
 
index 280bf31..49ce3d9 100644 (file)
@@ -294,8 +294,9 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
 void FuncOp::print(mlir::OpAsmPrinter &p) {
   // Dispatch to the FunctionOpInterface provided utility method that prints the
   // function operation.
-  mlir::function_interface_impl::printFunctionOp(p, *this,
-                                                 /*isVariadic=*/false);
+  mlir::function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 /// Returns the region on the function operation that is callable.
index 45ecdd3..08671a7 100644 (file)
@@ -186,7 +186,9 @@ def FuncOp : Toy_Op<"func", [
 
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    TypeAttrOf<FunctionType>:$function_type
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let regions = (region AnyRegion:$body);
 
index b0d2130..cb65a95 100644 (file)
@@ -321,8 +321,9 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
 void FuncOp::print(mlir::OpAsmPrinter &p) {
   // Dispatch to the FunctionOpInterface provided utility method that prints the
   // function operation.
-  mlir::function_interface_impl::printFunctionOp(p, *this,
-                                                 /*isVariadic=*/false);
+  mlir::function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 /// Returns the region on the function operation that is callable.
index 30895e5..14146cd 100644 (file)
@@ -140,7 +140,9 @@ def Async_FuncOp : Async_Op<"func",
 
   let arguments = (ins SymbolNameAttr:$sym_name,
                        TypeAttrOf<FunctionType>:$function_type,
-                       OptionalAttr<StrAttr>:$sym_visibility);
+                       OptionalAttr<StrAttr>:$sym_visibility,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs);
 
   let regions = (region AnyRegion:$body);
 
index f1b7cfd..4922689 100644 (file)
@@ -251,7 +251,9 @@ def FuncOp : Func_Op<"func", [
 
   let arguments = (ins SymbolNameAttr:$sym_name,
                        TypeAttrOf<FunctionType>:$function_type,
-                       OptionalAttr<StrAttr>:$sym_visibility);
+                       OptionalAttr<StrAttr>:$sym_visibility,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs);
   let regions = (region AnyRegion:$body);
 
   let builders = [OpBuilder<(ins
index 0642b18..f9fff78 100644 (file)
@@ -242,7 +242,9 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
     attribution.
   }];
 
-  let arguments = (ins TypeAttrOf<FunctionType>:$function_type);
+  let arguments = (ins TypeAttrOf<FunctionType>:$function_type,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs);
   let regions = (region AnyRegion:$body);
 
   let skipDefaultBuilders = 1;
index c36e390..4d2b2f9 100644 (file)
@@ -1308,7 +1308,9 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     DefaultValuedAttr<CConv, "CConv::C">:$CConv,
     OptionalAttr<FlatSymbolRefAttr>:$personality,
     OptionalAttr<StrAttr>:$garbageCollector,
-    OptionalAttr<ArrayAttr>:$passthrough
+    OptionalAttr<ArrayAttr>:$passthrough,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
 
   let regions = (region AnyRegion:$body);
index 422680a..db6c773 100644 (file)
@@ -52,6 +52,8 @@ def MLProgram_FuncOp : MLProgram_Op<"func", [
 
   let arguments = (ins SymbolNameAttr:$sym_name,
                        TypeAttrOf<FunctionType>:$function_type,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs,
                        OptionalAttr<StrAttr>:$sym_visibility);
   let regions = (region AnyRegion:$body);
 
@@ -401,6 +403,8 @@ def MLProgram_SubgraphOp : MLProgram_Op<"subgraph", [
 
   let arguments = (ins SymbolNameAttr:$sym_name,
                        TypeAttrOf<FunctionType>:$function_type,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs,
                        OptionalAttr<StrAttr>:$sym_visibility);
   let regions = (region AnyRegion:$body);
 
index 42a48cd..6ecbed2 100644 (file)
@@ -652,7 +652,9 @@ def PDLInterp_FuncOp : PDLInterp_Op<"func", [
 
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    TypeAttrOf<FunctionType>:$function_type
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let regions = (region MinSizedRegion<1>:$body);
 
index 147705e..8339afc 100644 (file)
@@ -291,6 +291,8 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
 
   let arguments = (ins
     TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs,
     StrAttr:$sym_name,
     SPIRV_FunctionControlAttr:$function_control
   );
index c3697f0..97d1f0c 100644 (file)
@@ -1107,6 +1107,8 @@ def Shape_FuncOp : Shape_Op<"func",
 
   let arguments = (ins SymbolNameAttr:$sym_name,
                        TypeAttrOf<FunctionType>:$function_type,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs,
                        OptionalAttr<StrAttr>:$sym_visibility);
   let regions = (region AnyRegion:$body);
 
index f4c0cc0..eb79790 100644 (file)
@@ -39,10 +39,12 @@ private:
 /// with special names given by getResultAttrName, getArgumentAttrName.
 void addArgAndResultAttrs(Builder &builder, OperationState &result,
                           ArrayRef<DictionaryAttr> argAttrs,
-                          ArrayRef<DictionaryAttr> resultAttrs);
+                          ArrayRef<DictionaryAttr> resultAttrs,
+                          StringAttr argAttrsName, StringAttr resAttrsName);
 void addArgAndResultAttrs(Builder &builder, OperationState &result,
-                          ArrayRef<OpAsmParser::Argument> argAttrs,
-                          ArrayRef<DictionaryAttr> resultAttrs);
+                          ArrayRef<OpAsmParser::Argument> args,
+                          ArrayRef<DictionaryAttr> resultAttrs,
+                          StringAttr argAttrsName, StringAttr resAttrsName);
 
 /// Callback type for `parseFunctionOp`, the callback should produce the
 /// type that will be associated with a function-like operation from lists of
@@ -77,15 +79,17 @@ Type getFunctionType(Builder &builder, ArrayRef<OpAsmParser::Argument> argAttrs,
 /// type, report the error or delegate the reporting to the op's verifier.
 ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result,
                             bool allowVariadic, StringAttr typeAttrName,
-                            FuncTypeBuilder funcTypeBuilder);
+                            FuncTypeBuilder funcTypeBuilder,
+                            StringAttr argAttrsName, StringAttr resAttrsName);
 
 /// Printer implementation for function-like operations.
 void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
-                     StringRef typeAttrName);
+                     StringRef typeAttrName, StringAttr argAttrsName,
+                     StringAttr resAttrsName);
 
 /// Prints the signature of the function-like operation `op`. Assumes `op` has
 /// is a FunctionOpInterface and has passed verification.
-void printFunctionSignature(OpAsmPrinter &p, Operation *op,
+void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op,
                             ArrayRef<Type> argTypes, bool isVariadic,
                             ArrayRef<Type> resultTypes);
 
index bc2ec47..3beb3db 100644 (file)
@@ -26,48 +26,30 @@ class FunctionOpInterface;
 
 namespace function_interface_impl {
 
-/// Return the name of the attribute used for function argument attributes.
-inline StringRef getArgDictAttrName() { return "arg_attrs"; }
-
-/// Return the name of the attribute used for function argument attributes.
-inline StringRef getResultDictAttrName() { return "res_attrs"; }
-
 /// Returns the dictionary attribute corresponding to the argument at 'index'.
 /// If there are no argument attributes at 'index', a null attribute is
 /// returned.
-DictionaryAttr getArgAttrDict(Operation *op, unsigned index);
+DictionaryAttr getArgAttrDict(FunctionOpInterface op, unsigned index);
 
 /// Returns the dictionary attribute corresponding to the result at 'index'.
 /// If there are no result attributes at 'index', a null attribute is
 /// returned.
-DictionaryAttr getResultAttrDict(Operation *op, unsigned index);
+DictionaryAttr getResultAttrDict(FunctionOpInterface op, unsigned index);
 
-namespace detail {
-/// Update the given index into an argument or result attribute dictionary.
-void setArgResAttrDict(Operation *op, StringRef attrName,
-                       unsigned numTotalIndices, unsigned index,
-                       DictionaryAttr attrs);
-} // namespace detail
+/// Return all of the attributes for the argument at 'index'.
+ArrayRef<NamedAttribute> getArgAttrs(FunctionOpInterface op, unsigned index);
+
+/// Return all of the attributes for the result at 'index'.
+ArrayRef<NamedAttribute> getResultAttrs(FunctionOpInterface op, unsigned index);
 
 /// Set all of the argument or result attribute dictionaries for a function. The
 /// size of `attrs` is expected to match the number of arguments/results of the
 /// given `op`.
-void setAllArgAttrDicts(Operation *op, ArrayRef<DictionaryAttr> attrs);
-void setAllArgAttrDicts(Operation *op, ArrayRef<Attribute> attrs);
-void setAllResultAttrDicts(Operation *op, ArrayRef<DictionaryAttr> attrs);
-void setAllResultAttrDicts(Operation *op, ArrayRef<Attribute> attrs);
-
-/// Return all of the attributes for the argument at 'index'.
-inline ArrayRef<NamedAttribute> getArgAttrs(Operation *op, unsigned index) {
-  auto argDict = getArgAttrDict(op, index);
-  return argDict ? argDict.getValue() : std::nullopt;
-}
-
-/// Return all of the attributes for the result at 'index'.
-inline ArrayRef<NamedAttribute> getResultAttrs(Operation *op, unsigned index) {
-  auto resultDict = getResultAttrDict(op, index);
-  return resultDict ? resultDict.getValue() : std::nullopt;
-}
+void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs);
+void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef<Attribute> attrs);
+void setAllResultAttrDicts(FunctionOpInterface op,
+                           ArrayRef<DictionaryAttr> attrs);
+void setAllResultAttrDicts(FunctionOpInterface op, ArrayRef<Attribute> attrs);
 
 /// Insert the specified arguments and update the function type attribute.
 void insertFunctionArguments(FunctionOpInterface op,
@@ -110,20 +92,10 @@ TypeRange filterTypesOut(TypeRange types, const BitVector &indices,
 //===----------------------------------------------------------------------===//
 
 /// Set the attributes held by the argument at 'index'.
-template <typename ConcreteType>
-void setArgAttrs(ConcreteType op, unsigned index,
-                 ArrayRef<NamedAttribute> attributes) {
-  assert(index < op.getNumArguments() && "invalid argument number");
-  return detail::setArgResAttrDict(
-      op, getArgDictAttrName(), op.getNumArguments(), index,
-      DictionaryAttr::get(op->getContext(), attributes));
-}
-template <typename ConcreteType>
-void setArgAttrs(ConcreteType op, unsigned index, DictionaryAttr attributes) {
-  return detail::setArgResAttrDict(
-      op, getArgDictAttrName(), op.getNumArguments(), index,
-      attributes ? attributes : DictionaryAttr::get(op->getContext()));
-}
+void setArgAttrs(FunctionOpInterface op, unsigned index,
+                 ArrayRef<NamedAttribute> attributes);
+void setArgAttrs(FunctionOpInterface op, unsigned index,
+                 DictionaryAttr attributes);
 
 /// If the an attribute exists with the specified name, change it to the new
 /// value. Otherwise, add a new attribute with the specified name/value.
@@ -157,23 +129,10 @@ Attribute removeArgAttr(ConcreteType op, unsigned index, StringAttr name) {
 //===----------------------------------------------------------------------===//
 
 /// Set the attributes held by the result at 'index'.
-template <typename ConcreteType>
-void setResultAttrs(ConcreteType op, unsigned index,
-                    ArrayRef<NamedAttribute> attributes) {
-  assert(index < op.getNumResults() && "invalid result number");
-  return detail::setArgResAttrDict(
-      op, getResultDictAttrName(), op.getNumResults(), index,
-      DictionaryAttr::get(op->getContext(), attributes));
-}
-
-template <typename ConcreteType>
-void setResultAttrs(ConcreteType op, unsigned index,
-                    DictionaryAttr attributes) {
-  assert(index < op.getNumResults() && "invalid result number");
-  return detail::setArgResAttrDict(
-      op, getResultDictAttrName(), op.getNumResults(), index,
-      attributes ? attributes : DictionaryAttr::get(op->getContext()));
-}
+void setResultAttrs(FunctionOpInterface op, unsigned index,
+                    ArrayRef<NamedAttribute> attributes);
+void setResultAttrs(FunctionOpInterface op, unsigned index,
+                    DictionaryAttr attributes);
 
 /// If the an attribute exists with the specified name, change it to the new
 /// value. Otherwise, add a new attribute with the specified name/value.
@@ -213,9 +172,8 @@ LogicalResult verifyTrait(ConcreteOp op) {
     unsigned numArgs = op.getNumArguments();
     if (allArgAttrs.size() != numArgs) {
       return op.emitOpError()
-             << "expects argument attribute array `" << getArgDictAttrName()
-             << "` to have the same number of elements as the number of "
-                "function arguments, got "
+             << "expects argument attribute array to have the same number of "
+                "elements as the number of function arguments, got "
              << allArgAttrs.size() << ", but expected " << numArgs;
     }
     for (unsigned i = 0; i != numArgs; ++i) {
@@ -245,9 +203,8 @@ LogicalResult verifyTrait(ConcreteOp op) {
     unsigned numResults = op.getNumResults();
     if (allResultAttrs.size() != numResults) {
       return op.emitOpError()
-             << "expects result attribute array `" << getResultDictAttrName()
-             << "` to have the same number of elements as the number of "
-                "function results, got "
+             << "expects result attribute array to have the same number of "
+                "elements as the number of function results, got "
              << allResultAttrs.size() << ", but expected " << numResults;
     }
     for (unsigned i = 0; i != numResults; ++i) {
index e86057a..0e8a3ad 100644 (file)
@@ -59,6 +59,42 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
       result attributes.
     }],
     "void", "setFunctionTypeAttr", (ins "::mlir::TypeAttr":$type)>,
+
+    InterfaceMethod<[{
+      Get the array of argument attribute dictionaries. The method should return
+      an array attribute containing only dictionary attributes equal in number
+      to the number of function arguments. Alternatively, the method can return
+      null to indicate that the function has no argument attributes.
+    }],
+    "::mlir::ArrayAttr", "getArgAttrsAttr">,
+    InterfaceMethod<[{
+      Get the array of result attribute dictionaries. The method should return
+      an array attribute containing only dictionary attributes equal in number
+      to the number of function results. Alternatively, the method can return
+      null to indicate that the function has no result attributes.
+    }],
+    "::mlir::ArrayAttr", "getResAttrsAttr">,
+    InterfaceMethod<[{
+      Set the array of argument attribute dictionaries.
+    }],
+    "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>,
+    InterfaceMethod<[{
+      Set the array of result attribute dictionaries.
+    }],
+    "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>,
+    InterfaceMethod<[{
+      Remove the array of argument attribute dictionaries. This is the same as
+      setting all argument attributes to an empty dictionary. The method should
+      return the removed attribute.
+    }],
+    "::mlir::Attribute", "removeArgAttrsAttr">,
+    InterfaceMethod<[{
+      Remove the array of result attribute dictionaries. This is the same as
+      setting all result attributes to an empty dictionary. The method should
+      return the removed attribute.
+    }],
+    "::mlir::Attribute", "removeResAttrsAttr">,
+
     InterfaceMethod<[{
       Returns the function argument types based exclusively on
       the type (to allow for this method may be called on function
@@ -250,20 +286,6 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
       function_interface_impl::setFunctionType(this->getOperation(), newType);
     }
 
-    // FIXME: These functions should be removed in favor of just forwarding to
-    // the derived operation, which should already have these defined
-    // (via ODS).
-
-    /// Returns the name of the attribute used for function argument attributes.
-    static StringRef getArgDictAttrName() {
-      return function_interface_impl::getArgDictAttrName();
-    }
-
-    /// Returns the name of the attribute used for function argument attributes.
-    static StringRef getResultDictAttrName() {
-      return function_interface_impl::getResultDictAttrName();
-    }
-
     //===------------------------------------------------------------------===//
     // Argument and Result Handling
     //===------------------------------------------------------------------===//
@@ -405,10 +427,8 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
 
     /// Return an ArrayAttr containing all argument attribute dictionaries of
     /// this function, or nullptr if no arguments have attributes.
-    ArrayAttr getAllArgAttrs() {
-      return this->getOperation()->template getAttrOfType<ArrayAttr>(
-          getArgDictAttrName());
-    }
+    ArrayAttr getAllArgAttrs() { return $_op.getArgAttrsAttr(); }
+
     /// Return all argument attributes of this function.
     void getAllArgAttrs(SmallVectorImpl<DictionaryAttr> &result) {
       if (ArrayAttr argAttrs = getAllArgAttrs()) {
@@ -460,7 +480,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
     }
     void setAllArgAttrs(ArrayAttr attributes) {
       assert(attributes.size() == $_op.getNumArguments());
-      this->getOperation()->setAttr(getArgDictAttrName(), attributes);
+      $_op.setArgAttrsAttr(attributes);
     }
 
     /// If the an attribute exists with the specified name, change it to the new
@@ -496,10 +516,8 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
 
     /// Return an ArrayAttr containing all result attribute dictionaries of this
     /// function, or nullptr if no result have attributes.
-    ArrayAttr getAllResultAttrs() {
-      return this->getOperation()->template getAttrOfType<ArrayAttr>(
-          getResultDictAttrName());
-    }
+    ArrayAttr getAllResultAttrs() { return $_op.getResAttrsAttr(); }
+
     /// Return all result attributes of this function.
     void getAllResultAttrs(SmallVectorImpl<DictionaryAttr> &result) {
       if (ArrayAttr argAttrs = getAllResultAttrs()) {
@@ -553,7 +571,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
     }
     void setAllResultAttrs(ArrayAttr attributes) {
       assert(attributes.size() == $_op.getNumResults());
-      this->getOperation()->setAttr(getResultDictAttrName(), attributes);
+      $_op.setResAttrsAttr(attributes);
     }
 
     /// If the an attribute exists with the specified name, change it to the new
index 2affd9a..400f671 100644 (file)
@@ -1524,6 +1524,8 @@ def TypeArrayAttr : TypedArrayAttrBase<TypeAttr, "type array attribute"> {
 }
 def IndexListArrayAttr :
   TypedArrayAttrBase<I64ArrayAttr, "Array of 64-bit integer array attributes">;
+def DictArrayAttr :
+  TypedArrayAttrBase<DictionaryAttr, "Array of dictionary attributes">;
 
 // Attributes containing symbol references.
 def SymbolRefAttr : Attr<CPred<"$_self.isa<::mlir::SymbolRefAttr>()">,
index 9f522aa..0cd024e 100644 (file)
@@ -66,8 +66,8 @@ static void filterFuncAttributes(func::FuncOp func, bool filterArgAndResAttrs,
         attr.getName() == func.getFunctionTypeAttrName() ||
         attr.getName() == "func.varargs" ||
         (filterArgAndResAttrs &&
-         (attr.getName() == FunctionOpInterface::getArgDictAttrName() ||
-          attr.getName() == FunctionOpInterface::getResultDictAttrName())))
+         (attr.getName() == func.getArgAttrsAttrName() ||
+          attr.getName() == func.getResAttrsAttrName())))
       continue;
     result.push_back(attr);
   }
@@ -90,18 +90,19 @@ static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
 static void
 prependResAttrsToArgAttrs(OpBuilder &builder,
                           SmallVectorImpl<NamedAttribute> &attributes,
-                          size_t numArguments) {
+                          func::FuncOp func) {
+  size_t numArguments = func.getNumArguments();
   auto allAttrs = SmallVector<Attribute>(
       numArguments + 1, DictionaryAttr::get(builder.getContext()));
   NamedAttribute *argAttrs = nullptr;
   for (auto *it = attributes.begin(); it != attributes.end();) {
-    if (it->getName() == FunctionOpInterface::getArgDictAttrName()) {
+    if (it->getName() == func.getArgAttrsAttrName()) {
       auto arrayAttrs = it->getValue().cast<ArrayAttr>();
       assert(arrayAttrs.size() == numArguments &&
              "Number of arg attrs and args should match");
       std::copy(arrayAttrs.begin(), arrayAttrs.end(), allAttrs.begin() + 1);
       argAttrs = it;
-    } else if (it->getName() == FunctionOpInterface::getResultDictAttrName()) {
+    } else if (it->getName() == func.getResAttrsAttrName()) {
       auto arrayAttrs = it->getValue().cast<ArrayAttr>();
       assert(!arrayAttrs.empty() && "expected array to be non-empty");
       allAttrs[0] = (arrayAttrs.size() == 1)
@@ -113,9 +114,8 @@ prependResAttrsToArgAttrs(OpBuilder &builder,
     it++;
   }
 
-  auto newArgAttrs =
-      builder.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
-                           builder.getArrayAttr(allAttrs));
+  auto newArgAttrs = builder.getNamedAttr(func.getArgAttrsAttrName(),
+                                          builder.getArrayAttr(allAttrs));
   if (!argAttrs) {
     attributes.emplace_back(newArgAttrs);
     return;
@@ -141,7 +141,7 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
   auto [wrapperFuncType, resultIsNowArg] =
       typeConverter.convertFunctionTypeCWrapper(type);
   if (resultIsNowArg)
-    prependResAttrsToArgAttrs(rewriter, attributes, funcOp.getNumArguments());
+    prependResAttrsToArgAttrs(rewriter, attributes, funcOp);
   auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
       wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false,
@@ -205,7 +205,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes);
 
   if (resultIsNowArg)
-    prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments());
+    prependResAttrsToArgAttrs(builder, attributes, funcOp);
   // Create the auxiliary function.
   auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
@@ -309,8 +309,8 @@ protected:
               ? resAttrDicts
               : rewriter.getArrayAttr(
                     {wrapAsStructAttrs(rewriter, resAttrDicts)});
-      attributes.push_back(rewriter.getNamedAttr(
-          FunctionOpInterface::getResultDictAttrName(), newResAttrDicts));
+      attributes.push_back(
+          rewriter.getNamedAttr(funcOp.getResAttrsAttrName(), newResAttrDicts));
     }
     if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
       SmallVector<Attribute, 4> newArgAttrs(
@@ -353,9 +353,8 @@ protected:
           newArgAttrs[mapping->inputNo + j] =
               DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
       }
-      attributes.push_back(
-          rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
-                                rewriter.getArrayAttr(newArgAttrs)));
+      attributes.push_back(rewriter.getNamedAttr(
+          funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(newArgAttrs)));
     }
     for (const auto &pair : llvm::enumerate(attributes)) {
       if (pair.value().getName() == "llvm.linkage") {
index 064bf52..54acc37 100644 (file)
@@ -340,8 +340,9 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
-                                                /*resultAttrs=*/std::nullopt);
+  function_interface_impl::addArgAndResultAttrs(
+      builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
+      getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 }
 
 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -352,12 +353,14 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
 
   return function_interface_impl::parseFunctionOp(
       parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 /// Check that the result type of async.func is not void and must be
index fc9bd11..7bb3663 100644 (file)
@@ -251,8 +251,9 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
-                                                /*resultAttrs=*/std::nullopt);
+  function_interface_impl::addArgAndResultAttrs(
+      builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
+      getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 }
 
 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -263,12 +264,14 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
 
   return function_interface_impl::parseFunctionOp(
       parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 /// Clone the internal blocks from this function into dest and all attributes
index 80db646..9ea1b11 100644 (file)
@@ -934,8 +934,9 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
   result.addAttribute(getFunctionTypeAttrName(result.name),
                       TypeAttr::get(type));
 
-  function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
-                                                resultAttrs);
+  function_interface_impl::addArgAndResultAttrs(
+      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
+      getResAttrsAttrName(result.name));
 
   // Parse workgroup memory attributions.
   if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
@@ -996,7 +997,8 @@ void GPUFuncOp::print(OpAsmPrinter &p) {
   function_interface_impl::printFunctionAttributes(
       p, *this,
       {getNumWorkgroupAttributionsAttrName(),
-       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName()});
+       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
+       getArgAttrsAttrName(), getResAttrsAttrName()});
   p << ' ';
   p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
 }
index 6b428a1..cff547e 100644 (file)
@@ -2006,8 +2006,9 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
 
   assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() &&
          "expected as many argument attribute lists as arguments");
-  function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs,
-                                                /*resultAttrs=*/std::nullopt);
+  function_interface_impl::addArgAndResultAttrs(
+      builder, result, argAttrs, /*resultAttrs=*/std::nullopt,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 // Builds an LLVM function type from the given lists of input and output types.
@@ -2095,8 +2096,9 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
 
   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
     return failure();
-  function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result,
-                                                entryArgs, resultAttrs);
+  function_interface_impl::addArgAndResultAttrs(
+      parser.getBuilder(), result, entryArgs, resultAttrs,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 
   auto *body = result.addRegion();
   OptionalParseResult parseResult =
@@ -2131,7 +2133,8 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
                                                   isVarArg(), resTypes);
   function_interface_impl::printFunctionAttributes(
       p, *this,
-      {getFunctionTypeAttrName(), getLinkageAttrName(), getCConvAttrName()});
+      {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
+       getLinkageAttrName(), getCConvAttrName()});
 
   // Print the body if this is not an external function.
   Region &body = getBody();
index 27c6130..31ed5ad 100644 (file)
@@ -153,12 +153,14 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
 
   return function_interface_impl::parseFunctionOp(
       parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 //===----------------------------------------------------------------------===//
@@ -316,12 +318,14 @@ ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
 
   return function_interface_impl::parseFunctionOp(
       parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void SubgraphOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 //===----------------------------------------------------------------------===//
index 28fc4db..2cc282d 100644 (file)
@@ -221,12 +221,14 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
 
   return function_interface_impl::parseFunctionOp(
       parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 //===----------------------------------------------------------------------===//
index 3ce3913..3341b5e 100644 (file)
@@ -2396,8 +2396,9 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
 
   // Add the attributes to the function arguments.
   assert(resultAttrs.size() == resultTypes.size());
-  function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
-                                                resultAttrs);
+  function_interface_impl::addArgAndResultAttrs(
+      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
+      getResAttrsAttrName(result.name));
 
   // Parse the optional function body.
   auto *body = result.addRegion();
@@ -2419,7 +2420,8 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
   function_interface_impl::printFunctionAttributes(
       printer, *this,
       {spirv::attributeName<spirv::FunctionControl>(),
-       getFunctionTypeAttrName(), getFunctionControlAttrName()});
+       getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
+       getFunctionControlAttrName()});
 
   // Print the body if this is not an external function.
   Region &body = this->getBody();
index 30c5f56..28ac98e 100644 (file)
@@ -1300,8 +1300,9 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
-                                                /*resultAttrs=*/std::nullopt);
+  function_interface_impl::addArgAndResultAttrs(
+      builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
+      getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 }
 
 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -1312,12 +1313,14 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
 
   return function_interface_impl::parseFunctionOp(
       parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 //===----------------------------------------------------------------------===//
index af692be..5ca6777 100644 (file)
@@ -113,7 +113,7 @@ parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
   return parser.parseRParen();
 }
 
-ParseResult mlir::function_interface_impl::parseFunctionSignature(
+ParseResult function_interface_impl::parseFunctionSignature(
     OpAsmParser &parser, bool allowVariadic,
     SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
     SmallVectorImpl<Type> &resultTypes,
@@ -125,9 +125,10 @@ ParseResult mlir::function_interface_impl::parseFunctionSignature(
   return success();
 }
 
-void mlir::function_interface_impl::addArgAndResultAttrs(
+void function_interface_impl::addArgAndResultAttrs(
     Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
-    ArrayRef<DictionaryAttr> resultAttrs) {
+    ArrayRef<DictionaryAttr> resultAttrs, StringAttr argAttrsName,
+    StringAttr resAttrsName) {
   auto nonEmptyAttrsFn = [](DictionaryAttr attrs) {
     return attrs && !attrs.empty();
   };
@@ -142,28 +143,28 @@ void mlir::function_interface_impl::addArgAndResultAttrs(
 
   // Add the attributes to the function arguments.
   if (llvm::any_of(argAttrs, nonEmptyAttrsFn))
-    result.addAttribute(function_interface_impl::getArgDictAttrName(),
-                        getArrayAttr(argAttrs));
+    result.addAttribute(argAttrsName, getArrayAttr(argAttrs));
 
   // Add the attributes to the function results.
   if (llvm::any_of(resultAttrs, nonEmptyAttrsFn))
-    result.addAttribute(function_interface_impl::getResultDictAttrName(),
-                        getArrayAttr(resultAttrs));
+    result.addAttribute(resAttrsName, getArrayAttr(resultAttrs));
 }
 
-void mlir::function_interface_impl::addArgAndResultAttrs(
+void function_interface_impl::addArgAndResultAttrs(
     Builder &builder, OperationState &result,
-    ArrayRef<OpAsmParser::Argument> args,
-    ArrayRef<DictionaryAttr> resultAttrs) {
+    ArrayRef<OpAsmParser::Argument> args, ArrayRef<DictionaryAttr> resultAttrs,
+    StringAttr argAttrsName, StringAttr resAttrsName) {
   SmallVector<DictionaryAttr> argAttrs;
   for (const auto &arg : args)
     argAttrs.push_back(arg.attrs);
-  addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
+  addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName,
+                       resAttrsName);
 }
 
-ParseResult mlir::function_interface_impl::parseFunctionOp(
+ParseResult function_interface_impl::parseFunctionOp(
     OpAsmParser &parser, OperationState &result, bool allowVariadic,
-    StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder) {
+    StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder,
+    StringAttr argAttrsName, StringAttr resAttrsName) {
   SmallVector<OpAsmParser::Argument> entryArgs;
   SmallVector<DictionaryAttr> resultAttrs;
   SmallVector<Type> resultTypes;
@@ -220,7 +221,8 @@ ParseResult mlir::function_interface_impl::parseFunctionOp(
 
   // Add the attributes to the function arguments.
   assert(resultAttrs.size() == resultTypes.size());
-  addArgAndResultAttrs(builder, result, entryArgs, resultAttrs);
+  addArgAndResultAttrs(builder, result, entryArgs, resultAttrs, argAttrsName,
+                       resAttrsName);
 
   // Parse the optional function body. The printer will not print the body if
   // its empty, so disallow parsing of empty body in the parser.
@@ -261,14 +263,14 @@ static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types,
     os << ')';
 }
 
-void mlir::function_interface_impl::printFunctionSignature(
-    OpAsmPrinter &p, Operation *op, ArrayRef<Type> argTypes, bool isVariadic,
-    ArrayRef<Type> resultTypes) {
+void function_interface_impl::printFunctionSignature(
+    OpAsmPrinter &p, FunctionOpInterface op, ArrayRef<Type> argTypes,
+    bool isVariadic, ArrayRef<Type> resultTypes) {
   Region &body = op->getRegion(0);
   bool isExternal = body.empty();
 
   p << '(';
-  ArrayAttr argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
+  ArrayAttr argAttrs = op.getArgAttrsAttr();
   for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
     if (i > 0)
       p << ", ";
@@ -295,26 +297,23 @@ void mlir::function_interface_impl::printFunctionSignature(
 
   if (!resultTypes.empty()) {
     p.getStream() << " -> ";
-    auto resultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
+    auto resultAttrs = op.getResAttrsAttr();
     printFunctionResultList(p, resultTypes, resultAttrs);
   }
 }
 
-void mlir::function_interface_impl::printFunctionAttributes(
+void function_interface_impl::printFunctionAttributes(
     OpAsmPrinter &p, Operation *op, ArrayRef<StringRef> elided) {
   // Print out function attributes, if present.
-  SmallVector<StringRef, 2> ignoredAttrs = {SymbolTable::getSymbolAttrName(),
-                                            getArgDictAttrName(),
-                                            getResultDictAttrName()};
+  SmallVector<StringRef, 8> ignoredAttrs = {SymbolTable::getSymbolAttrName()};
   ignoredAttrs.append(elided.begin(), elided.end());
 
   p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
 }
 
-void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p,
-                                                    FunctionOpInterface op,
-                                                    bool isVariadic,
-                                                    StringRef typeAttrName) {
+void function_interface_impl::printFunctionOp(
+    OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
+    StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName) {
   // Print the operation and the function name.
   auto funcName =
       op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
@@ -329,7 +328,8 @@ void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p,
   ArrayRef<Type> argTypes = op.getArgumentTypes();
   ArrayRef<Type> resultTypes = op.getResultTypes();
   printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
-  printFunctionAttributes(p, op, {visibilityAttrName, typeAttrName});
+  printFunctionAttributes(
+      p, op, {visibilityAttrName, typeAttrName, argAttrsName, resAttrsName});
   // Print the body if this is not an external function.
   Region &body = op->getRegion(0);
   if (!body.empty()) {
index 9ba8303..347fb15 100644 (file)
@@ -24,27 +24,104 @@ static bool isEmptyAttrDict(Attribute attr) {
   return attr.cast<DictionaryAttr>().empty();
 }
 
-DictionaryAttr mlir::function_interface_impl::getArgAttrDict(Operation *op,
-                                                             unsigned index) {
-  ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
+DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op,
+                                                       unsigned index) {
+  ArrayAttr attrs = op.getArgAttrsAttr();
   DictionaryAttr argAttrs =
       attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
   return argAttrs;
 }
 
 DictionaryAttr
-mlir::function_interface_impl::getResultAttrDict(Operation *op,
-                                                 unsigned index) {
-  ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
+function_interface_impl::getResultAttrDict(FunctionOpInterface op,
+                                           unsigned index) {
+  ArrayAttr attrs = op.getResAttrsAttr();
   DictionaryAttr resAttrs =
       attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
   return resAttrs;
 }
 
-void mlir::function_interface_impl::detail::setArgResAttrDict(
-    Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index,
-    DictionaryAttr attrs) {
-  ArrayAttr allAttrs = op->getAttrOfType<ArrayAttr>(attrName);
+ArrayRef<NamedAttribute>
+function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) {
+  auto argDict = getArgAttrDict(op, index);
+  return argDict ? argDict.getValue() : std::nullopt;
+}
+
+ArrayRef<NamedAttribute>
+function_interface_impl::getResultAttrs(FunctionOpInterface op,
+                                        unsigned index) {
+  auto resultDict = getResultAttrDict(op, index);
+  return resultDict ? resultDict.getValue() : std::nullopt;
+}
+
+/// Get either the argument or result attributes array.
+template <bool isArg>
+static ArrayAttr getArgResAttrs(FunctionOpInterface op) {
+  if constexpr (isArg)
+    return op.getArgAttrsAttr();
+  else
+    return op.getResAttrsAttr();
+}
+
+/// Set either the argument or result attributes array.
+template <bool isArg>
+static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) {
+  if constexpr (isArg)
+    op.setArgAttrsAttr(attrs);
+  else
+    op.setResAttrsAttr(attrs);
+}
+
+/// Erase either the argument or result attributes array.
+template <bool isArg>
+static void removeArgResAttrs(FunctionOpInterface op) {
+  if constexpr (isArg)
+    op.removeArgAttrsAttr();
+  else
+    op.removeResAttrsAttr();
+}
+
+/// Set all of the argument or result attribute dictionaries for a function.
+template <bool isArg>
+static void setAllArgResAttrDicts(FunctionOpInterface op,
+                                  ArrayRef<Attribute> attrs) {
+  if (llvm::all_of(attrs, isEmptyAttrDict))
+    removeArgResAttrs<isArg>(op);
+  else
+    setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), attrs));
+}
+
+void function_interface_impl::setAllArgAttrDicts(
+    FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
+  setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
+}
+
+void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op,
+                                                 ArrayRef<Attribute> attrs) {
+  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
+    return !attr ? DictionaryAttr::get(op->getContext()) : attr;
+  });
+  setAllArgResAttrDicts</*isArg=*/true>(op, llvm::to_vector<8>(wrappedAttrs));
+}
+
+void function_interface_impl::setAllResultAttrDicts(
+    FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
+  setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
+}
+
+void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op,
+                                                    ArrayRef<Attribute> attrs) {
+  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
+    return !attr ? DictionaryAttr::get(op->getContext()) : attr;
+  });
+  setAllArgResAttrDicts</*isArg=*/false>(op, llvm::to_vector<8>(wrappedAttrs));
+}
+
+/// Update the given index into an argument or result attribute dictionary.
+template <bool isArg>
+static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices,
+                              unsigned index, DictionaryAttr attrs) {
+  ArrayAttr allAttrs = getArgResAttrs<isArg>(op);
   if (!allAttrs) {
     if (attrs.empty())
       return;
@@ -53,7 +130,7 @@ void mlir::function_interface_impl::detail::setArgResAttrDict(
     SmallVector<Attribute, 8> newAttrs(numTotalIndices,
                                        DictionaryAttr::get(op->getContext()));
     newAttrs[index] = attrs;
-    op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
+    setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
     return;
   }
   // Check to see if the attribute is different from what we already have.
@@ -65,53 +142,51 @@ void mlir::function_interface_impl::detail::setArgResAttrDict(
   ArrayRef<Attribute> rawAttrArray = allAttrs.getValue();
   if (attrs.empty() &&
       llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) &&
-      llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) {
-    op->removeAttr(attrName);
-    return;
-  }
+      llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict))
+    return removeArgResAttrs<isArg>(op);
 
   // Otherwise, create a new attribute array with the updated dictionary.
   SmallVector<Attribute, 8> newAttrs(rawAttrArray.begin(), rawAttrArray.end());
   newAttrs[index] = attrs;
-  op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
+  setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
 }
 
-/// Set all of the argument or result attribute dictionaries for a function.
-static void setAllArgResAttrDicts(Operation *op, StringRef attrName,
-                                  ArrayRef<Attribute> attrs) {
-  if (llvm::all_of(attrs, isEmptyAttrDict))
-    op->removeAttr(attrName);
-  else
-    op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs));
+void function_interface_impl::setArgAttrs(FunctionOpInterface op,
+                                          unsigned index,
+                                          ArrayRef<NamedAttribute> attributes) {
+  assert(index < op.getNumArguments() && "invalid argument number");
+  return setArgResAttrDict</*isArg=*/true>(
+      op, op.getNumArguments(), index,
+      DictionaryAttr::get(op->getContext(), attributes));
 }
 
-void mlir::function_interface_impl::setAllArgAttrDicts(
-    Operation *op, ArrayRef<DictionaryAttr> attrs) {
-  setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
-}
-void mlir::function_interface_impl::setAllArgAttrDicts(
-    Operation *op, ArrayRef<Attribute> attrs) {
-  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
-    return !attr ? DictionaryAttr::get(op->getContext()) : attr;
-  });
-  setAllArgResAttrDicts(op, getArgDictAttrName(),
-                        llvm::to_vector<8>(wrappedAttrs));
+void function_interface_impl::setArgAttrs(FunctionOpInterface op,
+                                          unsigned index,
+                                          DictionaryAttr attributes) {
+  return setArgResAttrDict</*isArg=*/true>(
+      op, op.getNumArguments(), index,
+      attributes ? attributes : DictionaryAttr::get(op->getContext()));
 }
 
-void mlir::function_interface_impl::setAllResultAttrDicts(
-    Operation *op, ArrayRef<DictionaryAttr> attrs) {
-  setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
+void function_interface_impl::setResultAttrs(
+    FunctionOpInterface op, unsigned index,
+    ArrayRef<NamedAttribute> attributes) {
+  assert(index < op.getNumResults() && "invalid result number");
+  return setArgResAttrDict</*isArg=*/false>(
+      op, op.getNumResults(), index,
+      DictionaryAttr::get(op->getContext(), attributes));
 }
-void mlir::function_interface_impl::setAllResultAttrDicts(
-    Operation *op, ArrayRef<Attribute> attrs) {
-  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
-    return !attr ? DictionaryAttr::get(op->getContext()) : attr;
-  });
-  setAllArgResAttrDicts(op, getResultDictAttrName(),
-                        llvm::to_vector<8>(wrappedAttrs));
+
+void function_interface_impl::setResultAttrs(FunctionOpInterface op,
+                                             unsigned index,
+                                             DictionaryAttr attributes) {
+  assert(index < op.getNumResults() && "invalid result number");
+  return setArgResAttrDict</*isArg=*/false>(
+      op, op.getNumResults(), index,
+      attributes ? attributes : DictionaryAttr::get(op->getContext()));
 }
 
-void mlir::function_interface_impl::insertFunctionArguments(
+void function_interface_impl::insertFunctionArguments(
     FunctionOpInterface op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
     ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs,
     unsigned originalNumArgs, Type newType) {
@@ -128,7 +203,7 @@ void mlir::function_interface_impl::insertFunctionArguments(
   Block &entry = op->getRegion(0).front();
 
   // Update the argument attributes of the function.
-  auto oldArgAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
+  ArrayAttr oldArgAttrs = op.getArgAttrsAttr();
   if (oldArgAttrs || !argAttrs.empty()) {
     SmallVector<DictionaryAttr, 4> newArgAttrs;
     newArgAttrs.reserve(originalNumArgs + argIndices.size());
@@ -157,7 +232,7 @@ void mlir::function_interface_impl::insertFunctionArguments(
     entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]);
 }
 
-void mlir::function_interface_impl::insertFunctionResults(
+void function_interface_impl::insertFunctionResults(
     FunctionOpInterface op, ArrayRef<unsigned> resultIndices,
     TypeRange resultTypes, ArrayRef<DictionaryAttr> resultAttrs,
     unsigned originalNumResults, Type newType) {
@@ -171,7 +246,7 @@ void mlir::function_interface_impl::insertFunctionResults(
   // - Result attrs.
 
   // Update the result attributes of the function.
-  auto oldResultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
+  ArrayAttr oldResultAttrs = op.getResAttrsAttr();
   if (oldResultAttrs || !resultAttrs.empty()) {
     SmallVector<DictionaryAttr, 4> newResultAttrs;
     newResultAttrs.reserve(originalNumResults + resultIndices.size());
@@ -199,7 +274,7 @@ void mlir::function_interface_impl::insertFunctionResults(
   op.setFunctionTypeAttr(TypeAttr::get(newType));
 }
 
-void mlir::function_interface_impl::eraseFunctionArguments(
+void function_interface_impl::eraseFunctionArguments(
     FunctionOpInterface op, const BitVector &argIndices, Type newType) {
   // There are 3 things that need to be updated:
   // - Function type.
@@ -208,7 +283,7 @@ void mlir::function_interface_impl::eraseFunctionArguments(
   Block &entry = op->getRegion(0).front();
 
   // Update the argument attributes of the function.
-  if (auto argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName())) {
+  if (ArrayAttr argAttrs = op.getArgAttrsAttr()) {
     SmallVector<DictionaryAttr, 4> newArgAttrs;
     newArgAttrs.reserve(argAttrs.size());
     for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
@@ -222,14 +297,14 @@ void mlir::function_interface_impl::eraseFunctionArguments(
   entry.eraseArguments(argIndices);
 }
 
-void mlir::function_interface_impl::eraseFunctionResults(
+void function_interface_impl::eraseFunctionResults(
     FunctionOpInterface op, const BitVector &resultIndices, Type newType) {
   // There are 2 things that need to be updated:
   // - Function type.
   // - Result attrs.
 
   // Update the result attributes of the function.
-  if (auto resAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName())) {
+  if (ArrayAttr resAttrs = op.getResAttrsAttr()) {
     SmallVector<DictionaryAttr, 4> newResultAttrs;
     newResultAttrs.reserve(resAttrs.size());
     for (unsigned i = 0, e = resultIndices.size(); i < e; ++i)
@@ -242,7 +317,7 @@ void mlir::function_interface_impl::eraseFunctionResults(
   op.setFunctionTypeAttr(TypeAttr::get(newType));
 }
 
-TypeRange mlir::function_interface_impl::insertTypesInto(
+TypeRange function_interface_impl::insertTypesInto(
     TypeRange oldTypes, ArrayRef<unsigned> indices, TypeRange newTypes,
     SmallVectorImpl<Type> &storage) {
   assert(indices.size() == newTypes.size() &&
@@ -261,7 +336,7 @@ TypeRange mlir::function_interface_impl::insertTypesInto(
   return storage;
 }
 
-TypeRange mlir::function_interface_impl::filterTypesOut(
+TypeRange function_interface_impl::filterTypesOut(
     TypeRange types, const BitVector &indices, SmallVectorImpl<Type> &storage) {
   if (indices.none())
     return types;
@@ -276,8 +351,8 @@ TypeRange mlir::function_interface_impl::filterTypesOut(
 // Function type signature.
 //===----------------------------------------------------------------------===//
 
-void mlir::function_interface_impl::setFunctionType(FunctionOpInterface op,
-                                                    Type newType) {
+void function_interface_impl::setFunctionType(FunctionOpInterface op,
+                                              Type newType) {
   unsigned oldNumArgs = op.getNumArguments();
   unsigned oldNumResults = op.getNumResults();
   op.setFunctionTypeAttr(TypeAttr::get(newType));
@@ -285,35 +360,31 @@ void mlir::function_interface_impl::setFunctionType(FunctionOpInterface op,
   unsigned newNumResults = op.getNumResults();
 
   // Functor used to update the argument and result attributes of the function.
-  auto updateAttrFn = [&](StringRef attrName, unsigned oldCount,
-                          unsigned newCount, auto setAttrFn) {
+  auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) {
+    constexpr bool isArgVal = std::is_same_v<decltype(isArg), std::true_type>;
+
     if (oldCount == newCount)
       return;
     // The new type has no arguments/results, just drop the attribute.
-    if (newCount == 0) {
-      op->removeAttr(attrName);
-      return;
-    }
-    ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(attrName);
+    if (newCount == 0)
+      return removeArgResAttrs<isArgVal>(op);
+    ArrayAttr attrs = getArgResAttrs<isArgVal>(op);
     if (!attrs)
       return;
 
     // The new type has less arguments/results, take the first N attributes.
     if (newCount < oldCount)
-      return setAttrFn(op, attrs.getValue().take_front(newCount));
+      return setAllArgResAttrDicts<isArgVal>(
+          op, attrs.getValue().take_front(newCount));
 
     // Otherwise, the new type has more arguments/results. Initialize the new
     // arguments/results with empty attributes.
     SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
     newAttrs.resize(newCount);
-    setAttrFn(op, newAttrs);
+    setAllArgResAttrDicts<isArgVal>(op, newAttrs);
   };
 
   // Update the argument and result attributes.
-  updateAttrFn(
-      getArgDictAttrName(), oldNumArgs, newNumArgs,
-      [&](Operation *op, auto &&attrs) { setAllArgAttrDicts(op, attrs); });
-  updateAttrFn(
-      getResultDictAttrName(), oldNumResults, newNumResults,
-      [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); });
+  updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs);
+  updateAttrFn(std::false_type{}, oldNumResults, newNumResults);
 }
index a72abad..d995689 100644 (file)
@@ -96,20 +96,11 @@ func.func private @invalid_symbol_type_attr() attributes { function_type = "x" }
 
 // -----
 
-// expected-error@+1 {{argument attribute array `arg_attrs` to have the same number of elements as the number of function arguments}}
+// expected-error@+1 {{argument attribute array to have the same number of elements as the number of function arguments}}
 func.func private @invalid_arg_attrs() attributes { arg_attrs = [{}] }
 
 // -----
 
-// expected-error@+1 {{expects argument attribute dictionary to be a DictionaryAttr, but got `10 : i64`}}
-func.func private @invalid_arg_attrs(i32) attributes { arg_attrs = [10] }
 
-// -----
-
-// expected-error@+1 {{result attribute array `res_attrs` to have the same number of elements as the number of function results}}
+// expected-error@+1 {{result attribute array to have the same number of elements as the number of function results}}
 func.func private @invalid_res_attrs() attributes { res_attrs = [{}] }
-
-// -----
-
-// expected-error@+1 {{expects result attribute dictionary to be a DictionaryAttr, but got `10 : i64`}}
-func.func private @invalid_res_attrs() -> i32 attributes { res_attrs = [10] }