From 53b946aa636a31e9243b8c5bf1703a1f6eae798e Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 7 May 2021 19:30:25 -0700 Subject: [PATCH] [mlir] Refactor the representation of function-like argument/result attributes. The current design uses a unique entry for each argument/result attribute, with the name of the entry being something like "arg0". This provides for a somewhat sparse design, but ends up being much more expensive (from a runtime perspective) in-practice. The design requires building a string every time we lookup the dictionary for a specific arg/result, and also requires N attribute lookups when collecting all of the arg/result attribute dictionaries. This revision restructures the design to instead have an ArrayAttr that contains all of the attribute dictionaries for arguments and another for results. This design reduces the number of attribute name lookups to 1, and allows for O(1) lookup for individual element dictionaries. The major downside is that we can end up with larger memory usage, as the ArrayAttr contains an entry for each element even if that element has no attributes. If the memory usage becomes too problematic, we can experiment with a more sparse structure that still provides a lot of the wins in this revision. This dropped the compilation time of a somewhat large TensorFlow model from ~650 seconds to ~400 seconds. Differential Revision: https://reviews.llvm.org/D102035 --- mlir/include/mlir/Dialect/GPU/GPUOps.td | 10 - mlir/include/mlir/IR/BuiltinAttributes.td | 2 +- mlir/include/mlir/IR/FunctionImplementation.h | 7 +- mlir/include/mlir/IR/FunctionSupport.h | 317 +++++++++++---------- mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 2 +- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 2 +- .../Conversion/StandardToLLVM/StandardToLLVM.cpp | 30 +- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 36 +-- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 38 +-- mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp | 2 +- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 19 +- .../Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 2 +- mlir/lib/IR/BuiltinDialect.cpp | 61 ++-- mlir/lib/IR/FunctionImplementation.cpp | 143 +++++----- mlir/lib/IR/FunctionSupport.cpp | 216 ++++++++++---- mlir/lib/Transforms/Utils/DialectConversion.cpp | 8 +- mlir/test/Dialect/LLVMIR/func.mlir | 2 +- mlir/test/IR/invalid-func-op.mlir | 19 ++ mlir/test/IR/test-func-set-type.mlir | 2 - 19 files changed, 524 insertions(+), 394 deletions(-) diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td index 5bd6956..a29a22a 100644 --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -213,16 +213,6 @@ def GPU_GPUFuncOp : GPU_Op<"func", [HasParent<"GPUModuleOp">, GPUDialect::getKernelFuncAttrName()) != nullptr; } - /// Change the type of this function in place. This is an extremely - /// dangerous operation and it is up to the caller to ensure that this is - /// legal for this function, and to restore invariants: - /// - the entry block args must be updated to match the function params. - /// - the argument/result attributes may need an update: if the new type - /// has less parameters we drop the extra attributes, if there are more - /// parameters they won't have any attributes. - // TODO: consider removing this function thanks to rewrite patterns. - void setType(FunctionType newType); - /// Returns the number of buffers located in the workgroup memory. unsigned getNumWorkgroupAttributions() { return (*this)->getAttrOfType( diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td index c248ad5..05dbc6b 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -300,7 +300,7 @@ def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary"> { }]; let parameters = (ins ArrayRefParameter<"NamedAttribute", "">:$value); let builders = [ - AttrBuilder<(ins "ArrayRef":$value)> + AttrBuilder<(ins CArg<"ArrayRef", "llvm::None">:$value)> ]; let extraClassDeclaration = [{ using ValueType = ArrayRef; diff --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h index c19100c..cb7776f 100644 --- a/mlir/include/mlir/IR/FunctionImplementation.h +++ b/mlir/include/mlir/IR/FunctionImplementation.h @@ -20,7 +20,7 @@ namespace mlir { -namespace impl { +namespace function_like_impl { /// A named class for passing around the variadic flag. class VariadicFlag { @@ -38,6 +38,9 @@ private: /// Internally, argument and result attributes are stored as dict attributes /// with special names given by getResultAttrName, getArgumentAttrName. void addArgAndResultAttrs(Builder &builder, OperationState &result, + ArrayRef argAttrs, + ArrayRef resultAttrs); +void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef argAttrs, ArrayRef resultAttrs); @@ -103,7 +106,7 @@ void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults, ArrayRef elided = {}); -} // namespace impl +} // namespace function_like_impl } // namespace mlir diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h index a3be1a7..21d6e37 100644 --- a/mlir/include/mlir/IR/FunctionSupport.h +++ b/mlir/include/mlir/IR/FunctionSupport.h @@ -20,45 +20,41 @@ namespace mlir { -namespace impl { +namespace function_like_impl { /// Return the name of the attribute used for function types. inline StringRef getTypeAttrName() { return "type"; } -/// Return the name of the attribute used for function arguments. -inline StringRef getArgAttrName(unsigned arg, SmallVectorImpl &out) { - out.clear(); - return ("arg" + Twine(arg)).toStringRef(out); -} - -/// Returns true if the given name is a valid argument attribute name. -inline bool isArgAttrName(StringRef name) { - APInt unused; - return name.startswith("arg") && - !name.drop_front(3).getAsInteger(/*Radix=*/10, unused); -} +/// Return the name of the attribute used for function argument attributes. +inline StringRef getArgDictAttrName() { return "arg_attrs"; } -/// Return the name of the attribute used for function results. -inline StringRef getResultAttrName(unsigned arg, SmallVectorImpl &out) { - out.clear(); - return ("result" + Twine(arg)).toStringRef(out); -} +/// Return the name of the attribute used for function argument attributes. +inline StringRef getResultDictAttrName() { return "res_attrs"; } /// Returns the dictionary attribute corresponding to the argument at 'index'. /// If there are no argument attributes at 'index', a null attribute is /// returned. -inline DictionaryAttr getArgAttrDict(Operation *op, unsigned index) { - SmallString<8> nameOut; - return op->getAttrOfType(getArgAttrName(index, nameOut)); -} +DictionaryAttr getArgAttrDict(Operation *op, unsigned index); /// Returns the dictionary attribute corresponding to the result at 'index'. /// If there are no result attributes at 'index', a null attribute is /// returned. -inline DictionaryAttr getResultAttrDict(Operation *op, unsigned index) { - SmallString<8> nameOut; - return op->getAttrOfType(getResultAttrName(index, nameOut)); -} +DictionaryAttr getResultAttrDict(Operation *op, unsigned index); + +namespace detail { +/// Update the given index into an argument or result attribute dictionary. +void setArgResAttrDict(Operation *op, StringRef attrName, + unsigned numTotalIndices, unsigned index, + DictionaryAttr attrs); +} // namespace detail + +/// Set all of the argument or result attribute dictionaries for a function. The +/// size of `attrs` is expected to match the number of arguments/results of the +/// given `op`. +void setAllArgAttrDicts(Operation *op, ArrayRef attrs); +void setAllArgAttrDicts(Operation *op, ArrayRef attrs); +void setAllResultAttrDicts(Operation *op, ArrayRef attrs); +void setAllResultAttrDicts(Operation *op, ArrayRef attrs); /// Return all of the attributes for the argument at 'index'. inline ArrayRef getArgAttrs(Operation *op, unsigned index) { @@ -87,7 +83,7 @@ void setFunctionType(Operation *op, FunctionType newType); /// Get a FunctionLike operation's body. Region &getFunctionBody(Operation *op); -} // namespace impl +} // namespace function_like_impl namespace OpTrait { @@ -142,7 +138,7 @@ public: bool isExternal() { return empty(); } Region &getBody() { - return ::mlir::impl::getFunctionBody(this->getOperation()); + return function_like_impl::getFunctionBody(this->getOperation()); } /// Delete all blocks from this function. @@ -194,7 +190,9 @@ public: //===--------------------------------------------------------------------===// /// Return the name of the attribute used for function types. - static StringRef getTypeAttrName() { return ::mlir::impl::getTypeAttrName(); } + static StringRef getTypeAttrName() { + return function_like_impl::getTypeAttrName(); + } TypeAttr getTypeAttr() { return this->getOperation()->template getAttrOfType( @@ -207,7 +205,7 @@ public: /// hide this one if the concrete class does not use FunctionType for the /// function type under the hood. FunctionType getType() { - return ::mlir::impl::getFunctionType(this->getOperation()); + return function_like_impl::getFunctionType(this->getOperation()); } /// Return the type of this function without the specified arguments and @@ -277,8 +275,8 @@ public: void eraseArguments(ArrayRef argIndices) { unsigned originalNumArgs = getNumArguments(); Type newType = getTypeWithoutArgsAndResults(argIndices, {}); - ::mlir::impl::eraseFunctionArguments(this->getOperation(), argIndices, - originalNumArgs, newType); + function_like_impl::eraseFunctionArguments(this->getOperation(), argIndices, + originalNumArgs, newType); } /// Erase a single result at `resultIndex`. @@ -289,8 +287,8 @@ public: void eraseResults(ArrayRef resultIndices) { unsigned originalNumResults = getNumResults(); Type newType = getTypeWithoutArgsAndResults({}, resultIndices); - ::mlir::impl::eraseFunctionResults(this->getOperation(), resultIndices, - originalNumResults, newType); + function_like_impl::eraseFunctionResults( + this->getOperation(), resultIndices, originalNumResults, newType); } //===--------------------------------------------------------------------===// @@ -306,14 +304,23 @@ public: /// Return all of the attributes for the argument at 'index'. ArrayRef getArgAttrs(unsigned index) { - return ::mlir::impl::getArgAttrs(this->getOperation(), index); + return function_like_impl::getArgAttrs(this->getOperation(), index); } - /// Return all argument attributes of this function. If an argument does not - /// have any attributes, the corresponding entry in `result` is nullptr. + /// Return an ArrayAttr containing all argument attribute dictionaries of this + /// function, or nullptr if no arguments have attributes. + ArrayAttr getAllArgAttrs() { + return this->getOperation()->template getAttrOfType( + function_like_impl::getArgDictAttrName()); + } + /// Return all argument attributes of this function. void getAllArgAttrs(SmallVectorImpl &result) { - for (unsigned i = 0, e = getNumArguments(); i != e; ++i) - result.emplace_back(getArgAttrDict(i)); + if (ArrayAttr argAttrs = getAllArgAttrs()) { + auto argAttrRange = argAttrs.template getAsRange(); + result.append(argAttrRange.begin(), argAttrRange.end()); + } else { + result.resize(getNumArguments()); + } } /// Return the specified attribute, if present, for the argument at 'index', @@ -342,7 +349,19 @@ public: /// Set the attributes held by the argument at 'index'. `attributes` may be /// null, in which case any existing argument attributes are removed. void setArgAttrs(unsigned index, DictionaryAttr attributes); - void setAllArgAttrs(ArrayRef attributes); + void setAllArgAttrs(ArrayRef attributes) { + assert(attributes.size() == getNumArguments()); + function_like_impl::setAllArgAttrDicts(this->getOperation(), attributes); + } + void setAllArgAttrs(ArrayRef attributes) { + assert(attributes.size() == getNumArguments()); + function_like_impl::setAllArgAttrDicts(this->getOperation(), attributes); + } + void setAllArgAttrs(ArrayAttr attributes) { + assert(attributes.size() == getNumArguments()); + this->getOperation()->setAttr(function_like_impl::getArgDictAttrName(), + attributes); + } /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. @@ -370,14 +389,23 @@ public: /// Return all of the attributes for the result at 'index'. ArrayRef getResultAttrs(unsigned index) { - return ::mlir::impl::getResultAttrs(this->getOperation(), index); + return function_like_impl::getResultAttrs(this->getOperation(), index); } - /// Return all result attributes of this function. If a result does not have - /// any attributes, the corresponding entry in `result` is nullptr. + /// Return an ArrayAttr containing all result attribute dictionaries of this + /// function, or nullptr if no result have attributes. + ArrayAttr getAllResultAttrs() { + return this->getOperation()->template getAttrOfType( + function_like_impl::getResultDictAttrName()); + } + /// Return all result attributes of this function. void getAllResultAttrs(SmallVectorImpl &result) { - for (unsigned i = 0, e = getNumResults(); i != e; ++i) - result.emplace_back(getResultAttrDict(i)); + if (ArrayAttr argAttrs = getAllResultAttrs()) { + auto argAttrRange = argAttrs.template getAsRange(); + result.append(argAttrRange.begin(), argAttrRange.end()); + } else { + result.resize(getNumResults()); + } } /// Return the specified attribute, if present, for the result at 'index', @@ -402,10 +430,23 @@ public: /// Set the attributes held by the result at 'index'. void setResultAttrs(unsigned index, ArrayRef attributes); + /// Set the attributes held by the result at 'index'. `attributes` may be /// null, in which case any existing argument attributes are removed. void setResultAttrs(unsigned index, DictionaryAttr attributes); - void setAllResultAttrs(ArrayRef attributes); + void setAllResultAttrs(ArrayRef attributes) { + assert(attributes.size() == getNumResults()); + function_like_impl::setAllResultAttrDicts(this->getOperation(), attributes); + } + void setAllResultAttrs(ArrayRef attributes) { + assert(attributes.size() == getNumResults()); + function_like_impl::setAllResultAttrDicts(this->getOperation(), attributes); + } + void setAllResultAttrs(ArrayAttr attributes) { + assert(attributes.size() == getNumResults()); + this->getOperation()->setAttr(function_like_impl::getResultDictAttrName(), + attributes); + } /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. @@ -422,25 +463,12 @@ public: Attribute removeResultAttr(unsigned index, Identifier name); protected: - /// Returns the attribute entry name for the set of argument attributes at - /// 'index'. - static StringRef getArgAttrName(unsigned index, SmallVectorImpl &out) { - return ::mlir::impl::getArgAttrName(index, out); - } - /// Returns the dictionary attribute corresponding to the argument at 'index'. /// If there are no argument attributes at 'index', a null attribute is /// returned. DictionaryAttr getArgAttrDict(unsigned index) { assert(index < getNumArguments() && "invalid argument number"); - return ::mlir::impl::getArgAttrDict(this->getOperation(), index); - } - - /// Returns the attribute entry name for the set of result attributes at - /// 'index'. - static StringRef getResultAttrName(unsigned index, - SmallVectorImpl &out) { - return ::mlir::impl::getResultAttrName(index, out); + return function_like_impl::getArgAttrDict(this->getOperation(), index); } /// Returns the dictionary attribute corresponding to the result at 'index'. @@ -448,7 +476,7 @@ protected: /// returned. DictionaryAttr getResultAttrDict(unsigned index) { assert(index < getNumResults() && "invalid result number"); - return ::mlir::impl::getResultAttrDict(this->getOperation(), index); + return function_like_impl::getResultAttrDict(this->getOperation(), index); } /// Hook for concrete classes to verify that the type attribute respects @@ -475,9 +503,7 @@ LogicalResult FunctionLike::verifyBody() { template LogicalResult FunctionLike::verifyTrait(Operation *op) { - MLIRContext *ctx = op->getContext(); auto funcOp = cast(op); - if (!funcOp.isTypeAttrValid()) return funcOp.emitOpError("requires a type attribute '") << getTypeAttrName() << '\''; @@ -485,35 +511,69 @@ LogicalResult FunctionLike::verifyTrait(Operation *op) { if (failed(funcOp.verifyType())) return failure(); - for (unsigned i = 0, e = funcOp.getNumArguments(); i != e; ++i) { - // Verify that all of the argument attributes are dialect attributes, i.e. - // that they contain a dialect prefix in their name. Call the dialect, if - // registered, to verify the attributes themselves. - for (auto attr : funcOp.getArgAttrs(i)) { - if (!attr.first.strref().contains('.')) - return funcOp.emitOpError("arguments may only have dialect attributes"); - auto dialectNamePair = attr.first.strref().split('.'); - if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) { - if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0, - /*argIndex=*/i, attr))) - return failure(); + if (ArrayAttr allArgAttrs = funcOp.getAllArgAttrs()) { + unsigned numArgs = funcOp.getNumArguments(); + if (allArgAttrs.size() != numArgs) { + return funcOp.emitOpError() + << "expects argument attribute array `" + << function_like_impl::getArgDictAttrName() + << "` to have the same number of elements as the number of " + "function arguments, got " + << allArgAttrs.size() << ", but expected " << numArgs; + } + for (unsigned i = 0; i != numArgs; ++i) { + DictionaryAttr argAttrs = allArgAttrs[i].dyn_cast(); + if (!argAttrs) { + return funcOp.emitOpError() << "expects argument attribute dictionary " + "to be a DictionaryAttr, but got `" + << allArgAttrs[i] << "`"; + } + + // Verify that all of the argument attributes are dialect attributes, i.e. + // that they contain a dialect prefix in their name. Call the dialect, if + // registered, to verify the attributes themselves. + for (auto attr : argAttrs) { + if (!attr.first.strref().contains('.')) + return funcOp.emitOpError( + "arguments may only have dialect attributes"); + if (Dialect *dialect = attr.first.getDialect()) { + if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0, + /*argIndex=*/i, attr))) + return failure(); + } } } } + if (ArrayAttr allResultAttrs = funcOp.getAllResultAttrs()) { + unsigned numResults = funcOp.getNumResults(); + if (allResultAttrs.size() != numResults) { + return funcOp.emitOpError() + << "expects result attribute array `" + << function_like_impl::getResultDictAttrName() + << "` to have the same number of elements as the number of " + "function results, got " + << allResultAttrs.size() << ", but expected " << numResults; + } + for (unsigned i = 0; i != numResults; ++i) { + DictionaryAttr resultAttrs = allResultAttrs[i].dyn_cast(); + if (!resultAttrs) { + return funcOp.emitOpError() << "expects result attribute dictionary " + "to be a DictionaryAttr, but got `" + << allResultAttrs[i] << "`"; + } - for (unsigned i = 0, e = funcOp.getNumResults(); i != e; ++i) { - // Verify that all of the result attributes are dialect attributes, i.e. - // that they contain a dialect prefix in their name. Call the dialect, if - // registered, to verify the attributes themselves. - for (auto attr : funcOp.getResultAttrs(i)) { - if (!attr.first.strref().contains('.')) - return funcOp.emitOpError("results may only have dialect attributes"); - auto dialectNamePair = attr.first.strref().split('.'); - if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) { - if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0, - /*resultIndex=*/i, - attr))) - return failure(); + // Verify that all of the result attributes are dialect attributes, i.e. + // that they contain a dialect prefix in their name. Call the dialect, if + // registered, to verify the attributes themselves. + for (auto attr : resultAttrs) { + if (!attr.first.strref().contains('.')) + return funcOp.emitOpError("results may only have dialect attributes"); + if (Dialect *dialect = attr.first.getDialect()) { + if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0, + /*resultIndex=*/i, + attr))) + return failure(); + } } } } @@ -551,7 +611,7 @@ Block *FunctionLike::addBlock() { template void FunctionLike::setType(FunctionType newType) { - ::mlir::impl::setFunctionType(this->getOperation(), newType); + function_like_impl::setFunctionType(this->getOperation(), newType); } //===----------------------------------------------------------------------===// @@ -563,45 +623,19 @@ template void FunctionLike::setArgAttrs( unsigned index, ArrayRef attributes) { assert(index < getNumArguments() && "invalid argument number"); - SmallString<8> nameOut; - getArgAttrName(index, nameOut); - Operation *op = this->getOperation(); - if (attributes.empty()) - return (void)op->removeAttr(nameOut); - op->setAttr(nameOut, DictionaryAttr::get(op->getContext(), attributes)); + return function_like_impl::detail::setArgResAttrDict( + op, function_like_impl::getArgDictAttrName(), getNumArguments(), index, + DictionaryAttr::get(op->getContext(), attributes)); } template void FunctionLike::setArgAttrs(unsigned index, DictionaryAttr attributes) { - assert(index < getNumArguments() && "invalid argument number"); - SmallString<8> nameOut; - if (!attributes || attributes.empty()) - this->getOperation()->removeAttr(getArgAttrName(index, nameOut)); - else - return this->getOperation()->setAttr(getArgAttrName(index, nameOut), - attributes); -} - -template -void FunctionLike::setAllArgAttrs( - ArrayRef attributes) { - assert(attributes.size() == getNumArguments()); - NamedAttrList attrs = this->getOperation()->getAttrs(); - - // Instead of calling setArgAttrs() multiple times, which rebuild the - // attribute dictionary every time, build a new list of attributes for the - // operation so that we rebuild the attribute dictionary in one shot. - SmallString<8> argAttrName; - for (unsigned i = 0, e = attributes.size(); i != e; ++i) { - StringRef attrName = getArgAttrName(i, argAttrName); - if (!attributes[i] || attributes[i].empty()) - attrs.erase(attrName); - else - attrs.set(attrName, attributes[i]); - } - this->getOperation()->setAttrs(attrs); + Operation *op = this->getOperation(); + return function_like_impl::detail::setArgResAttrDict( + op, function_like_impl::getArgDictAttrName(), getNumArguments(), index, + attributes ? attributes : DictionaryAttr::get(op->getContext())); } /// If the an attribute exists with the specified name, change it to the new @@ -640,45 +674,20 @@ template void FunctionLike::setResultAttrs( unsigned index, ArrayRef attributes) { assert(index < getNumResults() && "invalid result number"); - SmallString<8> nameOut; - getResultAttrName(index, nameOut); - - if (attributes.empty()) - return (void)this->getOperation()->removeAttr(nameOut); Operation *op = this->getOperation(); - op->setAttr(nameOut, DictionaryAttr::get(op->getContext(), attributes)); + return function_like_impl::detail::setArgResAttrDict( + op, function_like_impl::getResultDictAttrName(), getNumResults(), index, + DictionaryAttr::get(op->getContext(), attributes)); } template void FunctionLike::setResultAttrs(unsigned index, DictionaryAttr attributes) { assert(index < getNumResults() && "invalid result number"); - SmallString<8> nameOut; - if (!attributes || attributes.empty()) - this->getOperation()->removeAttr(getResultAttrName(index, nameOut)); - else - this->getOperation()->setAttr(getResultAttrName(index, nameOut), - attributes); -} - -template -void FunctionLike::setAllResultAttrs( - ArrayRef attributes) { - assert(attributes.size() == getNumResults()); - NamedAttrList attrs = this->getOperation()->getAttrs(); - - // Instead of calling setResultAttrs() multiple times, which rebuild the - // attribute dictionary every time, build a new list of attributes for the - // operation so that we rebuild the attribute dictionary in one shot. - SmallString<8> resultAttrName; - for (unsigned i = 0, e = attributes.size(); i != e; ++i) { - StringRef attrName = getResultAttrName(i, resultAttrName); - if (!attributes[i] || attributes[i].empty()) - attrs.erase(attrName); - else - attrs.set(attrName, attributes[i]); - } - this->getOperation()->setAttrs(attrs); + Operation *op = this->getOperation(); + return function_like_impl::detail::setArgResAttrDict( + op, function_like_impl::getResultDictAttrName(), getNumResults(), index, + attributes ? attributes : DictionaryAttr::get(op->getContext())); } /// If the an attribute exists with the specified name, change it to the new diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index 0833953..67f699a 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -58,7 +58,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, SmallVector attributes; for (const auto &attr : gpuFuncOp->getAttrs()) { if (attr.first == SymbolTable::getSymbolAttrName() || - attr.first == impl::getTypeAttrName() || + attr.first == function_like_impl::getTypeAttrName() || attr.first == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName()) continue; attributes.push_back(attr); diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 2066deb..fa4bbff 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -195,7 +195,7 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter, rewriter.getFunctionType(signatureConverter.getConvertedTypes(), llvm::None)); for (const auto &namedAttr : funcOp->getAttrs()) { - if (namedAttr.first == impl::getTypeAttrName() || + if (namedAttr.first == function_like_impl::getTypeAttrName() || namedAttr.first == SymbolTable::getSymbolAttrName()) continue; newFuncOp->setAttr(namedAttr.first, namedAttr.second); diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index 5f94804..3949cd4 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -1211,8 +1211,10 @@ static void filterFuncAttributes(ArrayRef attrs, SmallVectorImpl &result) { for (const auto &attr : attrs) { if (attr.first == SymbolTable::getSymbolAttrName() || - attr.first == impl::getTypeAttrName() || attr.first == "std.varargs" || - (filterArgAttrs && impl::isArgAttrName(attr.first.strref()))) + attr.first == function_like_impl::getTypeAttrName() || + attr.first == "std.varargs" || + (filterArgAttrs && + attr.first == function_like_impl::getArgDictAttrName())) continue; result.push_back(attr); } @@ -1395,19 +1397,19 @@ protected: SmallVector attributes; filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true, attributes); - for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) { - auto attr = impl::getArgAttrDict(funcOp, i); - if (!attr) - continue; - - auto mapping = result.getInputMapping(i); - assert(mapping.hasValue() && "unexpected deletion of function argument"); - - SmallString<8> name; - for (size_t j = 0; j < mapping->size; ++j) { - impl::getArgAttrName(mapping->inputNo + j, name); - attributes.push_back(rewriter.getNamedAttr(name, attr)); + if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) { + SmallVector newArgAttrs( + llvmType.cast().getNumParams()); + for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) { + auto mapping = result.getInputMapping(i); + assert(mapping.hasValue() && + "unexpected deletion of function argument"); + for (size_t j = 0; j < mapping->size; ++j) + newArgAttrs[mapping->inputNo + j] = argAttrDicts[i]; } + attributes.push_back( + rewriter.getNamedAttr(function_like_impl::getArgDictAttrName(), + rewriter.getArrayAttr(newArgAttrs))); } // Create an LLVM function, use external linkage by default until MLIR diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 1fa687f8..1f081d8 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -599,9 +599,9 @@ parseLaunchFuncOperands(OpAsmParser &parser, return success(); SmallVector argAttrs; bool isVariadic = false; - return impl::parseFunctionArgumentList(parser, /*allowAttributes=*/false, - /*allowVariadic=*/false, argNames, - argTypes, argAttrs, isVariadic); + return function_like_impl::parseFunctionArgumentList( + parser, /*allowAttributes=*/false, + /*allowVariadic=*/false, argNames, argTypes, argAttrs, isVariadic); } static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, @@ -717,7 +717,7 @@ static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) { return failure(); auto signatureLocation = parser.getCurrentLocation(); - if (failed(impl::parseFunctionSignature( + if (failed(function_like_impl::parseFunctionSignature( parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs, isVariadic, resultTypes, resultAttrs))) return failure(); @@ -756,7 +756,8 @@ static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) { // Parse attributes. if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) return failure(); - mlir::impl::addArgAndResultAttrs(builder, result, argAttrs, resultAttrs); + function_like_impl::addArgAndResultAttrs(builder, result, argAttrs, + resultAttrs); // Parse the region. If no argument names were provided, take all names // (including those of attributions) from the entry block. @@ -781,33 +782,22 @@ static void printGPUFuncOp(OpAsmPrinter &p, GPUFuncOp op) { p.printSymbolName(op.getName()); FunctionType type = op.getType(); - impl::printFunctionSignature(p, op.getOperation(), type.getInputs(), - /*isVariadic=*/false, type.getResults()); + function_like_impl::printFunctionSignature( + p, op.getOperation(), type.getInputs(), + /*isVariadic=*/false, type.getResults()); printAttributions(p, op.getWorkgroupKeyword(), op.getWorkgroupAttributions()); printAttributions(p, op.getPrivateKeyword(), op.getPrivateAttributions()); if (op.isKernel()) p << ' ' << op.getKernelKeyword(); - impl::printFunctionAttributes(p, op.getOperation(), type.getNumInputs(), - type.getNumResults(), - {op.getNumWorkgroupAttributionsAttrName(), - GPUDialect::getKernelFuncAttrName()}); + function_like_impl::printFunctionAttributes( + p, op.getOperation(), type.getNumInputs(), type.getNumResults(), + {op.getNumWorkgroupAttributionsAttrName(), + GPUDialect::getKernelFuncAttrName()}); p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false); } -void GPUFuncOp::setType(FunctionType newType) { - auto oldType = getType(); - assert(newType.getNumResults() == oldType.getNumResults() && - "unimplemented: changes to the number of results"); - - SmallVector nameBuf; - for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++) - (*this)->removeAttr(getArgAttrName(i, nameBuf)); - - (*this)->setAttr(getTypeAttrName(), TypeAttr::get(newType)); -} - /// Hook for FunctionLike verifier. LogicalResult GPUFuncOp::verifyType() { Type type = getTypeAttr().getValue(); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index e1ad37e..12e6ccc 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1732,21 +1732,19 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, if (argAttrs.empty()) return; - unsigned numInputs = type.cast().getNumParams(); - assert(numInputs == argAttrs.size() && + assert(type.cast().getNumParams() == argAttrs.size() && "expected as many argument attribute lists as arguments"); - SmallString<8> argAttrName; - for (unsigned i = 0; i < numInputs; ++i) - if (DictionaryAttr argDict = argAttrs[i]) - result.addAttribute(getArgAttrName(i, argAttrName), argDict); + function_like_impl::addArgAndResultAttrs(builder, result, argAttrs, + /*resultAttrs=*/llvm::None); } // Builds an LLVM function type from the given lists of input and output types. // Returns a null type if any of the types provided are non-LLVM types, or if // there is more than one output type. -static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc, - ArrayRef inputs, ArrayRef outputs, - impl::VariadicFlag variadicFlag) { +static Type +buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc, + ArrayRef inputs, ArrayRef outputs, + function_like_impl::VariadicFlag variadicFlag) { Builder &b = parser.getBuilder(); if (outputs.size() > 1) { parser.emitError(loc, "failed to construct function type: expected zero or " @@ -1803,22 +1801,23 @@ static ParseResult parseLLVMFuncOp(OpAsmParser &parser, auto signatureLocation = parser.getCurrentLocation(); if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), result.attributes) || - impl::parseFunctionSignature(parser, /*allowVariadic=*/true, entryArgs, - argTypes, argAttrs, isVariadic, resultTypes, - resultAttrs)) + function_like_impl::parseFunctionSignature( + parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs, + isVariadic, resultTypes, resultAttrs)) return failure(); auto type = buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes, - impl::VariadicFlag(isVariadic)); + function_like_impl::VariadicFlag(isVariadic)); if (!type) return failure(); - result.addAttribute(impl::getTypeAttrName(), TypeAttr::get(type)); + result.addAttribute(function_like_impl::getTypeAttrName(), + TypeAttr::get(type)); if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) return failure(); - impl::addArgAndResultAttrs(parser.getBuilder(), result, argAttrs, - resultAttrs); + function_like_impl::addArgAndResultAttrs(parser.getBuilder(), result, + argAttrs, resultAttrs); auto *body = result.addRegion(); OptionalParseResult parseResult = parser.parseOptionalRegion( @@ -1846,9 +1845,10 @@ static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) { if (!returnType.isa()) resTypes.push_back(returnType); - impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), resTypes); - impl::printFunctionAttributes(p, op, argTypes.size(), resTypes.size(), - {getLinkageAttrName()}); + function_like_impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), + resTypes); + function_like_impl::printFunctionAttributes( + p, op, argTypes.size(), resTypes.size(), {getLinkageAttrName()}); // Print the body if this is not an external function. Region &body = op.body(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index 4ca7da6..04b6353 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -99,7 +99,7 @@ struct FunctionNonEntryBlockConversion : public ConversionPattern { matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.startRootUpdate(op); - Region ®ion = mlir::impl::getFunctionBody(op); + Region ®ion = function_like_impl::getFunctionBody(op); SmallVector conversions; for (Block &block : llvm::drop_begin(region, 1)) { diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 31a92c3..c74528c 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1783,13 +1783,14 @@ static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state) { // Parse the function signature. bool isVariadic = false; - if (impl::parseFunctionSignature(parser, /*allowVariadic=*/false, entryArgs, - argTypes, argAttrs, isVariadic, resultTypes, - resultAttrs)) + if (function_like_impl::parseFunctionSignature( + parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs, + isVariadic, resultTypes, resultAttrs)) return failure(); auto fnType = builder.getFunctionType(argTypes, resultTypes); - state.addAttribute(impl::getTypeAttrName(), TypeAttr::get(fnType)); + state.addAttribute(function_like_impl::getTypeAttrName(), + TypeAttr::get(fnType)); // Parse the optional function control keyword. spirv::FunctionControl fnControl; @@ -1803,7 +1804,8 @@ static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state) { // Add the attributes to the function arguments. assert(argAttrs.size() == argTypes.size()); assert(resultAttrs.size() == resultTypes.size()); - impl::addArgAndResultAttrs(builder, state, argAttrs, resultAttrs); + function_like_impl::addArgAndResultAttrs(builder, state, argAttrs, + resultAttrs); // Parse the optional function body. auto *body = state.addRegion(); @@ -1817,11 +1819,12 @@ static void print(spirv::FuncOp fnOp, OpAsmPrinter &printer) { printer << spirv::FuncOp::getOperationName() << " "; printer.printSymbolName(fnOp.sym_name()); auto fnType = fnOp.getType(); - impl::printFunctionSignature(printer, fnOp, fnType.getInputs(), - /*isVariadic=*/false, fnType.getResults()); + function_like_impl::printFunctionSignature(printer, fnOp, fnType.getInputs(), + /*isVariadic=*/false, + fnType.getResults()); printer << " \"" << spirv::stringifyFunctionControl(fnOp.function_control()) << "\""; - impl::printFunctionAttributes( + function_like_impl::printFunctionAttributes( printer, fnOp, fnType.getNumInputs(), fnType.getNumResults(), {spirv::attributeName()}); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 5a59021..6e807a7 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -582,7 +582,7 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef operands, // Copy over all attributes other than the function name and type. for (const auto &namedAttr : funcOp->getAttrs()) { - if (namedAttr.first != impl::getTypeAttrName() && + if (namedAttr.first != function_like_impl::getTypeAttrName() && namedAttr.first != SymbolTable::getSymbolAttrName()) newFuncOp->setAttr(namedAttr.first, namedAttr.second); } diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp index e1706f2..728443e 100644 --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -106,27 +106,25 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - SmallString<8> argAttrName; - for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) - if (DictionaryAttr argDict = argAttrs[i]) - state.addAttribute(getArgAttrName(i, argAttrName), argDict); + function_like_impl::addArgAndResultAttrs(builder, state, argAttrs, + /*resultAttrs=*/llvm::None); } static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) { auto buildFuncType = [](Builder &builder, ArrayRef argTypes, - ArrayRef results, impl::VariadicFlag, - std::string &) { + ArrayRef results, + function_like_impl::VariadicFlag, std::string &) { return builder.getFunctionType(argTypes, results); }; - return impl::parseFunctionLikeOp(parser, result, /*allowVariadic=*/false, - buildFuncType); + return function_like_impl::parseFunctionLikeOp( + parser, result, /*allowVariadic=*/false, buildFuncType); } static void print(FuncOp op, OpAsmPrinter &p) { FunctionType fnType = op.getType(); - impl::printFunctionLikeOp(p, op, fnType.getInputs(), /*isVariadic=*/false, - fnType.getResults()); + function_like_impl::printFunctionLikeOp( + p, op, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults()); } static LogicalResult verify(FuncOp op) { @@ -170,30 +168,39 @@ void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) { /// to cloned sub-values with the corresponding value that is copied, and adds /// those mappings to the mapper. FuncOp FuncOp::clone(BlockAndValueMapping &mapper) { - FunctionType newType = getType(); + // Create the new function. + FuncOp newFunc = cast(getOperation()->cloneWithoutRegions()); // If the function has a body, then the user might be deleting arguments to // the function by specifying them in the mapper. If so, we don't add the // argument to the input type vector. - bool isExternalFn = isExternal(); - if (!isExternalFn) { - SmallVector inputTypes; - inputTypes.reserve(newType.getNumInputs()); - for (unsigned i = 0, e = getNumArguments(); i != e; ++i) + if (!isExternal()) { + FunctionType oldType = getType(); + + unsigned oldNumArgs = oldType.getNumInputs(); + SmallVector newInputs; + newInputs.reserve(oldNumArgs); + for (unsigned i = 0; i != oldNumArgs; ++i) if (!mapper.contains(getArgument(i))) - inputTypes.push_back(newType.getInput(i)); - newType = FunctionType::get(getContext(), inputTypes, newType.getResults()); + newInputs.push_back(oldType.getInput(i)); + + /// If any of the arguments were dropped, update the type and drop any + /// necessary argument attributes. + if (newInputs.size() != oldNumArgs) { + newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, + oldType.getResults())); + + if (ArrayAttr argAttrs = getAllArgAttrs()) { + SmallVector newArgAttrs; + newArgAttrs.reserve(newInputs.size()); + for (unsigned i = 0; i != oldNumArgs; ++i) + if (!mapper.contains(getArgument(i))) + newArgAttrs.push_back(argAttrs[i]); + newFunc.setAllArgAttrs(newArgAttrs); + } + } } - // Create the new function. - FuncOp newFunc = cast(getOperation()->cloneWithoutRegions()); - newFunc.setType(newType); - - /// Set the argument attributes for arguments that aren't being replaced. - for (unsigned i = 0, e = getNumArguments(), destI = 0; i != e; ++i) - if (isExternalFn || !mapper.contains(getArgument(i))) - newFunc.setArgAttrs(destI++, getArgAttrs(i)); - /// Clone the current function into the new one and return it. cloneInto(newFunc, mapper); return newFunc; diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp index 4bec168..aadf545 100644 --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -13,7 +13,7 @@ using namespace mlir; -ParseResult mlir::impl::parseFunctionArgumentList( +ParseResult mlir::function_like_impl::parseFunctionArgumentList( OpAsmParser &parser, bool allowAttributes, bool allowVariadic, SmallVectorImpl &argNames, SmallVectorImpl &argTypes, SmallVectorImpl &argAttrs, @@ -125,7 +125,7 @@ parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl &resultTypes, /// indicates whether functions with variadic arguments are supported. The /// trailing arguments are populated by this function with names, types and /// attributes of the arguments and those of the results. -ParseResult mlir::impl::parseFunctionSignature( +ParseResult mlir::function_like_impl::parseFunctionSignature( OpAsmParser &parser, bool allowVariadic, SmallVectorImpl &argNames, SmallVectorImpl &argTypes, SmallVectorImpl &argAttrs, @@ -140,29 +140,53 @@ ParseResult mlir::impl::parseFunctionSignature( return success(); } -void mlir::impl::addArgAndResultAttrs(Builder &builder, OperationState &result, - ArrayRef argAttrs, - ArrayRef resultAttrs) { - // Add the attributes to the function arguments. - SmallString<8> attrNameBuf; - for (unsigned i = 0, e = argAttrs.size(); i != e; ++i) - if (!argAttrs[i].empty()) - result.addAttribute(getArgAttrName(i, attrNameBuf), - builder.getDictionaryAttr(argAttrs[i])); +/// Implementation of `addArgAndResultAttrs` that is attribute list type +/// agnostic. +template +static void addArgAndResultAttrsImpl(Builder &builder, OperationState &result, + ArrayRef argAttrs, + ArrayRef resultAttrs, + AttrArrayBuildFnT &&buildAttrArrayFn) { + auto nonEmptyAttrsFn = [](const AttrListT &attrs) { return !attrs.empty(); }; + // Add the attributes to the function arguments. + if (!argAttrs.empty() && llvm::any_of(argAttrs, nonEmptyAttrsFn)) { + ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(argAttrs)); + result.addAttribute(function_like_impl::getArgDictAttrName(), attrDicts); + } // Add the attributes to the function results. - for (unsigned i = 0, e = resultAttrs.size(); i != e; ++i) - if (!resultAttrs[i].empty()) - result.addAttribute(getResultAttrName(i, attrNameBuf), - builder.getDictionaryAttr(resultAttrs[i])); + if (!resultAttrs.empty() && llvm::any_of(resultAttrs, nonEmptyAttrsFn)) { + ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(resultAttrs)); + result.addAttribute(function_like_impl::getResultDictAttrName(), attrDicts); + } +} + +void mlir::function_like_impl::addArgAndResultAttrs( + Builder &builder, OperationState &result, ArrayRef argAttrs, + ArrayRef resultAttrs) { + auto buildFn = [](ArrayRef attrs) { + return ArrayRef(attrs.data(), attrs.size()); + }; + addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn); +} +void mlir::function_like_impl::addArgAndResultAttrs( + Builder &builder, OperationState &result, ArrayRef argAttrs, + ArrayRef resultAttrs) { + MLIRContext *context = builder.getContext(); + auto buildFn = [=](ArrayRef attrs) { + return llvm::to_vector<8>( + llvm::map_range(attrs, [=](const NamedAttrList &attrList) -> Attribute { + return attrList.getDictionary(context); + })); + }; + addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn); } /// Parser implementation for function-like operations. Uses `funcTypeBuilder` /// to construct the custom function type given lists of input and output types. -ParseResult -mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, - bool allowVariadic, - mlir::impl::FuncTypeBuilder funcTypeBuilder) { +ParseResult mlir::function_like_impl::parseFunctionLikeOp( + OpAsmParser &parser, OperationState &result, bool allowVariadic, + FuncTypeBuilder funcTypeBuilder) { SmallVector entryArgs; SmallVector argAttrs; SmallVector resultAttrs; @@ -187,13 +211,14 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, return failure(); std::string errorMessage; - if (auto type = funcTypeBuilder(builder, argTypes, resultTypes, - impl::VariadicFlag(isVariadic), errorMessage)) - result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); - else + Type type = funcTypeBuilder(builder, argTypes, resultTypes, + VariadicFlag(isVariadic), errorMessage); + if (!type) { return parser.emitError(signatureLocation) << "failed to construct function type" << (errorMessage.empty() ? "" : ": ") << errorMessage; + } + result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); // If function attributes are present, parse them. NamedAttrList parsedAttributes; @@ -236,35 +261,38 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, return success(); } -// Print a function result list. +/// Print a function result list. The provided `attrs` must either be null, or +/// contain a set of DictionaryAttrs of the same arity as `types`. static void printFunctionResultList(OpAsmPrinter &p, ArrayRef types, - ArrayRef> attrs) { + ArrayAttr attrs) { assert(!types.empty() && "Should not be called for empty result list."); + assert((!attrs || attrs.size() == types.size()) && + "Invalid number of attributes."); + auto &os = p.getStream(); - bool needsParens = - types.size() > 1 || types[0].isa() || !attrs[0].empty(); + bool needsParens = types.size() > 1 || types[0].isa() || + (attrs && !attrs[0].cast().empty()); if (needsParens) os << '('; - llvm::interleaveComma( - llvm::zip(types, attrs), os, - [&](const std::tuple> &t) { - p.printType(std::get<0>(t)); - p.printOptionalAttrDict(std::get<1>(t)); - }); + llvm::interleaveComma(llvm::seq(0, types.size()), os, [&](size_t i) { + p.printType(types[i]); + if (attrs) + p.printOptionalAttrDict(attrs[i].cast().getValue()); + }); if (needsParens) os << ')'; } /// Print the signature of the function-like operation `op`. Assumes `op` has /// the FunctionLike trait and passed the verification. -void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op, - ArrayRef argTypes, - bool isVariadic, - ArrayRef resultTypes) { +void mlir::function_like_impl::printFunctionSignature( + OpAsmPrinter &p, Operation *op, ArrayRef argTypes, bool isVariadic, + ArrayRef resultTypes) { Region &body = op->getRegion(0); bool isExternal = body.empty(); p << '('; + ArrayAttr argAttrs = op->getAttrOfType(getArgDictAttrName()); for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { if (i > 0) p << ", "; @@ -275,7 +303,8 @@ void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op, } p.printType(argTypes[i]); - p.printOptionalAttrDict(::mlir::impl::getArgAttrs(op, i)); + if (argAttrs) + p.printOptionalAttrDict(argAttrs[i].cast().getValue()); } if (isVariadic) { @@ -288,9 +317,7 @@ void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op, if (!resultTypes.empty()) { p.getStream() << " -> "; - SmallVector, 4> resultAttrs; - for (int i = 0, e = resultTypes.size(); i < e; ++i) - resultAttrs.push_back(::mlir::impl::getResultAttrs(op, i)); + auto resultAttrs = op->getAttrOfType(getResultDictAttrName()); printFunctionResultList(p, resultTypes, resultAttrs); } } @@ -300,39 +327,25 @@ void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op, /// function-like operation internally are not printed. Nothing is printed /// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and /// passed the verification. -void mlir::impl::printFunctionAttributes(OpAsmPrinter &p, Operation *op, - unsigned numInputs, - unsigned numResults, - ArrayRef elided) { +void mlir::function_like_impl::printFunctionAttributes( + OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults, + ArrayRef elided) { // Print out function attributes, if present. SmallVector ignoredAttrs = { - ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName()}; + ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName(), + getArgDictAttrName(), getResultDictAttrName()}; ignoredAttrs.append(elided.begin(), elided.end()); - SmallString<8> attrNameBuf; - - // Ignore any argument attributes. - std::vector> argAttrStorage; - for (unsigned i = 0; i != numInputs; ++i) - if (op->getAttr(getArgAttrName(i, attrNameBuf))) - argAttrStorage.emplace_back(attrNameBuf); - ignoredAttrs.append(argAttrStorage.begin(), argAttrStorage.end()); - - // Ignore any result attributes. - std::vector> resultAttrStorage; - for (unsigned i = 0; i != numResults; ++i) - if (op->getAttr(getResultAttrName(i, attrNameBuf))) - resultAttrStorage.emplace_back(attrNameBuf); - ignoredAttrs.append(resultAttrStorage.begin(), resultAttrStorage.end()); - p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs); } /// Printer implementation for function-like operations. Accepts lists of /// argument and result types to use while printing. -void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op, - ArrayRef argTypes, bool isVariadic, - ArrayRef resultTypes) { +void mlir::function_like_impl::printFunctionLikeOp(OpAsmPrinter &p, + Operation *op, + ArrayRef argTypes, + bool isVariadic, + ArrayRef resultTypes) { // Print the operation and the function name. auto funcName = op->getAttrOfType(SymbolTable::getSymbolAttrName()) diff --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp index 347ea15..2538271 100644 --- a/mlir/lib/IR/FunctionSupport.cpp +++ b/mlir/lib/IR/FunctionSupport.cpp @@ -31,103 +31,199 @@ inline void iterateIndicesExcept(unsigned totalIndices, // Function Arguments and Results. //===----------------------------------------------------------------------===// -void mlir::impl::eraseFunctionArguments(Operation *op, - ArrayRef argIndices, - unsigned originalNumArgs, - Type newType) { +static bool isEmptyAttrDict(Attribute attr) { + return attr.cast().empty(); +} + +DictionaryAttr mlir::function_like_impl::getArgAttrDict(Operation *op, + unsigned index) { + ArrayAttr attrs = op->getAttrOfType(getArgDictAttrName()); + DictionaryAttr argAttrs = + attrs ? attrs[index].cast() : DictionaryAttr(); + return (argAttrs && !argAttrs.empty()) ? argAttrs : DictionaryAttr(); +} + +DictionaryAttr mlir::function_like_impl::getResultAttrDict(Operation *op, + unsigned index) { + ArrayAttr attrs = op->getAttrOfType(getResultDictAttrName()); + DictionaryAttr resAttrs = + attrs ? attrs[index].cast() : DictionaryAttr(); + return (resAttrs && !resAttrs.empty()) ? resAttrs : DictionaryAttr(); +} + +void mlir::function_like_impl::detail::setArgResAttrDict( + Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index, + DictionaryAttr attrs) { + ArrayAttr allAttrs = op->getAttrOfType(attrName); + if (!allAttrs) { + if (attrs.empty()) + return; + + // If this attribute is not empty, we need to create a new attribute array. + SmallVector newAttrs(numTotalIndices, + DictionaryAttr::get(op->getContext())); + newAttrs[index] = attrs; + op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs)); + return; + } + // Check to see if the attribute is different from what we already have. + if (allAttrs[index] == attrs) + return; + + // If it is, check to see if the attribute array would now contain only empty + // dictionaries. + ArrayRef rawAttrArray = allAttrs.getValue(); + if (attrs.empty() && + llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) && + llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) { + op->removeAttr(attrName); + return; + } + + // Otherwise, create a new attribute array with the updated dictionary. + SmallVector newAttrs(rawAttrArray.begin(), rawAttrArray.end()); + newAttrs[index] = attrs; + op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs)); +} + +/// Set all of the argument or result attribute dictionaries for a function. +static void setAllArgResAttrDicts(Operation *op, StringRef attrName, + ArrayRef attrs) { + if (llvm::all_of(attrs, isEmptyAttrDict)) + op->removeAttr(attrName); + else + op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs)); +} + +void mlir::function_like_impl::setAllArgAttrDicts( + Operation *op, ArrayRef attrs) { + setAllArgAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); +} +void mlir::function_like_impl::setAllArgAttrDicts(Operation *op, + ArrayRef attrs) { + auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { + return !attr ? DictionaryAttr::get(op->getContext()) : attr; + }); + setAllArgResAttrDicts(op, getArgDictAttrName(), + llvm::to_vector<8>(wrappedAttrs)); +} + +void mlir::function_like_impl::setAllResultAttrDicts( + Operation *op, ArrayRef attrs) { + setAllResultAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); +} +void mlir::function_like_impl::setAllResultAttrDicts( + Operation *op, ArrayRef attrs) { + auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { + return !attr ? DictionaryAttr::get(op->getContext()) : attr; + }); + setAllArgResAttrDicts(op, getResultDictAttrName(), + llvm::to_vector<8>(wrappedAttrs)); +} + +void mlir::function_like_impl::eraseFunctionArguments( + Operation *op, ArrayRef argIndices, unsigned originalNumArgs, + Type newType) { // There are 3 things that need to be updated: // - Function type. // - Arg attrs. // - Block arguments of entry block. Block &entry = op->getRegion(0).front(); - SmallString<8> nameBuf; - - // Collect arg attrs to set. - SmallVector newArgAttrs; - iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { - newArgAttrs.emplace_back(getArgAttrDict(op, i)); - }); - - // Remove any arg attrs that are no longer needed. - for (unsigned i = newArgAttrs.size(), e = originalNumArgs; i < e; ++i) - op->removeAttr(getArgAttrName(i, nameBuf)); - - // Set the function type. - op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); - // Set the new arg attrs, or remove them if empty. - for (unsigned i = 0, e = newArgAttrs.size(); i != e; ++i) { - auto nameAttr = getArgAttrName(i, nameBuf); - if (newArgAttrs[i] && !newArgAttrs[i].empty()) - op->setAttr(nameAttr, newArgAttrs[i]); - else - op->removeAttr(nameAttr); + // Update the argument attributes of the function. + if (auto argAttrs = op->getAttrOfType(getArgDictAttrName())) { + SmallVector newArgAttrs; + newArgAttrs.reserve(argAttrs.size()); + iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { + newArgAttrs.emplace_back(argAttrs[i].cast()); + }); + setAllArgAttrDicts(op, newArgAttrs); } - // Update the entry block's arguments. + // Update the function type and any entry block arguments. + op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); entry.eraseArguments(argIndices); } -void mlir::impl::eraseFunctionResults(Operation *op, - ArrayRef resultIndices, - unsigned originalNumResults, - Type newType) { +void mlir::function_like_impl::eraseFunctionResults( + Operation *op, ArrayRef resultIndices, + unsigned originalNumResults, Type newType) { // There are 2 things that need to be updated: // - Function type. // - Result attrs. - SmallString<8> nameBuf; - - // Collect result attrs to set. - SmallVector newResultAttrs; - iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { - newResultAttrs.emplace_back(getResultAttrDict(op, i)); - }); - // Remove any result attrs that are no longer needed. - for (unsigned i = newResultAttrs.size(), e = originalNumResults; i < e; ++i) - op->removeAttr(getResultAttrName(i, nameBuf)); + // Update the result attributes of the function. + if (auto resAttrs = op->getAttrOfType(getResultDictAttrName())) { + SmallVector newResultAttrs; + newResultAttrs.reserve(resAttrs.size()); + iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { + newResultAttrs.emplace_back(resAttrs[i].cast()); + }); + setAllResultAttrDicts(op, newResultAttrs); + } - // Set the function type. + // Update the function type. op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); - - // Set the new result attrs, or remove them if empty. - for (unsigned i = 0, e = newResultAttrs.size(); i != e; ++i) { - auto nameAttr = getResultAttrName(i, nameBuf); - if (newResultAttrs[i] && !newResultAttrs[i].empty()) - op->setAttr(nameAttr, newResultAttrs[i]); - else - op->removeAttr(nameAttr); - } } //===----------------------------------------------------------------------===// // Function type signature. //===----------------------------------------------------------------------===// -FunctionType mlir::impl::getFunctionType(Operation *op) { +FunctionType mlir::function_like_impl::getFunctionType(Operation *op) { assert(op->hasTrait()); - return op->getAttrOfType(mlir::impl::getTypeAttrName()) + return op->getAttrOfType(getTypeAttrName()) .getValue() .cast(); } -void mlir::impl::setFunctionType(Operation *op, FunctionType newType) { +void mlir::function_like_impl::setFunctionType(Operation *op, + FunctionType newType) { assert(op->hasTrait()); - SmallVector nameBuf; FunctionType oldType = getFunctionType(op); - - for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++) - op->removeAttr(getArgAttrName(i, nameBuf)); - for (int i = newType.getNumResults(), e = oldType.getNumResults(); i < e; i++) - op->removeAttr(getResultAttrName(i, nameBuf)); op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); + + // Functor used to update the argument and result attributes of the function. + auto updateAttrFn = [&](StringRef attrName, unsigned oldCount, + unsigned newCount, auto setAttrFn) { + if (oldCount == newCount) + return; + // The new type has no arguments/results, just drop the attribute. + if (newCount == 0) { + op->removeAttr(attrName); + return; + } + ArrayAttr attrs = op->getAttrOfType(attrName); + if (!attrs) + return; + + // The new type has less arguments/results, take the first N attributes. + if (newCount < oldCount) + return setAttrFn(op, attrs.getValue().take_front(newCount)); + + // Otherwise, the new type has more arguments/results. Initialize the new + // arguments/results with empty attributes. + SmallVector newAttrs(attrs.begin(), attrs.end()); + newAttrs.resize(newCount); + setAttrFn(op, newAttrs); + }; + + // Update the argument and result attributes. + updateAttrFn(function_like_impl::getArgDictAttrName(), oldType.getNumInputs(), + newType.getNumInputs(), [&](Operation *op, auto &&attrs) { + setAllArgAttrDicts(op, attrs); + }); + updateAttrFn( + function_like_impl::getResultDictAttrName(), oldType.getNumResults(), + newType.getNumResults(), + [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); }); } //===----------------------------------------------------------------------===// // Function body. //===----------------------------------------------------------------------===// -Region &mlir::impl::getFunctionBody(Operation *op) { +Region &mlir::function_like_impl::getFunctionBody(Operation *op) { assert(op->hasTrait()); return op->getRegion(0); } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index c8bb22e..00b006c 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2628,15 +2628,15 @@ struct FunctionLikeSignatureConversion : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - FunctionType type = mlir::impl::getFunctionType(op); + FunctionType type = function_like_impl::getFunctionType(op); // Convert the original function types. TypeConverter::SignatureConversion result(type.getNumInputs()); SmallVector newResults; if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) || failed(typeConverter->convertTypes(type.getResults(), newResults)) || - failed(rewriter.convertRegionTypes(&mlir::impl::getFunctionBody(op), - *typeConverter, &result))) + failed(rewriter.convertRegionTypes( + &function_like_impl::getFunctionBody(op), *typeConverter, &result))) return failure(); // Update the function signature in-place. @@ -2644,7 +2644,7 @@ struct FunctionLikeSignatureConversion : public ConversionPattern { result.getConvertedTypes(), newResults); rewriter.updateRootInPlace( - op, [&] { mlir::impl::setFunctionType(op, newType); }); + op, [&] { function_like_impl::setFunctionType(op, newType); }); return success(); } diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir index ab32af2..e52acf6 100644 --- a/mlir/test/Dialect/LLVMIR/func.mlir +++ b/mlir/test/Dialect/LLVMIR/func.mlir @@ -35,7 +35,7 @@ module { // CHECK: attributes {xxx = {yyy = 42 : i64}} "llvm.func"() ({ }) {sym_name = "qux", type = !llvm.func, i64)>, - arg0 = {llvm.noalias = true}, xxx = {yyy = 42}} : () -> () + arg_attrs = [{llvm.noalias = true}, {}], xxx = {yyy = 42}} : () -> () // CHECK: llvm.func @roundtrip1() llvm.func @roundtrip1() diff --git a/mlir/test/IR/invalid-func-op.mlir b/mlir/test/IR/invalid-func-op.mlir index c2ceefe..6b15a07 100644 --- a/mlir/test/IR/invalid-func-op.mlir +++ b/mlir/test/IR/invalid-func-op.mlir @@ -94,3 +94,22 @@ func private @invalid_symbol_name_attr() attributes { sym_name = "x" } // expected-error@+1 {{'type' is an inferred attribute and should not be specified in the explicit attribute dictionary}} func private @invalid_symbol_type_attr() attributes { type = "x" } +// ----- + +// expected-error@+1 {{argument attribute array `arg_attrs` to have the same number of elements as the number of function arguments}} +func private @invalid_arg_attrs() attributes { arg_attrs = [{}] } + +// ----- + +// expected-error@+1 {{expects argument attribute dictionary to be a DictionaryAttr, but got `10 : i64`}} +func private @invalid_arg_attrs(i32) attributes { arg_attrs = [10] } + +// ----- + +// expected-error@+1 {{result attribute array `res_attrs` to have the same number of elements as the number of function results}} +func private @invalid_res_attrs() attributes { res_attrs = [{}] } + +// ----- + +// expected-error@+1 {{expects result attribute dictionary to be a DictionaryAttr, but got `10 : i64`}} +func private @invalid_res_attrs() -> i32 attributes { res_attrs = [10] } diff --git a/mlir/test/IR/test-func-set-type.mlir b/mlir/test/IR/test-func-set-type.mlir index 05a1393..42f56ae 100644 --- a/mlir/test/IR/test-func-set-type.mlir +++ b/mlir/test/IR/test-func-set-type.mlir @@ -9,7 +9,6 @@ // Test case: The setType call needs to erase some arg attrs. // CHECK: func private @erase_arg(f32 {test.A}) -// CHECK-NOT: attributes{{.*arg[0-9]}} func private @t(f32) func private @erase_arg(%arg0: f32 {test.A}, %arg1: f32 {test.B}) attributes {test.set_type_from = @t} @@ -19,7 +18,6 @@ attributes {test.set_type_from = @t} // Test case: The setType call needs to erase some result attrs. // CHECK: func private @erase_result() -> (f32 {test.A}) -// CHECK-NOT: attributes{{.*result[0-9]}} func private @t() -> (f32) func private @erase_result() -> (f32 {test.A}, f32 {test.B}) attributes {test.set_type_from = @t} -- 2.7.4