std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType,
- getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType,
- getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType,
- getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType,
- getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType,
- getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType,
- getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
/// Parser implementation for function-like operations. Uses
/// `funcTypeBuilder` to construct the custom function type given lists of
-/// input and output types. The parser sets the `typeAttrName` attribute to the
-/// resulting function type. If `allowVariadic` is set, the parser will accept
+/// input and output types. If `allowVariadic` is set, the parser will accept
/// trailing ellipsis in the function signature and indicate to the builder
/// whether the function is variadic. If the builder returns a null type,
/// `result` will not contain the `type` attribute. The caller can then add a
/// type, report the error or delegate the reporting to the op's verifier.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result,
- bool allowVariadic, StringAttr typeAttrName,
+ bool allowVariadic,
FuncTypeBuilder funcTypeBuilder);
/// Printer implementation for function-like operations.
-void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
- StringRef typeAttrName);
+void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic);
/// Prints the signature of the function-like operation `op`. Assumes `op` has
/// is a FunctionOpInterface and has passed verification.
/// function-like operation internally are not printed. Nothing is printed
/// if all attributes are elided. Assumes `op` is a FunctionOpInterface and
/// has passed verification.
-void printFunctionAttributes(OpAsmPrinter &p, Operation *op,
+void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs,
+ unsigned numResults,
ArrayRef<StringRef> elided = {});
} // namespace function_interface_impl
#include "llvm/ADT/SmallString.h"
namespace mlir {
-class FunctionOpInterface;
namespace function_interface_impl {
+/// Return the name of the attribute used for function types.
+inline StringRef getTypeAttrName() { return "function_type"; }
+
/// Return the name of the attribute used for function argument attributes.
inline StringRef getArgDictAttrName() { return "arg_attrs"; }
}
/// Insert the specified arguments and update the function type attribute.
-void insertFunctionArguments(FunctionOpInterface op,
- ArrayRef<unsigned> argIndices, TypeRange argTypes,
+void insertFunctionArguments(Operation *op, ArrayRef<unsigned> argIndices,
+ TypeRange argTypes,
ArrayRef<DictionaryAttr> argAttrs,
ArrayRef<Location> argLocs,
unsigned originalNumArgs, Type newType);
/// Insert the specified results and update the function type attribute.
-void insertFunctionResults(FunctionOpInterface op,
- ArrayRef<unsigned> resultIndices,
+void insertFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
TypeRange resultTypes,
ArrayRef<DictionaryAttr> resultAttrs,
unsigned originalNumResults, Type newType);
/// Erase the specified arguments and update the function type attribute.
-void eraseFunctionArguments(FunctionOpInterface op, const BitVector &argIndices,
+void eraseFunctionArguments(Operation *op, const BitVector &argIndices,
Type newType);
/// Erase the specified results and update the function type attribute.
-void eraseFunctionResults(FunctionOpInterface op,
- const BitVector &resultIndices, Type newType);
+void eraseFunctionResults(Operation *op, const BitVector &resultIndices,
+ Type newType);
/// Set a FunctionOpInterface operation's type signature.
-void setFunctionType(FunctionOpInterface op, Type newType);
+void setFunctionType(Operation *op, Type newType);
/// Insert a set of `newTypes` into `oldTypes` at the given `indices`. If any
/// types are inserted, `storage` is used to hold the new type list. The new
/// method on FunctionOpInterface::Trait.
template <typename ConcreteOp>
LogicalResult verifyTrait(ConcreteOp op) {
+ if (!op.getFunctionTypeAttr())
+ return op.emitOpError("requires a type attribute '")
+ << function_interface_impl::getTypeAttrName() << '\'';
+
if (failed(op.verifyType()))
return failure();
for each of the function results.
}];
let methods = [
- InterfaceMethod<[{
- Returns the type of the function.
- }],
- "::mlir::Type", "getFunctionType">,
- InterfaceMethod<[{
- Set the type of the function. This method should perform an unsafe
- modification to the function type; it should not update argument or
- result attributes.
- }],
- "void", "setFunctionTypeAttr", (ins "::mlir::TypeAttr":$type)>,
InterfaceMethod<[{
Returns the function argument types based exclusively on
the type (to allow for this method may be called on function
ArrayRef<NamedAttribute> attrs, TypeRange inputTypes) {
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
- state.addAttribute(ConcreteOp::getFunctionTypeAttrName(state.name),
+ state.addAttribute(function_interface_impl::getTypeAttrName(),
TypeAttr::get(type));
state.attributes.append(attrs.begin(), attrs.end());
// the derived operation, which should already have these defined
// (via ODS).
+ /// Returns the name of the attribute used for function types.
+ static StringRef getTypeAttrName() {
+ return function_interface_impl::getTypeAttrName();
+ }
+
/// Returns the name of the attribute used for function argument attributes.
static StringRef getArgDictAttrName() {
return function_interface_impl::getArgDictAttrName();
return function_interface_impl::getResultDictAttrName();
}
+ /// Return the attribute containing the type of this function.
+ TypeAttr getFunctionTypeAttr() {
+ return this->getOperation()->template getAttrOfType<TypeAttr>(
+ getTypeAttrName());
+ }
+
+ /// Return the type of this function.
+ Type getFunctionType() { return getFunctionTypeAttr().getValue(); }
+
//===------------------------------------------------------------------===//
// Argument and Result Handling
//===------------------------------------------------------------------===//
/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
/// attributes.
-static void filterFuncAttributes(func::FuncOp func, bool filterArgAndResAttrs,
+static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
+ bool filterArgAndResAttrs,
SmallVectorImpl<NamedAttribute> &result) {
- for (const NamedAttribute &attr : func->getAttrs()) {
+ for (const auto &attr : attrs) {
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
- attr.getName() == func.getFunctionTypeAttrName() ||
+ attr.getName() == FunctionOpInterface::getTypeAttrName() ||
attr.getName() == "func.varargs" ||
(filterArgAndResAttrs &&
(attr.getName() == FunctionOpInterface::getArgDictAttrName() ||
LLVM::LLVMFuncOp newFuncOp) {
auto type = funcOp.getFunctionType();
SmallVector<NamedAttribute, 4> attributes;
- filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes);
+ filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
+ attributes);
auto [wrapperFuncType, resultIsNowArg] =
typeConverter.convertFunctionTypeCWrapper(type);
if (resultIsNowArg)
assert(wrapperType && "unexpected type conversion failure");
SmallVector<NamedAttribute, 4> attributes;
- filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes);
+ filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
+ attributes);
if (resultIsNowArg)
prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments());
// Propagate argument/result attributes to all converted arguments/result
// obtained after converting a given original argument/result.
SmallVector<NamedAttribute, 4> attributes;
- filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/true, attributes);
+ filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true,
+ attributes);
if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
assert(!resAttrDicts.empty() && "expected array to be non-empty");
auto newResAttrDicts =
SmallVector<NamedAttribute, 4> attributes;
for (const auto &attr : gpuFuncOp->getAttrs()) {
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
- attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
+ attr.getName() == FunctionOpInterface::getTypeAttrName() ||
attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
continue;
attributes.push_back(attr);
rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
std::nullopt));
for (const auto &namedAttr : funcOp->getAttrs()) {
- if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
+ if (namedAttr.getName() == FunctionOpInterface::getTypeAttrName() ||
namedAttr.getName() == SymbolTable::getSymbolAttrName())
continue;
newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
ArrayRef<DictionaryAttr> argAttrs) {
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
- state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
+ state.addAttribute(FunctionOpInterface::getTypeAttrName(),
+ TypeAttr::get(type));
state.attributes.append(attrs.begin(), attrs.end());
state.addRegion();
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType);
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(OpAsmPrinter &p) {
- function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
- getFunctionTypeAttrName());
+ function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}
/// Check that the result type of async.func is not void and must be
ArrayRef<DictionaryAttr> argAttrs) {
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
- state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
+ state.addAttribute(FunctionOpInterface::getTypeAttrName(),
+ TypeAttr::get(type));
state.attributes.append(attrs.begin(), attrs.end());
state.addRegion();
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType);
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(OpAsmPrinter &p) {
- function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
- getFunctionTypeAttrName());
+ function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}
/// Clone the internal blocks from this function into dest and all attributes
ArrayRef<NamedAttribute> attrs) {
result.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
- result.addAttribute(getFunctionTypeAttrName(result.name),
- TypeAttr::get(type));
+ result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
result.addAttribute(getNumWorkgroupAttributionsAttrName(),
builder.getI64IntegerAttr(workgroupAttributions.size()));
result.addAttributes(attrs);
for (auto &arg : entryArgs)
argTypes.push_back(arg.type);
auto type = builder.getFunctionType(argTypes, resultTypes);
- result.addAttribute(getFunctionTypeAttrName(result.name),
- TypeAttr::get(type));
+ result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type));
function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
resultAttrs);
p << ' ' << getKernelKeyword();
function_interface_impl::printFunctionAttributes(
- p, *this,
+ p, *this, type.getNumInputs(), type.getNumResults(),
{getNumWorkgroupAttributionsAttrName(),
- GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName()});
+ GPUDialect::getKernelFuncAttrName()});
p << ' ';
p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
LogicalResult GPUFuncOp::verifyType() {
+ Type type = getFunctionTypeAttr().getValue();
+ if (!type.isa<FunctionType>())
+ return emitOpError("requires '" + getTypeAttrName() +
+ "' attribute of function type");
+
if (isKernel() && getFunctionType().getNumResults() != 0)
return emitOpError() << "expected void return type for kernel function";
function_interface_impl::VariadicFlag(isVariadic));
if (!type)
return failure();
- result.addAttribute(getFunctionTypeAttrName(result.name),
+ result.addAttribute(FunctionOpInterface::getTypeAttrName(),
TypeAttr::get(type));
if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
function_interface_impl::printFunctionSignature(p, *this, argTypes,
isVarArg(), resTypes);
function_interface_impl::printFunctionAttributes(
- p, *this,
- {getFunctionTypeAttrName(), getLinkageAttrName(), getCConvAttrName()});
+ p, *this, argTypes.size(), resTypes.size(),
+ {getLinkageAttrName(), getCConvAttrName()});
// Print the body if this is not an external function.
Region &body = getBody();
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType);
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(OpAsmPrinter &p) {
- function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
- getFunctionTypeAttrName());
+ function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}
//===----------------------------------------------------------------------===//
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType);
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void SubgraphOp::print(OpAsmPrinter &p) {
- function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
- getFunctionTypeAttrName());
+ function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}
//===----------------------------------------------------------------------===//
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType);
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(OpAsmPrinter &p) {
- function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
- getFunctionTypeAttrName());
+ function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}
//===----------------------------------------------------------------------===//
for (auto &arg : entryArgs)
argTypes.push_back(arg.type);
auto fnType = builder.getFunctionType(argTypes, resultTypes);
- result.addAttribute(getFunctionTypeAttrName(result.name),
+ result.addAttribute(FunctionOpInterface::getTypeAttrName(),
TypeAttr::get(fnType));
// Parse the optional function control keyword.
printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
<< "\"";
function_interface_impl::printFunctionAttributes(
- printer, *this,
- {spirv::attributeName<spirv::FunctionControl>(),
- getFunctionTypeAttrName(), getFunctionControlAttrName()});
+ printer, *this, fnType.getNumInputs(), fnType.getNumResults(),
+ {spirv::attributeName<spirv::FunctionControl>()});
// Print the body if this is not an external function.
Region &body = this->getBody();
}
LogicalResult spirv::FuncOp::verifyType() {
+ auto type = getFunctionTypeAttr().getValue();
+ if (!type.isa<FunctionType>())
+ return emitOpError("requires '" + getTypeAttrName() +
+ "' attribute of function type");
if (getFunctionType().getNumResults() > 1)
return emitOpError("cannot have more than one result");
return success();
ArrayRef<NamedAttribute> attrs) {
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
- state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
+ state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
builder.getAttr<spirv::FunctionControlAttr>(control));
state.attributes.append(attrs.begin(), attrs.end());
// Copy over all attributes other than the function name and type.
for (const auto &namedAttr : funcOp->getAttrs()) {
- if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
+ if (namedAttr.getName() != FunctionOpInterface::getTypeAttrName() &&
namedAttr.getName() != SymbolTable::getSymbolAttrName())
newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
}
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType);
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(OpAsmPrinter &p) {
- function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
- getFunctionTypeAttrName());
+ function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}
//===----------------------------------------------------------------------===//
ParseResult mlir::function_interface_impl::parseFunctionOp(
OpAsmParser &parser, OperationState &result, bool allowVariadic,
- StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder) {
+ FuncTypeBuilder funcTypeBuilder) {
SmallVector<OpAsmParser::Argument> entryArgs;
SmallVector<DictionaryAttr> resultAttrs;
SmallVector<Type> resultTypes;
<< "failed to construct function type"
<< (errorMessage.empty() ? "" : ": ") << errorMessage;
}
- result.addAttribute(typeAttrName, TypeAttr::get(type));
+ result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
// If function attributes are present, parse them.
NamedAttrList parsedAttributes;
// dictionary.
for (StringRef disallowed :
{SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(),
- typeAttrName.getValue()}) {
+ getTypeAttrName()}) {
if (parsedAttributes.get(disallowed))
return parser.emitError(attributeDictLocation, "'")
<< disallowed
}
void mlir::function_interface_impl::printFunctionAttributes(
- OpAsmPrinter &p, Operation *op, ArrayRef<StringRef> elided) {
+ OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults,
+ ArrayRef<StringRef> elided) {
// Print out function attributes, if present.
- SmallVector<StringRef, 2> ignoredAttrs = {SymbolTable::getSymbolAttrName(),
- getArgDictAttrName(),
- getResultDictAttrName()};
+ SmallVector<StringRef, 2> ignoredAttrs = {
+ ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName(),
+ getArgDictAttrName(), getResultDictAttrName()};
ignoredAttrs.append(elided.begin(), elided.end());
p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p,
FunctionOpInterface op,
- bool isVariadic,
- StringRef typeAttrName) {
+ bool isVariadic) {
// Print the operation and the function name.
auto funcName =
op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
ArrayRef<Type> argTypes = op.getArgumentTypes();
ArrayRef<Type> resultTypes = op.getResultTypes();
printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
- printFunctionAttributes(p, op, {visibilityAttrName, typeAttrName});
+ printFunctionAttributes(p, op, argTypes.size(), resultTypes.size(),
+ {visibilityAttrName});
// Print the body if this is not an external function.
Region &body = op->getRegion(0);
if (!body.empty()) {
}
void mlir::function_interface_impl::insertFunctionArguments(
- FunctionOpInterface op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
+ Operation *op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs,
unsigned originalNumArgs, Type newType) {
assert(argIndices.size() == argTypes.size());
}
// Update the function type and any entry block arguments.
- op.setFunctionTypeAttr(TypeAttr::get(newType));
+ op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]);
}
void mlir::function_interface_impl::insertFunctionResults(
- FunctionOpInterface op, ArrayRef<unsigned> resultIndices,
- TypeRange resultTypes, ArrayRef<DictionaryAttr> resultAttrs,
- unsigned originalNumResults, Type newType) {
+ Operation *op, ArrayRef<unsigned> resultIndices, TypeRange resultTypes,
+ ArrayRef<DictionaryAttr> resultAttrs, unsigned originalNumResults,
+ Type newType) {
assert(resultIndices.size() == resultTypes.size());
assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
if (resultIndices.empty())
}
// Update the function type.
- op.setFunctionTypeAttr(TypeAttr::get(newType));
+ op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
}
void mlir::function_interface_impl::eraseFunctionArguments(
- FunctionOpInterface op, const BitVector &argIndices, Type newType) {
+ Operation *op, const BitVector &argIndices, Type newType) {
// There are 3 things that need to be updated:
// - Function type.
// - Arg attrs.
}
// Update the function type and any entry block arguments.
- op.setFunctionTypeAttr(TypeAttr::get(newType));
+ op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
entry.eraseArguments(argIndices);
}
void mlir::function_interface_impl::eraseFunctionResults(
- FunctionOpInterface op, const BitVector &resultIndices, Type newType) {
+ Operation *op, const BitVector &resultIndices, Type newType) {
// There are 2 things that need to be updated:
// - Function type.
// - Result attrs.
}
// Update the function type.
- op.setFunctionTypeAttr(TypeAttr::get(newType));
+ op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
}
TypeRange mlir::function_interface_impl::insertTypesInto(
// Function type signature.
//===----------------------------------------------------------------------===//
-void mlir::function_interface_impl::setFunctionType(FunctionOpInterface op,
+void mlir::function_interface_impl::setFunctionType(Operation *op,
Type newType) {
- unsigned oldNumArgs = op.getNumArguments();
- unsigned oldNumResults = op.getNumResults();
- op.setFunctionTypeAttr(TypeAttr::get(newType));
- unsigned newNumArgs = op.getNumArguments();
- unsigned newNumResults = op.getNumResults();
+ FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
+ unsigned oldNumArgs = funcOp.getNumArguments();
+ unsigned oldNumResults = funcOp.getNumResults();
+ op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+ unsigned newNumArgs = funcOp.getNumArguments();
+ unsigned newNumResults = funcOp.getNumResults();
// Functor used to update the argument and result attributes of the function.
auto updateAttrFn = [&](StringRef attrName, unsigned oldCount,