[mlir] FunctionOpInterface: turn required attributes into interface methods (Reland)
authorJeff Niu <jeff@modular.com>
Tue, 6 Dec 2022 19:28:47 +0000 (11:28 -0800)
committerJeff Niu <jeff@modular.com>
Sat, 10 Dec 2022 23:17:09 +0000 (15:17 -0800)
Reland D139447, D139471 With flang actually working

- FunctionOpInterface: make get/setFunctionType interface methods

This patch removes the concept of a `function_type`-named type attribute
as a requirement for implementors of FunctionOpInterface. Instead, this
type should be provided through two interface methods, `getFunctionType`
and `setFunctionTypeAttr` (*Attr because functions may use different
concrete function types), which should be automatically implemented by
ODS for ops that define a `$function_type` attribute.

This also allows FunctionOpInterface to materialize function types if
they don't carry them in an attribute, for example.

Importantly, all the function "helper" still accept an attribute name to
use in parsing and printing functions, for example.

- FunctionOpInterface: arg and result attrs dispatch to interface

This patch removes the `arg_attrs` and `res_attrs` named attributes as a
requirement for FunctionOpInterface and replaces them with interface
methods for the getters, setters, and removers of the relevent
attributes. This allows operations to use their own storage for the
argument and result attributes.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D139736

40 files changed:
flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
mlir/examples/toy/Ch2/include/toy/Ops.td
mlir/examples/toy/Ch2/mlir/Dialect.cpp
mlir/examples/toy/Ch3/include/toy/Ops.td
mlir/examples/toy/Ch3/mlir/Dialect.cpp
mlir/examples/toy/Ch4/include/toy/Ops.td
mlir/examples/toy/Ch4/mlir/Dialect.cpp
mlir/examples/toy/Ch5/include/toy/Ops.td
mlir/examples/toy/Ch5/mlir/Dialect.cpp
mlir/examples/toy/Ch6/include/toy/Ops.td
mlir/examples/toy/Ch6/mlir/Dialect.cpp
mlir/examples/toy/Ch7/include/toy/Ops.td
mlir/examples/toy/Ch7/mlir/Dialect.cpp
mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
mlir/include/mlir/Dialect/Func/IR/FuncOps.td
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/include/mlir/IR/FunctionImplementation.h
mlir/include/mlir/IR/FunctionInterfaces.h
mlir/include/mlir/IR/FunctionInterfaces.td
mlir/include/mlir/IR/OpBase.td
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/lib/Dialect/Func/IR/FuncOps.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/IR/FunctionImplementation.cpp
mlir/lib/IR/FunctionInterfaces.cpp
mlir/test/IR/invalid-func-op.mlir

index 1ad2526..87206c1 100644 (file)
@@ -501,7 +501,8 @@ public:
     // correctly.
     for (auto e : llvm::enumerate(funcTy.getInputs())) {
       unsigned index = e.index();
-      llvm::ArrayRef<mlir::NamedAttribute> attrs = func.getArgAttrs(index);
+      llvm::ArrayRef<mlir::NamedAttribute> attrs =
+          mlir::function_interface_impl::getArgAttrs(func, index);
       for (mlir::NamedAttribute attr : attrs) {
         savedAttrs.push_back({index, attr});
       }
index 380536b..4e2fb9e 100644 (file)
@@ -134,7 +134,9 @@ def FuncOp : Toy_Op<"func", [
 
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    TypeAttrOf<FunctionType>:$function_type
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let regions = (region AnyRegion:$body);
 
index dbc1efb..a6ccbbf 100644 (file)
@@ -211,14 +211,17 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {
   // Dispatch to the FunctionOpInterface provided utility method that prints the
   // function operation.
-  mlir::function_interface_impl::printFunctionOp(p, *this,
-                                                 /*isVariadic=*/false);
+  mlir::function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 //===----------------------------------------------------------------------===//
index e526fe5..1a4e6a1 100644 (file)
@@ -133,7 +133,9 @@ def FuncOp : Toy_Op<"func", [
 
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    TypeAttrOf<FunctionType>:$function_type
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let regions = (region AnyRegion:$body);
 
index 50e2dfc..913979a 100644 (file)
@@ -198,14 +198,17 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {
   // Dispatch to the FunctionOpInterface provided utility method that prints the
   // function operation.
-  mlir::function_interface_impl::printFunctionOp(p, *this,
-                                                 /*isVariadic=*/false);
+  mlir::function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 //===----------------------------------------------------------------------===//
index 4956b0e..cbece47 100644 (file)
@@ -163,7 +163,9 @@ def FuncOp : Toy_Op<"func", [
 
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    TypeAttrOf<FunctionType>:$function_type
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let regions = (region AnyRegion:$body);
 
index 0a6195b..5db2f95 100644 (file)
@@ -287,14 +287,17 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {
   // Dispatch to the FunctionOpInterface provided utility method that prints the
   // function operation.
-  mlir::function_interface_impl::printFunctionOp(p, *this,
-                                                 /*isVariadic=*/false);
+  mlir::function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 /// Returns the region on the function operation that is callable.
index f4e7b08..70e482d 100644 (file)
@@ -163,7 +163,9 @@ def FuncOp : Toy_Op<"func", [
 
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    TypeAttrOf<FunctionType>:$function_type
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let regions = (region AnyRegion:$body);
 
index f236a1f..c2015ee 100644 (file)
@@ -287,14 +287,17 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {
   // Dispatch to the FunctionOpInterface provided utility method that prints the
   // function operation.
-  mlir::function_interface_impl::printFunctionOp(p, *this,
-                                                 /*isVariadic=*/false);
+  mlir::function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 /// Returns the region on the function operation that is callable.
index ea9323e..cf2bc3f 100644 (file)
@@ -163,7 +163,9 @@ def FuncOp : Toy_Op<"func", [
 
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    TypeAttrOf<FunctionType>:$function_type
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let regions = (region AnyRegion:$body);
 
index f236a1f..c2015ee 100644 (file)
@@ -287,14 +287,17 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {
   // Dispatch to the FunctionOpInterface provided utility method that prints the
   // function operation.
-  mlir::function_interface_impl::printFunctionOp(p, *this,
-                                                 /*isVariadic=*/false);
+  mlir::function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 /// Returns the region on the function operation that is callable.
index 45ecdd3..08671a7 100644 (file)
@@ -186,7 +186,9 @@ def FuncOp : Toy_Op<"func", [
 
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    TypeAttrOf<FunctionType>:$function_type
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let regions = (region AnyRegion:$body);
 
index cc66a5d..ffcdd7a 100644 (file)
@@ -314,14 +314,17 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {
   // Dispatch to the FunctionOpInterface provided utility method that prints the
   // function operation.
-  mlir::function_interface_impl::printFunctionOp(p, *this,
-                                                 /*isVariadic=*/false);
+  mlir::function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 /// Returns the region on the function operation that is callable.
index 30895e5..14146cd 100644 (file)
@@ -140,7 +140,9 @@ def Async_FuncOp : Async_Op<"func",
 
   let arguments = (ins SymbolNameAttr:$sym_name,
                        TypeAttrOf<FunctionType>:$function_type,
-                       OptionalAttr<StrAttr>:$sym_visibility);
+                       OptionalAttr<StrAttr>:$sym_visibility,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs);
 
   let regions = (region AnyRegion:$body);
 
index f1b7cfd..4922689 100644 (file)
@@ -251,7 +251,9 @@ def FuncOp : Func_Op<"func", [
 
   let arguments = (ins SymbolNameAttr:$sym_name,
                        TypeAttrOf<FunctionType>:$function_type,
-                       OptionalAttr<StrAttr>:$sym_visibility);
+                       OptionalAttr<StrAttr>:$sym_visibility,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs);
   let regions = (region AnyRegion:$body);
 
   let builders = [OpBuilder<(ins
index 0642b18..f9fff78 100644 (file)
@@ -242,7 +242,9 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
     attribution.
   }];
 
-  let arguments = (ins TypeAttrOf<FunctionType>:$function_type);
+  let arguments = (ins TypeAttrOf<FunctionType>:$function_type,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs);
   let regions = (region AnyRegion:$body);
 
   let skipDefaultBuilders = 1;
index 19e589c..afc07e2 100644 (file)
@@ -1311,7 +1311,9 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     DefaultValuedAttr<CConv, "CConv::C">:$CConv,
     OptionalAttr<FlatSymbolRefAttr>:$personality,
     OptionalAttr<StrAttr>:$garbageCollector,
-    OptionalAttr<ArrayAttr>:$passthrough
+    OptionalAttr<ArrayAttr>:$passthrough,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
 
   let regions = (region AnyRegion:$body);
index 422680a..db6c773 100644 (file)
@@ -52,6 +52,8 @@ def MLProgram_FuncOp : MLProgram_Op<"func", [
 
   let arguments = (ins SymbolNameAttr:$sym_name,
                        TypeAttrOf<FunctionType>:$function_type,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs,
                        OptionalAttr<StrAttr>:$sym_visibility);
   let regions = (region AnyRegion:$body);
 
@@ -401,6 +403,8 @@ def MLProgram_SubgraphOp : MLProgram_Op<"subgraph", [
 
   let arguments = (ins SymbolNameAttr:$sym_name,
                        TypeAttrOf<FunctionType>:$function_type,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs,
                        OptionalAttr<StrAttr>:$sym_visibility);
   let regions = (region AnyRegion:$body);
 
index 42a48cd..6ecbed2 100644 (file)
@@ -652,7 +652,9 @@ def PDLInterp_FuncOp : PDLInterp_Op<"func", [
 
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    TypeAttrOf<FunctionType>:$function_type
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let regions = (region MinSizedRegion<1>:$body);
 
index 147705e..8339afc 100644 (file)
@@ -291,6 +291,8 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
 
   let arguments = (ins
     TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs,
     StrAttr:$sym_name,
     SPIRV_FunctionControlAttr:$function_control
   );
index c3697f0..97d1f0c 100644 (file)
@@ -1107,6 +1107,8 @@ def Shape_FuncOp : Shape_Op<"func",
 
   let arguments = (ins SymbolNameAttr:$sym_name,
                        TypeAttrOf<FunctionType>:$function_type,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs,
                        OptionalAttr<StrAttr>:$sym_visibility);
   let regions = (region AnyRegion:$body);
 
index 5265f78..eb79790 100644 (file)
@@ -39,10 +39,12 @@ private:
 /// with special names given by getResultAttrName, getArgumentAttrName.
 void addArgAndResultAttrs(Builder &builder, OperationState &result,
                           ArrayRef<DictionaryAttr> argAttrs,
-                          ArrayRef<DictionaryAttr> resultAttrs);
+                          ArrayRef<DictionaryAttr> resultAttrs,
+                          StringAttr argAttrsName, StringAttr resAttrsName);
 void addArgAndResultAttrs(Builder &builder, OperationState &result,
-                          ArrayRef<OpAsmParser::Argument> argAttrs,
-                          ArrayRef<DictionaryAttr> resultAttrs);
+                          ArrayRef<OpAsmParser::Argument> args,
+                          ArrayRef<DictionaryAttr> resultAttrs,
+                          StringAttr argAttrsName, StringAttr resAttrsName);
 
 /// Callback type for `parseFunctionOp`, the callback should produce the
 /// type that will be associated with a function-like operation from lists of
@@ -69,21 +71,25 @@ Type getFunctionType(Builder &builder, ArrayRef<OpAsmParser::Argument> argAttrs,
 
 /// Parser implementation for function-like operations.  Uses
 /// `funcTypeBuilder` to construct the custom function type given lists of
-/// input and output types.  If `allowVariadic` is set, the parser will accept
+/// input and output types. The parser sets the `typeAttrName` attribute to the
+/// resulting function type. If `allowVariadic` is set, the parser will accept
 /// trailing ellipsis in the function signature and indicate to the builder
 /// whether the function is variadic.  If the builder returns a null type,
 /// `result` will not contain the `type` attribute.  The caller can then add a
 /// type, report the error or delegate the reporting to the op's verifier.
 ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result,
-                            bool allowVariadic,
-                            FuncTypeBuilder funcTypeBuilder);
+                            bool allowVariadic, StringAttr typeAttrName,
+                            FuncTypeBuilder funcTypeBuilder,
+                            StringAttr argAttrsName, StringAttr resAttrsName);
 
 /// Printer implementation for function-like operations.
-void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic);
+void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
+                     StringRef typeAttrName, StringAttr argAttrsName,
+                     StringAttr resAttrsName);
 
 /// Prints the signature of the function-like operation `op`. Assumes `op` has
 /// is a FunctionOpInterface and has passed verification.
-void printFunctionSignature(OpAsmPrinter &p, Operation *op,
+void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op,
                             ArrayRef<Type> argTypes, bool isVariadic,
                             ArrayRef<Type> resultTypes);
 
@@ -92,8 +98,7 @@ void printFunctionSignature(OpAsmPrinter &p, Operation *op,
 /// function-like operation internally are not printed. Nothing is printed
 /// if all attributes are elided. Assumes `op` is a FunctionOpInterface and
 /// has passed verification.
-void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs,
-                             unsigned numResults,
+void printFunctionAttributes(OpAsmPrinter &p, Operation *op,
                              ArrayRef<StringRef> elided = {});
 
 } // namespace function_interface_impl
index 23fd884..3beb3db 100644 (file)
 #include "llvm/ADT/SmallString.h"
 
 namespace mlir {
+class FunctionOpInterface;
 
 namespace function_interface_impl {
 
-/// Return the name of the attribute used for function types.
-inline StringRef getTypeAttrName() { return "function_type"; }
-
-/// Return the name of the attribute used for function argument attributes.
-inline StringRef getArgDictAttrName() { return "arg_attrs"; }
-
-/// Return the name of the attribute used for function argument attributes.
-inline StringRef getResultDictAttrName() { return "res_attrs"; }
-
 /// Returns the dictionary attribute corresponding to the argument at 'index'.
 /// If there are no argument attributes at 'index', a null attribute is
 /// returned.
-DictionaryAttr getArgAttrDict(Operation *op, unsigned index);
+DictionaryAttr getArgAttrDict(FunctionOpInterface op, unsigned index);
 
 /// Returns the dictionary attribute corresponding to the result at 'index'.
 /// If there are no result attributes at 'index', a null attribute is
 /// returned.
-DictionaryAttr getResultAttrDict(Operation *op, unsigned index);
+DictionaryAttr getResultAttrDict(FunctionOpInterface op, unsigned index);
 
-namespace detail {
-/// Update the given index into an argument or result attribute dictionary.
-void setArgResAttrDict(Operation *op, StringRef attrName,
-                       unsigned numTotalIndices, unsigned index,
-                       DictionaryAttr attrs);
-} // namespace detail
+/// Return all of the attributes for the argument at 'index'.
+ArrayRef<NamedAttribute> getArgAttrs(FunctionOpInterface op, unsigned index);
+
+/// Return all of the attributes for the result at 'index'.
+ArrayRef<NamedAttribute> getResultAttrs(FunctionOpInterface op, unsigned index);
 
 /// Set all of the argument or result attribute dictionaries for a function. The
 /// size of `attrs` is expected to match the number of arguments/results of the
 /// given `op`.
-void setAllArgAttrDicts(Operation *op, ArrayRef<DictionaryAttr> attrs);
-void setAllArgAttrDicts(Operation *op, ArrayRef<Attribute> attrs);
-void setAllResultAttrDicts(Operation *op, ArrayRef<DictionaryAttr> attrs);
-void setAllResultAttrDicts(Operation *op, ArrayRef<Attribute> attrs);
-
-/// Return all of the attributes for the argument at 'index'.
-inline ArrayRef<NamedAttribute> getArgAttrs(Operation *op, unsigned index) {
-  auto argDict = getArgAttrDict(op, index);
-  return argDict ? argDict.getValue() : std::nullopt;
-}
-
-/// Return all of the attributes for the result at 'index'.
-inline ArrayRef<NamedAttribute> getResultAttrs(Operation *op, unsigned index) {
-  auto resultDict = getResultAttrDict(op, index);
-  return resultDict ? resultDict.getValue() : std::nullopt;
-}
+void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs);
+void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef<Attribute> attrs);
+void setAllResultAttrDicts(FunctionOpInterface op,
+                           ArrayRef<DictionaryAttr> attrs);
+void setAllResultAttrDicts(FunctionOpInterface op, ArrayRef<Attribute> attrs);
 
 /// Insert the specified arguments and update the function type attribute.
-void insertFunctionArguments(Operation *op, ArrayRef<unsigned> argIndices,
-                             TypeRange argTypes,
+void insertFunctionArguments(FunctionOpInterface op,
+                             ArrayRef<unsigned> argIndices, TypeRange argTypes,
                              ArrayRef<DictionaryAttr> argAttrs,
                              ArrayRef<Location> argLocs,
                              unsigned originalNumArgs, Type newType);
 
 /// Insert the specified results and update the function type attribute.
-void insertFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
+void insertFunctionResults(FunctionOpInterface op,
+                           ArrayRef<unsigned> resultIndices,
                            TypeRange resultTypes,
                            ArrayRef<DictionaryAttr> resultAttrs,
                            unsigned originalNumResults, Type newType);
 
 /// Erase the specified arguments and update the function type attribute.
-void eraseFunctionArguments(Operation *op, const BitVector &argIndices,
+void eraseFunctionArguments(FunctionOpInterface op, const BitVector &argIndices,
                             Type newType);
 
 /// Erase the specified results and update the function type attribute.
-void eraseFunctionResults(Operation *op, const BitVector &resultIndices,
-                          Type newType);
+void eraseFunctionResults(FunctionOpInterface op,
+                          const BitVector &resultIndices, Type newType);
 
 /// Set a FunctionOpInterface operation's type signature.
-void setFunctionType(Operation *op, Type newType);
+void setFunctionType(FunctionOpInterface op, Type newType);
 
 /// Insert a set of `newTypes` into `oldTypes` at the given `indices`. If any
 /// types are inserted, `storage` is used to hold the new type list. The new
@@ -111,20 +92,10 @@ TypeRange filterTypesOut(TypeRange types, const BitVector &indices,
 //===----------------------------------------------------------------------===//
 
 /// Set the attributes held by the argument at 'index'.
-template <typename ConcreteType>
-void setArgAttrs(ConcreteType op, unsigned index,
-                 ArrayRef<NamedAttribute> attributes) {
-  assert(index < op.getNumArguments() && "invalid argument number");
-  return detail::setArgResAttrDict(
-      op, getArgDictAttrName(), op.getNumArguments(), index,
-      DictionaryAttr::get(op->getContext(), attributes));
-}
-template <typename ConcreteType>
-void setArgAttrs(ConcreteType op, unsigned index, DictionaryAttr attributes) {
-  return detail::setArgResAttrDict(
-      op, getArgDictAttrName(), op.getNumArguments(), index,
-      attributes ? attributes : DictionaryAttr::get(op->getContext()));
-}
+void setArgAttrs(FunctionOpInterface op, unsigned index,
+                 ArrayRef<NamedAttribute> attributes);
+void setArgAttrs(FunctionOpInterface op, unsigned index,
+                 DictionaryAttr attributes);
 
 /// If the an attribute exists with the specified name, change it to the new
 /// value. Otherwise, add a new attribute with the specified name/value.
@@ -158,23 +129,10 @@ Attribute removeArgAttr(ConcreteType op, unsigned index, StringAttr name) {
 //===----------------------------------------------------------------------===//
 
 /// Set the attributes held by the result at 'index'.
-template <typename ConcreteType>
-void setResultAttrs(ConcreteType op, unsigned index,
-                    ArrayRef<NamedAttribute> attributes) {
-  assert(index < op.getNumResults() && "invalid result number");
-  return detail::setArgResAttrDict(
-      op, getResultDictAttrName(), op.getNumResults(), index,
-      DictionaryAttr::get(op->getContext(), attributes));
-}
-
-template <typename ConcreteType>
-void setResultAttrs(ConcreteType op, unsigned index,
-                    DictionaryAttr attributes) {
-  assert(index < op.getNumResults() && "invalid result number");
-  return detail::setArgResAttrDict(
-      op, getResultDictAttrName(), op.getNumResults(), index,
-      attributes ? attributes : DictionaryAttr::get(op->getContext()));
-}
+void setResultAttrs(FunctionOpInterface op, unsigned index,
+                    ArrayRef<NamedAttribute> attributes);
+void setResultAttrs(FunctionOpInterface op, unsigned index,
+                    DictionaryAttr attributes);
 
 /// If the an attribute exists with the specified name, change it to the new
 /// value. Otherwise, add a new attribute with the specified name/value.
@@ -207,10 +165,6 @@ Attribute removeResultAttr(ConcreteType op, unsigned index, StringAttr name) {
 /// method on FunctionOpInterface::Trait.
 template <typename ConcreteOp>
 LogicalResult verifyTrait(ConcreteOp op) {
-  if (!op.getFunctionTypeAttr())
-    return op.emitOpError("requires a type attribute '")
-           << function_interface_impl::getTypeAttrName() << '\'';
-
   if (failed(op.verifyType()))
     return failure();
 
@@ -218,9 +172,8 @@ LogicalResult verifyTrait(ConcreteOp op) {
     unsigned numArgs = op.getNumArguments();
     if (allArgAttrs.size() != numArgs) {
       return op.emitOpError()
-             << "expects argument attribute array `" << getArgDictAttrName()
-             << "` to have the same number of elements as the number of "
-                "function arguments, got "
+             << "expects argument attribute array to have the same number of "
+                "elements as the number of function arguments, got "
              << allArgAttrs.size() << ", but expected " << numArgs;
     }
     for (unsigned i = 0; i != numArgs; ++i) {
@@ -250,9 +203,8 @@ LogicalResult verifyTrait(ConcreteOp op) {
     unsigned numResults = op.getNumResults();
     if (allResultAttrs.size() != numResults) {
       return op.emitOpError()
-             << "expects result attribute array `" << getResultDictAttrName()
-             << "` to have the same number of elements as the number of "
-                "function results, got "
+             << "expects result attribute array to have the same number of "
+                "elements as the number of function results, got "
              << allResultAttrs.size() << ", but expected " << numResults;
     }
     for (unsigned i = 0; i != numResults; ++i) {
index c56129e..0e8a3ad 100644 (file)
@@ -50,6 +50,52 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
   }];
   let methods = [
     InterfaceMethod<[{
+      Returns the type of the function.
+    }],
+    "::mlir::Type", "getFunctionType">,
+    InterfaceMethod<[{
+      Set the type of the function. This method should perform an unsafe
+      modification to the function type; it should not update argument or
+      result attributes.
+    }],
+    "void", "setFunctionTypeAttr", (ins "::mlir::TypeAttr":$type)>,
+
+    InterfaceMethod<[{
+      Get the array of argument attribute dictionaries. The method should return
+      an array attribute containing only dictionary attributes equal in number
+      to the number of function arguments. Alternatively, the method can return
+      null to indicate that the function has no argument attributes.
+    }],
+    "::mlir::ArrayAttr", "getArgAttrsAttr">,
+    InterfaceMethod<[{
+      Get the array of result attribute dictionaries. The method should return
+      an array attribute containing only dictionary attributes equal in number
+      to the number of function results. Alternatively, the method can return
+      null to indicate that the function has no result attributes.
+    }],
+    "::mlir::ArrayAttr", "getResAttrsAttr">,
+    InterfaceMethod<[{
+      Set the array of argument attribute dictionaries.
+    }],
+    "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>,
+    InterfaceMethod<[{
+      Set the array of result attribute dictionaries.
+    }],
+    "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>,
+    InterfaceMethod<[{
+      Remove the array of argument attribute dictionaries. This is the same as
+      setting all argument attributes to an empty dictionary. The method should
+      return the removed attribute.
+    }],
+    "::mlir::Attribute", "removeArgAttrsAttr">,
+    InterfaceMethod<[{
+      Remove the array of result attribute dictionaries. This is the same as
+      setting all result attributes to an empty dictionary. The method should
+      return the removed attribute.
+    }],
+    "::mlir::Attribute", "removeResAttrsAttr">,
+
+    InterfaceMethod<[{
       Returns the function argument types based exclusively on
       the type (to allow for this method may be called on function
       declarations).
@@ -139,7 +185,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
         ArrayRef<NamedAttribute> attrs, TypeRange inputTypes) {
       state.addAttribute(SymbolTable::getSymbolAttrName(),
                         builder.getStringAttr(name));
-      state.addAttribute(function_interface_impl::getTypeAttrName(),
+      state.addAttribute(ConcreteOp::getFunctionTypeAttrName(state.name),
                         TypeAttr::get(type));
       state.attributes.append(attrs.begin(), attrs.end());
 
@@ -240,34 +286,6 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
       function_interface_impl::setFunctionType(this->getOperation(), newType);
     }
 
-    // FIXME: These functions should be removed in favor of just forwarding to
-    // the derived operation, which should already have these defined
-    // (via ODS).
-
-    /// Returns the name of the attribute used for function types.
-    static StringRef getTypeAttrName() {
-      return function_interface_impl::getTypeAttrName();
-    }
-
-    /// Returns the name of the attribute used for function argument attributes.
-    static StringRef getArgDictAttrName() {
-      return function_interface_impl::getArgDictAttrName();
-    }
-
-    /// Returns the name of the attribute used for function argument attributes.
-    static StringRef getResultDictAttrName() {
-      return function_interface_impl::getResultDictAttrName();
-    }
-
-    /// Return the attribute containing the type of this function.
-    TypeAttr getFunctionTypeAttr() {
-      return this->getOperation()->template getAttrOfType<TypeAttr>(
-          getTypeAttrName());
-    }
-
-    /// Return the type of this function.
-    Type getFunctionType() { return getFunctionTypeAttr().getValue(); }
-
     //===------------------------------------------------------------------===//
     // Argument and Result Handling
     //===------------------------------------------------------------------===//
@@ -409,10 +427,8 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
 
     /// Return an ArrayAttr containing all argument attribute dictionaries of
     /// this function, or nullptr if no arguments have attributes.
-    ArrayAttr getAllArgAttrs() {
-      return this->getOperation()->template getAttrOfType<ArrayAttr>(
-          getArgDictAttrName());
-    }
+    ArrayAttr getAllArgAttrs() { return $_op.getArgAttrsAttr(); }
+
     /// Return all argument attributes of this function.
     void getAllArgAttrs(SmallVectorImpl<DictionaryAttr> &result) {
       if (ArrayAttr argAttrs = getAllArgAttrs()) {
@@ -464,7 +480,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
     }
     void setAllArgAttrs(ArrayAttr attributes) {
       assert(attributes.size() == $_op.getNumArguments());
-      this->getOperation()->setAttr(getArgDictAttrName(), attributes);
+      $_op.setArgAttrsAttr(attributes);
     }
 
     /// If the an attribute exists with the specified name, change it to the new
@@ -500,10 +516,8 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
 
     /// Return an ArrayAttr containing all result attribute dictionaries of this
     /// function, or nullptr if no result have attributes.
-    ArrayAttr getAllResultAttrs() {
-      return this->getOperation()->template getAttrOfType<ArrayAttr>(
-          getResultDictAttrName());
-    }
+    ArrayAttr getAllResultAttrs() { return $_op.getResAttrsAttr(); }
+
     /// Return all result attributes of this function.
     void getAllResultAttrs(SmallVectorImpl<DictionaryAttr> &result) {
       if (ArrayAttr argAttrs = getAllResultAttrs()) {
@@ -557,7 +571,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
     }
     void setAllResultAttrs(ArrayAttr attributes) {
       assert(attributes.size() == $_op.getNumResults());
-      this->getOperation()->setAttr(getResultDictAttrName(), attributes);
+      $_op.setResAttrsAttr(attributes);
     }
 
     /// If the an attribute exists with the specified name, change it to the new
index 2affd9a..400f671 100644 (file)
@@ -1524,6 +1524,8 @@ def TypeArrayAttr : TypedArrayAttrBase<TypeAttr, "type array attribute"> {
 }
 def IndexListArrayAttr :
   TypedArrayAttrBase<I64ArrayAttr, "Array of 64-bit integer array attributes">;
+def DictArrayAttr :
+  TypedArrayAttrBase<DictionaryAttr, "Array of dictionary attributes">;
 
 // Attributes containing symbol references.
 def SymbolRefAttr : Attr<CPred<"$_self.isa<::mlir::SymbolRefAttr>()">,
index d0e82de..0cd024e 100644 (file)
@@ -59,16 +59,15 @@ using namespace mlir;
 /// Only retain those attributes that are not constructed by
 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
 /// attributes.
-static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
-                                 bool filterArgAndResAttrs,
+static void filterFuncAttributes(func::FuncOp func, bool filterArgAndResAttrs,
                                  SmallVectorImpl<NamedAttribute> &result) {
-  for (const auto &attr : attrs) {
+  for (const NamedAttribute &attr : func->getAttrs()) {
     if (attr.getName() == SymbolTable::getSymbolAttrName() ||
-        attr.getName() == FunctionOpInterface::getTypeAttrName() ||
+        attr.getName() == func.getFunctionTypeAttrName() ||
         attr.getName() == "func.varargs" ||
         (filterArgAndResAttrs &&
-         (attr.getName() == FunctionOpInterface::getArgDictAttrName() ||
-          attr.getName() == FunctionOpInterface::getResultDictAttrName())))
+         (attr.getName() == func.getArgAttrsAttrName() ||
+          attr.getName() == func.getResAttrsAttrName())))
       continue;
     result.push_back(attr);
   }
@@ -91,18 +90,19 @@ static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
 static void
 prependResAttrsToArgAttrs(OpBuilder &builder,
                           SmallVectorImpl<NamedAttribute> &attributes,
-                          size_t numArguments) {
+                          func::FuncOp func) {
+  size_t numArguments = func.getNumArguments();
   auto allAttrs = SmallVector<Attribute>(
       numArguments + 1, DictionaryAttr::get(builder.getContext()));
   NamedAttribute *argAttrs = nullptr;
   for (auto *it = attributes.begin(); it != attributes.end();) {
-    if (it->getName() == FunctionOpInterface::getArgDictAttrName()) {
+    if (it->getName() == func.getArgAttrsAttrName()) {
       auto arrayAttrs = it->getValue().cast<ArrayAttr>();
       assert(arrayAttrs.size() == numArguments &&
              "Number of arg attrs and args should match");
       std::copy(arrayAttrs.begin(), arrayAttrs.end(), allAttrs.begin() + 1);
       argAttrs = it;
-    } else if (it->getName() == FunctionOpInterface::getResultDictAttrName()) {
+    } else if (it->getName() == func.getResAttrsAttrName()) {
       auto arrayAttrs = it->getValue().cast<ArrayAttr>();
       assert(!arrayAttrs.empty() && "expected array to be non-empty");
       allAttrs[0] = (arrayAttrs.size() == 1)
@@ -114,9 +114,8 @@ prependResAttrsToArgAttrs(OpBuilder &builder,
     it++;
   }
 
-  auto newArgAttrs =
-      builder.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
-                           builder.getArrayAttr(allAttrs));
+  auto newArgAttrs = builder.getNamedAttr(func.getArgAttrsAttrName(),
+                                          builder.getArrayAttr(allAttrs));
   if (!argAttrs) {
     attributes.emplace_back(newArgAttrs);
     return;
@@ -138,12 +137,11 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
                                    LLVM::LLVMFuncOp newFuncOp) {
   auto type = funcOp.getFunctionType();
   SmallVector<NamedAttribute, 4> attributes;
-  filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
-                       attributes);
+  filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes);
   auto [wrapperFuncType, resultIsNowArg] =
       typeConverter.convertFunctionTypeCWrapper(type);
   if (resultIsNowArg)
-    prependResAttrsToArgAttrs(rewriter, attributes, funcOp.getNumArguments());
+    prependResAttrsToArgAttrs(rewriter, attributes, funcOp);
   auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
       wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false,
@@ -204,11 +202,10 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   assert(wrapperType && "unexpected type conversion failure");
 
   SmallVector<NamedAttribute, 4> attributes;
-  filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
-                       attributes);
+  filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes);
 
   if (resultIsNowArg)
-    prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments());
+    prependResAttrsToArgAttrs(builder, attributes, funcOp);
   // Create the auxiliary function.
   auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
@@ -304,8 +301,7 @@ protected:
     // Propagate argument/result attributes to all converted arguments/result
     // obtained after converting a given original argument/result.
     SmallVector<NamedAttribute, 4> attributes;
-    filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true,
-                         attributes);
+    filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/true, attributes);
     if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
       assert(!resAttrDicts.empty() && "expected array to be non-empty");
       auto newResAttrDicts =
@@ -313,8 +309,8 @@ protected:
               ? resAttrDicts
               : rewriter.getArrayAttr(
                     {wrapAsStructAttrs(rewriter, resAttrDicts)});
-      attributes.push_back(rewriter.getNamedAttr(
-          FunctionOpInterface::getResultDictAttrName(), newResAttrDicts));
+      attributes.push_back(
+          rewriter.getNamedAttr(funcOp.getResAttrsAttrName(), newResAttrDicts));
     }
     if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
       SmallVector<Attribute, 4> newArgAttrs(
@@ -357,9 +353,8 @@ protected:
           newArgAttrs[mapping->inputNo + j] =
               DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
       }
-      attributes.push_back(
-          rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
-                                rewriter.getArrayAttr(newArgAttrs)));
+      attributes.push_back(rewriter.getNamedAttr(
+          funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(newArgAttrs)));
     }
     for (const auto &pair : llvm::enumerate(attributes)) {
       if (pair.value().getName() == "llvm.linkage") {
index 85001d5..48effe2 100644 (file)
@@ -60,7 +60,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
   SmallVector<NamedAttribute, 4> attributes;
   for (const auto &attr : gpuFuncOp->getAttrs()) {
     if (attr.getName() == SymbolTable::getSymbolAttrName() ||
-        attr.getName() == FunctionOpInterface::getTypeAttrName() ||
+        attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
         attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
       continue;
     attributes.push_back(attr);
index 119b1d3..2a83895 100644 (file)
@@ -226,7 +226,7 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter,
       rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
                                std::nullopt));
   for (const auto &namedAttr : funcOp->getAttrs()) {
-    if (namedAttr.getName() == FunctionOpInterface::getTypeAttrName() ||
+    if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
         namedAttr.getName() == SymbolTable::getSymbolAttrName())
       continue;
     newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
index e0772b4..54acc37 100644 (file)
@@ -332,8 +332,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                    ArrayRef<DictionaryAttr> argAttrs) {
   state.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
-  state.addAttribute(FunctionOpInterface::getTypeAttrName(),
-                     TypeAttr::get(type));
+  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
 
   state.attributes.append(attrs.begin(), attrs.end());
   state.addRegion();
@@ -341,8 +340,9 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
-                                                /*resultAttrs=*/std::nullopt);
+  function_interface_impl::addArgAndResultAttrs(
+      builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
+      getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 }
 
 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -352,11 +352,15 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 /// Check that the result type of async.func is not void and must be
index 961cf2e..7bb3663 100644 (file)
@@ -244,16 +244,16 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                    ArrayRef<DictionaryAttr> argAttrs) {
   state.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
-  state.addAttribute(FunctionOpInterface::getTypeAttrName(),
-                     TypeAttr::get(type));
+  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
   state.attributes.append(attrs.begin(), attrs.end());
   state.addRegion();
 
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
-                                                /*resultAttrs=*/std::nullopt);
+  function_interface_impl::addArgAndResultAttrs(
+      builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
+      getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 }
 
 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -263,11 +263,15 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 /// Clone the internal blocks from this function into dest and all attributes
index 7f73a65..9ea1b11 100644 (file)
@@ -859,7 +859,8 @@ void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                       ArrayRef<NamedAttribute> attrs) {
   result.addAttribute(SymbolTable::getSymbolAttrName(),
                       builder.getStringAttr(name));
-  result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
+  result.addAttribute(getFunctionTypeAttrName(result.name),
+                      TypeAttr::get(type));
   result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                       builder.getI64IntegerAttr(workgroupAttributions.size()));
   result.addAttributes(attrs);
@@ -930,10 +931,12 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
   for (auto &arg : entryArgs)
     argTypes.push_back(arg.type);
   auto type = builder.getFunctionType(argTypes, resultTypes);
-  result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type));
+  result.addAttribute(getFunctionTypeAttrName(result.name),
+                      TypeAttr::get(type));
 
-  function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
-                                                resultAttrs);
+  function_interface_impl::addArgAndResultAttrs(
+      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
+      getResAttrsAttrName(result.name));
 
   // Parse workgroup memory attributions.
   if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
@@ -992,19 +995,15 @@ void GPUFuncOp::print(OpAsmPrinter &p) {
     p << ' ' << getKernelKeyword();
 
   function_interface_impl::printFunctionAttributes(
-      p, *this, type.getNumInputs(), type.getNumResults(),
+      p, *this,
       {getNumWorkgroupAttributionsAttrName(),
-       GPUDialect::getKernelFuncAttrName()});
+       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
+       getArgAttrsAttrName(), getResAttrsAttrName()});
   p << ' ';
   p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
 }
 
 LogicalResult GPUFuncOp::verifyType() {
-  Type type = getFunctionTypeAttr().getValue();
-  if (!type.isa<FunctionType>())
-    return emitOpError("requires '" + getTypeAttrName() +
-                       "' attribute of function type");
-
   if (isKernel() && getFunctionType().getNumResults() != 0)
     return emitOpError() << "expected void return type for kernel function";
 
index 1087bcf..a4860e3 100644 (file)
@@ -2006,8 +2006,9 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
 
   assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() &&
          "expected as many argument attribute lists as arguments");
-  function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs,
-                                                /*resultAttrs=*/std::nullopt);
+  function_interface_impl::addArgAndResultAttrs(
+      builder, result, argAttrs, /*resultAttrs=*/std::nullopt,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 // Builds an LLVM function type from the given lists of input and output types.
@@ -2090,13 +2091,14 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
                             function_interface_impl::VariadicFlag(isVariadic));
   if (!type)
     return failure();
-  result.addAttribute(FunctionOpInterface::getTypeAttrName(),
+  result.addAttribute(getFunctionTypeAttrName(result.name),
                       TypeAttr::get(type));
 
   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
     return failure();
-  function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result,
-                                                entryArgs, resultAttrs);
+  function_interface_impl::addArgAndResultAttrs(
+      parser.getBuilder(), result, entryArgs, resultAttrs,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 
   auto *body = result.addRegion();
   OptionalParseResult parseResult =
@@ -2130,8 +2132,9 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
   function_interface_impl::printFunctionSignature(p, *this, argTypes,
                                                   isVarArg(), resTypes);
   function_interface_impl::printFunctionAttributes(
-      p, *this, argTypes.size(), resTypes.size(),
-      {getLinkageAttrName(), getCConvAttrName()});
+      p, *this,
+      {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
+       getLinkageAttrName(), getCConvAttrName()});
 
   // Print the body if this is not an external function.
   Region &body = getBody();
index 2f1e4b9..31ed5ad 100644 (file)
@@ -152,11 +152,15 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 //===----------------------------------------------------------------------===//
@@ -313,11 +317,15 @@ ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void SubgraphOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 //===----------------------------------------------------------------------===//
index e8a61ef..2cc282d 100644 (file)
@@ -220,11 +220,15 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 //===----------------------------------------------------------------------===//
index 52ad8ad..3341b5e 100644 (file)
@@ -2382,7 +2382,7 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
   for (auto &arg : entryArgs)
     argTypes.push_back(arg.type);
   auto fnType = builder.getFunctionType(argTypes, resultTypes);
-  result.addAttribute(FunctionOpInterface::getTypeAttrName(),
+  result.addAttribute(getFunctionTypeAttrName(result.name),
                       TypeAttr::get(fnType));
 
   // Parse the optional function control keyword.
@@ -2396,8 +2396,9 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
 
   // Add the attributes to the function arguments.
   assert(resultAttrs.size() == resultTypes.size());
-  function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
-                                                resultAttrs);
+  function_interface_impl::addArgAndResultAttrs(
+      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
+      getResAttrsAttrName(result.name));
 
   // Parse the optional function body.
   auto *body = result.addRegion();
@@ -2417,8 +2418,10 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
   printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
           << "\"";
   function_interface_impl::printFunctionAttributes(
-      printer, *this, fnType.getNumInputs(), fnType.getNumResults(),
-      {spirv::attributeName<spirv::FunctionControl>()});
+      printer, *this,
+      {spirv::attributeName<spirv::FunctionControl>(),
+       getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
+       getFunctionControlAttrName()});
 
   // Print the body if this is not an external function.
   Region &body = this->getBody();
@@ -2430,10 +2433,6 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
 }
 
 LogicalResult spirv::FuncOp::verifyType() {
-  auto type = getFunctionTypeAttr().getValue();
-  if (!type.isa<FunctionType>())
-    return emitOpError("requires '" + getTypeAttrName() +
-                       "' attribute of function type");
   if (getFunctionType().getNumResults() > 1)
     return emitOpError("cannot have more than one result");
   return success();
@@ -2473,7 +2472,7 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
                           ArrayRef<NamedAttribute> attrs) {
   state.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
-  state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
+  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
   state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
                      builder.getAttr<spirv::FunctionControlAttr>(control));
   state.attributes.append(attrs.begin(), attrs.end());
index 2772c01..62e3a3d 100644 (file)
@@ -531,7 +531,7 @@ FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
 
   // Copy over all attributes other than the function name and type.
   for (const auto &namedAttr : funcOp->getAttrs()) {
-    if (namedAttr.getName() != FunctionOpInterface::getTypeAttrName() &&
+    if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
         namedAttr.getName() != SymbolTable::getSymbolAttrName())
       newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
   }
index 8c89ec8..28ac98e 100644 (file)
@@ -1300,8 +1300,9 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
-                                                /*resultAttrs=*/std::nullopt);
+  function_interface_impl::addArgAndResultAttrs(
+      builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
+      getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 }
 
 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -1311,11 +1312,15 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 //===----------------------------------------------------------------------===//
index 9481e4a..5ca6777 100644 (file)
@@ -113,7 +113,7 @@ parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
   return parser.parseRParen();
 }
 
-ParseResult mlir::function_interface_impl::parseFunctionSignature(
+ParseResult function_interface_impl::parseFunctionSignature(
     OpAsmParser &parser, bool allowVariadic,
     SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
     SmallVectorImpl<Type> &resultTypes,
@@ -125,9 +125,10 @@ ParseResult mlir::function_interface_impl::parseFunctionSignature(
   return success();
 }
 
-void mlir::function_interface_impl::addArgAndResultAttrs(
+void function_interface_impl::addArgAndResultAttrs(
     Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
-    ArrayRef<DictionaryAttr> resultAttrs) {
+    ArrayRef<DictionaryAttr> resultAttrs, StringAttr argAttrsName,
+    StringAttr resAttrsName) {
   auto nonEmptyAttrsFn = [](DictionaryAttr attrs) {
     return attrs && !attrs.empty();
   };
@@ -142,28 +143,28 @@ void mlir::function_interface_impl::addArgAndResultAttrs(
 
   // Add the attributes to the function arguments.
   if (llvm::any_of(argAttrs, nonEmptyAttrsFn))
-    result.addAttribute(function_interface_impl::getArgDictAttrName(),
-                        getArrayAttr(argAttrs));
+    result.addAttribute(argAttrsName, getArrayAttr(argAttrs));
 
   // Add the attributes to the function results.
   if (llvm::any_of(resultAttrs, nonEmptyAttrsFn))
-    result.addAttribute(function_interface_impl::getResultDictAttrName(),
-                        getArrayAttr(resultAttrs));
+    result.addAttribute(resAttrsName, getArrayAttr(resultAttrs));
 }
 
-void mlir::function_interface_impl::addArgAndResultAttrs(
+void function_interface_impl::addArgAndResultAttrs(
     Builder &builder, OperationState &result,
-    ArrayRef<OpAsmParser::Argument> args,
-    ArrayRef<DictionaryAttr> resultAttrs) {
+    ArrayRef<OpAsmParser::Argument> args, ArrayRef<DictionaryAttr> resultAttrs,
+    StringAttr argAttrsName, StringAttr resAttrsName) {
   SmallVector<DictionaryAttr> argAttrs;
   for (const auto &arg : args)
     argAttrs.push_back(arg.attrs);
-  addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
+  addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName,
+                       resAttrsName);
 }
 
-ParseResult mlir::function_interface_impl::parseFunctionOp(
+ParseResult function_interface_impl::parseFunctionOp(
     OpAsmParser &parser, OperationState &result, bool allowVariadic,
-    FuncTypeBuilder funcTypeBuilder) {
+    StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder,
+    StringAttr argAttrsName, StringAttr resAttrsName) {
   SmallVector<OpAsmParser::Argument> entryArgs;
   SmallVector<DictionaryAttr> resultAttrs;
   SmallVector<Type> resultTypes;
@@ -197,7 +198,7 @@ ParseResult mlir::function_interface_impl::parseFunctionOp(
            << "failed to construct function type"
            << (errorMessage.empty() ? "" : ": ") << errorMessage;
   }
-  result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
+  result.addAttribute(typeAttrName, TypeAttr::get(type));
 
   // If function attributes are present, parse them.
   NamedAttrList parsedAttributes;
@@ -209,7 +210,7 @@ ParseResult mlir::function_interface_impl::parseFunctionOp(
   // dictionary.
   for (StringRef disallowed :
        {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(),
-        getTypeAttrName()}) {
+        typeAttrName.getValue()}) {
     if (parsedAttributes.get(disallowed))
       return parser.emitError(attributeDictLocation, "'")
              << disallowed
@@ -220,7 +221,8 @@ ParseResult mlir::function_interface_impl::parseFunctionOp(
 
   // Add the attributes to the function arguments.
   assert(resultAttrs.size() == resultTypes.size());
-  addArgAndResultAttrs(builder, result, entryArgs, resultAttrs);
+  addArgAndResultAttrs(builder, result, entryArgs, resultAttrs, argAttrsName,
+                       resAttrsName);
 
   // Parse the optional function body. The printer will not print the body if
   // its empty, so disallow parsing of empty body in the parser.
@@ -261,14 +263,14 @@ static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types,
     os << ')';
 }
 
-void mlir::function_interface_impl::printFunctionSignature(
-    OpAsmPrinter &p, Operation *op, ArrayRef<Type> argTypes, bool isVariadic,
-    ArrayRef<Type> resultTypes) {
+void function_interface_impl::printFunctionSignature(
+    OpAsmPrinter &p, FunctionOpInterface op, ArrayRef<Type> argTypes,
+    bool isVariadic, ArrayRef<Type> resultTypes) {
   Region &body = op->getRegion(0);
   bool isExternal = body.empty();
 
   p << '(';
-  ArrayAttr argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
+  ArrayAttr argAttrs = op.getArgAttrsAttr();
   for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
     if (i > 0)
       p << ", ";
@@ -295,26 +297,23 @@ void mlir::function_interface_impl::printFunctionSignature(
 
   if (!resultTypes.empty()) {
     p.getStream() << " -> ";
-    auto resultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
+    auto resultAttrs = op.getResAttrsAttr();
     printFunctionResultList(p, resultTypes, resultAttrs);
   }
 }
 
-void mlir::function_interface_impl::printFunctionAttributes(
-    OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults,
-    ArrayRef<StringRef> elided) {
+void function_interface_impl::printFunctionAttributes(
+    OpAsmPrinter &p, Operation *op, ArrayRef<StringRef> elided) {
   // Print out function attributes, if present.
-  SmallVector<StringRef, 2> ignoredAttrs = {
-      ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName(),
-      getArgDictAttrName(), getResultDictAttrName()};
+  SmallVector<StringRef, 8> ignoredAttrs = {SymbolTable::getSymbolAttrName()};
   ignoredAttrs.append(elided.begin(), elided.end());
 
   p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
 }
 
-void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p,
-                                                    FunctionOpInterface op,
-                                                    bool isVariadic) {
+void function_interface_impl::printFunctionOp(
+    OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
+    StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName) {
   // Print the operation and the function name.
   auto funcName =
       op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
@@ -329,8 +328,8 @@ void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p,
   ArrayRef<Type> argTypes = op.getArgumentTypes();
   ArrayRef<Type> resultTypes = op.getResultTypes();
   printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
-  printFunctionAttributes(p, op, argTypes.size(), resultTypes.size(),
-                          {visibilityAttrName});
+  printFunctionAttributes(
+      p, op, {visibilityAttrName, typeAttrName, argAttrsName, resAttrsName});
   // Print the body if this is not an external function.
   Region &body = op->getRegion(0);
   if (!body.empty()) {
index 3331aef..4a50a49 100644 (file)
@@ -24,27 +24,104 @@ static bool isEmptyAttrDict(Attribute attr) {
   return attr.cast<DictionaryAttr>().empty();
 }
 
-DictionaryAttr mlir::function_interface_impl::getArgAttrDict(Operation *op,
-                                                             unsigned index) {
-  ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
+DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op,
+                                                       unsigned index) {
+  ArrayAttr attrs = op.getArgAttrsAttr();
   DictionaryAttr argAttrs =
       attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
   return argAttrs;
 }
 
 DictionaryAttr
-mlir::function_interface_impl::getResultAttrDict(Operation *op,
-                                                 unsigned index) {
-  ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
+function_interface_impl::getResultAttrDict(FunctionOpInterface op,
+                                           unsigned index) {
+  ArrayAttr attrs = op.getResAttrsAttr();
   DictionaryAttr resAttrs =
       attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
   return resAttrs;
 }
 
-void mlir::function_interface_impl::detail::setArgResAttrDict(
-    Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index,
-    DictionaryAttr attrs) {
-  ArrayAttr allAttrs = op->getAttrOfType<ArrayAttr>(attrName);
+ArrayRef<NamedAttribute>
+function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) {
+  auto argDict = getArgAttrDict(op, index);
+  return argDict ? argDict.getValue() : std::nullopt;
+}
+
+ArrayRef<NamedAttribute>
+function_interface_impl::getResultAttrs(FunctionOpInterface op,
+                                        unsigned index) {
+  auto resultDict = getResultAttrDict(op, index);
+  return resultDict ? resultDict.getValue() : std::nullopt;
+}
+
+/// Get either the argument or result attributes array.
+template <bool isArg>
+static ArrayAttr getArgResAttrs(FunctionOpInterface op) {
+  if constexpr (isArg)
+    return op.getArgAttrsAttr();
+  else
+    return op.getResAttrsAttr();
+}
+
+/// Set either the argument or result attributes array.
+template <bool isArg>
+static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) {
+  if constexpr (isArg)
+    op.setArgAttrsAttr(attrs);
+  else
+    op.setResAttrsAttr(attrs);
+}
+
+/// Erase either the argument or result attributes array.
+template <bool isArg>
+static void removeArgResAttrs(FunctionOpInterface op) {
+  if constexpr (isArg)
+    op.removeArgAttrsAttr();
+  else
+    op.removeResAttrsAttr();
+}
+
+/// Set all of the argument or result attribute dictionaries for a function.
+template <bool isArg>
+static void setAllArgResAttrDicts(FunctionOpInterface op,
+                                  ArrayRef<Attribute> attrs) {
+  if (llvm::all_of(attrs, isEmptyAttrDict))
+    removeArgResAttrs<isArg>(op);
+  else
+    setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), attrs));
+}
+
+void function_interface_impl::setAllArgAttrDicts(
+    FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
+  setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
+}
+
+void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op,
+                                                 ArrayRef<Attribute> attrs) {
+  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
+    return !attr ? DictionaryAttr::get(op->getContext()) : attr;
+  });
+  setAllArgResAttrDicts</*isArg=*/true>(op, llvm::to_vector<8>(wrappedAttrs));
+}
+
+void function_interface_impl::setAllResultAttrDicts(
+    FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
+  setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
+}
+
+void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op,
+                                                    ArrayRef<Attribute> attrs) {
+  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
+    return !attr ? DictionaryAttr::get(op->getContext()) : attr;
+  });
+  setAllArgResAttrDicts</*isArg=*/false>(op, llvm::to_vector<8>(wrappedAttrs));
+}
+
+/// Update the given index into an argument or result attribute dictionary.
+template <bool isArg>
+static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices,
+                              unsigned index, DictionaryAttr attrs) {
+  ArrayAttr allAttrs = getArgResAttrs<isArg>(op);
   if (!allAttrs) {
     if (attrs.empty())
       return;
@@ -53,7 +130,7 @@ void mlir::function_interface_impl::detail::setArgResAttrDict(
     SmallVector<Attribute, 8> newAttrs(numTotalIndices,
                                        DictionaryAttr::get(op->getContext()));
     newAttrs[index] = attrs;
-    op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
+    setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
     return;
   }
   // Check to see if the attribute is different from what we already have.
@@ -65,54 +142,52 @@ void mlir::function_interface_impl::detail::setArgResAttrDict(
   ArrayRef<Attribute> rawAttrArray = allAttrs.getValue();
   if (attrs.empty() &&
       llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) &&
-      llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) {
-    op->removeAttr(attrName);
-    return;
-  }
+      llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict))
+    return removeArgResAttrs<isArg>(op);
 
   // Otherwise, create a new attribute array with the updated dictionary.
   SmallVector<Attribute, 8> newAttrs(rawAttrArray.begin(), rawAttrArray.end());
   newAttrs[index] = attrs;
-  op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
+  setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
 }
 
-/// Set all of the argument or result attribute dictionaries for a function.
-static void setAllArgResAttrDicts(Operation *op, StringRef attrName,
-                                  ArrayRef<Attribute> attrs) {
-  if (llvm::all_of(attrs, isEmptyAttrDict))
-    op->removeAttr(attrName);
-  else
-    op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs));
+void function_interface_impl::setArgAttrs(FunctionOpInterface op,
+                                          unsigned index,
+                                          ArrayRef<NamedAttribute> attributes) {
+  assert(index < op.getNumArguments() && "invalid argument number");
+  return setArgResAttrDict</*isArg=*/true>(
+      op, op.getNumArguments(), index,
+      DictionaryAttr::get(op->getContext(), attributes));
 }
 
-void mlir::function_interface_impl::setAllArgAttrDicts(
-    Operation *op, ArrayRef<DictionaryAttr> attrs) {
-  setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
-}
-void mlir::function_interface_impl::setAllArgAttrDicts(
-    Operation *op, ArrayRef<Attribute> attrs) {
-  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
-    return !attr ? DictionaryAttr::get(op->getContext()) : attr;
-  });
-  setAllArgResAttrDicts(op, getArgDictAttrName(),
-                        llvm::to_vector<8>(wrappedAttrs));
+void function_interface_impl::setArgAttrs(FunctionOpInterface op,
+                                          unsigned index,
+                                          DictionaryAttr attributes) {
+  return setArgResAttrDict</*isArg=*/true>(
+      op, op.getNumArguments(), index,
+      attributes ? attributes : DictionaryAttr::get(op->getContext()));
 }
 
-void mlir::function_interface_impl::setAllResultAttrDicts(
-    Operation *op, ArrayRef<DictionaryAttr> attrs) {
-  setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
+void function_interface_impl::setResultAttrs(
+    FunctionOpInterface op, unsigned index,
+    ArrayRef<NamedAttribute> attributes) {
+  assert(index < op.getNumResults() && "invalid result number");
+  return setArgResAttrDict</*isArg=*/false>(
+      op, op.getNumResults(), index,
+      DictionaryAttr::get(op->getContext(), attributes));
 }
-void mlir::function_interface_impl::setAllResultAttrDicts(
-    Operation *op, ArrayRef<Attribute> attrs) {
-  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
-    return !attr ? DictionaryAttr::get(op->getContext()) : attr;
-  });
-  setAllArgResAttrDicts(op, getResultDictAttrName(),
-                        llvm::to_vector<8>(wrappedAttrs));
+
+void function_interface_impl::setResultAttrs(FunctionOpInterface op,
+                                             unsigned index,
+                                             DictionaryAttr attributes) {
+  assert(index < op.getNumResults() && "invalid result number");
+  return setArgResAttrDict</*isArg=*/false>(
+      op, op.getNumResults(), index,
+      attributes ? attributes : DictionaryAttr::get(op->getContext()));
 }
 
-void mlir::function_interface_impl::insertFunctionArguments(
-    Operation *op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
+void function_interface_impl::insertFunctionArguments(
+    FunctionOpInterface op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
     ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs,
     unsigned originalNumArgs, Type newType) {
   assert(argIndices.size() == argTypes.size());
@@ -128,7 +203,7 @@ void mlir::function_interface_impl::insertFunctionArguments(
   Block &entry = op->getRegion(0).front();
 
   // Update the argument attributes of the function.
-  auto oldArgAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
+  ArrayAttr oldArgAttrs = op.getArgAttrsAttr();
   if (oldArgAttrs || !argAttrs.empty()) {
     SmallVector<DictionaryAttr, 4> newArgAttrs;
     newArgAttrs.reserve(originalNumArgs + argIndices.size());
@@ -152,15 +227,15 @@ void mlir::function_interface_impl::insertFunctionArguments(
   }
 
   // Update the function type and any entry block arguments.
-  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+  op.setFunctionTypeAttr(TypeAttr::get(newType));
   for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
     entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]);
 }
 
-void mlir::function_interface_impl::insertFunctionResults(
-    Operation *op, ArrayRef<unsigned> resultIndices, TypeRange resultTypes,
-    ArrayRef<DictionaryAttr> resultAttrs, unsigned originalNumResults,
-    Type newType) {
+void function_interface_impl::insertFunctionResults(
+    FunctionOpInterface op, ArrayRef<unsigned> resultIndices,
+    TypeRange resultTypes, ArrayRef<DictionaryAttr> resultAttrs,
+    unsigned originalNumResults, Type newType) {
   assert(resultIndices.size() == resultTypes.size());
   assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
   if (resultIndices.empty())
@@ -171,7 +246,7 @@ void mlir::function_interface_impl::insertFunctionResults(
   // - Result attrs.
 
   // Update the result attributes of the function.
-  auto oldResultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
+  ArrayAttr oldResultAttrs = op.getResAttrsAttr();
   if (oldResultAttrs || !resultAttrs.empty()) {
     SmallVector<DictionaryAttr, 4> newResultAttrs;
     newResultAttrs.reserve(originalNumResults + resultIndices.size());
@@ -196,11 +271,11 @@ void mlir::function_interface_impl::insertFunctionResults(
   }
 
   // Update the function type.
-  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+  op.setFunctionTypeAttr(TypeAttr::get(newType));
 }
 
-void mlir::function_interface_impl::eraseFunctionArguments(
-    Operation *op, const BitVector &argIndices, Type newType) {
+void function_interface_impl::eraseFunctionArguments(
+    FunctionOpInterface op, const BitVector &argIndices, Type newType) {
   // There are 3 things that need to be updated:
   // - Function type.
   // - Arg attrs.
@@ -208,7 +283,7 @@ void mlir::function_interface_impl::eraseFunctionArguments(
   Block &entry = op->getRegion(0).front();
 
   // Update the argument attributes of the function.
-  if (auto argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName())) {
+  if (ArrayAttr argAttrs = op.getArgAttrsAttr()) {
     SmallVector<DictionaryAttr, 4> newArgAttrs;
     newArgAttrs.reserve(argAttrs.size());
     for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
@@ -218,18 +293,18 @@ void mlir::function_interface_impl::eraseFunctionArguments(
   }
 
   // Update the function type and any entry block arguments.
-  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+  op.setFunctionTypeAttr(TypeAttr::get(newType));
   entry.eraseArguments(argIndices);
 }
 
-void mlir::function_interface_impl::eraseFunctionResults(
-    Operation *op, const BitVector &resultIndices, Type newType) {
+void function_interface_impl::eraseFunctionResults(
+    FunctionOpInterface op, const BitVector &resultIndices, Type newType) {
   // There are 2 things that need to be updated:
   // - Function type.
   // - Result attrs.
 
   // Update the result attributes of the function.
-  if (auto resAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName())) {
+  if (ArrayAttr resAttrs = op.getResAttrsAttr()) {
     SmallVector<DictionaryAttr, 4> newResultAttrs;
     newResultAttrs.reserve(resAttrs.size());
     for (unsigned i = 0, e = resultIndices.size(); i < e; ++i)
@@ -239,10 +314,10 @@ void mlir::function_interface_impl::eraseFunctionResults(
   }
 
   // Update the function type.
-  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+  op.setFunctionTypeAttr(TypeAttr::get(newType));
 }
 
-TypeRange mlir::function_interface_impl::insertTypesInto(
+TypeRange function_interface_impl::insertTypesInto(
     TypeRange oldTypes, ArrayRef<unsigned> indices, TypeRange newTypes,
     SmallVectorImpl<Type> &storage) {
   assert(indices.size() == newTypes.size() &&
@@ -261,7 +336,7 @@ TypeRange mlir::function_interface_impl::insertTypesInto(
   return storage;
 }
 
-TypeRange mlir::function_interface_impl::filterTypesOut(
+TypeRange function_interface_impl::filterTypesOut(
     TypeRange types, const BitVector &indices, SmallVectorImpl<Type> &storage) {
   if (indices.none())
     return types;
@@ -276,45 +351,41 @@ TypeRange mlir::function_interface_impl::filterTypesOut(
 // Function type signature.
 //===----------------------------------------------------------------------===//
 
-void mlir::function_interface_impl::setFunctionType(Operation *op,
-                                                    Type newType) {
-  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
-  unsigned oldNumArgs = funcOp.getNumArguments();
-  unsigned oldNumResults = funcOp.getNumResults();
-  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
-  unsigned newNumArgs = funcOp.getNumArguments();
-  unsigned newNumResults = funcOp.getNumResults();
+void function_interface_impl::setFunctionType(FunctionOpInterface op,
+                                              Type newType) {
+  unsigned oldNumArgs = op.getNumArguments();
+  unsigned oldNumResults = op.getNumResults();
+  op.setFunctionTypeAttr(TypeAttr::get(newType));
+  unsigned newNumArgs = op.getNumArguments();
+  unsigned newNumResults = op.getNumResults();
 
   // Functor used to update the argument and result attributes of the function.
-  auto updateAttrFn = [&](StringRef attrName, unsigned oldCount,
-                          unsigned newCount, auto setAttrFn) {
+  auto emptyDict = DictionaryAttr::get(op.getContext());
+  auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) {
+    constexpr bool isArgVal = std::is_same_v<decltype(isArg), std::true_type>;
+
     if (oldCount == newCount)
       return;
     // The new type has no arguments/results, just drop the attribute.
-    if (newCount == 0) {
-      op->removeAttr(attrName);
-      return;
-    }
-    ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(attrName);
+    if (newCount == 0)
+      return removeArgResAttrs<isArgVal>(op);
+    ArrayAttr attrs = getArgResAttrs<isArgVal>(op);
     if (!attrs)
       return;
 
     // The new type has less arguments/results, take the first N attributes.
     if (newCount < oldCount)
-      return setAttrFn(op, attrs.getValue().take_front(newCount));
+      return setAllArgResAttrDicts<isArgVal>(
+          op, attrs.getValue().take_front(newCount));
 
     // Otherwise, the new type has more arguments/results. Initialize the new
-    // arguments/results with empty attributes.
+    // arguments/results with empty dictionary attributes.
     SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
-    newAttrs.resize(newCount);
-    setAttrFn(op, newAttrs);
+    newAttrs.resize(newCount, emptyDict);
+    setAllArgResAttrDicts<isArgVal>(op, newAttrs);
   };
 
   // Update the argument and result attributes.
-  updateAttrFn(
-      getArgDictAttrName(), oldNumArgs, newNumArgs,
-      [&](Operation *op, auto &&attrs) { setAllArgAttrDicts(op, attrs); });
-  updateAttrFn(
-      getResultDictAttrName(), oldNumResults, newNumResults,
-      [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); });
+  updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs);
+  updateAttrFn(std::false_type{}, oldNumResults, newNumResults);
 }
index a72abad..d995689 100644 (file)
@@ -96,20 +96,11 @@ func.func private @invalid_symbol_type_attr() attributes { function_type = "x" }
 
 // -----
 
-// expected-error@+1 {{argument attribute array `arg_attrs` to have the same number of elements as the number of function arguments}}
+// expected-error@+1 {{argument attribute array to have the same number of elements as the number of function arguments}}
 func.func private @invalid_arg_attrs() attributes { arg_attrs = [{}] }
 
 // -----
 
-// expected-error@+1 {{expects argument attribute dictionary to be a DictionaryAttr, but got `10 : i64`}}
-func.func private @invalid_arg_attrs(i32) attributes { arg_attrs = [10] }
 
-// -----
-
-// expected-error@+1 {{result attribute array `res_attrs` to have the same number of elements as the number of function results}}
+// expected-error@+1 {{result attribute array to have the same number of elements as the number of function results}}
 func.func private @invalid_res_attrs() attributes { res_attrs = [{}] }
-
-// -----
-
-// expected-error@+1 {{expects result attribute dictionary to be a DictionaryAttr, but got `10 : i64`}}
-func.func private @invalid_res_attrs() -> i32 attributes { res_attrs = [10] }