From dd74e6b6f4fb7a4685086a4895c1934e043f875b Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Tue, 6 Dec 2022 12:55:43 -0800 Subject: [PATCH] [mlir] FunctionOpInterface: arg and result attrs dispatch to interface This patch removes the `arg_attrs` and `res_attrs` named attributes as a requirement for FunctionOpInterface and replaces them with interface methods for the getters, setters, and removers of the relevent attributes. This allows operations to use their own storage for the argument and result attributes. Depends on D139471 Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D139472 --- mlir/examples/toy/Ch2/include/toy/Ops.td | 4 +- mlir/examples/toy/Ch2/mlir/Dialect.cpp | 5 +- mlir/examples/toy/Ch3/include/toy/Ops.td | 4 +- mlir/examples/toy/Ch3/mlir/Dialect.cpp | 5 +- mlir/examples/toy/Ch4/include/toy/Ops.td | 4 +- mlir/examples/toy/Ch4/mlir/Dialect.cpp | 5 +- mlir/examples/toy/Ch5/include/toy/Ops.td | 4 +- mlir/examples/toy/Ch5/mlir/Dialect.cpp | 5 +- mlir/examples/toy/Ch6/include/toy/Ops.td | 4 +- mlir/examples/toy/Ch6/mlir/Dialect.cpp | 5 +- mlir/examples/toy/Ch7/include/toy/Ops.td | 4 +- mlir/examples/toy/Ch7/mlir/Dialect.cpp | 5 +- mlir/include/mlir/Dialect/Async/IR/AsyncOps.td | 4 +- mlir/include/mlir/Dialect/Func/IR/FuncOps.td | 4 +- mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 4 +- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 4 +- .../mlir/Dialect/MLProgram/IR/MLProgramOps.td | 4 + .../mlir/Dialect/PDLInterp/IR/PDLInterpOps.td | 4 +- .../mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td | 2 + mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td | 2 + mlir/include/mlir/IR/FunctionImplementation.h | 16 +- mlir/include/mlir/IR/FunctionInterfaces.h | 91 +++------ mlir/include/mlir/IR/FunctionInterfaces.td | 66 ++++--- mlir/include/mlir/IR/OpBase.td | 2 + mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 29 ++- mlir/lib/Dialect/Async/IR/Async.cpp | 13 +- mlir/lib/Dialect/Func/IR/FuncOps.cpp | 13 +- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 8 +- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 13 +- mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp | 16 +- mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp | 8 +- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 8 +- mlir/lib/Dialect/Shape/IR/Shape.cpp | 13 +- mlir/lib/IR/FunctionImplementation.cpp | 56 +++--- mlir/lib/IR/FunctionInterfaces.cpp | 213 ++++++++++++++------- mlir/test/IR/invalid-func-op.mlir | 13 +- 36 files changed, 380 insertions(+), 280 deletions(-) diff --git a/mlir/examples/toy/Ch2/include/toy/Ops.td b/mlir/examples/toy/Ch2/include/toy/Ops.td index 380536b..4e2fb9e 100644 --- a/mlir/examples/toy/Ch2/include/toy/Ops.td +++ b/mlir/examples/toy/Ch2/include/toy/Ops.td @@ -134,7 +134,9 @@ def FuncOp : Toy_Op<"func", [ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp index ac12c5c..201f9c7 100644 --- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp @@ -218,8 +218,9 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/examples/toy/Ch3/include/toy/Ops.td b/mlir/examples/toy/Ch3/include/toy/Ops.td index e526fe5..1a4e6a1 100644 --- a/mlir/examples/toy/Ch3/include/toy/Ops.td +++ b/mlir/examples/toy/Ch3/include/toy/Ops.td @@ -133,7 +133,9 @@ def FuncOp : Toy_Op<"func", [ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp index 75cb57e..4bd1055 100644 --- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp @@ -205,8 +205,9 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td index 4956b0e..cbece47 100644 --- a/mlir/examples/toy/Ch4/include/toy/Ops.td +++ b/mlir/examples/toy/Ch4/include/toy/Ops.td @@ -163,7 +163,9 @@ def FuncOp : Toy_Op<"func", [ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp index 2d5a369..3a02ea3 100644 --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -294,8 +294,9 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Returns the region on the function operation that is callable. diff --git a/mlir/examples/toy/Ch5/include/toy/Ops.td b/mlir/examples/toy/Ch5/include/toy/Ops.td index f4e7b08..70e482d 100644 --- a/mlir/examples/toy/Ch5/include/toy/Ops.td +++ b/mlir/examples/toy/Ch5/include/toy/Ops.td @@ -163,7 +163,9 @@ def FuncOp : Toy_Op<"func", [ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp index 280bf31..49ce3d9 100644 --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -294,8 +294,9 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Returns the region on the function operation that is callable. diff --git a/mlir/examples/toy/Ch6/include/toy/Ops.td b/mlir/examples/toy/Ch6/include/toy/Ops.td index ea9323e..cf2bc3f 100644 --- a/mlir/examples/toy/Ch6/include/toy/Ops.td +++ b/mlir/examples/toy/Ch6/include/toy/Ops.td @@ -163,7 +163,9 @@ def FuncOp : Toy_Op<"func", [ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp index 280bf31..49ce3d9 100644 --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -294,8 +294,9 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Returns the region on the function operation that is callable. diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td index 45ecdd3..08671a7 100644 --- a/mlir/examples/toy/Ch7/include/toy/Ops.td +++ b/mlir/examples/toy/Ch7/include/toy/Ops.td @@ -186,7 +186,9 @@ def FuncOp : Toy_Op<"func", [ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp index b0d2130..cb65a95 100644 --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -321,8 +321,9 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Returns the region on the function operation that is callable. diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td index 30895e5..14146cd 100644 --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -140,7 +140,9 @@ def Async_FuncOp : Async_Op<"func", let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, - OptionalAttr:$sym_visibility); + OptionalAttr:$sym_visibility, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); let regions = (region AnyRegion:$body); diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td index f1b7cfd..4922689 100644 --- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td +++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td @@ -251,7 +251,9 @@ def FuncOp : Func_Op<"func", [ let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, - OptionalAttr:$sym_visibility); + OptionalAttr:$sym_visibility, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); let regions = (region AnyRegion:$body); let builders = [OpBuilder<(ins diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td index 0642b18..f9fff78 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -242,7 +242,9 @@ def GPU_GPUFuncOp : GPU_Op<"func", [ attribution. }]; - let arguments = (ins TypeAttrOf:$function_type); + let arguments = (ins TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); let regions = (region AnyRegion:$body); let skipDefaultBuilders = 1; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index c36e390..4d2b2f9 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1308,7 +1308,9 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [ DefaultValuedAttr:$CConv, OptionalAttr:$personality, OptionalAttr:$garbageCollector, - OptionalAttr:$passthrough + OptionalAttr:$passthrough, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td index 422680a..db6c773 100644 --- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td +++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td @@ -52,6 +52,8 @@ def MLProgram_FuncOp : MLProgram_Op<"func", [ let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, OptionalAttr:$sym_visibility); let regions = (region AnyRegion:$body); @@ -401,6 +403,8 @@ def MLProgram_SubgraphOp : MLProgram_Op<"subgraph", [ let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, OptionalAttr:$sym_visibility); let regions = (region AnyRegion:$body); diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td index 42a48cd..6ecbed2 100644 --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -652,7 +652,9 @@ def PDLInterp_FuncOp : PDLInterp_Op<"func", [ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region MinSizedRegion<1>:$body); diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td index 147705e..8339afc 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td @@ -291,6 +291,8 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [ let arguments = (ins TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, StrAttr:$sym_name, SPIRV_FunctionControlAttr:$function_control ); diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index c3697f0..97d1f0c 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -1107,6 +1107,8 @@ def Shape_FuncOp : Shape_Op<"func", let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, OptionalAttr:$sym_visibility); let regions = (region AnyRegion:$body); diff --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h index f4c0cc0..eb79790 100644 --- a/mlir/include/mlir/IR/FunctionImplementation.h +++ b/mlir/include/mlir/IR/FunctionImplementation.h @@ -39,10 +39,12 @@ private: /// with special names given by getResultAttrName, getArgumentAttrName. void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef argAttrs, - ArrayRef resultAttrs); + ArrayRef resultAttrs, + StringAttr argAttrsName, StringAttr resAttrsName); void addArgAndResultAttrs(Builder &builder, OperationState &result, - ArrayRef argAttrs, - ArrayRef resultAttrs); + ArrayRef args, + ArrayRef resultAttrs, + StringAttr argAttrsName, StringAttr resAttrsName); /// Callback type for `parseFunctionOp`, the callback should produce the /// type that will be associated with a function-like operation from lists of @@ -77,15 +79,17 @@ Type getFunctionType(Builder &builder, ArrayRef argAttrs, /// type, report the error or delegate the reporting to the op's verifier. ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, - FuncTypeBuilder funcTypeBuilder); + FuncTypeBuilder funcTypeBuilder, + StringAttr argAttrsName, StringAttr resAttrsName); /// Printer implementation for function-like operations. void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, - StringRef typeAttrName); + StringRef typeAttrName, StringAttr argAttrsName, + StringAttr resAttrsName); /// Prints the signature of the function-like operation `op`. Assumes `op` has /// is a FunctionOpInterface and has passed verification. -void printFunctionSignature(OpAsmPrinter &p, Operation *op, +void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef argTypes, bool isVariadic, ArrayRef resultTypes); diff --git a/mlir/include/mlir/IR/FunctionInterfaces.h b/mlir/include/mlir/IR/FunctionInterfaces.h index bc2ec47..3beb3db 100644 --- a/mlir/include/mlir/IR/FunctionInterfaces.h +++ b/mlir/include/mlir/IR/FunctionInterfaces.h @@ -26,48 +26,30 @@ class FunctionOpInterface; namespace function_interface_impl { -/// Return the name of the attribute used for function argument attributes. -inline StringRef getArgDictAttrName() { return "arg_attrs"; } - -/// Return the name of the attribute used for function argument attributes. -inline StringRef getResultDictAttrName() { return "res_attrs"; } - /// Returns the dictionary attribute corresponding to the argument at 'index'. /// If there are no argument attributes at 'index', a null attribute is /// returned. -DictionaryAttr getArgAttrDict(Operation *op, unsigned index); +DictionaryAttr getArgAttrDict(FunctionOpInterface op, unsigned index); /// Returns the dictionary attribute corresponding to the result at 'index'. /// If there are no result attributes at 'index', a null attribute is /// returned. -DictionaryAttr getResultAttrDict(Operation *op, unsigned index); +DictionaryAttr getResultAttrDict(FunctionOpInterface op, unsigned index); -namespace detail { -/// Update the given index into an argument or result attribute dictionary. -void setArgResAttrDict(Operation *op, StringRef attrName, - unsigned numTotalIndices, unsigned index, - DictionaryAttr attrs); -} // namespace detail +/// Return all of the attributes for the argument at 'index'. +ArrayRef getArgAttrs(FunctionOpInterface op, unsigned index); + +/// Return all of the attributes for the result at 'index'. +ArrayRef getResultAttrs(FunctionOpInterface op, unsigned index); /// Set all of the argument or result attribute dictionaries for a function. The /// size of `attrs` is expected to match the number of arguments/results of the /// given `op`. -void setAllArgAttrDicts(Operation *op, ArrayRef attrs); -void setAllArgAttrDicts(Operation *op, ArrayRef attrs); -void setAllResultAttrDicts(Operation *op, ArrayRef attrs); -void setAllResultAttrDicts(Operation *op, ArrayRef attrs); - -/// Return all of the attributes for the argument at 'index'. -inline ArrayRef getArgAttrs(Operation *op, unsigned index) { - auto argDict = getArgAttrDict(op, index); - return argDict ? argDict.getValue() : std::nullopt; -} - -/// Return all of the attributes for the result at 'index'. -inline ArrayRef getResultAttrs(Operation *op, unsigned index) { - auto resultDict = getResultAttrDict(op, index); - return resultDict ? resultDict.getValue() : std::nullopt; -} +void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef attrs); +void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef attrs); +void setAllResultAttrDicts(FunctionOpInterface op, + ArrayRef attrs); +void setAllResultAttrDicts(FunctionOpInterface op, ArrayRef attrs); /// Insert the specified arguments and update the function type attribute. void insertFunctionArguments(FunctionOpInterface op, @@ -110,20 +92,10 @@ TypeRange filterTypesOut(TypeRange types, const BitVector &indices, //===----------------------------------------------------------------------===// /// Set the attributes held by the argument at 'index'. -template -void setArgAttrs(ConcreteType op, unsigned index, - ArrayRef attributes) { - assert(index < op.getNumArguments() && "invalid argument number"); - return detail::setArgResAttrDict( - op, getArgDictAttrName(), op.getNumArguments(), index, - DictionaryAttr::get(op->getContext(), attributes)); -} -template -void setArgAttrs(ConcreteType op, unsigned index, DictionaryAttr attributes) { - return detail::setArgResAttrDict( - op, getArgDictAttrName(), op.getNumArguments(), index, - attributes ? attributes : DictionaryAttr::get(op->getContext())); -} +void setArgAttrs(FunctionOpInterface op, unsigned index, + ArrayRef attributes); +void setArgAttrs(FunctionOpInterface op, unsigned index, + DictionaryAttr attributes); /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. @@ -157,23 +129,10 @@ Attribute removeArgAttr(ConcreteType op, unsigned index, StringAttr name) { //===----------------------------------------------------------------------===// /// Set the attributes held by the result at 'index'. -template -void setResultAttrs(ConcreteType op, unsigned index, - ArrayRef attributes) { - assert(index < op.getNumResults() && "invalid result number"); - return detail::setArgResAttrDict( - op, getResultDictAttrName(), op.getNumResults(), index, - DictionaryAttr::get(op->getContext(), attributes)); -} - -template -void setResultAttrs(ConcreteType op, unsigned index, - DictionaryAttr attributes) { - assert(index < op.getNumResults() && "invalid result number"); - return detail::setArgResAttrDict( - op, getResultDictAttrName(), op.getNumResults(), index, - attributes ? attributes : DictionaryAttr::get(op->getContext())); -} +void setResultAttrs(FunctionOpInterface op, unsigned index, + ArrayRef attributes); +void setResultAttrs(FunctionOpInterface op, unsigned index, + DictionaryAttr attributes); /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. @@ -213,9 +172,8 @@ LogicalResult verifyTrait(ConcreteOp op) { unsigned numArgs = op.getNumArguments(); if (allArgAttrs.size() != numArgs) { return op.emitOpError() - << "expects argument attribute array `" << getArgDictAttrName() - << "` to have the same number of elements as the number of " - "function arguments, got " + << "expects argument attribute array to have the same number of " + "elements as the number of function arguments, got " << allArgAttrs.size() << ", but expected " << numArgs; } for (unsigned i = 0; i != numArgs; ++i) { @@ -245,9 +203,8 @@ LogicalResult verifyTrait(ConcreteOp op) { unsigned numResults = op.getNumResults(); if (allResultAttrs.size() != numResults) { return op.emitOpError() - << "expects result attribute array `" << getResultDictAttrName() - << "` to have the same number of elements as the number of " - "function results, got " + << "expects result attribute array to have the same number of " + "elements as the number of function results, got " << allResultAttrs.size() << ", but expected " << numResults; } for (unsigned i = 0; i != numResults; ++i) { diff --git a/mlir/include/mlir/IR/FunctionInterfaces.td b/mlir/include/mlir/IR/FunctionInterfaces.td index e86057a..0e8a3ad 100644 --- a/mlir/include/mlir/IR/FunctionInterfaces.td +++ b/mlir/include/mlir/IR/FunctionInterfaces.td @@ -59,6 +59,42 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> { result attributes. }], "void", "setFunctionTypeAttr", (ins "::mlir::TypeAttr":$type)>, + + InterfaceMethod<[{ + Get the array of argument attribute dictionaries. The method should return + an array attribute containing only dictionary attributes equal in number + to the number of function arguments. Alternatively, the method can return + null to indicate that the function has no argument attributes. + }], + "::mlir::ArrayAttr", "getArgAttrsAttr">, + InterfaceMethod<[{ + Get the array of result attribute dictionaries. The method should return + an array attribute containing only dictionary attributes equal in number + to the number of function results. Alternatively, the method can return + null to indicate that the function has no result attributes. + }], + "::mlir::ArrayAttr", "getResAttrsAttr">, + InterfaceMethod<[{ + Set the array of argument attribute dictionaries. + }], + "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>, + InterfaceMethod<[{ + Set the array of result attribute dictionaries. + }], + "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>, + InterfaceMethod<[{ + Remove the array of argument attribute dictionaries. This is the same as + setting all argument attributes to an empty dictionary. The method should + return the removed attribute. + }], + "::mlir::Attribute", "removeArgAttrsAttr">, + InterfaceMethod<[{ + Remove the array of result attribute dictionaries. This is the same as + setting all result attributes to an empty dictionary. The method should + return the removed attribute. + }], + "::mlir::Attribute", "removeResAttrsAttr">, + InterfaceMethod<[{ Returns the function argument types based exclusively on the type (to allow for this method may be called on function @@ -250,20 +286,6 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> { function_interface_impl::setFunctionType(this->getOperation(), newType); } - // FIXME: These functions should be removed in favor of just forwarding to - // the derived operation, which should already have these defined - // (via ODS). - - /// Returns the name of the attribute used for function argument attributes. - static StringRef getArgDictAttrName() { - return function_interface_impl::getArgDictAttrName(); - } - - /// Returns the name of the attribute used for function argument attributes. - static StringRef getResultDictAttrName() { - return function_interface_impl::getResultDictAttrName(); - } - //===------------------------------------------------------------------===// // Argument and Result Handling //===------------------------------------------------------------------===// @@ -405,10 +427,8 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> { /// Return an ArrayAttr containing all argument attribute dictionaries of /// this function, or nullptr if no arguments have attributes. - ArrayAttr getAllArgAttrs() { - return this->getOperation()->template getAttrOfType( - getArgDictAttrName()); - } + ArrayAttr getAllArgAttrs() { return $_op.getArgAttrsAttr(); } + /// Return all argument attributes of this function. void getAllArgAttrs(SmallVectorImpl &result) { if (ArrayAttr argAttrs = getAllArgAttrs()) { @@ -460,7 +480,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> { } void setAllArgAttrs(ArrayAttr attributes) { assert(attributes.size() == $_op.getNumArguments()); - this->getOperation()->setAttr(getArgDictAttrName(), attributes); + $_op.setArgAttrsAttr(attributes); } /// If the an attribute exists with the specified name, change it to the new @@ -496,10 +516,8 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> { /// Return an ArrayAttr containing all result attribute dictionaries of this /// function, or nullptr if no result have attributes. - ArrayAttr getAllResultAttrs() { - return this->getOperation()->template getAttrOfType( - getResultDictAttrName()); - } + ArrayAttr getAllResultAttrs() { return $_op.getResAttrsAttr(); } + /// Return all result attributes of this function. void getAllResultAttrs(SmallVectorImpl &result) { if (ArrayAttr argAttrs = getAllResultAttrs()) { @@ -553,7 +571,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> { } void setAllResultAttrs(ArrayAttr attributes) { assert(attributes.size() == $_op.getNumResults()); - this->getOperation()->setAttr(getResultDictAttrName(), attributes); + $_op.setResAttrsAttr(attributes); } /// If the an attribute exists with the specified name, change it to the new diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 2affd9ae..400f671 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1524,6 +1524,8 @@ def TypeArrayAttr : TypedArrayAttrBase { } def IndexListArrayAttr : TypedArrayAttrBase; +def DictArrayAttr : + TypedArrayAttrBase; // Attributes containing symbol references. def SymbolRefAttr : Attr()">, diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 9f522aaa..0cd024e 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -66,8 +66,8 @@ static void filterFuncAttributes(func::FuncOp func, bool filterArgAndResAttrs, attr.getName() == func.getFunctionTypeAttrName() || attr.getName() == "func.varargs" || (filterArgAndResAttrs && - (attr.getName() == FunctionOpInterface::getArgDictAttrName() || - attr.getName() == FunctionOpInterface::getResultDictAttrName()))) + (attr.getName() == func.getArgAttrsAttrName() || + attr.getName() == func.getResAttrsAttrName()))) continue; result.push_back(attr); } @@ -90,18 +90,19 @@ static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) { static void prependResAttrsToArgAttrs(OpBuilder &builder, SmallVectorImpl &attributes, - size_t numArguments) { + func::FuncOp func) { + size_t numArguments = func.getNumArguments(); auto allAttrs = SmallVector( numArguments + 1, DictionaryAttr::get(builder.getContext())); NamedAttribute *argAttrs = nullptr; for (auto *it = attributes.begin(); it != attributes.end();) { - if (it->getName() == FunctionOpInterface::getArgDictAttrName()) { + if (it->getName() == func.getArgAttrsAttrName()) { auto arrayAttrs = it->getValue().cast(); assert(arrayAttrs.size() == numArguments && "Number of arg attrs and args should match"); std::copy(arrayAttrs.begin(), arrayAttrs.end(), allAttrs.begin() + 1); argAttrs = it; - } else if (it->getName() == FunctionOpInterface::getResultDictAttrName()) { + } else if (it->getName() == func.getResAttrsAttrName()) { auto arrayAttrs = it->getValue().cast(); assert(!arrayAttrs.empty() && "expected array to be non-empty"); allAttrs[0] = (arrayAttrs.size() == 1) @@ -113,9 +114,8 @@ prependResAttrsToArgAttrs(OpBuilder &builder, it++; } - auto newArgAttrs = - builder.getNamedAttr(FunctionOpInterface::getArgDictAttrName(), - builder.getArrayAttr(allAttrs)); + auto newArgAttrs = builder.getNamedAttr(func.getArgAttrsAttrName(), + builder.getArrayAttr(allAttrs)); if (!argAttrs) { attributes.emplace_back(newArgAttrs); return; @@ -141,7 +141,7 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, auto [wrapperFuncType, resultIsNowArg] = typeConverter.convertFunctionTypeCWrapper(type); if (resultIsNowArg) - prependResAttrsToArgAttrs(rewriter, attributes, funcOp.getNumArguments()); + prependResAttrsToArgAttrs(rewriter, attributes, funcOp); auto wrapperFuncOp = rewriter.create( loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false, @@ -205,7 +205,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes); if (resultIsNowArg) - prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments()); + prependResAttrsToArgAttrs(builder, attributes, funcOp); // Create the auxiliary function. auto wrapperFunc = builder.create( loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), @@ -309,8 +309,8 @@ protected: ? resAttrDicts : rewriter.getArrayAttr( {wrapAsStructAttrs(rewriter, resAttrDicts)}); - attributes.push_back(rewriter.getNamedAttr( - FunctionOpInterface::getResultDictAttrName(), newResAttrDicts)); + attributes.push_back( + rewriter.getNamedAttr(funcOp.getResAttrsAttrName(), newResAttrDicts)); } if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) { SmallVector newArgAttrs( @@ -353,9 +353,8 @@ protected: newArgAttrs[mapping->inputNo + j] = DictionaryAttr::get(rewriter.getContext(), convertedAttrs); } - attributes.push_back( - rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(), - rewriter.getArrayAttr(newArgAttrs))); + attributes.push_back(rewriter.getNamedAttr( + funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(newArgAttrs))); } for (const auto &pair : llvm::enumerate(attributes)) { if (pair.value().getName() == "llvm.linkage") { diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp index 064bf52..54acc37 100644 --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -340,8 +340,9 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs, - /*resultAttrs=*/std::nullopt); + function_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { @@ -352,12 +353,14 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, - getFunctionTypeAttrName(result.name), buildFuncType); + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, - getFunctionTypeAttrName()); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Check that the result type of async.func is not void and must be diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp index fc9bd11..7bb3663c 100644 --- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp +++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp @@ -251,8 +251,9 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs, - /*resultAttrs=*/std::nullopt); + function_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { @@ -263,12 +264,14 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, - getFunctionTypeAttrName(result.name), buildFuncType); + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, - getFunctionTypeAttrName()); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Clone the internal blocks from this function into dest and all attributes diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 80db646..9ea1b11 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -934,8 +934,9 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) { result.addAttribute(getFunctionTypeAttrName(result.name), TypeAttr::get(type)); - function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs, - resultAttrs); + function_interface_impl::addArgAndResultAttrs( + builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), + getResAttrsAttrName(result.name)); // Parse workgroup memory attributions. if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(), @@ -996,7 +997,8 @@ void GPUFuncOp::print(OpAsmPrinter &p) { function_interface_impl::printFunctionAttributes( p, *this, {getNumWorkgroupAttributionsAttrName(), - GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName()}); + GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()}); p << ' '; p.printRegion(getBody(), /*printEntryBlockArgs=*/false); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 6b428a1..cff547e 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2006,8 +2006,9 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, assert(type.cast().getNumParams() == argAttrs.size() && "expected as many argument attribute lists as arguments"); - function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs, - /*resultAttrs=*/std::nullopt); + function_interface_impl::addArgAndResultAttrs( + builder, result, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } // Builds an LLVM function type from the given lists of input and output types. @@ -2095,8 +2096,9 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) { if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) return failure(); - function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result, - entryArgs, resultAttrs); + function_interface_impl::addArgAndResultAttrs( + parser.getBuilder(), result, entryArgs, resultAttrs, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); auto *body = result.addRegion(); OptionalParseResult parseResult = @@ -2131,7 +2133,8 @@ void LLVMFuncOp::print(OpAsmPrinter &p) { isVarArg(), resTypes); function_interface_impl::printFunctionAttributes( p, *this, - {getFunctionTypeAttrName(), getLinkageAttrName(), getCConvAttrName()}); + {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(), + getLinkageAttrName(), getCConvAttrName()}); // Print the body if this is not an external function. Region &body = getBody(); diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp index 27c6130..31ed5ad 100644 --- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp +++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp @@ -153,12 +153,14 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, - getFunctionTypeAttrName(result.name), buildFuncType); + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, - getFunctionTypeAttrName()); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// @@ -316,12 +318,14 @@ ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) { return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, - getFunctionTypeAttrName(result.name), buildFuncType); + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void SubgraphOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, - getFunctionTypeAttrName()); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp index 28fc4db..2cc282d 100644 --- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp +++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp @@ -221,12 +221,14 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, - getFunctionTypeAttrName(result.name), buildFuncType); + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, - getFunctionTypeAttrName()); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 3ce3913..3341b5e 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -2396,8 +2396,9 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) { // Add the attributes to the function arguments. assert(resultAttrs.size() == resultTypes.size()); - function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs, - resultAttrs); + function_interface_impl::addArgAndResultAttrs( + builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), + getResAttrsAttrName(result.name)); // Parse the optional function body. auto *body = result.addRegion(); @@ -2419,7 +2420,8 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) { function_interface_impl::printFunctionAttributes( printer, *this, {spirv::attributeName(), - getFunctionTypeAttrName(), getFunctionControlAttrName()}); + getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(), + getFunctionControlAttrName()}); // Print the body if this is not an external function. Region &body = this->getBody(); diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 30c5f56..28ac98e 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1300,8 +1300,9 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs, - /*resultAttrs=*/std::nullopt); + function_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { @@ -1312,12 +1313,14 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, - getFunctionTypeAttrName(result.name), buildFuncType); + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, - getFunctionTypeAttrName()); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp index af692be..5ca6777 100644 --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -113,7 +113,7 @@ parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl &resultTypes, return parser.parseRParen(); } -ParseResult mlir::function_interface_impl::parseFunctionSignature( +ParseResult function_interface_impl::parseFunctionSignature( OpAsmParser &parser, bool allowVariadic, SmallVectorImpl &arguments, bool &isVariadic, SmallVectorImpl &resultTypes, @@ -125,9 +125,10 @@ ParseResult mlir::function_interface_impl::parseFunctionSignature( return success(); } -void mlir::function_interface_impl::addArgAndResultAttrs( +void function_interface_impl::addArgAndResultAttrs( Builder &builder, OperationState &result, ArrayRef argAttrs, - ArrayRef resultAttrs) { + ArrayRef resultAttrs, StringAttr argAttrsName, + StringAttr resAttrsName) { auto nonEmptyAttrsFn = [](DictionaryAttr attrs) { return attrs && !attrs.empty(); }; @@ -142,28 +143,28 @@ void mlir::function_interface_impl::addArgAndResultAttrs( // Add the attributes to the function arguments. if (llvm::any_of(argAttrs, nonEmptyAttrsFn)) - result.addAttribute(function_interface_impl::getArgDictAttrName(), - getArrayAttr(argAttrs)); + result.addAttribute(argAttrsName, getArrayAttr(argAttrs)); // Add the attributes to the function results. if (llvm::any_of(resultAttrs, nonEmptyAttrsFn)) - result.addAttribute(function_interface_impl::getResultDictAttrName(), - getArrayAttr(resultAttrs)); + result.addAttribute(resAttrsName, getArrayAttr(resultAttrs)); } -void mlir::function_interface_impl::addArgAndResultAttrs( +void function_interface_impl::addArgAndResultAttrs( Builder &builder, OperationState &result, - ArrayRef args, - ArrayRef resultAttrs) { + ArrayRef args, ArrayRef resultAttrs, + StringAttr argAttrsName, StringAttr resAttrsName) { SmallVector argAttrs; for (const auto &arg : args) argAttrs.push_back(arg.attrs); - addArgAndResultAttrs(builder, result, argAttrs, resultAttrs); + addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName, + resAttrsName); } -ParseResult mlir::function_interface_impl::parseFunctionOp( +ParseResult function_interface_impl::parseFunctionOp( OpAsmParser &parser, OperationState &result, bool allowVariadic, - StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder) { + StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, + StringAttr argAttrsName, StringAttr resAttrsName) { SmallVector entryArgs; SmallVector resultAttrs; SmallVector resultTypes; @@ -220,7 +221,8 @@ ParseResult mlir::function_interface_impl::parseFunctionOp( // Add the attributes to the function arguments. assert(resultAttrs.size() == resultTypes.size()); - addArgAndResultAttrs(builder, result, entryArgs, resultAttrs); + addArgAndResultAttrs(builder, result, entryArgs, resultAttrs, argAttrsName, + resAttrsName); // Parse the optional function body. The printer will not print the body if // its empty, so disallow parsing of empty body in the parser. @@ -261,14 +263,14 @@ static void printFunctionResultList(OpAsmPrinter &p, ArrayRef types, os << ')'; } -void mlir::function_interface_impl::printFunctionSignature( - OpAsmPrinter &p, Operation *op, ArrayRef argTypes, bool isVariadic, - ArrayRef resultTypes) { +void function_interface_impl::printFunctionSignature( + OpAsmPrinter &p, FunctionOpInterface op, ArrayRef argTypes, + bool isVariadic, ArrayRef resultTypes) { Region &body = op->getRegion(0); bool isExternal = body.empty(); p << '('; - ArrayAttr argAttrs = op->getAttrOfType(getArgDictAttrName()); + ArrayAttr argAttrs = op.getArgAttrsAttr(); for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { if (i > 0) p << ", "; @@ -295,26 +297,23 @@ void mlir::function_interface_impl::printFunctionSignature( if (!resultTypes.empty()) { p.getStream() << " -> "; - auto resultAttrs = op->getAttrOfType(getResultDictAttrName()); + auto resultAttrs = op.getResAttrsAttr(); printFunctionResultList(p, resultTypes, resultAttrs); } } -void mlir::function_interface_impl::printFunctionAttributes( +void function_interface_impl::printFunctionAttributes( OpAsmPrinter &p, Operation *op, ArrayRef elided) { // Print out function attributes, if present. - SmallVector ignoredAttrs = {SymbolTable::getSymbolAttrName(), - getArgDictAttrName(), - getResultDictAttrName()}; + SmallVector ignoredAttrs = {SymbolTable::getSymbolAttrName()}; ignoredAttrs.append(elided.begin(), elided.end()); p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs); } -void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p, - FunctionOpInterface op, - bool isVariadic, - StringRef typeAttrName) { +void function_interface_impl::printFunctionOp( + OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, + StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName) { // Print the operation and the function name. auto funcName = op->getAttrOfType(SymbolTable::getSymbolAttrName()) @@ -329,7 +328,8 @@ void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p, ArrayRef argTypes = op.getArgumentTypes(); ArrayRef resultTypes = op.getResultTypes(); printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); - printFunctionAttributes(p, op, {visibilityAttrName, typeAttrName}); + printFunctionAttributes( + p, op, {visibilityAttrName, typeAttrName, argAttrsName, resAttrsName}); // Print the body if this is not an external function. Region &body = op->getRegion(0); if (!body.empty()) { diff --git a/mlir/lib/IR/FunctionInterfaces.cpp b/mlir/lib/IR/FunctionInterfaces.cpp index 9ba8303..347fb15 100644 --- a/mlir/lib/IR/FunctionInterfaces.cpp +++ b/mlir/lib/IR/FunctionInterfaces.cpp @@ -24,27 +24,104 @@ static bool isEmptyAttrDict(Attribute attr) { return attr.cast().empty(); } -DictionaryAttr mlir::function_interface_impl::getArgAttrDict(Operation *op, - unsigned index) { - ArrayAttr attrs = op->getAttrOfType(getArgDictAttrName()); +DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op, + unsigned index) { + ArrayAttr attrs = op.getArgAttrsAttr(); DictionaryAttr argAttrs = attrs ? attrs[index].cast() : DictionaryAttr(); return argAttrs; } DictionaryAttr -mlir::function_interface_impl::getResultAttrDict(Operation *op, - unsigned index) { - ArrayAttr attrs = op->getAttrOfType(getResultDictAttrName()); +function_interface_impl::getResultAttrDict(FunctionOpInterface op, + unsigned index) { + ArrayAttr attrs = op.getResAttrsAttr(); DictionaryAttr resAttrs = attrs ? attrs[index].cast() : DictionaryAttr(); return resAttrs; } -void mlir::function_interface_impl::detail::setArgResAttrDict( - Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index, - DictionaryAttr attrs) { - ArrayAttr allAttrs = op->getAttrOfType(attrName); +ArrayRef +function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) { + auto argDict = getArgAttrDict(op, index); + return argDict ? argDict.getValue() : std::nullopt; +} + +ArrayRef +function_interface_impl::getResultAttrs(FunctionOpInterface op, + unsigned index) { + auto resultDict = getResultAttrDict(op, index); + return resultDict ? resultDict.getValue() : std::nullopt; +} + +/// Get either the argument or result attributes array. +template +static ArrayAttr getArgResAttrs(FunctionOpInterface op) { + if constexpr (isArg) + return op.getArgAttrsAttr(); + else + return op.getResAttrsAttr(); +} + +/// Set either the argument or result attributes array. +template +static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) { + if constexpr (isArg) + op.setArgAttrsAttr(attrs); + else + op.setResAttrsAttr(attrs); +} + +/// Erase either the argument or result attributes array. +template +static void removeArgResAttrs(FunctionOpInterface op) { + if constexpr (isArg) + op.removeArgAttrsAttr(); + else + op.removeResAttrsAttr(); +} + +/// Set all of the argument or result attribute dictionaries for a function. +template +static void setAllArgResAttrDicts(FunctionOpInterface op, + ArrayRef attrs) { + if (llvm::all_of(attrs, isEmptyAttrDict)) + removeArgResAttrs(op); + else + setArgResAttrs(op, ArrayAttr::get(op->getContext(), attrs)); +} + +void function_interface_impl::setAllArgAttrDicts( + FunctionOpInterface op, ArrayRef attrs) { + setAllArgAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); +} + +void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op, + ArrayRef attrs) { + auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { + return !attr ? DictionaryAttr::get(op->getContext()) : attr; + }); + setAllArgResAttrDicts(op, llvm::to_vector<8>(wrappedAttrs)); +} + +void function_interface_impl::setAllResultAttrDicts( + FunctionOpInterface op, ArrayRef attrs) { + setAllResultAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); +} + +void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op, + ArrayRef attrs) { + auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { + return !attr ? DictionaryAttr::get(op->getContext()) : attr; + }); + setAllArgResAttrDicts(op, llvm::to_vector<8>(wrappedAttrs)); +} + +/// Update the given index into an argument or result attribute dictionary. +template +static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices, + unsigned index, DictionaryAttr attrs) { + ArrayAttr allAttrs = getArgResAttrs(op); if (!allAttrs) { if (attrs.empty()) return; @@ -53,7 +130,7 @@ void mlir::function_interface_impl::detail::setArgResAttrDict( SmallVector newAttrs(numTotalIndices, DictionaryAttr::get(op->getContext())); newAttrs[index] = attrs; - op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs)); + setArgResAttrs(op, ArrayAttr::get(op->getContext(), newAttrs)); return; } // Check to see if the attribute is different from what we already have. @@ -65,53 +142,51 @@ void mlir::function_interface_impl::detail::setArgResAttrDict( ArrayRef rawAttrArray = allAttrs.getValue(); if (attrs.empty() && llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) && - llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) { - op->removeAttr(attrName); - return; - } + llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) + return removeArgResAttrs(op); // Otherwise, create a new attribute array with the updated dictionary. SmallVector newAttrs(rawAttrArray.begin(), rawAttrArray.end()); newAttrs[index] = attrs; - op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs)); + setArgResAttrs(op, ArrayAttr::get(op->getContext(), newAttrs)); } -/// Set all of the argument or result attribute dictionaries for a function. -static void setAllArgResAttrDicts(Operation *op, StringRef attrName, - ArrayRef attrs) { - if (llvm::all_of(attrs, isEmptyAttrDict)) - op->removeAttr(attrName); - else - op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs)); +void function_interface_impl::setArgAttrs(FunctionOpInterface op, + unsigned index, + ArrayRef attributes) { + assert(index < op.getNumArguments() && "invalid argument number"); + return setArgResAttrDict( + op, op.getNumArguments(), index, + DictionaryAttr::get(op->getContext(), attributes)); } -void mlir::function_interface_impl::setAllArgAttrDicts( - Operation *op, ArrayRef attrs) { - setAllArgAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); -} -void mlir::function_interface_impl::setAllArgAttrDicts( - Operation *op, ArrayRef attrs) { - auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { - return !attr ? DictionaryAttr::get(op->getContext()) : attr; - }); - setAllArgResAttrDicts(op, getArgDictAttrName(), - llvm::to_vector<8>(wrappedAttrs)); +void function_interface_impl::setArgAttrs(FunctionOpInterface op, + unsigned index, + DictionaryAttr attributes) { + return setArgResAttrDict( + op, op.getNumArguments(), index, + attributes ? attributes : DictionaryAttr::get(op->getContext())); } -void mlir::function_interface_impl::setAllResultAttrDicts( - Operation *op, ArrayRef attrs) { - setAllResultAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); +void function_interface_impl::setResultAttrs( + FunctionOpInterface op, unsigned index, + ArrayRef attributes) { + assert(index < op.getNumResults() && "invalid result number"); + return setArgResAttrDict( + op, op.getNumResults(), index, + DictionaryAttr::get(op->getContext(), attributes)); } -void mlir::function_interface_impl::setAllResultAttrDicts( - Operation *op, ArrayRef attrs) { - auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { - return !attr ? DictionaryAttr::get(op->getContext()) : attr; - }); - setAllArgResAttrDicts(op, getResultDictAttrName(), - llvm::to_vector<8>(wrappedAttrs)); + +void function_interface_impl::setResultAttrs(FunctionOpInterface op, + unsigned index, + DictionaryAttr attributes) { + assert(index < op.getNumResults() && "invalid result number"); + return setArgResAttrDict( + op, op.getNumResults(), index, + attributes ? attributes : DictionaryAttr::get(op->getContext())); } -void mlir::function_interface_impl::insertFunctionArguments( +void function_interface_impl::insertFunctionArguments( FunctionOpInterface op, ArrayRef argIndices, TypeRange argTypes, ArrayRef argAttrs, ArrayRef argLocs, unsigned originalNumArgs, Type newType) { @@ -128,7 +203,7 @@ void mlir::function_interface_impl::insertFunctionArguments( Block &entry = op->getRegion(0).front(); // Update the argument attributes of the function. - auto oldArgAttrs = op->getAttrOfType(getArgDictAttrName()); + ArrayAttr oldArgAttrs = op.getArgAttrsAttr(); if (oldArgAttrs || !argAttrs.empty()) { SmallVector newArgAttrs; newArgAttrs.reserve(originalNumArgs + argIndices.size()); @@ -157,7 +232,7 @@ void mlir::function_interface_impl::insertFunctionArguments( entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]); } -void mlir::function_interface_impl::insertFunctionResults( +void function_interface_impl::insertFunctionResults( FunctionOpInterface op, ArrayRef resultIndices, TypeRange resultTypes, ArrayRef resultAttrs, unsigned originalNumResults, Type newType) { @@ -171,7 +246,7 @@ void mlir::function_interface_impl::insertFunctionResults( // - Result attrs. // Update the result attributes of the function. - auto oldResultAttrs = op->getAttrOfType(getResultDictAttrName()); + ArrayAttr oldResultAttrs = op.getResAttrsAttr(); if (oldResultAttrs || !resultAttrs.empty()) { SmallVector newResultAttrs; newResultAttrs.reserve(originalNumResults + resultIndices.size()); @@ -199,7 +274,7 @@ void mlir::function_interface_impl::insertFunctionResults( op.setFunctionTypeAttr(TypeAttr::get(newType)); } -void mlir::function_interface_impl::eraseFunctionArguments( +void function_interface_impl::eraseFunctionArguments( FunctionOpInterface op, const BitVector &argIndices, Type newType) { // There are 3 things that need to be updated: // - Function type. @@ -208,7 +283,7 @@ void mlir::function_interface_impl::eraseFunctionArguments( Block &entry = op->getRegion(0).front(); // Update the argument attributes of the function. - if (auto argAttrs = op->getAttrOfType(getArgDictAttrName())) { + if (ArrayAttr argAttrs = op.getArgAttrsAttr()) { SmallVector newArgAttrs; newArgAttrs.reserve(argAttrs.size()); for (unsigned i = 0, e = argIndices.size(); i < e; ++i) @@ -222,14 +297,14 @@ void mlir::function_interface_impl::eraseFunctionArguments( entry.eraseArguments(argIndices); } -void mlir::function_interface_impl::eraseFunctionResults( +void function_interface_impl::eraseFunctionResults( FunctionOpInterface op, const BitVector &resultIndices, Type newType) { // There are 2 things that need to be updated: // - Function type. // - Result attrs. // Update the result attributes of the function. - if (auto resAttrs = op->getAttrOfType(getResultDictAttrName())) { + if (ArrayAttr resAttrs = op.getResAttrsAttr()) { SmallVector newResultAttrs; newResultAttrs.reserve(resAttrs.size()); for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) @@ -242,7 +317,7 @@ void mlir::function_interface_impl::eraseFunctionResults( op.setFunctionTypeAttr(TypeAttr::get(newType)); } -TypeRange mlir::function_interface_impl::insertTypesInto( +TypeRange function_interface_impl::insertTypesInto( TypeRange oldTypes, ArrayRef indices, TypeRange newTypes, SmallVectorImpl &storage) { assert(indices.size() == newTypes.size() && @@ -261,7 +336,7 @@ TypeRange mlir::function_interface_impl::insertTypesInto( return storage; } -TypeRange mlir::function_interface_impl::filterTypesOut( +TypeRange function_interface_impl::filterTypesOut( TypeRange types, const BitVector &indices, SmallVectorImpl &storage) { if (indices.none()) return types; @@ -276,8 +351,8 @@ TypeRange mlir::function_interface_impl::filterTypesOut( // Function type signature. //===----------------------------------------------------------------------===// -void mlir::function_interface_impl::setFunctionType(FunctionOpInterface op, - Type newType) { +void function_interface_impl::setFunctionType(FunctionOpInterface op, + Type newType) { unsigned oldNumArgs = op.getNumArguments(); unsigned oldNumResults = op.getNumResults(); op.setFunctionTypeAttr(TypeAttr::get(newType)); @@ -285,35 +360,31 @@ void mlir::function_interface_impl::setFunctionType(FunctionOpInterface op, unsigned newNumResults = op.getNumResults(); // Functor used to update the argument and result attributes of the function. - auto updateAttrFn = [&](StringRef attrName, unsigned oldCount, - unsigned newCount, auto setAttrFn) { + auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) { + constexpr bool isArgVal = std::is_same_v; + if (oldCount == newCount) return; // The new type has no arguments/results, just drop the attribute. - if (newCount == 0) { - op->removeAttr(attrName); - return; - } - ArrayAttr attrs = op->getAttrOfType(attrName); + if (newCount == 0) + return removeArgResAttrs(op); + ArrayAttr attrs = getArgResAttrs(op); if (!attrs) return; // The new type has less arguments/results, take the first N attributes. if (newCount < oldCount) - return setAttrFn(op, attrs.getValue().take_front(newCount)); + return setAllArgResAttrDicts( + op, attrs.getValue().take_front(newCount)); // Otherwise, the new type has more arguments/results. Initialize the new // arguments/results with empty attributes. SmallVector newAttrs(attrs.begin(), attrs.end()); newAttrs.resize(newCount); - setAttrFn(op, newAttrs); + setAllArgResAttrDicts(op, newAttrs); }; // Update the argument and result attributes. - updateAttrFn( - getArgDictAttrName(), oldNumArgs, newNumArgs, - [&](Operation *op, auto &&attrs) { setAllArgAttrDicts(op, attrs); }); - updateAttrFn( - getResultDictAttrName(), oldNumResults, newNumResults, - [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); }); + updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs); + updateAttrFn(std::false_type{}, oldNumResults, newNumResults); } diff --git a/mlir/test/IR/invalid-func-op.mlir b/mlir/test/IR/invalid-func-op.mlir index a72abad..d995689 100644 --- a/mlir/test/IR/invalid-func-op.mlir +++ b/mlir/test/IR/invalid-func-op.mlir @@ -96,20 +96,11 @@ func.func private @invalid_symbol_type_attr() attributes { function_type = "x" } // ----- -// expected-error@+1 {{argument attribute array `arg_attrs` to have the same number of elements as the number of function arguments}} +// expected-error@+1 {{argument attribute array to have the same number of elements as the number of function arguments}} func.func private @invalid_arg_attrs() attributes { arg_attrs = [{}] } // ----- -// expected-error@+1 {{expects argument attribute dictionary to be a DictionaryAttr, but got `10 : i64`}} -func.func private @invalid_arg_attrs(i32) attributes { arg_attrs = [10] } -// ----- - -// expected-error@+1 {{result attribute array `res_attrs` to have the same number of elements as the number of function results}} +// expected-error@+1 {{result attribute array to have the same number of elements as the number of function results}} func.func private @invalid_res_attrs() attributes { res_attrs = [{}] } - -// ----- - -// expected-error@+1 {{expects result attribute dictionary to be a DictionaryAttr, but got `10 : i64`}} -func.func private @invalid_res_attrs() -> i32 attributes { res_attrs = [10] } -- 2.7.4