// correctly.
for (auto e : llvm::enumerate(funcTy.getInputs())) {
unsigned index = e.index();
- llvm::ArrayRef<mlir::NamedAttribute> attrs = func.getArgAttrs(index);
+ llvm::ArrayRef<mlir::NamedAttribute> attrs =
+ mlir::function_interface_impl::getArgAttrs(func, index);
for (mlir::NamedAttribute attr : attrs) {
savedAttrs.push_back({index, attr});
}
let arguments = (ins
SymbolNameAttr:$sym_name,
- TypeAttrOf<FunctionType>:$function_type
+ TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs
);
let regions = (region AnyRegion:$body);
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false, buildFuncType);
+ parser, result, /*allowVariadic=*/false,
+ getFunctionTypeAttrName(result.name), buildFuncType,
+ getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
- mlir::function_interface_impl::printFunctionOp(p, *this,
- /*isVariadic=*/false);
+ mlir::function_interface_impl::printFunctionOp(
+ p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+ getArgAttrsAttrName(), getResAttrsAttrName());
}
//===----------------------------------------------------------------------===//
let arguments = (ins
SymbolNameAttr:$sym_name,
- TypeAttrOf<FunctionType>:$function_type
+ TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs
);
let regions = (region AnyRegion:$body);
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false, buildFuncType);
+ parser, result, /*allowVariadic=*/false,
+ getFunctionTypeAttrName(result.name), buildFuncType,
+ getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
- mlir::function_interface_impl::printFunctionOp(p, *this,
- /*isVariadic=*/false);
+ mlir::function_interface_impl::printFunctionOp(
+ p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+ getArgAttrsAttrName(), getResAttrsAttrName());
}
//===----------------------------------------------------------------------===//
let arguments = (ins
SymbolNameAttr:$sym_name,
- TypeAttrOf<FunctionType>:$function_type
+ TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs
);
let regions = (region AnyRegion:$body);
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false, buildFuncType);
+ parser, result, /*allowVariadic=*/false,
+ getFunctionTypeAttrName(result.name), buildFuncType,
+ getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
- mlir::function_interface_impl::printFunctionOp(p, *this,
- /*isVariadic=*/false);
+ mlir::function_interface_impl::printFunctionOp(
+ p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+ getArgAttrsAttrName(), getResAttrsAttrName());
}
/// Returns the region on the function operation that is callable.
let arguments = (ins
SymbolNameAttr:$sym_name,
- TypeAttrOf<FunctionType>:$function_type
+ TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs
);
let regions = (region AnyRegion:$body);
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false, buildFuncType);
+ parser, result, /*allowVariadic=*/false,
+ getFunctionTypeAttrName(result.name), buildFuncType,
+ getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
- mlir::function_interface_impl::printFunctionOp(p, *this,
- /*isVariadic=*/false);
+ mlir::function_interface_impl::printFunctionOp(
+ p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+ getArgAttrsAttrName(), getResAttrsAttrName());
}
/// Returns the region on the function operation that is callable.
let arguments = (ins
SymbolNameAttr:$sym_name,
- TypeAttrOf<FunctionType>:$function_type
+ TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs
);
let regions = (region AnyRegion:$body);
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false, buildFuncType);
+ parser, result, /*allowVariadic=*/false,
+ getFunctionTypeAttrName(result.name), buildFuncType,
+ getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
- mlir::function_interface_impl::printFunctionOp(p, *this,
- /*isVariadic=*/false);
+ mlir::function_interface_impl::printFunctionOp(
+ p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+ getArgAttrsAttrName(), getResAttrsAttrName());
}
/// Returns the region on the function operation that is callable.
let arguments = (ins
SymbolNameAttr:$sym_name,
- TypeAttrOf<FunctionType>:$function_type
+ TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs
);
let regions = (region AnyRegion:$body);
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false, buildFuncType);
+ parser, result, /*allowVariadic=*/false,
+ getFunctionTypeAttrName(result.name), buildFuncType,
+ getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
- mlir::function_interface_impl::printFunctionOp(p, *this,
- /*isVariadic=*/false);
+ mlir::function_interface_impl::printFunctionOp(
+ p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+ getArgAttrsAttrName(), getResAttrsAttrName());
}
/// Returns the region on the function operation that is callable.
let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
- OptionalAttr<StrAttr>:$sym_visibility);
+ OptionalAttr<StrAttr>:$sym_visibility,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs);
let regions = (region AnyRegion:$body);
let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
- OptionalAttr<StrAttr>:$sym_visibility);
+ OptionalAttr<StrAttr>:$sym_visibility,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs);
let regions = (region AnyRegion:$body);
let builders = [OpBuilder<(ins
attribution.
}];
- let arguments = (ins TypeAttrOf<FunctionType>:$function_type);
+ let arguments = (ins TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs);
let regions = (region AnyRegion:$body);
let skipDefaultBuilders = 1;
DefaultValuedAttr<CConv, "CConv::C">:$CConv,
OptionalAttr<FlatSymbolRefAttr>:$personality,
OptionalAttr<StrAttr>:$garbageCollector,
- OptionalAttr<ArrayAttr>:$passthrough
+ OptionalAttr<ArrayAttr>:$passthrough,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs
);
let regions = (region AnyRegion:$body);
let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs,
OptionalAttr<StrAttr>:$sym_visibility);
let regions = (region AnyRegion:$body);
let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs,
OptionalAttr<StrAttr>:$sym_visibility);
let regions = (region AnyRegion:$body);
let arguments = (ins
SymbolNameAttr:$sym_name,
- TypeAttrOf<FunctionType>:$function_type
+ TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs
);
let regions = (region MinSizedRegion<1>:$body);
let arguments = (ins
TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs,
StrAttr:$sym_name,
SPIRV_FunctionControlAttr:$function_control
);
let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs,
OptionalAttr<StrAttr>:$sym_visibility);
let regions = (region AnyRegion:$body);
/// with special names given by getResultAttrName, getArgumentAttrName.
void addArgAndResultAttrs(Builder &builder, OperationState &result,
ArrayRef<DictionaryAttr> argAttrs,
- ArrayRef<DictionaryAttr> resultAttrs);
+ ArrayRef<DictionaryAttr> resultAttrs,
+ StringAttr argAttrsName, StringAttr resAttrsName);
void addArgAndResultAttrs(Builder &builder, OperationState &result,
- ArrayRef<OpAsmParser::Argument> argAttrs,
- ArrayRef<DictionaryAttr> resultAttrs);
+ ArrayRef<OpAsmParser::Argument> args,
+ ArrayRef<DictionaryAttr> resultAttrs,
+ StringAttr argAttrsName, StringAttr resAttrsName);
/// Callback type for `parseFunctionOp`, the callback should produce the
/// type that will be associated with a function-like operation from lists of
/// Parser implementation for function-like operations. Uses
/// `funcTypeBuilder` to construct the custom function type given lists of
-/// input and output types. If `allowVariadic` is set, the parser will accept
+/// input and output types. The parser sets the `typeAttrName` attribute to the
+/// resulting function type. If `allowVariadic` is set, the parser will accept
/// trailing ellipsis in the function signature and indicate to the builder
/// whether the function is variadic. If the builder returns a null type,
/// `result` will not contain the `type` attribute. The caller can then add a
/// type, report the error or delegate the reporting to the op's verifier.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result,
- bool allowVariadic,
- FuncTypeBuilder funcTypeBuilder);
+ bool allowVariadic, StringAttr typeAttrName,
+ FuncTypeBuilder funcTypeBuilder,
+ StringAttr argAttrsName, StringAttr resAttrsName);
/// Printer implementation for function-like operations.
-void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic);
+void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
+ StringRef typeAttrName, StringAttr argAttrsName,
+ StringAttr resAttrsName);
/// Prints the signature of the function-like operation `op`. Assumes `op` has
/// is a FunctionOpInterface and has passed verification.
-void printFunctionSignature(OpAsmPrinter &p, Operation *op,
+void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op,
ArrayRef<Type> argTypes, bool isVariadic,
ArrayRef<Type> resultTypes);
/// function-like operation internally are not printed. Nothing is printed
/// if all attributes are elided. Assumes `op` is a FunctionOpInterface and
/// has passed verification.
-void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs,
- unsigned numResults,
+void printFunctionAttributes(OpAsmPrinter &p, Operation *op,
ArrayRef<StringRef> elided = {});
} // namespace function_interface_impl
#include "llvm/ADT/SmallString.h"
namespace mlir {
+class FunctionOpInterface;
namespace function_interface_impl {
-/// Return the name of the attribute used for function types.
-inline StringRef getTypeAttrName() { return "function_type"; }
-
-/// Return the name of the attribute used for function argument attributes.
-inline StringRef getArgDictAttrName() { return "arg_attrs"; }
-
-/// Return the name of the attribute used for function argument attributes.
-inline StringRef getResultDictAttrName() { return "res_attrs"; }
-
/// Returns the dictionary attribute corresponding to the argument at 'index'.
/// If there are no argument attributes at 'index', a null attribute is
/// returned.
-DictionaryAttr getArgAttrDict(Operation *op, unsigned index);
+DictionaryAttr getArgAttrDict(FunctionOpInterface op, unsigned index);
/// Returns the dictionary attribute corresponding to the result at 'index'.
/// If there are no result attributes at 'index', a null attribute is
/// returned.
-DictionaryAttr getResultAttrDict(Operation *op, unsigned index);
+DictionaryAttr getResultAttrDict(FunctionOpInterface op, unsigned index);
-namespace detail {
-/// Update the given index into an argument or result attribute dictionary.
-void setArgResAttrDict(Operation *op, StringRef attrName,
- unsigned numTotalIndices, unsigned index,
- DictionaryAttr attrs);
-} // namespace detail
+/// Return all of the attributes for the argument at 'index'.
+ArrayRef<NamedAttribute> getArgAttrs(FunctionOpInterface op, unsigned index);
+
+/// Return all of the attributes for the result at 'index'.
+ArrayRef<NamedAttribute> getResultAttrs(FunctionOpInterface op, unsigned index);
/// Set all of the argument or result attribute dictionaries for a function. The
/// size of `attrs` is expected to match the number of arguments/results of the
/// given `op`.
-void setAllArgAttrDicts(Operation *op, ArrayRef<DictionaryAttr> attrs);
-void setAllArgAttrDicts(Operation *op, ArrayRef<Attribute> attrs);
-void setAllResultAttrDicts(Operation *op, ArrayRef<DictionaryAttr> attrs);
-void setAllResultAttrDicts(Operation *op, ArrayRef<Attribute> attrs);
-
-/// Return all of the attributes for the argument at 'index'.
-inline ArrayRef<NamedAttribute> getArgAttrs(Operation *op, unsigned index) {
- auto argDict = getArgAttrDict(op, index);
- return argDict ? argDict.getValue() : std::nullopt;
-}
-
-/// Return all of the attributes for the result at 'index'.
-inline ArrayRef<NamedAttribute> getResultAttrs(Operation *op, unsigned index) {
- auto resultDict = getResultAttrDict(op, index);
- return resultDict ? resultDict.getValue() : std::nullopt;
-}
+void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs);
+void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef<Attribute> attrs);
+void setAllResultAttrDicts(FunctionOpInterface op,
+ ArrayRef<DictionaryAttr> attrs);
+void setAllResultAttrDicts(FunctionOpInterface op, ArrayRef<Attribute> attrs);
/// Insert the specified arguments and update the function type attribute.
-void insertFunctionArguments(Operation *op, ArrayRef<unsigned> argIndices,
- TypeRange argTypes,
+void insertFunctionArguments(FunctionOpInterface op,
+ ArrayRef<unsigned> argIndices, TypeRange argTypes,
ArrayRef<DictionaryAttr> argAttrs,
ArrayRef<Location> argLocs,
unsigned originalNumArgs, Type newType);
/// Insert the specified results and update the function type attribute.
-void insertFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
+void insertFunctionResults(FunctionOpInterface op,
+ ArrayRef<unsigned> resultIndices,
TypeRange resultTypes,
ArrayRef<DictionaryAttr> resultAttrs,
unsigned originalNumResults, Type newType);
/// Erase the specified arguments and update the function type attribute.
-void eraseFunctionArguments(Operation *op, const BitVector &argIndices,
+void eraseFunctionArguments(FunctionOpInterface op, const BitVector &argIndices,
Type newType);
/// Erase the specified results and update the function type attribute.
-void eraseFunctionResults(Operation *op, const BitVector &resultIndices,
- Type newType);
+void eraseFunctionResults(FunctionOpInterface op,
+ const BitVector &resultIndices, Type newType);
/// Set a FunctionOpInterface operation's type signature.
-void setFunctionType(Operation *op, Type newType);
+void setFunctionType(FunctionOpInterface op, Type newType);
/// Insert a set of `newTypes` into `oldTypes` at the given `indices`. If any
/// types are inserted, `storage` is used to hold the new type list. The new
//===----------------------------------------------------------------------===//
/// Set the attributes held by the argument at 'index'.
-template <typename ConcreteType>
-void setArgAttrs(ConcreteType op, unsigned index,
- ArrayRef<NamedAttribute> attributes) {
- assert(index < op.getNumArguments() && "invalid argument number");
- return detail::setArgResAttrDict(
- op, getArgDictAttrName(), op.getNumArguments(), index,
- DictionaryAttr::get(op->getContext(), attributes));
-}
-template <typename ConcreteType>
-void setArgAttrs(ConcreteType op, unsigned index, DictionaryAttr attributes) {
- return detail::setArgResAttrDict(
- op, getArgDictAttrName(), op.getNumArguments(), index,
- attributes ? attributes : DictionaryAttr::get(op->getContext()));
-}
+void setArgAttrs(FunctionOpInterface op, unsigned index,
+ ArrayRef<NamedAttribute> attributes);
+void setArgAttrs(FunctionOpInterface op, unsigned index,
+ DictionaryAttr attributes);
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
//===----------------------------------------------------------------------===//
/// Set the attributes held by the result at 'index'.
-template <typename ConcreteType>
-void setResultAttrs(ConcreteType op, unsigned index,
- ArrayRef<NamedAttribute> attributes) {
- assert(index < op.getNumResults() && "invalid result number");
- return detail::setArgResAttrDict(
- op, getResultDictAttrName(), op.getNumResults(), index,
- DictionaryAttr::get(op->getContext(), attributes));
-}
-
-template <typename ConcreteType>
-void setResultAttrs(ConcreteType op, unsigned index,
- DictionaryAttr attributes) {
- assert(index < op.getNumResults() && "invalid result number");
- return detail::setArgResAttrDict(
- op, getResultDictAttrName(), op.getNumResults(), index,
- attributes ? attributes : DictionaryAttr::get(op->getContext()));
-}
+void setResultAttrs(FunctionOpInterface op, unsigned index,
+ ArrayRef<NamedAttribute> attributes);
+void setResultAttrs(FunctionOpInterface op, unsigned index,
+ DictionaryAttr attributes);
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
/// method on FunctionOpInterface::Trait.
template <typename ConcreteOp>
LogicalResult verifyTrait(ConcreteOp op) {
- if (!op.getFunctionTypeAttr())
- return op.emitOpError("requires a type attribute '")
- << function_interface_impl::getTypeAttrName() << '\'';
-
if (failed(op.verifyType()))
return failure();
unsigned numArgs = op.getNumArguments();
if (allArgAttrs.size() != numArgs) {
return op.emitOpError()
- << "expects argument attribute array `" << getArgDictAttrName()
- << "` to have the same number of elements as the number of "
- "function arguments, got "
+ << "expects argument attribute array to have the same number of "
+ "elements as the number of function arguments, got "
<< allArgAttrs.size() << ", but expected " << numArgs;
}
for (unsigned i = 0; i != numArgs; ++i) {
unsigned numResults = op.getNumResults();
if (allResultAttrs.size() != numResults) {
return op.emitOpError()
- << "expects result attribute array `" << getResultDictAttrName()
- << "` to have the same number of elements as the number of "
- "function results, got "
+ << "expects result attribute array to have the same number of "
+ "elements as the number of function results, got "
<< allResultAttrs.size() << ", but expected " << numResults;
}
for (unsigned i = 0; i != numResults; ++i) {
}];
let methods = [
InterfaceMethod<[{
+ Returns the type of the function.
+ }],
+ "::mlir::Type", "getFunctionType">,
+ InterfaceMethod<[{
+ Set the type of the function. This method should perform an unsafe
+ modification to the function type; it should not update argument or
+ result attributes.
+ }],
+ "void", "setFunctionTypeAttr", (ins "::mlir::TypeAttr":$type)>,
+
+ InterfaceMethod<[{
+ Get the array of argument attribute dictionaries. The method should return
+ an array attribute containing only dictionary attributes equal in number
+ to the number of function arguments. Alternatively, the method can return
+ null to indicate that the function has no argument attributes.
+ }],
+ "::mlir::ArrayAttr", "getArgAttrsAttr">,
+ InterfaceMethod<[{
+ Get the array of result attribute dictionaries. The method should return
+ an array attribute containing only dictionary attributes equal in number
+ to the number of function results. Alternatively, the method can return
+ null to indicate that the function has no result attributes.
+ }],
+ "::mlir::ArrayAttr", "getResAttrsAttr">,
+ InterfaceMethod<[{
+ Set the array of argument attribute dictionaries.
+ }],
+ "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>,
+ InterfaceMethod<[{
+ Set the array of result attribute dictionaries.
+ }],
+ "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>,
+ InterfaceMethod<[{
+ Remove the array of argument attribute dictionaries. This is the same as
+ setting all argument attributes to an empty dictionary. The method should
+ return the removed attribute.
+ }],
+ "::mlir::Attribute", "removeArgAttrsAttr">,
+ InterfaceMethod<[{
+ Remove the array of result attribute dictionaries. This is the same as
+ setting all result attributes to an empty dictionary. The method should
+ return the removed attribute.
+ }],
+ "::mlir::Attribute", "removeResAttrsAttr">,
+
+ InterfaceMethod<[{
Returns the function argument types based exclusively on
the type (to allow for this method may be called on function
declarations).
ArrayRef<NamedAttribute> attrs, TypeRange inputTypes) {
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
- state.addAttribute(function_interface_impl::getTypeAttrName(),
+ state.addAttribute(ConcreteOp::getFunctionTypeAttrName(state.name),
TypeAttr::get(type));
state.attributes.append(attrs.begin(), attrs.end());
function_interface_impl::setFunctionType(this->getOperation(), newType);
}
- // FIXME: These functions should be removed in favor of just forwarding to
- // the derived operation, which should already have these defined
- // (via ODS).
-
- /// Returns the name of the attribute used for function types.
- static StringRef getTypeAttrName() {
- return function_interface_impl::getTypeAttrName();
- }
-
- /// Returns the name of the attribute used for function argument attributes.
- static StringRef getArgDictAttrName() {
- return function_interface_impl::getArgDictAttrName();
- }
-
- /// Returns the name of the attribute used for function argument attributes.
- static StringRef getResultDictAttrName() {
- return function_interface_impl::getResultDictAttrName();
- }
-
- /// Return the attribute containing the type of this function.
- TypeAttr getFunctionTypeAttr() {
- return this->getOperation()->template getAttrOfType<TypeAttr>(
- getTypeAttrName());
- }
-
- /// Return the type of this function.
- Type getFunctionType() { return getFunctionTypeAttr().getValue(); }
-
//===------------------------------------------------------------------===//
// Argument and Result Handling
//===------------------------------------------------------------------===//
/// Return an ArrayAttr containing all argument attribute dictionaries of
/// this function, or nullptr if no arguments have attributes.
- ArrayAttr getAllArgAttrs() {
- return this->getOperation()->template getAttrOfType<ArrayAttr>(
- getArgDictAttrName());
- }
+ ArrayAttr getAllArgAttrs() { return $_op.getArgAttrsAttr(); }
+
/// Return all argument attributes of this function.
void getAllArgAttrs(SmallVectorImpl<DictionaryAttr> &result) {
if (ArrayAttr argAttrs = getAllArgAttrs()) {
}
void setAllArgAttrs(ArrayAttr attributes) {
assert(attributes.size() == $_op.getNumArguments());
- this->getOperation()->setAttr(getArgDictAttrName(), attributes);
+ $_op.setArgAttrsAttr(attributes);
}
/// If the an attribute exists with the specified name, change it to the new
/// Return an ArrayAttr containing all result attribute dictionaries of this
/// function, or nullptr if no result have attributes.
- ArrayAttr getAllResultAttrs() {
- return this->getOperation()->template getAttrOfType<ArrayAttr>(
- getResultDictAttrName());
- }
+ ArrayAttr getAllResultAttrs() { return $_op.getResAttrsAttr(); }
+
/// Return all result attributes of this function.
void getAllResultAttrs(SmallVectorImpl<DictionaryAttr> &result) {
if (ArrayAttr argAttrs = getAllResultAttrs()) {
}
void setAllResultAttrs(ArrayAttr attributes) {
assert(attributes.size() == $_op.getNumResults());
- this->getOperation()->setAttr(getResultDictAttrName(), attributes);
+ $_op.setResAttrsAttr(attributes);
}
/// If the an attribute exists with the specified name, change it to the new
}
def IndexListArrayAttr :
TypedArrayAttrBase<I64ArrayAttr, "Array of 64-bit integer array attributes">;
+def DictArrayAttr :
+ TypedArrayAttrBase<DictionaryAttr, "Array of dictionary attributes">;
// Attributes containing symbol references.
def SymbolRefAttr : Attr<CPred<"$_self.isa<::mlir::SymbolRefAttr>()">,
/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
/// attributes.
-static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
- bool filterArgAndResAttrs,
+static void filterFuncAttributes(func::FuncOp func, bool filterArgAndResAttrs,
SmallVectorImpl<NamedAttribute> &result) {
- for (const auto &attr : attrs) {
+ for (const NamedAttribute &attr : func->getAttrs()) {
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
- attr.getName() == FunctionOpInterface::getTypeAttrName() ||
+ attr.getName() == func.getFunctionTypeAttrName() ||
attr.getName() == "func.varargs" ||
(filterArgAndResAttrs &&
- (attr.getName() == FunctionOpInterface::getArgDictAttrName() ||
- attr.getName() == FunctionOpInterface::getResultDictAttrName())))
+ (attr.getName() == func.getArgAttrsAttrName() ||
+ attr.getName() == func.getResAttrsAttrName())))
continue;
result.push_back(attr);
}
static void
prependResAttrsToArgAttrs(OpBuilder &builder,
SmallVectorImpl<NamedAttribute> &attributes,
- size_t numArguments) {
+ func::FuncOp func) {
+ size_t numArguments = func.getNumArguments();
auto allAttrs = SmallVector<Attribute>(
numArguments + 1, DictionaryAttr::get(builder.getContext()));
NamedAttribute *argAttrs = nullptr;
for (auto *it = attributes.begin(); it != attributes.end();) {
- if (it->getName() == FunctionOpInterface::getArgDictAttrName()) {
+ if (it->getName() == func.getArgAttrsAttrName()) {
auto arrayAttrs = it->getValue().cast<ArrayAttr>();
assert(arrayAttrs.size() == numArguments &&
"Number of arg attrs and args should match");
std::copy(arrayAttrs.begin(), arrayAttrs.end(), allAttrs.begin() + 1);
argAttrs = it;
- } else if (it->getName() == FunctionOpInterface::getResultDictAttrName()) {
+ } else if (it->getName() == func.getResAttrsAttrName()) {
auto arrayAttrs = it->getValue().cast<ArrayAttr>();
assert(!arrayAttrs.empty() && "expected array to be non-empty");
allAttrs[0] = (arrayAttrs.size() == 1)
it++;
}
- auto newArgAttrs =
- builder.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
- builder.getArrayAttr(allAttrs));
+ auto newArgAttrs = builder.getNamedAttr(func.getArgAttrsAttrName(),
+ builder.getArrayAttr(allAttrs));
if (!argAttrs) {
attributes.emplace_back(newArgAttrs);
return;
LLVM::LLVMFuncOp newFuncOp) {
auto type = funcOp.getFunctionType();
SmallVector<NamedAttribute, 4> attributes;
- filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
- attributes);
+ filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes);
auto [wrapperFuncType, resultIsNowArg] =
typeConverter.convertFunctionTypeCWrapper(type);
if (resultIsNowArg)
- prependResAttrsToArgAttrs(rewriter, attributes, funcOp.getNumArguments());
+ prependResAttrsToArgAttrs(rewriter, attributes, funcOp);
auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false,
assert(wrapperType && "unexpected type conversion failure");
SmallVector<NamedAttribute, 4> attributes;
- filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
- attributes);
+ filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes);
if (resultIsNowArg)
- prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments());
+ prependResAttrsToArgAttrs(builder, attributes, funcOp);
// Create the auxiliary function.
auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
// Propagate argument/result attributes to all converted arguments/result
// obtained after converting a given original argument/result.
SmallVector<NamedAttribute, 4> attributes;
- filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true,
- attributes);
+ filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/true, attributes);
if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
assert(!resAttrDicts.empty() && "expected array to be non-empty");
auto newResAttrDicts =
? resAttrDicts
: rewriter.getArrayAttr(
{wrapAsStructAttrs(rewriter, resAttrDicts)});
- attributes.push_back(rewriter.getNamedAttr(
- FunctionOpInterface::getResultDictAttrName(), newResAttrDicts));
+ attributes.push_back(
+ rewriter.getNamedAttr(funcOp.getResAttrsAttrName(), newResAttrDicts));
}
if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
SmallVector<Attribute, 4> newArgAttrs(
newArgAttrs[mapping->inputNo + j] =
DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
}
- attributes.push_back(
- rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
- rewriter.getArrayAttr(newArgAttrs)));
+ attributes.push_back(rewriter.getNamedAttr(
+ funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(newArgAttrs)));
}
for (const auto &pair : llvm::enumerate(attributes)) {
if (pair.value().getName() == "llvm.linkage") {
SmallVector<NamedAttribute, 4> attributes;
for (const auto &attr : gpuFuncOp->getAttrs()) {
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
- attr.getName() == FunctionOpInterface::getTypeAttrName() ||
+ attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
continue;
attributes.push_back(attr);
rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
std::nullopt));
for (const auto &namedAttr : funcOp->getAttrs()) {
- if (namedAttr.getName() == FunctionOpInterface::getTypeAttrName() ||
+ if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
namedAttr.getName() == SymbolTable::getSymbolAttrName())
continue;
newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
ArrayRef<DictionaryAttr> argAttrs) {
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
- state.addAttribute(FunctionOpInterface::getTypeAttrName(),
- TypeAttr::get(type));
+ state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
state.attributes.append(attrs.begin(), attrs.end());
state.addRegion();
if (argAttrs.empty())
return;
assert(type.getNumInputs() == argAttrs.size());
- function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
- /*resultAttrs=*/std::nullopt);
+ function_interface_impl::addArgAndResultAttrs(
+ builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
+ getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
}
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false, buildFuncType);
+ parser, result, /*allowVariadic=*/false,
+ getFunctionTypeAttrName(result.name), buildFuncType,
+ getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
void FuncOp::print(OpAsmPrinter &p) {
- function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+ function_interface_impl::printFunctionOp(
+ p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+ getArgAttrsAttrName(), getResAttrsAttrName());
}
/// Check that the result type of async.func is not void and must be
ArrayRef<DictionaryAttr> argAttrs) {
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
- state.addAttribute(FunctionOpInterface::getTypeAttrName(),
- TypeAttr::get(type));
+ state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
state.attributes.append(attrs.begin(), attrs.end());
state.addRegion();
if (argAttrs.empty())
return;
assert(type.getNumInputs() == argAttrs.size());
- function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
- /*resultAttrs=*/std::nullopt);
+ function_interface_impl::addArgAndResultAttrs(
+ builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
+ getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
}
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false, buildFuncType);
+ parser, result, /*allowVariadic=*/false,
+ getFunctionTypeAttrName(result.name), buildFuncType,
+ getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
void FuncOp::print(OpAsmPrinter &p) {
- function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+ function_interface_impl::printFunctionOp(
+ p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+ getArgAttrsAttrName(), getResAttrsAttrName());
}
/// Clone the internal blocks from this function into dest and all attributes
ArrayRef<NamedAttribute> attrs) {
result.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
- result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
+ result.addAttribute(getFunctionTypeAttrName(result.name),
+ TypeAttr::get(type));
result.addAttribute(getNumWorkgroupAttributionsAttrName(),
builder.getI64IntegerAttr(workgroupAttributions.size()));
result.addAttributes(attrs);
for (auto &arg : entryArgs)
argTypes.push_back(arg.type);
auto type = builder.getFunctionType(argTypes, resultTypes);
- result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type));
+ result.addAttribute(getFunctionTypeAttrName(result.name),
+ TypeAttr::get(type));
- function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
- resultAttrs);
+ function_interface_impl::addArgAndResultAttrs(
+ builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
+ getResAttrsAttrName(result.name));
// Parse workgroup memory attributions.
if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
p << ' ' << getKernelKeyword();
function_interface_impl::printFunctionAttributes(
- p, *this, type.getNumInputs(), type.getNumResults(),
+ p, *this,
{getNumWorkgroupAttributionsAttrName(),
- GPUDialect::getKernelFuncAttrName()});
+ GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
+ getArgAttrsAttrName(), getResAttrsAttrName()});
p << ' ';
p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
LogicalResult GPUFuncOp::verifyType() {
- Type type = getFunctionTypeAttr().getValue();
- if (!type.isa<FunctionType>())
- return emitOpError("requires '" + getTypeAttrName() +
- "' attribute of function type");
-
if (isKernel() && getFunctionType().getNumResults() != 0)
return emitOpError() << "expected void return type for kernel function";
assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() &&
"expected as many argument attribute lists as arguments");
- function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs,
- /*resultAttrs=*/std::nullopt);
+ function_interface_impl::addArgAndResultAttrs(
+ builder, result, argAttrs, /*resultAttrs=*/std::nullopt,
+ getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
// Builds an LLVM function type from the given lists of input and output types.
function_interface_impl::VariadicFlag(isVariadic));
if (!type)
return failure();
- result.addAttribute(FunctionOpInterface::getTypeAttrName(),
+ result.addAttribute(getFunctionTypeAttrName(result.name),
TypeAttr::get(type));
if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
return failure();
- function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result,
- entryArgs, resultAttrs);
+ function_interface_impl::addArgAndResultAttrs(
+ parser.getBuilder(), result, entryArgs, resultAttrs,
+ getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
auto *body = result.addRegion();
OptionalParseResult parseResult =
function_interface_impl::printFunctionSignature(p, *this, argTypes,
isVarArg(), resTypes);
function_interface_impl::printFunctionAttributes(
- p, *this, argTypes.size(), resTypes.size(),
- {getLinkageAttrName(), getCConvAttrName()});
+ p, *this,
+ {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
+ getLinkageAttrName(), getCConvAttrName()});
// Print the body if this is not an external function.
Region &body = getBody();
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false, buildFuncType);
+ parser, result, /*allowVariadic=*/false,
+ getFunctionTypeAttrName(result.name), buildFuncType,
+ getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
void FuncOp::print(OpAsmPrinter &p) {
- function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+ function_interface_impl::printFunctionOp(
+ p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+ getArgAttrsAttrName(), getResAttrsAttrName());
}
//===----------------------------------------------------------------------===//
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false, buildFuncType);
+ parser, result, /*allowVariadic=*/false,
+ getFunctionTypeAttrName(result.name), buildFuncType,
+ getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
void SubgraphOp::print(OpAsmPrinter &p) {
- function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+ function_interface_impl::printFunctionOp(
+ p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+ getArgAttrsAttrName(), getResAttrsAttrName());
}
//===----------------------------------------------------------------------===//
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false, buildFuncType);
+ parser, result, /*allowVariadic=*/false,
+ getFunctionTypeAttrName(result.name), buildFuncType,
+ getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
void FuncOp::print(OpAsmPrinter &p) {
- function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+ function_interface_impl::printFunctionOp(
+ p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+ getArgAttrsAttrName(), getResAttrsAttrName());
}
//===----------------------------------------------------------------------===//
for (auto &arg : entryArgs)
argTypes.push_back(arg.type);
auto fnType = builder.getFunctionType(argTypes, resultTypes);
- result.addAttribute(FunctionOpInterface::getTypeAttrName(),
+ result.addAttribute(getFunctionTypeAttrName(result.name),
TypeAttr::get(fnType));
// Parse the optional function control keyword.
// Add the attributes to the function arguments.
assert(resultAttrs.size() == resultTypes.size());
- function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
- resultAttrs);
+ function_interface_impl::addArgAndResultAttrs(
+ builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
+ getResAttrsAttrName(result.name));
// Parse the optional function body.
auto *body = result.addRegion();
printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
<< "\"";
function_interface_impl::printFunctionAttributes(
- printer, *this, fnType.getNumInputs(), fnType.getNumResults(),
- {spirv::attributeName<spirv::FunctionControl>()});
+ printer, *this,
+ {spirv::attributeName<spirv::FunctionControl>(),
+ getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
+ getFunctionControlAttrName()});
// Print the body if this is not an external function.
Region &body = this->getBody();
}
LogicalResult spirv::FuncOp::verifyType() {
- auto type = getFunctionTypeAttr().getValue();
- if (!type.isa<FunctionType>())
- return emitOpError("requires '" + getTypeAttrName() +
- "' attribute of function type");
if (getFunctionType().getNumResults() > 1)
return emitOpError("cannot have more than one result");
return success();
ArrayRef<NamedAttribute> attrs) {
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
- state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
+ state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
builder.getAttr<spirv::FunctionControlAttr>(control));
state.attributes.append(attrs.begin(), attrs.end());
// Copy over all attributes other than the function name and type.
for (const auto &namedAttr : funcOp->getAttrs()) {
- if (namedAttr.getName() != FunctionOpInterface::getTypeAttrName() &&
+ if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
namedAttr.getName() != SymbolTable::getSymbolAttrName())
newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
}
if (argAttrs.empty())
return;
assert(type.getNumInputs() == argAttrs.size());
- function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
- /*resultAttrs=*/std::nullopt);
+ function_interface_impl::addArgAndResultAttrs(
+ builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
+ getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
}
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false, buildFuncType);
+ parser, result, /*allowVariadic=*/false,
+ getFunctionTypeAttrName(result.name), buildFuncType,
+ getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
void FuncOp::print(OpAsmPrinter &p) {
- function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+ function_interface_impl::printFunctionOp(
+ p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+ getArgAttrsAttrName(), getResAttrsAttrName());
}
//===----------------------------------------------------------------------===//
return parser.parseRParen();
}
-ParseResult mlir::function_interface_impl::parseFunctionSignature(
+ParseResult function_interface_impl::parseFunctionSignature(
OpAsmParser &parser, bool allowVariadic,
SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
SmallVectorImpl<Type> &resultTypes,
return success();
}
-void mlir::function_interface_impl::addArgAndResultAttrs(
+void function_interface_impl::addArgAndResultAttrs(
Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
- ArrayRef<DictionaryAttr> resultAttrs) {
+ ArrayRef<DictionaryAttr> resultAttrs, StringAttr argAttrsName,
+ StringAttr resAttrsName) {
auto nonEmptyAttrsFn = [](DictionaryAttr attrs) {
return attrs && !attrs.empty();
};
// Add the attributes to the function arguments.
if (llvm::any_of(argAttrs, nonEmptyAttrsFn))
- result.addAttribute(function_interface_impl::getArgDictAttrName(),
- getArrayAttr(argAttrs));
+ result.addAttribute(argAttrsName, getArrayAttr(argAttrs));
// Add the attributes to the function results.
if (llvm::any_of(resultAttrs, nonEmptyAttrsFn))
- result.addAttribute(function_interface_impl::getResultDictAttrName(),
- getArrayAttr(resultAttrs));
+ result.addAttribute(resAttrsName, getArrayAttr(resultAttrs));
}
-void mlir::function_interface_impl::addArgAndResultAttrs(
+void function_interface_impl::addArgAndResultAttrs(
Builder &builder, OperationState &result,
- ArrayRef<OpAsmParser::Argument> args,
- ArrayRef<DictionaryAttr> resultAttrs) {
+ ArrayRef<OpAsmParser::Argument> args, ArrayRef<DictionaryAttr> resultAttrs,
+ StringAttr argAttrsName, StringAttr resAttrsName) {
SmallVector<DictionaryAttr> argAttrs;
for (const auto &arg : args)
argAttrs.push_back(arg.attrs);
- addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
+ addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName,
+ resAttrsName);
}
-ParseResult mlir::function_interface_impl::parseFunctionOp(
+ParseResult function_interface_impl::parseFunctionOp(
OpAsmParser &parser, OperationState &result, bool allowVariadic,
- FuncTypeBuilder funcTypeBuilder) {
+ StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder,
+ StringAttr argAttrsName, StringAttr resAttrsName) {
SmallVector<OpAsmParser::Argument> entryArgs;
SmallVector<DictionaryAttr> resultAttrs;
SmallVector<Type> resultTypes;
<< "failed to construct function type"
<< (errorMessage.empty() ? "" : ": ") << errorMessage;
}
- result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
+ result.addAttribute(typeAttrName, TypeAttr::get(type));
// If function attributes are present, parse them.
NamedAttrList parsedAttributes;
// dictionary.
for (StringRef disallowed :
{SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(),
- getTypeAttrName()}) {
+ typeAttrName.getValue()}) {
if (parsedAttributes.get(disallowed))
return parser.emitError(attributeDictLocation, "'")
<< disallowed
// Add the attributes to the function arguments.
assert(resultAttrs.size() == resultTypes.size());
- addArgAndResultAttrs(builder, result, entryArgs, resultAttrs);
+ addArgAndResultAttrs(builder, result, entryArgs, resultAttrs, argAttrsName,
+ resAttrsName);
// Parse the optional function body. The printer will not print the body if
// its empty, so disallow parsing of empty body in the parser.
os << ')';
}
-void mlir::function_interface_impl::printFunctionSignature(
- OpAsmPrinter &p, Operation *op, ArrayRef<Type> argTypes, bool isVariadic,
- ArrayRef<Type> resultTypes) {
+void function_interface_impl::printFunctionSignature(
+ OpAsmPrinter &p, FunctionOpInterface op, ArrayRef<Type> argTypes,
+ bool isVariadic, ArrayRef<Type> resultTypes) {
Region &body = op->getRegion(0);
bool isExternal = body.empty();
p << '(';
- ArrayAttr argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
+ ArrayAttr argAttrs = op.getArgAttrsAttr();
for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
if (i > 0)
p << ", ";
if (!resultTypes.empty()) {
p.getStream() << " -> ";
- auto resultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
+ auto resultAttrs = op.getResAttrsAttr();
printFunctionResultList(p, resultTypes, resultAttrs);
}
}
-void mlir::function_interface_impl::printFunctionAttributes(
- OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults,
- ArrayRef<StringRef> elided) {
+void function_interface_impl::printFunctionAttributes(
+ OpAsmPrinter &p, Operation *op, ArrayRef<StringRef> elided) {
// Print out function attributes, if present.
- SmallVector<StringRef, 2> ignoredAttrs = {
- ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName(),
- getArgDictAttrName(), getResultDictAttrName()};
+ SmallVector<StringRef, 8> ignoredAttrs = {SymbolTable::getSymbolAttrName()};
ignoredAttrs.append(elided.begin(), elided.end());
p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
}
-void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p,
- FunctionOpInterface op,
- bool isVariadic) {
+void function_interface_impl::printFunctionOp(
+ OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
+ StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName) {
// Print the operation and the function name.
auto funcName =
op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
ArrayRef<Type> argTypes = op.getArgumentTypes();
ArrayRef<Type> resultTypes = op.getResultTypes();
printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
- printFunctionAttributes(p, op, argTypes.size(), resultTypes.size(),
- {visibilityAttrName});
+ printFunctionAttributes(
+ p, op, {visibilityAttrName, typeAttrName, argAttrsName, resAttrsName});
// Print the body if this is not an external function.
Region &body = op->getRegion(0);
if (!body.empty()) {
return attr.cast<DictionaryAttr>().empty();
}
-DictionaryAttr mlir::function_interface_impl::getArgAttrDict(Operation *op,
- unsigned index) {
- ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
+DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op,
+ unsigned index) {
+ ArrayAttr attrs = op.getArgAttrsAttr();
DictionaryAttr argAttrs =
attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
return argAttrs;
}
DictionaryAttr
-mlir::function_interface_impl::getResultAttrDict(Operation *op,
- unsigned index) {
- ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
+function_interface_impl::getResultAttrDict(FunctionOpInterface op,
+ unsigned index) {
+ ArrayAttr attrs = op.getResAttrsAttr();
DictionaryAttr resAttrs =
attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
return resAttrs;
}
-void mlir::function_interface_impl::detail::setArgResAttrDict(
- Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index,
- DictionaryAttr attrs) {
- ArrayAttr allAttrs = op->getAttrOfType<ArrayAttr>(attrName);
+ArrayRef<NamedAttribute>
+function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) {
+ auto argDict = getArgAttrDict(op, index);
+ return argDict ? argDict.getValue() : std::nullopt;
+}
+
+ArrayRef<NamedAttribute>
+function_interface_impl::getResultAttrs(FunctionOpInterface op,
+ unsigned index) {
+ auto resultDict = getResultAttrDict(op, index);
+ return resultDict ? resultDict.getValue() : std::nullopt;
+}
+
+/// Get either the argument or result attributes array.
+template <bool isArg>
+static ArrayAttr getArgResAttrs(FunctionOpInterface op) {
+ if constexpr (isArg)
+ return op.getArgAttrsAttr();
+ else
+ return op.getResAttrsAttr();
+}
+
+/// Set either the argument or result attributes array.
+template <bool isArg>
+static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) {
+ if constexpr (isArg)
+ op.setArgAttrsAttr(attrs);
+ else
+ op.setResAttrsAttr(attrs);
+}
+
+/// Erase either the argument or result attributes array.
+template <bool isArg>
+static void removeArgResAttrs(FunctionOpInterface op) {
+ if constexpr (isArg)
+ op.removeArgAttrsAttr();
+ else
+ op.removeResAttrsAttr();
+}
+
+/// Set all of the argument or result attribute dictionaries for a function.
+template <bool isArg>
+static void setAllArgResAttrDicts(FunctionOpInterface op,
+ ArrayRef<Attribute> attrs) {
+ if (llvm::all_of(attrs, isEmptyAttrDict))
+ removeArgResAttrs<isArg>(op);
+ else
+ setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), attrs));
+}
+
+void function_interface_impl::setAllArgAttrDicts(
+ FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
+ setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
+}
+
+void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op,
+ ArrayRef<Attribute> attrs) {
+ auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
+ return !attr ? DictionaryAttr::get(op->getContext()) : attr;
+ });
+ setAllArgResAttrDicts</*isArg=*/true>(op, llvm::to_vector<8>(wrappedAttrs));
+}
+
+void function_interface_impl::setAllResultAttrDicts(
+ FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
+ setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
+}
+
+void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op,
+ ArrayRef<Attribute> attrs) {
+ auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
+ return !attr ? DictionaryAttr::get(op->getContext()) : attr;
+ });
+ setAllArgResAttrDicts</*isArg=*/false>(op, llvm::to_vector<8>(wrappedAttrs));
+}
+
+/// Update the given index into an argument or result attribute dictionary.
+template <bool isArg>
+static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices,
+ unsigned index, DictionaryAttr attrs) {
+ ArrayAttr allAttrs = getArgResAttrs<isArg>(op);
if (!allAttrs) {
if (attrs.empty())
return;
SmallVector<Attribute, 8> newAttrs(numTotalIndices,
DictionaryAttr::get(op->getContext()));
newAttrs[index] = attrs;
- op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
+ setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
return;
}
// Check to see if the attribute is different from what we already have.
ArrayRef<Attribute> rawAttrArray = allAttrs.getValue();
if (attrs.empty() &&
llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) &&
- llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) {
- op->removeAttr(attrName);
- return;
- }
+ llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict))
+ return removeArgResAttrs<isArg>(op);
// Otherwise, create a new attribute array with the updated dictionary.
SmallVector<Attribute, 8> newAttrs(rawAttrArray.begin(), rawAttrArray.end());
newAttrs[index] = attrs;
- op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
+ setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
}
-/// Set all of the argument or result attribute dictionaries for a function.
-static void setAllArgResAttrDicts(Operation *op, StringRef attrName,
- ArrayRef<Attribute> attrs) {
- if (llvm::all_of(attrs, isEmptyAttrDict))
- op->removeAttr(attrName);
- else
- op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs));
+void function_interface_impl::setArgAttrs(FunctionOpInterface op,
+ unsigned index,
+ ArrayRef<NamedAttribute> attributes) {
+ assert(index < op.getNumArguments() && "invalid argument number");
+ return setArgResAttrDict</*isArg=*/true>(
+ op, op.getNumArguments(), index,
+ DictionaryAttr::get(op->getContext(), attributes));
}
-void mlir::function_interface_impl::setAllArgAttrDicts(
- Operation *op, ArrayRef<DictionaryAttr> attrs) {
- setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
-}
-void mlir::function_interface_impl::setAllArgAttrDicts(
- Operation *op, ArrayRef<Attribute> attrs) {
- auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
- return !attr ? DictionaryAttr::get(op->getContext()) : attr;
- });
- setAllArgResAttrDicts(op, getArgDictAttrName(),
- llvm::to_vector<8>(wrappedAttrs));
+void function_interface_impl::setArgAttrs(FunctionOpInterface op,
+ unsigned index,
+ DictionaryAttr attributes) {
+ return setArgResAttrDict</*isArg=*/true>(
+ op, op.getNumArguments(), index,
+ attributes ? attributes : DictionaryAttr::get(op->getContext()));
}
-void mlir::function_interface_impl::setAllResultAttrDicts(
- Operation *op, ArrayRef<DictionaryAttr> attrs) {
- setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
+void function_interface_impl::setResultAttrs(
+ FunctionOpInterface op, unsigned index,
+ ArrayRef<NamedAttribute> attributes) {
+ assert(index < op.getNumResults() && "invalid result number");
+ return setArgResAttrDict</*isArg=*/false>(
+ op, op.getNumResults(), index,
+ DictionaryAttr::get(op->getContext(), attributes));
}
-void mlir::function_interface_impl::setAllResultAttrDicts(
- Operation *op, ArrayRef<Attribute> attrs) {
- auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
- return !attr ? DictionaryAttr::get(op->getContext()) : attr;
- });
- setAllArgResAttrDicts(op, getResultDictAttrName(),
- llvm::to_vector<8>(wrappedAttrs));
+
+void function_interface_impl::setResultAttrs(FunctionOpInterface op,
+ unsigned index,
+ DictionaryAttr attributes) {
+ assert(index < op.getNumResults() && "invalid result number");
+ return setArgResAttrDict</*isArg=*/false>(
+ op, op.getNumResults(), index,
+ attributes ? attributes : DictionaryAttr::get(op->getContext()));
}
-void mlir::function_interface_impl::insertFunctionArguments(
- Operation *op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
+void function_interface_impl::insertFunctionArguments(
+ FunctionOpInterface op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs,
unsigned originalNumArgs, Type newType) {
assert(argIndices.size() == argTypes.size());
Block &entry = op->getRegion(0).front();
// Update the argument attributes of the function.
- auto oldArgAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
+ ArrayAttr oldArgAttrs = op.getArgAttrsAttr();
if (oldArgAttrs || !argAttrs.empty()) {
SmallVector<DictionaryAttr, 4> newArgAttrs;
newArgAttrs.reserve(originalNumArgs + argIndices.size());
}
// Update the function type and any entry block arguments.
- op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+ op.setFunctionTypeAttr(TypeAttr::get(newType));
for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]);
}
-void mlir::function_interface_impl::insertFunctionResults(
- Operation *op, ArrayRef<unsigned> resultIndices, TypeRange resultTypes,
- ArrayRef<DictionaryAttr> resultAttrs, unsigned originalNumResults,
- Type newType) {
+void function_interface_impl::insertFunctionResults(
+ FunctionOpInterface op, ArrayRef<unsigned> resultIndices,
+ TypeRange resultTypes, ArrayRef<DictionaryAttr> resultAttrs,
+ unsigned originalNumResults, Type newType) {
assert(resultIndices.size() == resultTypes.size());
assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
if (resultIndices.empty())
// - Result attrs.
// Update the result attributes of the function.
- auto oldResultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
+ ArrayAttr oldResultAttrs = op.getResAttrsAttr();
if (oldResultAttrs || !resultAttrs.empty()) {
SmallVector<DictionaryAttr, 4> newResultAttrs;
newResultAttrs.reserve(originalNumResults + resultIndices.size());
}
// Update the function type.
- op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+ op.setFunctionTypeAttr(TypeAttr::get(newType));
}
-void mlir::function_interface_impl::eraseFunctionArguments(
- Operation *op, const BitVector &argIndices, Type newType) {
+void function_interface_impl::eraseFunctionArguments(
+ FunctionOpInterface op, const BitVector &argIndices, Type newType) {
// There are 3 things that need to be updated:
// - Function type.
// - Arg attrs.
Block &entry = op->getRegion(0).front();
// Update the argument attributes of the function.
- if (auto argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName())) {
+ if (ArrayAttr argAttrs = op.getArgAttrsAttr()) {
SmallVector<DictionaryAttr, 4> newArgAttrs;
newArgAttrs.reserve(argAttrs.size());
for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
}
// Update the function type and any entry block arguments.
- op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+ op.setFunctionTypeAttr(TypeAttr::get(newType));
entry.eraseArguments(argIndices);
}
-void mlir::function_interface_impl::eraseFunctionResults(
- Operation *op, const BitVector &resultIndices, Type newType) {
+void function_interface_impl::eraseFunctionResults(
+ FunctionOpInterface op, const BitVector &resultIndices, Type newType) {
// There are 2 things that need to be updated:
// - Function type.
// - Result attrs.
// Update the result attributes of the function.
- if (auto resAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName())) {
+ if (ArrayAttr resAttrs = op.getResAttrsAttr()) {
SmallVector<DictionaryAttr, 4> newResultAttrs;
newResultAttrs.reserve(resAttrs.size());
for (unsigned i = 0, e = resultIndices.size(); i < e; ++i)
}
// Update the function type.
- op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+ op.setFunctionTypeAttr(TypeAttr::get(newType));
}
-TypeRange mlir::function_interface_impl::insertTypesInto(
+TypeRange function_interface_impl::insertTypesInto(
TypeRange oldTypes, ArrayRef<unsigned> indices, TypeRange newTypes,
SmallVectorImpl<Type> &storage) {
assert(indices.size() == newTypes.size() &&
return storage;
}
-TypeRange mlir::function_interface_impl::filterTypesOut(
+TypeRange function_interface_impl::filterTypesOut(
TypeRange types, const BitVector &indices, SmallVectorImpl<Type> &storage) {
if (indices.none())
return types;
// Function type signature.
//===----------------------------------------------------------------------===//
-void mlir::function_interface_impl::setFunctionType(Operation *op,
- Type newType) {
- FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
- unsigned oldNumArgs = funcOp.getNumArguments();
- unsigned oldNumResults = funcOp.getNumResults();
- op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
- unsigned newNumArgs = funcOp.getNumArguments();
- unsigned newNumResults = funcOp.getNumResults();
+void function_interface_impl::setFunctionType(FunctionOpInterface op,
+ Type newType) {
+ unsigned oldNumArgs = op.getNumArguments();
+ unsigned oldNumResults = op.getNumResults();
+ op.setFunctionTypeAttr(TypeAttr::get(newType));
+ unsigned newNumArgs = op.getNumArguments();
+ unsigned newNumResults = op.getNumResults();
// Functor used to update the argument and result attributes of the function.
- auto updateAttrFn = [&](StringRef attrName, unsigned oldCount,
- unsigned newCount, auto setAttrFn) {
+ auto emptyDict = DictionaryAttr::get(op.getContext());
+ auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) {
+ constexpr bool isArgVal = std::is_same_v<decltype(isArg), std::true_type>;
+
if (oldCount == newCount)
return;
// The new type has no arguments/results, just drop the attribute.
- if (newCount == 0) {
- op->removeAttr(attrName);
- return;
- }
- ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(attrName);
+ if (newCount == 0)
+ return removeArgResAttrs<isArgVal>(op);
+ ArrayAttr attrs = getArgResAttrs<isArgVal>(op);
if (!attrs)
return;
// The new type has less arguments/results, take the first N attributes.
if (newCount < oldCount)
- return setAttrFn(op, attrs.getValue().take_front(newCount));
+ return setAllArgResAttrDicts<isArgVal>(
+ op, attrs.getValue().take_front(newCount));
// Otherwise, the new type has more arguments/results. Initialize the new
- // arguments/results with empty attributes.
+ // arguments/results with empty dictionary attributes.
SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
- newAttrs.resize(newCount);
- setAttrFn(op, newAttrs);
+ newAttrs.resize(newCount, emptyDict);
+ setAllArgResAttrDicts<isArgVal>(op, newAttrs);
};
// Update the argument and result attributes.
- updateAttrFn(
- getArgDictAttrName(), oldNumArgs, newNumArgs,
- [&](Operation *op, auto &&attrs) { setAllArgAttrDicts(op, attrs); });
- updateAttrFn(
- getResultDictAttrName(), oldNumResults, newNumResults,
- [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); });
+ updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs);
+ updateAttrFn(std::false_type{}, oldNumResults, newNumResults);
}
// -----
-// expected-error@+1 {{argument attribute array `arg_attrs` to have the same number of elements as the number of function arguments}}
+// expected-error@+1 {{argument attribute array to have the same number of elements as the number of function arguments}}
func.func private @invalid_arg_attrs() attributes { arg_attrs = [{}] }
// -----
-// expected-error@+1 {{expects argument attribute dictionary to be a DictionaryAttr, but got `10 : i64`}}
-func.func private @invalid_arg_attrs(i32) attributes { arg_attrs = [10] }
-// -----
-
-// expected-error@+1 {{result attribute array `res_attrs` to have the same number of elements as the number of function results}}
+// expected-error@+1 {{result attribute array to have the same number of elements as the number of function results}}
func.func private @invalid_res_attrs() attributes { res_attrs = [{}] }
-
-// -----
-
-// expected-error@+1 {{expects result attribute dictionary to be a DictionaryAttr, but got `10 : i64`}}
-func.func private @invalid_res_attrs() -> i32 attributes { res_attrs = [10] }