From: Jeff Niu Date: Tue, 6 Dec 2022 19:28:47 +0000 (-0800) Subject: [mlir] FunctionOpInterface: turn required attributes into interface methods (Reland) X-Git-Tag: upstream/17.0.6~24358 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=53406427cdf4290986d1a48ea0d582ad195bff15;p=platform%2Fupstream%2Fllvm.git [mlir] FunctionOpInterface: turn required attributes into interface methods (Reland) Reland D139447, D139471 With flang actually working - FunctionOpInterface: make get/setFunctionType interface methods This patch removes the concept of a `function_type`-named type attribute as a requirement for implementors of FunctionOpInterface. Instead, this type should be provided through two interface methods, `getFunctionType` and `setFunctionTypeAttr` (*Attr because functions may use different concrete function types), which should be automatically implemented by ODS for ops that define a `$function_type` attribute. This also allows FunctionOpInterface to materialize function types if they don't carry them in an attribute, for example. Importantly, all the function "helper" still accept an attribute name to use in parsing and printing functions, for example. - FunctionOpInterface: arg and result attrs dispatch to interface This patch removes the `arg_attrs` and `res_attrs` named attributes as a requirement for FunctionOpInterface and replaces them with interface methods for the getters, setters, and removers of the relevent attributes. This allows operations to use their own storage for the argument and result attributes. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D139736 --- diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp index 1ad2526..87206c1 100644 --- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -501,7 +501,8 @@ public: // correctly. for (auto e : llvm::enumerate(funcTy.getInputs())) { unsigned index = e.index(); - llvm::ArrayRef attrs = func.getArgAttrs(index); + llvm::ArrayRef attrs = + mlir::function_interface_impl::getArgAttrs(func, index); for (mlir::NamedAttribute attr : attrs) { savedAttrs.push_back({index, attr}); } diff --git a/mlir/examples/toy/Ch2/include/toy/Ops.td b/mlir/examples/toy/Ch2/include/toy/Ops.td index 380536b..4e2fb9e 100644 --- a/mlir/examples/toy/Ch2/include/toy/Ops.td +++ b/mlir/examples/toy/Ch2/include/toy/Ops.td @@ -134,7 +134,9 @@ def FuncOp : Toy_Op<"func", [ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp index dbc1efb..a6ccbbf 100644 --- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp @@ -211,14 +211,17 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/examples/toy/Ch3/include/toy/Ops.td b/mlir/examples/toy/Ch3/include/toy/Ops.td index e526fe5..1a4e6a1 100644 --- a/mlir/examples/toy/Ch3/include/toy/Ops.td +++ b/mlir/examples/toy/Ch3/include/toy/Ops.td @@ -133,7 +133,9 @@ def FuncOp : Toy_Op<"func", [ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp index 50e2dfc..913979a 100644 --- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp @@ -198,14 +198,17 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td index 4956b0e..cbece47 100644 --- a/mlir/examples/toy/Ch4/include/toy/Ops.td +++ b/mlir/examples/toy/Ch4/include/toy/Ops.td @@ -163,7 +163,9 @@ def FuncOp : Toy_Op<"func", [ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp index 0a6195b..5db2f95 100644 --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -287,14 +287,17 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Returns the region on the function operation that is callable. diff --git a/mlir/examples/toy/Ch5/include/toy/Ops.td b/mlir/examples/toy/Ch5/include/toy/Ops.td index f4e7b08..70e482d 100644 --- a/mlir/examples/toy/Ch5/include/toy/Ops.td +++ b/mlir/examples/toy/Ch5/include/toy/Ops.td @@ -163,7 +163,9 @@ def FuncOp : Toy_Op<"func", [ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp index f236a1f..c2015ee 100644 --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -287,14 +287,17 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Returns the region on the function operation that is callable. diff --git a/mlir/examples/toy/Ch6/include/toy/Ops.td b/mlir/examples/toy/Ch6/include/toy/Ops.td index ea9323e..cf2bc3f 100644 --- a/mlir/examples/toy/Ch6/include/toy/Ops.td +++ b/mlir/examples/toy/Ch6/include/toy/Ops.td @@ -163,7 +163,9 @@ def FuncOp : Toy_Op<"func", [ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp index f236a1f..c2015ee 100644 --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -287,14 +287,17 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Returns the region on the function operation that is callable. diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td index 45ecdd3..08671a7 100644 --- a/mlir/examples/toy/Ch7/include/toy/Ops.td +++ b/mlir/examples/toy/Ch7/include/toy/Ops.td @@ -186,7 +186,9 @@ def FuncOp : Toy_Op<"func", [ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp index cc66a5d..ffcdd7a 100644 --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -314,14 +314,17 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Returns the region on the function operation that is callable. diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td index 30895e5..14146cd 100644 --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -140,7 +140,9 @@ def Async_FuncOp : Async_Op<"func", let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, - OptionalAttr:$sym_visibility); + OptionalAttr:$sym_visibility, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); let regions = (region AnyRegion:$body); diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td index f1b7cfd..4922689 100644 --- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td +++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td @@ -251,7 +251,9 @@ def FuncOp : Func_Op<"func", [ let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, - OptionalAttr:$sym_visibility); + OptionalAttr:$sym_visibility, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); let regions = (region AnyRegion:$body); let builders = [OpBuilder<(ins diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td index 0642b18..f9fff78 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -242,7 +242,9 @@ def GPU_GPUFuncOp : GPU_Op<"func", [ attribution. }]; - let arguments = (ins TypeAttrOf:$function_type); + let arguments = (ins TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); let regions = (region AnyRegion:$body); let skipDefaultBuilders = 1; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 19e589c..afc07e2 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1311,7 +1311,9 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [ DefaultValuedAttr:$CConv, OptionalAttr:$personality, OptionalAttr:$garbageCollector, - OptionalAttr:$passthrough + OptionalAttr:$passthrough, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td index 422680a..db6c773 100644 --- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td +++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td @@ -52,6 +52,8 @@ def MLProgram_FuncOp : MLProgram_Op<"func", [ let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, OptionalAttr:$sym_visibility); let regions = (region AnyRegion:$body); @@ -401,6 +403,8 @@ def MLProgram_SubgraphOp : MLProgram_Op<"subgraph", [ let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, OptionalAttr:$sym_visibility); let regions = (region AnyRegion:$body); diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td index 42a48cd..6ecbed2 100644 --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -652,7 +652,9 @@ def PDLInterp_FuncOp : PDLInterp_Op<"func", [ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region MinSizedRegion<1>:$body); diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td index 147705e..8339afc 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td @@ -291,6 +291,8 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [ let arguments = (ins TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, StrAttr:$sym_name, SPIRV_FunctionControlAttr:$function_control ); diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index c3697f0..97d1f0c 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -1107,6 +1107,8 @@ def Shape_FuncOp : Shape_Op<"func", let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, OptionalAttr:$sym_visibility); let regions = (region AnyRegion:$body); diff --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h index 5265f78..eb79790 100644 --- a/mlir/include/mlir/IR/FunctionImplementation.h +++ b/mlir/include/mlir/IR/FunctionImplementation.h @@ -39,10 +39,12 @@ private: /// with special names given by getResultAttrName, getArgumentAttrName. void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef argAttrs, - ArrayRef resultAttrs); + ArrayRef resultAttrs, + StringAttr argAttrsName, StringAttr resAttrsName); void addArgAndResultAttrs(Builder &builder, OperationState &result, - ArrayRef argAttrs, - ArrayRef resultAttrs); + ArrayRef args, + ArrayRef resultAttrs, + StringAttr argAttrsName, StringAttr resAttrsName); /// Callback type for `parseFunctionOp`, the callback should produce the /// type that will be associated with a function-like operation from lists of @@ -69,21 +71,25 @@ Type getFunctionType(Builder &builder, ArrayRef argAttrs, /// Parser implementation for function-like operations. Uses /// `funcTypeBuilder` to construct the custom function type given lists of -/// input and output types. If `allowVariadic` is set, the parser will accept +/// input and output types. The parser sets the `typeAttrName` attribute to the +/// resulting function type. If `allowVariadic` is set, the parser will accept /// trailing ellipsis in the function signature and indicate to the builder /// whether the function is variadic. If the builder returns a null type, /// `result` will not contain the `type` attribute. The caller can then add a /// type, report the error or delegate the reporting to the op's verifier. ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, - bool allowVariadic, - FuncTypeBuilder funcTypeBuilder); + bool allowVariadic, StringAttr typeAttrName, + FuncTypeBuilder funcTypeBuilder, + StringAttr argAttrsName, StringAttr resAttrsName); /// Printer implementation for function-like operations. -void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic); +void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, + StringRef typeAttrName, StringAttr argAttrsName, + StringAttr resAttrsName); /// Prints the signature of the function-like operation `op`. Assumes `op` has /// is a FunctionOpInterface and has passed verification. -void printFunctionSignature(OpAsmPrinter &p, Operation *op, +void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef argTypes, bool isVariadic, ArrayRef resultTypes); @@ -92,8 +98,7 @@ void printFunctionSignature(OpAsmPrinter &p, Operation *op, /// function-like operation internally are not printed. Nothing is printed /// if all attributes are elided. Assumes `op` is a FunctionOpInterface and /// has passed verification. -void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs, - unsigned numResults, +void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef elided = {}); } // namespace function_interface_impl diff --git a/mlir/include/mlir/IR/FunctionInterfaces.h b/mlir/include/mlir/IR/FunctionInterfaces.h index 23fd884..3beb3db 100644 --- a/mlir/include/mlir/IR/FunctionInterfaces.h +++ b/mlir/include/mlir/IR/FunctionInterfaces.h @@ -22,78 +22,59 @@ #include "llvm/ADT/SmallString.h" namespace mlir { +class FunctionOpInterface; namespace function_interface_impl { -/// Return the name of the attribute used for function types. -inline StringRef getTypeAttrName() { return "function_type"; } - -/// Return the name of the attribute used for function argument attributes. -inline StringRef getArgDictAttrName() { return "arg_attrs"; } - -/// Return the name of the attribute used for function argument attributes. -inline StringRef getResultDictAttrName() { return "res_attrs"; } - /// Returns the dictionary attribute corresponding to the argument at 'index'. /// If there are no argument attributes at 'index', a null attribute is /// returned. -DictionaryAttr getArgAttrDict(Operation *op, unsigned index); +DictionaryAttr getArgAttrDict(FunctionOpInterface op, unsigned index); /// Returns the dictionary attribute corresponding to the result at 'index'. /// If there are no result attributes at 'index', a null attribute is /// returned. -DictionaryAttr getResultAttrDict(Operation *op, unsigned index); +DictionaryAttr getResultAttrDict(FunctionOpInterface op, unsigned index); -namespace detail { -/// Update the given index into an argument or result attribute dictionary. -void setArgResAttrDict(Operation *op, StringRef attrName, - unsigned numTotalIndices, unsigned index, - DictionaryAttr attrs); -} // namespace detail +/// Return all of the attributes for the argument at 'index'. +ArrayRef getArgAttrs(FunctionOpInterface op, unsigned index); + +/// Return all of the attributes for the result at 'index'. +ArrayRef getResultAttrs(FunctionOpInterface op, unsigned index); /// Set all of the argument or result attribute dictionaries for a function. The /// size of `attrs` is expected to match the number of arguments/results of the /// given `op`. -void setAllArgAttrDicts(Operation *op, ArrayRef attrs); -void setAllArgAttrDicts(Operation *op, ArrayRef attrs); -void setAllResultAttrDicts(Operation *op, ArrayRef attrs); -void setAllResultAttrDicts(Operation *op, ArrayRef attrs); - -/// Return all of the attributes for the argument at 'index'. -inline ArrayRef getArgAttrs(Operation *op, unsigned index) { - auto argDict = getArgAttrDict(op, index); - return argDict ? argDict.getValue() : std::nullopt; -} - -/// Return all of the attributes for the result at 'index'. -inline ArrayRef getResultAttrs(Operation *op, unsigned index) { - auto resultDict = getResultAttrDict(op, index); - return resultDict ? resultDict.getValue() : std::nullopt; -} +void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef attrs); +void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef attrs); +void setAllResultAttrDicts(FunctionOpInterface op, + ArrayRef attrs); +void setAllResultAttrDicts(FunctionOpInterface op, ArrayRef attrs); /// Insert the specified arguments and update the function type attribute. -void insertFunctionArguments(Operation *op, ArrayRef argIndices, - TypeRange argTypes, +void insertFunctionArguments(FunctionOpInterface op, + ArrayRef argIndices, TypeRange argTypes, ArrayRef argAttrs, ArrayRef argLocs, unsigned originalNumArgs, Type newType); /// Insert the specified results and update the function type attribute. -void insertFunctionResults(Operation *op, ArrayRef resultIndices, +void insertFunctionResults(FunctionOpInterface op, + ArrayRef resultIndices, TypeRange resultTypes, ArrayRef resultAttrs, unsigned originalNumResults, Type newType); /// Erase the specified arguments and update the function type attribute. -void eraseFunctionArguments(Operation *op, const BitVector &argIndices, +void eraseFunctionArguments(FunctionOpInterface op, const BitVector &argIndices, Type newType); /// Erase the specified results and update the function type attribute. -void eraseFunctionResults(Operation *op, const BitVector &resultIndices, - Type newType); +void eraseFunctionResults(FunctionOpInterface op, + const BitVector &resultIndices, Type newType); /// Set a FunctionOpInterface operation's type signature. -void setFunctionType(Operation *op, Type newType); +void setFunctionType(FunctionOpInterface op, Type newType); /// Insert a set of `newTypes` into `oldTypes` at the given `indices`. If any /// types are inserted, `storage` is used to hold the new type list. The new @@ -111,20 +92,10 @@ TypeRange filterTypesOut(TypeRange types, const BitVector &indices, //===----------------------------------------------------------------------===// /// Set the attributes held by the argument at 'index'. -template -void setArgAttrs(ConcreteType op, unsigned index, - ArrayRef attributes) { - assert(index < op.getNumArguments() && "invalid argument number"); - return detail::setArgResAttrDict( - op, getArgDictAttrName(), op.getNumArguments(), index, - DictionaryAttr::get(op->getContext(), attributes)); -} -template -void setArgAttrs(ConcreteType op, unsigned index, DictionaryAttr attributes) { - return detail::setArgResAttrDict( - op, getArgDictAttrName(), op.getNumArguments(), index, - attributes ? attributes : DictionaryAttr::get(op->getContext())); -} +void setArgAttrs(FunctionOpInterface op, unsigned index, + ArrayRef attributes); +void setArgAttrs(FunctionOpInterface op, unsigned index, + DictionaryAttr attributes); /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. @@ -158,23 +129,10 @@ Attribute removeArgAttr(ConcreteType op, unsigned index, StringAttr name) { //===----------------------------------------------------------------------===// /// Set the attributes held by the result at 'index'. -template -void setResultAttrs(ConcreteType op, unsigned index, - ArrayRef attributes) { - assert(index < op.getNumResults() && "invalid result number"); - return detail::setArgResAttrDict( - op, getResultDictAttrName(), op.getNumResults(), index, - DictionaryAttr::get(op->getContext(), attributes)); -} - -template -void setResultAttrs(ConcreteType op, unsigned index, - DictionaryAttr attributes) { - assert(index < op.getNumResults() && "invalid result number"); - return detail::setArgResAttrDict( - op, getResultDictAttrName(), op.getNumResults(), index, - attributes ? attributes : DictionaryAttr::get(op->getContext())); -} +void setResultAttrs(FunctionOpInterface op, unsigned index, + ArrayRef attributes); +void setResultAttrs(FunctionOpInterface op, unsigned index, + DictionaryAttr attributes); /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. @@ -207,10 +165,6 @@ Attribute removeResultAttr(ConcreteType op, unsigned index, StringAttr name) { /// method on FunctionOpInterface::Trait. template LogicalResult verifyTrait(ConcreteOp op) { - if (!op.getFunctionTypeAttr()) - return op.emitOpError("requires a type attribute '") - << function_interface_impl::getTypeAttrName() << '\''; - if (failed(op.verifyType())) return failure(); @@ -218,9 +172,8 @@ LogicalResult verifyTrait(ConcreteOp op) { unsigned numArgs = op.getNumArguments(); if (allArgAttrs.size() != numArgs) { return op.emitOpError() - << "expects argument attribute array `" << getArgDictAttrName() - << "` to have the same number of elements as the number of " - "function arguments, got " + << "expects argument attribute array to have the same number of " + "elements as the number of function arguments, got " << allArgAttrs.size() << ", but expected " << numArgs; } for (unsigned i = 0; i != numArgs; ++i) { @@ -250,9 +203,8 @@ LogicalResult verifyTrait(ConcreteOp op) { unsigned numResults = op.getNumResults(); if (allResultAttrs.size() != numResults) { return op.emitOpError() - << "expects result attribute array `" << getResultDictAttrName() - << "` to have the same number of elements as the number of " - "function results, got " + << "expects result attribute array to have the same number of " + "elements as the number of function results, got " << allResultAttrs.size() << ", but expected " << numResults; } for (unsigned i = 0; i != numResults; ++i) { diff --git a/mlir/include/mlir/IR/FunctionInterfaces.td b/mlir/include/mlir/IR/FunctionInterfaces.td index c56129e..0e8a3ad 100644 --- a/mlir/include/mlir/IR/FunctionInterfaces.td +++ b/mlir/include/mlir/IR/FunctionInterfaces.td @@ -50,6 +50,52 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> { }]; let methods = [ InterfaceMethod<[{ + Returns the type of the function. + }], + "::mlir::Type", "getFunctionType">, + InterfaceMethod<[{ + Set the type of the function. This method should perform an unsafe + modification to the function type; it should not update argument or + result attributes. + }], + "void", "setFunctionTypeAttr", (ins "::mlir::TypeAttr":$type)>, + + InterfaceMethod<[{ + Get the array of argument attribute dictionaries. The method should return + an array attribute containing only dictionary attributes equal in number + to the number of function arguments. Alternatively, the method can return + null to indicate that the function has no argument attributes. + }], + "::mlir::ArrayAttr", "getArgAttrsAttr">, + InterfaceMethod<[{ + Get the array of result attribute dictionaries. The method should return + an array attribute containing only dictionary attributes equal in number + to the number of function results. Alternatively, the method can return + null to indicate that the function has no result attributes. + }], + "::mlir::ArrayAttr", "getResAttrsAttr">, + InterfaceMethod<[{ + Set the array of argument attribute dictionaries. + }], + "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>, + InterfaceMethod<[{ + Set the array of result attribute dictionaries. + }], + "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>, + InterfaceMethod<[{ + Remove the array of argument attribute dictionaries. This is the same as + setting all argument attributes to an empty dictionary. The method should + return the removed attribute. + }], + "::mlir::Attribute", "removeArgAttrsAttr">, + InterfaceMethod<[{ + Remove the array of result attribute dictionaries. This is the same as + setting all result attributes to an empty dictionary. The method should + return the removed attribute. + }], + "::mlir::Attribute", "removeResAttrsAttr">, + + InterfaceMethod<[{ Returns the function argument types based exclusively on the type (to allow for this method may be called on function declarations). @@ -139,7 +185,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> { ArrayRef attrs, TypeRange inputTypes) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - state.addAttribute(function_interface_impl::getTypeAttrName(), + state.addAttribute(ConcreteOp::getFunctionTypeAttrName(state.name), TypeAttr::get(type)); state.attributes.append(attrs.begin(), attrs.end()); @@ -240,34 +286,6 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> { function_interface_impl::setFunctionType(this->getOperation(), newType); } - // FIXME: These functions should be removed in favor of just forwarding to - // the derived operation, which should already have these defined - // (via ODS). - - /// Returns the name of the attribute used for function types. - static StringRef getTypeAttrName() { - return function_interface_impl::getTypeAttrName(); - } - - /// Returns the name of the attribute used for function argument attributes. - static StringRef getArgDictAttrName() { - return function_interface_impl::getArgDictAttrName(); - } - - /// Returns the name of the attribute used for function argument attributes. - static StringRef getResultDictAttrName() { - return function_interface_impl::getResultDictAttrName(); - } - - /// Return the attribute containing the type of this function. - TypeAttr getFunctionTypeAttr() { - return this->getOperation()->template getAttrOfType( - getTypeAttrName()); - } - - /// Return the type of this function. - Type getFunctionType() { return getFunctionTypeAttr().getValue(); } - //===------------------------------------------------------------------===// // Argument and Result Handling //===------------------------------------------------------------------===// @@ -409,10 +427,8 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> { /// Return an ArrayAttr containing all argument attribute dictionaries of /// this function, or nullptr if no arguments have attributes. - ArrayAttr getAllArgAttrs() { - return this->getOperation()->template getAttrOfType( - getArgDictAttrName()); - } + ArrayAttr getAllArgAttrs() { return $_op.getArgAttrsAttr(); } + /// Return all argument attributes of this function. void getAllArgAttrs(SmallVectorImpl &result) { if (ArrayAttr argAttrs = getAllArgAttrs()) { @@ -464,7 +480,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> { } void setAllArgAttrs(ArrayAttr attributes) { assert(attributes.size() == $_op.getNumArguments()); - this->getOperation()->setAttr(getArgDictAttrName(), attributes); + $_op.setArgAttrsAttr(attributes); } /// If the an attribute exists with the specified name, change it to the new @@ -500,10 +516,8 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> { /// Return an ArrayAttr containing all result attribute dictionaries of this /// function, or nullptr if no result have attributes. - ArrayAttr getAllResultAttrs() { - return this->getOperation()->template getAttrOfType( - getResultDictAttrName()); - } + ArrayAttr getAllResultAttrs() { return $_op.getResAttrsAttr(); } + /// Return all result attributes of this function. void getAllResultAttrs(SmallVectorImpl &result) { if (ArrayAttr argAttrs = getAllResultAttrs()) { @@ -557,7 +571,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> { } void setAllResultAttrs(ArrayAttr attributes) { assert(attributes.size() == $_op.getNumResults()); - this->getOperation()->setAttr(getResultDictAttrName(), attributes); + $_op.setResAttrsAttr(attributes); } /// If the an attribute exists with the specified name, change it to the new diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 2affd9ae..400f671 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1524,6 +1524,8 @@ def TypeArrayAttr : TypedArrayAttrBase { } def IndexListArrayAttr : TypedArrayAttrBase; +def DictArrayAttr : + TypedArrayAttrBase; // Attributes containing symbol references. def SymbolRefAttr : Attr()">, diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index d0e82de..0cd024e 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -59,16 +59,15 @@ using namespace mlir; /// Only retain those attributes that are not constructed by /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument /// attributes. -static void filterFuncAttributes(ArrayRef attrs, - bool filterArgAndResAttrs, +static void filterFuncAttributes(func::FuncOp func, bool filterArgAndResAttrs, SmallVectorImpl &result) { - for (const auto &attr : attrs) { + for (const NamedAttribute &attr : func->getAttrs()) { if (attr.getName() == SymbolTable::getSymbolAttrName() || - attr.getName() == FunctionOpInterface::getTypeAttrName() || + attr.getName() == func.getFunctionTypeAttrName() || attr.getName() == "func.varargs" || (filterArgAndResAttrs && - (attr.getName() == FunctionOpInterface::getArgDictAttrName() || - attr.getName() == FunctionOpInterface::getResultDictAttrName()))) + (attr.getName() == func.getArgAttrsAttrName() || + attr.getName() == func.getResAttrsAttrName()))) continue; result.push_back(attr); } @@ -91,18 +90,19 @@ static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) { static void prependResAttrsToArgAttrs(OpBuilder &builder, SmallVectorImpl &attributes, - size_t numArguments) { + func::FuncOp func) { + size_t numArguments = func.getNumArguments(); auto allAttrs = SmallVector( numArguments + 1, DictionaryAttr::get(builder.getContext())); NamedAttribute *argAttrs = nullptr; for (auto *it = attributes.begin(); it != attributes.end();) { - if (it->getName() == FunctionOpInterface::getArgDictAttrName()) { + if (it->getName() == func.getArgAttrsAttrName()) { auto arrayAttrs = it->getValue().cast(); assert(arrayAttrs.size() == numArguments && "Number of arg attrs and args should match"); std::copy(arrayAttrs.begin(), arrayAttrs.end(), allAttrs.begin() + 1); argAttrs = it; - } else if (it->getName() == FunctionOpInterface::getResultDictAttrName()) { + } else if (it->getName() == func.getResAttrsAttrName()) { auto arrayAttrs = it->getValue().cast(); assert(!arrayAttrs.empty() && "expected array to be non-empty"); allAttrs[0] = (arrayAttrs.size() == 1) @@ -114,9 +114,8 @@ prependResAttrsToArgAttrs(OpBuilder &builder, it++; } - auto newArgAttrs = - builder.getNamedAttr(FunctionOpInterface::getArgDictAttrName(), - builder.getArrayAttr(allAttrs)); + auto newArgAttrs = builder.getNamedAttr(func.getArgAttrsAttrName(), + builder.getArrayAttr(allAttrs)); if (!argAttrs) { attributes.emplace_back(newArgAttrs); return; @@ -138,12 +137,11 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, LLVM::LLVMFuncOp newFuncOp) { auto type = funcOp.getFunctionType(); SmallVector attributes; - filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false, - attributes); + filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes); auto [wrapperFuncType, resultIsNowArg] = typeConverter.convertFunctionTypeCWrapper(type); if (resultIsNowArg) - prependResAttrsToArgAttrs(rewriter, attributes, funcOp.getNumArguments()); + prependResAttrsToArgAttrs(rewriter, attributes, funcOp); auto wrapperFuncOp = rewriter.create( loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false, @@ -204,11 +202,10 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, assert(wrapperType && "unexpected type conversion failure"); SmallVector attributes; - filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false, - attributes); + filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes); if (resultIsNowArg) - prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments()); + prependResAttrsToArgAttrs(builder, attributes, funcOp); // Create the auxiliary function. auto wrapperFunc = builder.create( loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), @@ -304,8 +301,7 @@ protected: // Propagate argument/result attributes to all converted arguments/result // obtained after converting a given original argument/result. SmallVector attributes; - filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true, - attributes); + filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/true, attributes); if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) { assert(!resAttrDicts.empty() && "expected array to be non-empty"); auto newResAttrDicts = @@ -313,8 +309,8 @@ protected: ? resAttrDicts : rewriter.getArrayAttr( {wrapAsStructAttrs(rewriter, resAttrDicts)}); - attributes.push_back(rewriter.getNamedAttr( - FunctionOpInterface::getResultDictAttrName(), newResAttrDicts)); + attributes.push_back( + rewriter.getNamedAttr(funcOp.getResAttrsAttrName(), newResAttrDicts)); } if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) { SmallVector newArgAttrs( @@ -357,9 +353,8 @@ protected: newArgAttrs[mapping->inputNo + j] = DictionaryAttr::get(rewriter.getContext(), convertedAttrs); } - attributes.push_back( - rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(), - rewriter.getArrayAttr(newArgAttrs))); + attributes.push_back(rewriter.getNamedAttr( + funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(newArgAttrs))); } for (const auto &pair : llvm::enumerate(attributes)) { if (pair.value().getName() == "llvm.linkage") { diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index 85001d5..48effe2 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -60,7 +60,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, SmallVector attributes; for (const auto &attr : gpuFuncOp->getAttrs()) { if (attr.getName() == SymbolTable::getSymbolAttrName() || - attr.getName() == FunctionOpInterface::getTypeAttrName() || + attr.getName() == gpuFuncOp.getFunctionTypeAttrName() || attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName()) continue; attributes.push_back(attr); diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 119b1d3..2a83895 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -226,7 +226,7 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter, rewriter.getFunctionType(signatureConverter.getConvertedTypes(), std::nullopt)); for (const auto &namedAttr : funcOp->getAttrs()) { - if (namedAttr.getName() == FunctionOpInterface::getTypeAttrName() || + if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() || namedAttr.getName() == SymbolTable::getSymbolAttrName()) continue; newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp index e0772b4..54acc37 100644 --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -332,8 +332,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, ArrayRef argAttrs) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - state.addAttribute(FunctionOpInterface::getTypeAttrName(), - TypeAttr::get(type)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); state.attributes.append(attrs.begin(), attrs.end()); state.addRegion(); @@ -341,8 +340,9 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs, - /*resultAttrs=*/std::nullopt); + function_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { @@ -352,11 +352,15 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Check that the result type of async.func is not void and must be diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp index 961cf2e..7bb3663c 100644 --- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp +++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp @@ -244,16 +244,16 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, ArrayRef argAttrs) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - state.addAttribute(FunctionOpInterface::getTypeAttrName(), - TypeAttr::get(type)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); state.attributes.append(attrs.begin(), attrs.end()); state.addRegion(); if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs, - /*resultAttrs=*/std::nullopt); + function_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { @@ -263,11 +263,15 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Clone the internal blocks from this function into dest and all attributes diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 7f73a65..9ea1b11 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -859,7 +859,8 @@ void GPUFuncOp::build(OpBuilder &builder, OperationState &result, ArrayRef attrs) { result.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); + result.addAttribute(getFunctionTypeAttrName(result.name), + TypeAttr::get(type)); result.addAttribute(getNumWorkgroupAttributionsAttrName(), builder.getI64IntegerAttr(workgroupAttributions.size())); result.addAttributes(attrs); @@ -930,10 +931,12 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) { for (auto &arg : entryArgs) argTypes.push_back(arg.type); auto type = builder.getFunctionType(argTypes, resultTypes); - result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type)); + result.addAttribute(getFunctionTypeAttrName(result.name), + TypeAttr::get(type)); - function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs, - resultAttrs); + function_interface_impl::addArgAndResultAttrs( + builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), + getResAttrsAttrName(result.name)); // Parse workgroup memory attributions. if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(), @@ -992,19 +995,15 @@ void GPUFuncOp::print(OpAsmPrinter &p) { p << ' ' << getKernelKeyword(); function_interface_impl::printFunctionAttributes( - p, *this, type.getNumInputs(), type.getNumResults(), + p, *this, {getNumWorkgroupAttributionsAttrName(), - GPUDialect::getKernelFuncAttrName()}); + GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()}); p << ' '; p.printRegion(getBody(), /*printEntryBlockArgs=*/false); } LogicalResult GPUFuncOp::verifyType() { - Type type = getFunctionTypeAttr().getValue(); - if (!type.isa()) - return emitOpError("requires '" + getTypeAttrName() + - "' attribute of function type"); - if (isKernel() && getFunctionType().getNumResults() != 0) return emitOpError() << "expected void return type for kernel function"; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 1087bcf..a4860e3 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2006,8 +2006,9 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, assert(type.cast().getNumParams() == argAttrs.size() && "expected as many argument attribute lists as arguments"); - function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs, - /*resultAttrs=*/std::nullopt); + function_interface_impl::addArgAndResultAttrs( + builder, result, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } // Builds an LLVM function type from the given lists of input and output types. @@ -2090,13 +2091,14 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) { function_interface_impl::VariadicFlag(isVariadic)); if (!type) return failure(); - result.addAttribute(FunctionOpInterface::getTypeAttrName(), + result.addAttribute(getFunctionTypeAttrName(result.name), TypeAttr::get(type)); if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) return failure(); - function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result, - entryArgs, resultAttrs); + function_interface_impl::addArgAndResultAttrs( + parser.getBuilder(), result, entryArgs, resultAttrs, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); auto *body = result.addRegion(); OptionalParseResult parseResult = @@ -2130,8 +2132,9 @@ void LLVMFuncOp::print(OpAsmPrinter &p) { function_interface_impl::printFunctionSignature(p, *this, argTypes, isVarArg(), resTypes); function_interface_impl::printFunctionAttributes( - p, *this, argTypes.size(), resTypes.size(), - {getLinkageAttrName(), getCConvAttrName()}); + p, *this, + {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(), + getLinkageAttrName(), getCConvAttrName()}); // Print the body if this is not an external function. Region &body = getBody(); diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp index 2f1e4b9..31ed5ad 100644 --- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp +++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp @@ -152,11 +152,15 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// @@ -313,11 +317,15 @@ ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) { std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void SubgraphOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp index e8a61ef..2cc282d 100644 --- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp +++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp @@ -220,11 +220,15 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 52ad8ad..3341b5e 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -2382,7 +2382,7 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) { for (auto &arg : entryArgs) argTypes.push_back(arg.type); auto fnType = builder.getFunctionType(argTypes, resultTypes); - result.addAttribute(FunctionOpInterface::getTypeAttrName(), + result.addAttribute(getFunctionTypeAttrName(result.name), TypeAttr::get(fnType)); // Parse the optional function control keyword. @@ -2396,8 +2396,9 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) { // Add the attributes to the function arguments. assert(resultAttrs.size() == resultTypes.size()); - function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs, - resultAttrs); + function_interface_impl::addArgAndResultAttrs( + builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), + getResAttrsAttrName(result.name)); // Parse the optional function body. auto *body = result.addRegion(); @@ -2417,8 +2418,10 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) { printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl()) << "\""; function_interface_impl::printFunctionAttributes( - printer, *this, fnType.getNumInputs(), fnType.getNumResults(), - {spirv::attributeName()}); + printer, *this, + {spirv::attributeName(), + getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(), + getFunctionControlAttrName()}); // Print the body if this is not an external function. Region &body = this->getBody(); @@ -2430,10 +2433,6 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) { } LogicalResult spirv::FuncOp::verifyType() { - auto type = getFunctionTypeAttr().getValue(); - if (!type.isa()) - return emitOpError("requires '" + getTypeAttrName() + - "' attribute of function type"); if (getFunctionType().getNumResults() > 1) return emitOpError("cannot have more than one result"); return success(); @@ -2473,7 +2472,7 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state, ArrayRef attrs) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - state.addAttribute(getTypeAttrName(), TypeAttr::get(type)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); state.addAttribute(spirv::attributeName(), builder.getAttr(control)); state.attributes.append(attrs.begin(), attrs.end()); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 2772c01..62e3a3d 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -531,7 +531,7 @@ FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, // Copy over all attributes other than the function name and type. for (const auto &namedAttr : funcOp->getAttrs()) { - if (namedAttr.getName() != FunctionOpInterface::getTypeAttrName() && + if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() && namedAttr.getName() != SymbolTable::getSymbolAttrName()) newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); } diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 8c89ec8..28ac98e 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1300,8 +1300,9 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs, - /*resultAttrs=*/std::nullopt); + function_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { @@ -1311,11 +1312,15 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp index 9481e4a..5ca6777 100644 --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -113,7 +113,7 @@ parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl &resultTypes, return parser.parseRParen(); } -ParseResult mlir::function_interface_impl::parseFunctionSignature( +ParseResult function_interface_impl::parseFunctionSignature( OpAsmParser &parser, bool allowVariadic, SmallVectorImpl &arguments, bool &isVariadic, SmallVectorImpl &resultTypes, @@ -125,9 +125,10 @@ ParseResult mlir::function_interface_impl::parseFunctionSignature( return success(); } -void mlir::function_interface_impl::addArgAndResultAttrs( +void function_interface_impl::addArgAndResultAttrs( Builder &builder, OperationState &result, ArrayRef argAttrs, - ArrayRef resultAttrs) { + ArrayRef resultAttrs, StringAttr argAttrsName, + StringAttr resAttrsName) { auto nonEmptyAttrsFn = [](DictionaryAttr attrs) { return attrs && !attrs.empty(); }; @@ -142,28 +143,28 @@ void mlir::function_interface_impl::addArgAndResultAttrs( // Add the attributes to the function arguments. if (llvm::any_of(argAttrs, nonEmptyAttrsFn)) - result.addAttribute(function_interface_impl::getArgDictAttrName(), - getArrayAttr(argAttrs)); + result.addAttribute(argAttrsName, getArrayAttr(argAttrs)); // Add the attributes to the function results. if (llvm::any_of(resultAttrs, nonEmptyAttrsFn)) - result.addAttribute(function_interface_impl::getResultDictAttrName(), - getArrayAttr(resultAttrs)); + result.addAttribute(resAttrsName, getArrayAttr(resultAttrs)); } -void mlir::function_interface_impl::addArgAndResultAttrs( +void function_interface_impl::addArgAndResultAttrs( Builder &builder, OperationState &result, - ArrayRef args, - ArrayRef resultAttrs) { + ArrayRef args, ArrayRef resultAttrs, + StringAttr argAttrsName, StringAttr resAttrsName) { SmallVector argAttrs; for (const auto &arg : args) argAttrs.push_back(arg.attrs); - addArgAndResultAttrs(builder, result, argAttrs, resultAttrs); + addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName, + resAttrsName); } -ParseResult mlir::function_interface_impl::parseFunctionOp( +ParseResult function_interface_impl::parseFunctionOp( OpAsmParser &parser, OperationState &result, bool allowVariadic, - FuncTypeBuilder funcTypeBuilder) { + StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, + StringAttr argAttrsName, StringAttr resAttrsName) { SmallVector entryArgs; SmallVector resultAttrs; SmallVector resultTypes; @@ -197,7 +198,7 @@ ParseResult mlir::function_interface_impl::parseFunctionOp( << "failed to construct function type" << (errorMessage.empty() ? "" : ": ") << errorMessage; } - result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); + result.addAttribute(typeAttrName, TypeAttr::get(type)); // If function attributes are present, parse them. NamedAttrList parsedAttributes; @@ -209,7 +210,7 @@ ParseResult mlir::function_interface_impl::parseFunctionOp( // dictionary. for (StringRef disallowed : {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(), - getTypeAttrName()}) { + typeAttrName.getValue()}) { if (parsedAttributes.get(disallowed)) return parser.emitError(attributeDictLocation, "'") << disallowed @@ -220,7 +221,8 @@ ParseResult mlir::function_interface_impl::parseFunctionOp( // Add the attributes to the function arguments. assert(resultAttrs.size() == resultTypes.size()); - addArgAndResultAttrs(builder, result, entryArgs, resultAttrs); + addArgAndResultAttrs(builder, result, entryArgs, resultAttrs, argAttrsName, + resAttrsName); // Parse the optional function body. The printer will not print the body if // its empty, so disallow parsing of empty body in the parser. @@ -261,14 +263,14 @@ static void printFunctionResultList(OpAsmPrinter &p, ArrayRef types, os << ')'; } -void mlir::function_interface_impl::printFunctionSignature( - OpAsmPrinter &p, Operation *op, ArrayRef argTypes, bool isVariadic, - ArrayRef resultTypes) { +void function_interface_impl::printFunctionSignature( + OpAsmPrinter &p, FunctionOpInterface op, ArrayRef argTypes, + bool isVariadic, ArrayRef resultTypes) { Region &body = op->getRegion(0); bool isExternal = body.empty(); p << '('; - ArrayAttr argAttrs = op->getAttrOfType(getArgDictAttrName()); + ArrayAttr argAttrs = op.getArgAttrsAttr(); for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { if (i > 0) p << ", "; @@ -295,26 +297,23 @@ void mlir::function_interface_impl::printFunctionSignature( if (!resultTypes.empty()) { p.getStream() << " -> "; - auto resultAttrs = op->getAttrOfType(getResultDictAttrName()); + auto resultAttrs = op.getResAttrsAttr(); printFunctionResultList(p, resultTypes, resultAttrs); } } -void mlir::function_interface_impl::printFunctionAttributes( - OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults, - ArrayRef elided) { +void function_interface_impl::printFunctionAttributes( + OpAsmPrinter &p, Operation *op, ArrayRef elided) { // Print out function attributes, if present. - SmallVector ignoredAttrs = { - ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName(), - getArgDictAttrName(), getResultDictAttrName()}; + SmallVector ignoredAttrs = {SymbolTable::getSymbolAttrName()}; ignoredAttrs.append(elided.begin(), elided.end()); p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs); } -void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p, - FunctionOpInterface op, - bool isVariadic) { +void function_interface_impl::printFunctionOp( + OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, + StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName) { // Print the operation and the function name. auto funcName = op->getAttrOfType(SymbolTable::getSymbolAttrName()) @@ -329,8 +328,8 @@ void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p, ArrayRef argTypes = op.getArgumentTypes(); ArrayRef resultTypes = op.getResultTypes(); printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); - printFunctionAttributes(p, op, argTypes.size(), resultTypes.size(), - {visibilityAttrName}); + printFunctionAttributes( + p, op, {visibilityAttrName, typeAttrName, argAttrsName, resAttrsName}); // Print the body if this is not an external function. Region &body = op->getRegion(0); if (!body.empty()) { diff --git a/mlir/lib/IR/FunctionInterfaces.cpp b/mlir/lib/IR/FunctionInterfaces.cpp index 3331aef..4a50a49 100644 --- a/mlir/lib/IR/FunctionInterfaces.cpp +++ b/mlir/lib/IR/FunctionInterfaces.cpp @@ -24,27 +24,104 @@ static bool isEmptyAttrDict(Attribute attr) { return attr.cast().empty(); } -DictionaryAttr mlir::function_interface_impl::getArgAttrDict(Operation *op, - unsigned index) { - ArrayAttr attrs = op->getAttrOfType(getArgDictAttrName()); +DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op, + unsigned index) { + ArrayAttr attrs = op.getArgAttrsAttr(); DictionaryAttr argAttrs = attrs ? attrs[index].cast() : DictionaryAttr(); return argAttrs; } DictionaryAttr -mlir::function_interface_impl::getResultAttrDict(Operation *op, - unsigned index) { - ArrayAttr attrs = op->getAttrOfType(getResultDictAttrName()); +function_interface_impl::getResultAttrDict(FunctionOpInterface op, + unsigned index) { + ArrayAttr attrs = op.getResAttrsAttr(); DictionaryAttr resAttrs = attrs ? attrs[index].cast() : DictionaryAttr(); return resAttrs; } -void mlir::function_interface_impl::detail::setArgResAttrDict( - Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index, - DictionaryAttr attrs) { - ArrayAttr allAttrs = op->getAttrOfType(attrName); +ArrayRef +function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) { + auto argDict = getArgAttrDict(op, index); + return argDict ? argDict.getValue() : std::nullopt; +} + +ArrayRef +function_interface_impl::getResultAttrs(FunctionOpInterface op, + unsigned index) { + auto resultDict = getResultAttrDict(op, index); + return resultDict ? resultDict.getValue() : std::nullopt; +} + +/// Get either the argument or result attributes array. +template +static ArrayAttr getArgResAttrs(FunctionOpInterface op) { + if constexpr (isArg) + return op.getArgAttrsAttr(); + else + return op.getResAttrsAttr(); +} + +/// Set either the argument or result attributes array. +template +static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) { + if constexpr (isArg) + op.setArgAttrsAttr(attrs); + else + op.setResAttrsAttr(attrs); +} + +/// Erase either the argument or result attributes array. +template +static void removeArgResAttrs(FunctionOpInterface op) { + if constexpr (isArg) + op.removeArgAttrsAttr(); + else + op.removeResAttrsAttr(); +} + +/// Set all of the argument or result attribute dictionaries for a function. +template +static void setAllArgResAttrDicts(FunctionOpInterface op, + ArrayRef attrs) { + if (llvm::all_of(attrs, isEmptyAttrDict)) + removeArgResAttrs(op); + else + setArgResAttrs(op, ArrayAttr::get(op->getContext(), attrs)); +} + +void function_interface_impl::setAllArgAttrDicts( + FunctionOpInterface op, ArrayRef attrs) { + setAllArgAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); +} + +void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op, + ArrayRef attrs) { + auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { + return !attr ? DictionaryAttr::get(op->getContext()) : attr; + }); + setAllArgResAttrDicts(op, llvm::to_vector<8>(wrappedAttrs)); +} + +void function_interface_impl::setAllResultAttrDicts( + FunctionOpInterface op, ArrayRef attrs) { + setAllResultAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); +} + +void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op, + ArrayRef attrs) { + auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { + return !attr ? DictionaryAttr::get(op->getContext()) : attr; + }); + setAllArgResAttrDicts(op, llvm::to_vector<8>(wrappedAttrs)); +} + +/// Update the given index into an argument or result attribute dictionary. +template +static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices, + unsigned index, DictionaryAttr attrs) { + ArrayAttr allAttrs = getArgResAttrs(op); if (!allAttrs) { if (attrs.empty()) return; @@ -53,7 +130,7 @@ void mlir::function_interface_impl::detail::setArgResAttrDict( SmallVector newAttrs(numTotalIndices, DictionaryAttr::get(op->getContext())); newAttrs[index] = attrs; - op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs)); + setArgResAttrs(op, ArrayAttr::get(op->getContext(), newAttrs)); return; } // Check to see if the attribute is different from what we already have. @@ -65,54 +142,52 @@ void mlir::function_interface_impl::detail::setArgResAttrDict( ArrayRef rawAttrArray = allAttrs.getValue(); if (attrs.empty() && llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) && - llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) { - op->removeAttr(attrName); - return; - } + llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) + return removeArgResAttrs(op); // Otherwise, create a new attribute array with the updated dictionary. SmallVector newAttrs(rawAttrArray.begin(), rawAttrArray.end()); newAttrs[index] = attrs; - op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs)); + setArgResAttrs(op, ArrayAttr::get(op->getContext(), newAttrs)); } -/// Set all of the argument or result attribute dictionaries for a function. -static void setAllArgResAttrDicts(Operation *op, StringRef attrName, - ArrayRef attrs) { - if (llvm::all_of(attrs, isEmptyAttrDict)) - op->removeAttr(attrName); - else - op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs)); +void function_interface_impl::setArgAttrs(FunctionOpInterface op, + unsigned index, + ArrayRef attributes) { + assert(index < op.getNumArguments() && "invalid argument number"); + return setArgResAttrDict( + op, op.getNumArguments(), index, + DictionaryAttr::get(op->getContext(), attributes)); } -void mlir::function_interface_impl::setAllArgAttrDicts( - Operation *op, ArrayRef attrs) { - setAllArgAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); -} -void mlir::function_interface_impl::setAllArgAttrDicts( - Operation *op, ArrayRef attrs) { - auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { - return !attr ? DictionaryAttr::get(op->getContext()) : attr; - }); - setAllArgResAttrDicts(op, getArgDictAttrName(), - llvm::to_vector<8>(wrappedAttrs)); +void function_interface_impl::setArgAttrs(FunctionOpInterface op, + unsigned index, + DictionaryAttr attributes) { + return setArgResAttrDict( + op, op.getNumArguments(), index, + attributes ? attributes : DictionaryAttr::get(op->getContext())); } -void mlir::function_interface_impl::setAllResultAttrDicts( - Operation *op, ArrayRef attrs) { - setAllResultAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); +void function_interface_impl::setResultAttrs( + FunctionOpInterface op, unsigned index, + ArrayRef attributes) { + assert(index < op.getNumResults() && "invalid result number"); + return setArgResAttrDict( + op, op.getNumResults(), index, + DictionaryAttr::get(op->getContext(), attributes)); } -void mlir::function_interface_impl::setAllResultAttrDicts( - Operation *op, ArrayRef attrs) { - auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { - return !attr ? DictionaryAttr::get(op->getContext()) : attr; - }); - setAllArgResAttrDicts(op, getResultDictAttrName(), - llvm::to_vector<8>(wrappedAttrs)); + +void function_interface_impl::setResultAttrs(FunctionOpInterface op, + unsigned index, + DictionaryAttr attributes) { + assert(index < op.getNumResults() && "invalid result number"); + return setArgResAttrDict( + op, op.getNumResults(), index, + attributes ? attributes : DictionaryAttr::get(op->getContext())); } -void mlir::function_interface_impl::insertFunctionArguments( - Operation *op, ArrayRef argIndices, TypeRange argTypes, +void function_interface_impl::insertFunctionArguments( + FunctionOpInterface op, ArrayRef argIndices, TypeRange argTypes, ArrayRef argAttrs, ArrayRef argLocs, unsigned originalNumArgs, Type newType) { assert(argIndices.size() == argTypes.size()); @@ -128,7 +203,7 @@ void mlir::function_interface_impl::insertFunctionArguments( Block &entry = op->getRegion(0).front(); // Update the argument attributes of the function. - auto oldArgAttrs = op->getAttrOfType(getArgDictAttrName()); + ArrayAttr oldArgAttrs = op.getArgAttrsAttr(); if (oldArgAttrs || !argAttrs.empty()) { SmallVector newArgAttrs; newArgAttrs.reserve(originalNumArgs + argIndices.size()); @@ -152,15 +227,15 @@ void mlir::function_interface_impl::insertFunctionArguments( } // Update the function type and any entry block arguments. - op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); + op.setFunctionTypeAttr(TypeAttr::get(newType)); for (unsigned i = 0, e = argIndices.size(); i < e; ++i) entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]); } -void mlir::function_interface_impl::insertFunctionResults( - Operation *op, ArrayRef resultIndices, TypeRange resultTypes, - ArrayRef resultAttrs, unsigned originalNumResults, - Type newType) { +void function_interface_impl::insertFunctionResults( + FunctionOpInterface op, ArrayRef resultIndices, + TypeRange resultTypes, ArrayRef resultAttrs, + unsigned originalNumResults, Type newType) { assert(resultIndices.size() == resultTypes.size()); assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty()); if (resultIndices.empty()) @@ -171,7 +246,7 @@ void mlir::function_interface_impl::insertFunctionResults( // - Result attrs. // Update the result attributes of the function. - auto oldResultAttrs = op->getAttrOfType(getResultDictAttrName()); + ArrayAttr oldResultAttrs = op.getResAttrsAttr(); if (oldResultAttrs || !resultAttrs.empty()) { SmallVector newResultAttrs; newResultAttrs.reserve(originalNumResults + resultIndices.size()); @@ -196,11 +271,11 @@ void mlir::function_interface_impl::insertFunctionResults( } // Update the function type. - op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); + op.setFunctionTypeAttr(TypeAttr::get(newType)); } -void mlir::function_interface_impl::eraseFunctionArguments( - Operation *op, const BitVector &argIndices, Type newType) { +void function_interface_impl::eraseFunctionArguments( + FunctionOpInterface op, const BitVector &argIndices, Type newType) { // There are 3 things that need to be updated: // - Function type. // - Arg attrs. @@ -208,7 +283,7 @@ void mlir::function_interface_impl::eraseFunctionArguments( Block &entry = op->getRegion(0).front(); // Update the argument attributes of the function. - if (auto argAttrs = op->getAttrOfType(getArgDictAttrName())) { + if (ArrayAttr argAttrs = op.getArgAttrsAttr()) { SmallVector newArgAttrs; newArgAttrs.reserve(argAttrs.size()); for (unsigned i = 0, e = argIndices.size(); i < e; ++i) @@ -218,18 +293,18 @@ void mlir::function_interface_impl::eraseFunctionArguments( } // Update the function type and any entry block arguments. - op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); + op.setFunctionTypeAttr(TypeAttr::get(newType)); entry.eraseArguments(argIndices); } -void mlir::function_interface_impl::eraseFunctionResults( - Operation *op, const BitVector &resultIndices, Type newType) { +void function_interface_impl::eraseFunctionResults( + FunctionOpInterface op, const BitVector &resultIndices, Type newType) { // There are 2 things that need to be updated: // - Function type. // - Result attrs. // Update the result attributes of the function. - if (auto resAttrs = op->getAttrOfType(getResultDictAttrName())) { + if (ArrayAttr resAttrs = op.getResAttrsAttr()) { SmallVector newResultAttrs; newResultAttrs.reserve(resAttrs.size()); for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) @@ -239,10 +314,10 @@ void mlir::function_interface_impl::eraseFunctionResults( } // Update the function type. - op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); + op.setFunctionTypeAttr(TypeAttr::get(newType)); } -TypeRange mlir::function_interface_impl::insertTypesInto( +TypeRange function_interface_impl::insertTypesInto( TypeRange oldTypes, ArrayRef indices, TypeRange newTypes, SmallVectorImpl &storage) { assert(indices.size() == newTypes.size() && @@ -261,7 +336,7 @@ TypeRange mlir::function_interface_impl::insertTypesInto( return storage; } -TypeRange mlir::function_interface_impl::filterTypesOut( +TypeRange function_interface_impl::filterTypesOut( TypeRange types, const BitVector &indices, SmallVectorImpl &storage) { if (indices.none()) return types; @@ -276,45 +351,41 @@ TypeRange mlir::function_interface_impl::filterTypesOut( // Function type signature. //===----------------------------------------------------------------------===// -void mlir::function_interface_impl::setFunctionType(Operation *op, - Type newType) { - FunctionOpInterface funcOp = cast(op); - unsigned oldNumArgs = funcOp.getNumArguments(); - unsigned oldNumResults = funcOp.getNumResults(); - op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); - unsigned newNumArgs = funcOp.getNumArguments(); - unsigned newNumResults = funcOp.getNumResults(); +void function_interface_impl::setFunctionType(FunctionOpInterface op, + Type newType) { + unsigned oldNumArgs = op.getNumArguments(); + unsigned oldNumResults = op.getNumResults(); + op.setFunctionTypeAttr(TypeAttr::get(newType)); + unsigned newNumArgs = op.getNumArguments(); + unsigned newNumResults = op.getNumResults(); // Functor used to update the argument and result attributes of the function. - auto updateAttrFn = [&](StringRef attrName, unsigned oldCount, - unsigned newCount, auto setAttrFn) { + auto emptyDict = DictionaryAttr::get(op.getContext()); + auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) { + constexpr bool isArgVal = std::is_same_v; + if (oldCount == newCount) return; // The new type has no arguments/results, just drop the attribute. - if (newCount == 0) { - op->removeAttr(attrName); - return; - } - ArrayAttr attrs = op->getAttrOfType(attrName); + if (newCount == 0) + return removeArgResAttrs(op); + ArrayAttr attrs = getArgResAttrs(op); if (!attrs) return; // The new type has less arguments/results, take the first N attributes. if (newCount < oldCount) - return setAttrFn(op, attrs.getValue().take_front(newCount)); + return setAllArgResAttrDicts( + op, attrs.getValue().take_front(newCount)); // Otherwise, the new type has more arguments/results. Initialize the new - // arguments/results with empty attributes. + // arguments/results with empty dictionary attributes. SmallVector newAttrs(attrs.begin(), attrs.end()); - newAttrs.resize(newCount); - setAttrFn(op, newAttrs); + newAttrs.resize(newCount, emptyDict); + setAllArgResAttrDicts(op, newAttrs); }; // Update the argument and result attributes. - updateAttrFn( - getArgDictAttrName(), oldNumArgs, newNumArgs, - [&](Operation *op, auto &&attrs) { setAllArgAttrDicts(op, attrs); }); - updateAttrFn( - getResultDictAttrName(), oldNumResults, newNumResults, - [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); }); + updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs); + updateAttrFn(std::false_type{}, oldNumResults, newNumResults); } diff --git a/mlir/test/IR/invalid-func-op.mlir b/mlir/test/IR/invalid-func-op.mlir index a72abad..d995689 100644 --- a/mlir/test/IR/invalid-func-op.mlir +++ b/mlir/test/IR/invalid-func-op.mlir @@ -96,20 +96,11 @@ func.func private @invalid_symbol_type_attr() attributes { function_type = "x" } // ----- -// expected-error@+1 {{argument attribute array `arg_attrs` to have the same number of elements as the number of function arguments}} +// expected-error@+1 {{argument attribute array to have the same number of elements as the number of function arguments}} func.func private @invalid_arg_attrs() attributes { arg_attrs = [{}] } // ----- -// expected-error@+1 {{expects argument attribute dictionary to be a DictionaryAttr, but got `10 : i64`}} -func.func private @invalid_arg_attrs(i32) attributes { arg_attrs = [10] } -// ----- - -// expected-error@+1 {{result attribute array `res_attrs` to have the same number of elements as the number of function results}} +// expected-error@+1 {{result attribute array to have the same number of elements as the number of function results}} func.func private @invalid_res_attrs() attributes { res_attrs = [{}] } - -// ----- - -// expected-error@+1 {{expects result attribute dictionary to be a DictionaryAttr, but got `10 : i64`}} -func.func private @invalid_res_attrs() -> i32 attributes { res_attrs = [10] }