From fbc253fe81da4e1d6bfa2519e01e03f21d8c40a8 Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Tue, 6 Dec 2022 11:28:47 -0800 Subject: [PATCH] [mlir] FunctionOpInterface: make get/setFunctionType interface methods This patch removes the concept of a `function_type`-named type attribute as a requirement for implementors of FunctionOpInterface. Instead, this type should be provided through two interface methods, `getFunctionType` and `setFunctionTypeAttr` (*Attr because functions may use different concrete function types), which should be automatically implemented by ODS for ops that define a `$function_type` attribute. This also allows FunctionOpInterface to materialize function types if they don't carry them in an attribute, for example. Importantly, all the function "helper" still accept an attribute name to use in parsing and printing functions, for example. Reviewed By: rriddle, lattner Differential Revision: https://reviews.llvm.org/D139447 --- mlir/examples/toy/Ch2/mlir/Dialect.cpp | 3 +- mlir/examples/toy/Ch3/mlir/Dialect.cpp | 3 +- mlir/examples/toy/Ch4/mlir/Dialect.cpp | 3 +- mlir/examples/toy/Ch5/mlir/Dialect.cpp | 3 +- mlir/examples/toy/Ch6/mlir/Dialect.cpp | 3 +- mlir/examples/toy/Ch7/mlir/Dialect.cpp | 3 +- mlir/include/mlir/IR/FunctionImplementation.h | 11 ++++---- mlir/include/mlir/IR/FunctionInterfaces.h | 23 ++++++--------- mlir/include/mlir/IR/FunctionInterfaces.td | 26 ++++++++--------- mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 16 ++++------- mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 2 +- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 2 +- mlir/lib/Dialect/Async/IR/Async.cpp | 9 +++--- mlir/lib/Dialect/Func/IR/FuncOps.cpp | 9 +++--- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 15 ++++------ mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 6 ++-- mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp | 12 +++++--- mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp | 6 ++-- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 13 ++++----- .../Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 2 +- mlir/lib/Dialect/Shape/IR/Shape.cpp | 6 ++-- mlir/lib/IR/FunctionImplementation.cpp | 21 +++++++------- mlir/lib/IR/FunctionInterfaces.cpp | 33 +++++++++++----------- 23 files changed, 113 insertions(+), 117 deletions(-) diff --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp index dbc1efb..ac12c5c 100644 --- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp @@ -211,7 +211,8 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(mlir::OpAsmPrinter &p) { diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp index 50e2dfc..75cb57e 100644 --- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp @@ -198,7 +198,8 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(mlir::OpAsmPrinter &p) { diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp index 0a6195b..2d5a369 100644 --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -287,7 +287,8 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(mlir::OpAsmPrinter &p) { diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp index f236a1f..280bf31 100644 --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -287,7 +287,8 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(mlir::OpAsmPrinter &p) { diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp index f236a1f..280bf31 100644 --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -287,7 +287,8 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(mlir::OpAsmPrinter &p) { diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp index cc66a5d..b0d2130 100644 --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -314,7 +314,8 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(mlir::OpAsmPrinter &p) { diff --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h index 5265f78..f4c0cc0 100644 --- a/mlir/include/mlir/IR/FunctionImplementation.h +++ b/mlir/include/mlir/IR/FunctionImplementation.h @@ -69,17 +69,19 @@ Type getFunctionType(Builder &builder, ArrayRef argAttrs, /// Parser implementation for function-like operations. Uses /// `funcTypeBuilder` to construct the custom function type given lists of -/// input and output types. If `allowVariadic` is set, the parser will accept +/// input and output types. The parser sets the `typeAttrName` attribute to the +/// resulting function type. If `allowVariadic` is set, the parser will accept /// trailing ellipsis in the function signature and indicate to the builder /// whether the function is variadic. If the builder returns a null type, /// `result` will not contain the `type` attribute. The caller can then add a /// type, report the error or delegate the reporting to the op's verifier. ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, - bool allowVariadic, + bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder); /// Printer implementation for function-like operations. -void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic); +void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, + StringRef typeAttrName); /// Prints the signature of the function-like operation `op`. Assumes `op` has /// is a FunctionOpInterface and has passed verification. @@ -92,8 +94,7 @@ void printFunctionSignature(OpAsmPrinter &p, Operation *op, /// function-like operation internally are not printed. Nothing is printed /// if all attributes are elided. Assumes `op` is a FunctionOpInterface and /// has passed verification. -void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs, - unsigned numResults, +void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef elided = {}); } // namespace function_interface_impl diff --git a/mlir/include/mlir/IR/FunctionInterfaces.h b/mlir/include/mlir/IR/FunctionInterfaces.h index 23fd884..bc2ec47 100644 --- a/mlir/include/mlir/IR/FunctionInterfaces.h +++ b/mlir/include/mlir/IR/FunctionInterfaces.h @@ -22,12 +22,10 @@ #include "llvm/ADT/SmallString.h" namespace mlir { +class FunctionOpInterface; namespace function_interface_impl { -/// Return the name of the attribute used for function types. -inline StringRef getTypeAttrName() { return "function_type"; } - /// Return the name of the attribute used for function argument attributes. inline StringRef getArgDictAttrName() { return "arg_attrs"; } @@ -72,28 +70,29 @@ inline ArrayRef getResultAttrs(Operation *op, unsigned index) { } /// Insert the specified arguments and update the function type attribute. -void insertFunctionArguments(Operation *op, ArrayRef argIndices, - TypeRange argTypes, +void insertFunctionArguments(FunctionOpInterface op, + ArrayRef argIndices, TypeRange argTypes, ArrayRef argAttrs, ArrayRef argLocs, unsigned originalNumArgs, Type newType); /// Insert the specified results and update the function type attribute. -void insertFunctionResults(Operation *op, ArrayRef resultIndices, +void insertFunctionResults(FunctionOpInterface op, + ArrayRef resultIndices, TypeRange resultTypes, ArrayRef resultAttrs, unsigned originalNumResults, Type newType); /// Erase the specified arguments and update the function type attribute. -void eraseFunctionArguments(Operation *op, const BitVector &argIndices, +void eraseFunctionArguments(FunctionOpInterface op, const BitVector &argIndices, Type newType); /// Erase the specified results and update the function type attribute. -void eraseFunctionResults(Operation *op, const BitVector &resultIndices, - Type newType); +void eraseFunctionResults(FunctionOpInterface op, + const BitVector &resultIndices, Type newType); /// Set a FunctionOpInterface operation's type signature. -void setFunctionType(Operation *op, Type newType); +void setFunctionType(FunctionOpInterface op, Type newType); /// Insert a set of `newTypes` into `oldTypes` at the given `indices`. If any /// types are inserted, `storage` is used to hold the new type list. The new @@ -207,10 +206,6 @@ Attribute removeResultAttr(ConcreteType op, unsigned index, StringAttr name) { /// method on FunctionOpInterface::Trait. template LogicalResult verifyTrait(ConcreteOp op) { - if (!op.getFunctionTypeAttr()) - return op.emitOpError("requires a type attribute '") - << function_interface_impl::getTypeAttrName() << '\''; - if (failed(op.verifyType())) return failure(); diff --git a/mlir/include/mlir/IR/FunctionInterfaces.td b/mlir/include/mlir/IR/FunctionInterfaces.td index c56129e..e86057a 100644 --- a/mlir/include/mlir/IR/FunctionInterfaces.td +++ b/mlir/include/mlir/IR/FunctionInterfaces.td @@ -50,6 +50,16 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> { }]; let methods = [ InterfaceMethod<[{ + Returns the type of the function. + }], + "::mlir::Type", "getFunctionType">, + InterfaceMethod<[{ + Set the type of the function. This method should perform an unsafe + modification to the function type; it should not update argument or + result attributes. + }], + "void", "setFunctionTypeAttr", (ins "::mlir::TypeAttr":$type)>, + InterfaceMethod<[{ Returns the function argument types based exclusively on the type (to allow for this method may be called on function declarations). @@ -139,7 +149,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> { ArrayRef attrs, TypeRange inputTypes) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - state.addAttribute(function_interface_impl::getTypeAttrName(), + state.addAttribute(ConcreteOp::getFunctionTypeAttrName(state.name), TypeAttr::get(type)); state.attributes.append(attrs.begin(), attrs.end()); @@ -244,11 +254,6 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> { // the derived operation, which should already have these defined // (via ODS). - /// Returns the name of the attribute used for function types. - static StringRef getTypeAttrName() { - return function_interface_impl::getTypeAttrName(); - } - /// Returns the name of the attribute used for function argument attributes. static StringRef getArgDictAttrName() { return function_interface_impl::getArgDictAttrName(); @@ -259,15 +264,6 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> { return function_interface_impl::getResultDictAttrName(); } - /// Return the attribute containing the type of this function. - TypeAttr getFunctionTypeAttr() { - return this->getOperation()->template getAttrOfType( - getTypeAttrName()); - } - - /// Return the type of this function. - Type getFunctionType() { return getFunctionTypeAttr().getValue(); } - //===------------------------------------------------------------------===// // Argument and Result Handling //===------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index d0e82de..9f522aaa 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -59,12 +59,11 @@ using namespace mlir; /// Only retain those attributes that are not constructed by /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument /// attributes. -static void filterFuncAttributes(ArrayRef attrs, - bool filterArgAndResAttrs, +static void filterFuncAttributes(func::FuncOp func, bool filterArgAndResAttrs, SmallVectorImpl &result) { - for (const auto &attr : attrs) { + for (const NamedAttribute &attr : func->getAttrs()) { if (attr.getName() == SymbolTable::getSymbolAttrName() || - attr.getName() == FunctionOpInterface::getTypeAttrName() || + attr.getName() == func.getFunctionTypeAttrName() || attr.getName() == "func.varargs" || (filterArgAndResAttrs && (attr.getName() == FunctionOpInterface::getArgDictAttrName() || @@ -138,8 +137,7 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, LLVM::LLVMFuncOp newFuncOp) { auto type = funcOp.getFunctionType(); SmallVector attributes; - filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false, - attributes); + filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes); auto [wrapperFuncType, resultIsNowArg] = typeConverter.convertFunctionTypeCWrapper(type); if (resultIsNowArg) @@ -204,8 +202,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, assert(wrapperType && "unexpected type conversion failure"); SmallVector attributes; - filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false, - attributes); + filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes); if (resultIsNowArg) prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments()); @@ -304,8 +301,7 @@ protected: // Propagate argument/result attributes to all converted arguments/result // obtained after converting a given original argument/result. SmallVector attributes; - filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true, - attributes); + filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/true, attributes); if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) { assert(!resAttrDicts.empty() && "expected array to be non-empty"); auto newResAttrDicts = diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index 85001d5..48effe2 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -60,7 +60,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, SmallVector attributes; for (const auto &attr : gpuFuncOp->getAttrs()) { if (attr.getName() == SymbolTable::getSymbolAttrName() || - attr.getName() == FunctionOpInterface::getTypeAttrName() || + attr.getName() == gpuFuncOp.getFunctionTypeAttrName() || attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName()) continue; attributes.push_back(attr); diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 119b1d3..2a83895 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -226,7 +226,7 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter, rewriter.getFunctionType(signatureConverter.getConvertedTypes(), std::nullopt)); for (const auto &namedAttr : funcOp->getAttrs()) { - if (namedAttr.getName() == FunctionOpInterface::getTypeAttrName() || + if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() || namedAttr.getName() == SymbolTable::getSymbolAttrName()) continue; newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp index e0772b4..064bf52 100644 --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -332,8 +332,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, ArrayRef argAttrs) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - state.addAttribute(FunctionOpInterface::getTypeAttrName(), - TypeAttr::get(type)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); state.attributes.append(attrs.begin(), attrs.end()); state.addRegion(); @@ -352,11 +351,13 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, + getFunctionTypeAttrName()); } /// Check that the result type of async.func is not void and must be diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp index 961cf2e..fc9bd11 100644 --- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp +++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp @@ -244,8 +244,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, ArrayRef argAttrs) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - state.addAttribute(FunctionOpInterface::getTypeAttrName(), - TypeAttr::get(type)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); state.attributes.append(attrs.begin(), attrs.end()); state.addRegion(); @@ -263,11 +262,13 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, + getFunctionTypeAttrName()); } /// Clone the internal blocks from this function into dest and all attributes diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 7f73a65..80db646 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -859,7 +859,8 @@ void GPUFuncOp::build(OpBuilder &builder, OperationState &result, ArrayRef attrs) { result.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); + result.addAttribute(getFunctionTypeAttrName(result.name), + TypeAttr::get(type)); result.addAttribute(getNumWorkgroupAttributionsAttrName(), builder.getI64IntegerAttr(workgroupAttributions.size())); result.addAttributes(attrs); @@ -930,7 +931,8 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) { for (auto &arg : entryArgs) argTypes.push_back(arg.type); auto type = builder.getFunctionType(argTypes, resultTypes); - result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type)); + result.addAttribute(getFunctionTypeAttrName(result.name), + TypeAttr::get(type)); function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs, resultAttrs); @@ -992,19 +994,14 @@ void GPUFuncOp::print(OpAsmPrinter &p) { p << ' ' << getKernelKeyword(); function_interface_impl::printFunctionAttributes( - p, *this, type.getNumInputs(), type.getNumResults(), + p, *this, {getNumWorkgroupAttributionsAttrName(), - GPUDialect::getKernelFuncAttrName()}); + GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName()}); p << ' '; p.printRegion(getBody(), /*printEntryBlockArgs=*/false); } LogicalResult GPUFuncOp::verifyType() { - Type type = getFunctionTypeAttr().getValue(); - if (!type.isa()) - return emitOpError("requires '" + getTypeAttrName() + - "' attribute of function type"); - if (isKernel() && getFunctionType().getNumResults() != 0) return emitOpError() << "expected void return type for kernel function"; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index f114acd..6b428a1 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2090,7 +2090,7 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) { function_interface_impl::VariadicFlag(isVariadic)); if (!type) return failure(); - result.addAttribute(FunctionOpInterface::getTypeAttrName(), + result.addAttribute(getFunctionTypeAttrName(result.name), TypeAttr::get(type)); if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) @@ -2130,8 +2130,8 @@ void LLVMFuncOp::print(OpAsmPrinter &p) { function_interface_impl::printFunctionSignature(p, *this, argTypes, isVarArg(), resTypes); function_interface_impl::printFunctionAttributes( - p, *this, argTypes.size(), resTypes.size(), - {getLinkageAttrName(), getCConvAttrName()}); + p, *this, + {getFunctionTypeAttrName(), getLinkageAttrName(), getCConvAttrName()}); // Print the body if this is not an external function. Region &body = getBody(); diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp index 2f1e4b9..27c6130 100644 --- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp +++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp @@ -152,11 +152,13 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, + getFunctionTypeAttrName()); } //===----------------------------------------------------------------------===// @@ -313,11 +315,13 @@ ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) { std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void SubgraphOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, + getFunctionTypeAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp index e8a61ef..28fc4db 100644 --- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp +++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp @@ -220,11 +220,13 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, + getFunctionTypeAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 52ad8ad..3ce3913 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -2382,7 +2382,7 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) { for (auto &arg : entryArgs) argTypes.push_back(arg.type); auto fnType = builder.getFunctionType(argTypes, resultTypes); - result.addAttribute(FunctionOpInterface::getTypeAttrName(), + result.addAttribute(getFunctionTypeAttrName(result.name), TypeAttr::get(fnType)); // Parse the optional function control keyword. @@ -2417,8 +2417,9 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) { printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl()) << "\""; function_interface_impl::printFunctionAttributes( - printer, *this, fnType.getNumInputs(), fnType.getNumResults(), - {spirv::attributeName()}); + printer, *this, + {spirv::attributeName(), + getFunctionTypeAttrName(), getFunctionControlAttrName()}); // Print the body if this is not an external function. Region &body = this->getBody(); @@ -2430,10 +2431,6 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) { } LogicalResult spirv::FuncOp::verifyType() { - auto type = getFunctionTypeAttr().getValue(); - if (!type.isa()) - return emitOpError("requires '" + getTypeAttrName() + - "' attribute of function type"); if (getFunctionType().getNumResults() > 1) return emitOpError("cannot have more than one result"); return success(); @@ -2473,7 +2470,7 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state, ArrayRef attrs) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - state.addAttribute(getTypeAttrName(), TypeAttr::get(type)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); state.addAttribute(spirv::attributeName(), builder.getAttr(control)); state.attributes.append(attrs.begin(), attrs.end()); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 2772c01..62e3a3d 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -531,7 +531,7 @@ FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, // Copy over all attributes other than the function name and type. for (const auto &namedAttr : funcOp->getAttrs()) { - if (namedAttr.getName() != FunctionOpInterface::getTypeAttrName() && + if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() && namedAttr.getName() != SymbolTable::getSymbolAttrName()) newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); } diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 8c89ec8..30c5f56 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1311,11 +1311,13 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, + getFunctionTypeAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp index 9481e4a..af692be 100644 --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -163,7 +163,7 @@ void mlir::function_interface_impl::addArgAndResultAttrs( ParseResult mlir::function_interface_impl::parseFunctionOp( OpAsmParser &parser, OperationState &result, bool allowVariadic, - FuncTypeBuilder funcTypeBuilder) { + StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder) { SmallVector entryArgs; SmallVector resultAttrs; SmallVector resultTypes; @@ -197,7 +197,7 @@ ParseResult mlir::function_interface_impl::parseFunctionOp( << "failed to construct function type" << (errorMessage.empty() ? "" : ": ") << errorMessage; } - result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); + result.addAttribute(typeAttrName, TypeAttr::get(type)); // If function attributes are present, parse them. NamedAttrList parsedAttributes; @@ -209,7 +209,7 @@ ParseResult mlir::function_interface_impl::parseFunctionOp( // dictionary. for (StringRef disallowed : {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(), - getTypeAttrName()}) { + typeAttrName.getValue()}) { if (parsedAttributes.get(disallowed)) return parser.emitError(attributeDictLocation, "'") << disallowed @@ -301,12 +301,11 @@ void mlir::function_interface_impl::printFunctionSignature( } void mlir::function_interface_impl::printFunctionAttributes( - OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults, - ArrayRef elided) { + OpAsmPrinter &p, Operation *op, ArrayRef elided) { // Print out function attributes, if present. - SmallVector ignoredAttrs = { - ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName(), - getArgDictAttrName(), getResultDictAttrName()}; + SmallVector ignoredAttrs = {SymbolTable::getSymbolAttrName(), + getArgDictAttrName(), + getResultDictAttrName()}; ignoredAttrs.append(elided.begin(), elided.end()); p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs); @@ -314,7 +313,8 @@ void mlir::function_interface_impl::printFunctionAttributes( void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, - bool isVariadic) { + bool isVariadic, + StringRef typeAttrName) { // Print the operation and the function name. auto funcName = op->getAttrOfType(SymbolTable::getSymbolAttrName()) @@ -329,8 +329,7 @@ void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p, ArrayRef argTypes = op.getArgumentTypes(); ArrayRef resultTypes = op.getResultTypes(); printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); - printFunctionAttributes(p, op, argTypes.size(), resultTypes.size(), - {visibilityAttrName}); + printFunctionAttributes(p, op, {visibilityAttrName, typeAttrName}); // Print the body if this is not an external function. Region &body = op->getRegion(0); if (!body.empty()) { diff --git a/mlir/lib/IR/FunctionInterfaces.cpp b/mlir/lib/IR/FunctionInterfaces.cpp index 3331aef..9ba8303 100644 --- a/mlir/lib/IR/FunctionInterfaces.cpp +++ b/mlir/lib/IR/FunctionInterfaces.cpp @@ -112,7 +112,7 @@ void mlir::function_interface_impl::setAllResultAttrDicts( } void mlir::function_interface_impl::insertFunctionArguments( - Operation *op, ArrayRef argIndices, TypeRange argTypes, + FunctionOpInterface op, ArrayRef argIndices, TypeRange argTypes, ArrayRef argAttrs, ArrayRef argLocs, unsigned originalNumArgs, Type newType) { assert(argIndices.size() == argTypes.size()); @@ -152,15 +152,15 @@ void mlir::function_interface_impl::insertFunctionArguments( } // Update the function type and any entry block arguments. - op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); + op.setFunctionTypeAttr(TypeAttr::get(newType)); for (unsigned i = 0, e = argIndices.size(); i < e; ++i) entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]); } void mlir::function_interface_impl::insertFunctionResults( - Operation *op, ArrayRef resultIndices, TypeRange resultTypes, - ArrayRef resultAttrs, unsigned originalNumResults, - Type newType) { + FunctionOpInterface op, ArrayRef resultIndices, + TypeRange resultTypes, ArrayRef resultAttrs, + unsigned originalNumResults, Type newType) { assert(resultIndices.size() == resultTypes.size()); assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty()); if (resultIndices.empty()) @@ -196,11 +196,11 @@ void mlir::function_interface_impl::insertFunctionResults( } // Update the function type. - op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); + op.setFunctionTypeAttr(TypeAttr::get(newType)); } void mlir::function_interface_impl::eraseFunctionArguments( - Operation *op, const BitVector &argIndices, Type newType) { + FunctionOpInterface op, const BitVector &argIndices, Type newType) { // There are 3 things that need to be updated: // - Function type. // - Arg attrs. @@ -218,12 +218,12 @@ void mlir::function_interface_impl::eraseFunctionArguments( } // Update the function type and any entry block arguments. - op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); + op.setFunctionTypeAttr(TypeAttr::get(newType)); entry.eraseArguments(argIndices); } void mlir::function_interface_impl::eraseFunctionResults( - Operation *op, const BitVector &resultIndices, Type newType) { + FunctionOpInterface op, const BitVector &resultIndices, Type newType) { // There are 2 things that need to be updated: // - Function type. // - Result attrs. @@ -239,7 +239,7 @@ void mlir::function_interface_impl::eraseFunctionResults( } // Update the function type. - op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); + op.setFunctionTypeAttr(TypeAttr::get(newType)); } TypeRange mlir::function_interface_impl::insertTypesInto( @@ -276,14 +276,13 @@ TypeRange mlir::function_interface_impl::filterTypesOut( // Function type signature. //===----------------------------------------------------------------------===// -void mlir::function_interface_impl::setFunctionType(Operation *op, +void mlir::function_interface_impl::setFunctionType(FunctionOpInterface op, Type newType) { - FunctionOpInterface funcOp = cast(op); - unsigned oldNumArgs = funcOp.getNumArguments(); - unsigned oldNumResults = funcOp.getNumResults(); - op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); - unsigned newNumArgs = funcOp.getNumArguments(); - unsigned newNumResults = funcOp.getNumResults(); + unsigned oldNumArgs = op.getNumArguments(); + unsigned oldNumResults = op.getNumResults(); + op.setFunctionTypeAttr(TypeAttr::get(newType)); + unsigned newNumArgs = op.getNumArguments(); + unsigned newNumResults = op.getNumResults(); // Functor used to update the argument and result attributes of the function. auto updateAttrFn = [&](StringRef attrName, unsigned oldCount, -- 2.7.4