Revert "[mlir] FunctionOpInterface: make get/setFunctionType interface methods"
authorDavid Spickett <david.spickett@linaro.org>
Fri, 9 Dec 2022 15:21:28 +0000 (15:21 +0000)
committerDavid Spickett <david.spickett@linaro.org>
Fri, 9 Dec 2022 15:36:48 +0000 (15:36 +0000)
and "[mlir] Fix examples build"

This reverts commit fbc253fe81da4e1d6bfa2519e01e03f21d8c40a8 and
96cf183bccd7d1c3083f169a89a6af1f263b3aae.

Which I missed in the first revert in f3379feabe38fd3711b13ffcf6de4aab03b7ccdc.

23 files changed:
mlir/examples/toy/Ch2/mlir/Dialect.cpp
mlir/examples/toy/Ch3/mlir/Dialect.cpp
mlir/examples/toy/Ch4/mlir/Dialect.cpp
mlir/examples/toy/Ch5/mlir/Dialect.cpp
mlir/examples/toy/Ch6/mlir/Dialect.cpp
mlir/examples/toy/Ch7/mlir/Dialect.cpp
mlir/include/mlir/IR/FunctionImplementation.h
mlir/include/mlir/IR/FunctionInterfaces.h
mlir/include/mlir/IR/FunctionInterfaces.td
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/lib/Dialect/Func/IR/FuncOps.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/IR/FunctionImplementation.cpp
mlir/lib/IR/FunctionInterfaces.cpp

index 8c36fafc2f001142ec58b569cd8674dd98943175..dbc1efb3d06beb0e3bcb3d74482fd2a6c67eb6bd 100644 (file)
@@ -211,9 +211,7 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType,
-      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {
index 6bf140487420fffbe33204456bdf1ddeb47b66c5..50e2dfc7f4a3e82b5f2af592cf083f875e53dc8c 100644 (file)
@@ -198,9 +198,7 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType,
-      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {
index 8343a1cb5fbc328891fc9c0a556432ea41e7073b..0a6195b12d5d48cdb9fca1ca6c1e9dc2d2f671e3 100644 (file)
@@ -287,9 +287,7 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType,
-      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {
index dde12f51c351e515211aada324e08533cf80b96a..f236a1ffe0e5afd142aa534c5a0a6a954f7c3692 100644 (file)
@@ -287,9 +287,7 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType,
-      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {
index dde12f51c351e515211aada324e08533cf80b96a..f236a1ffe0e5afd142aa534c5a0a6a954f7c3692 100644 (file)
@@ -287,9 +287,7 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType,
-      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {
index 3413e57d37b30abd679e1785d5173cb72c3fd569..cc66a5d44b5f4c4271eddc4916f22a8ae7f9e48f 100644 (file)
@@ -314,9 +314,7 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType,
-      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {
index f4c0cc03050fea637942df6f7132762b77e62dcf..5265f781d1a7707cf31f7e537d28bebec09bb1dc 100644 (file)
@@ -69,19 +69,17 @@ Type getFunctionType(Builder &builder, ArrayRef<OpAsmParser::Argument> argAttrs,
 
 /// Parser implementation for function-like operations.  Uses
 /// `funcTypeBuilder` to construct the custom function type given lists of
-/// input and output types. The parser sets the `typeAttrName` attribute to the
-/// resulting function type. If `allowVariadic` is set, the parser will accept
+/// input and output types.  If `allowVariadic` is set, the parser will accept
 /// trailing ellipsis in the function signature and indicate to the builder
 /// whether the function is variadic.  If the builder returns a null type,
 /// `result` will not contain the `type` attribute.  The caller can then add a
 /// type, report the error or delegate the reporting to the op's verifier.
 ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result,
-                            bool allowVariadic, StringAttr typeAttrName,
+                            bool allowVariadic,
                             FuncTypeBuilder funcTypeBuilder);
 
 /// Printer implementation for function-like operations.
-void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
-                     StringRef typeAttrName);
+void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic);
 
 /// Prints the signature of the function-like operation `op`. Assumes `op` has
 /// is a FunctionOpInterface and has passed verification.
@@ -94,7 +92,8 @@ void printFunctionSignature(OpAsmPrinter &p, Operation *op,
 /// function-like operation internally are not printed. Nothing is printed
 /// if all attributes are elided. Assumes `op` is a FunctionOpInterface and
 /// has passed verification.
-void printFunctionAttributes(OpAsmPrinter &p, Operation *op,
+void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs,
+                             unsigned numResults,
                              ArrayRef<StringRef> elided = {});
 
 } // namespace function_interface_impl
index bc2ec4751c5824ff9324f9f39366c11925d965d1..23fd884d97f142c71a688897406802aa2fe478af 100644 (file)
 #include "llvm/ADT/SmallString.h"
 
 namespace mlir {
-class FunctionOpInterface;
 
 namespace function_interface_impl {
 
+/// Return the name of the attribute used for function types.
+inline StringRef getTypeAttrName() { return "function_type"; }
+
 /// Return the name of the attribute used for function argument attributes.
 inline StringRef getArgDictAttrName() { return "arg_attrs"; }
 
@@ -70,29 +72,28 @@ inline ArrayRef<NamedAttribute> getResultAttrs(Operation *op, unsigned index) {
 }
 
 /// Insert the specified arguments and update the function type attribute.
-void insertFunctionArguments(FunctionOpInterface op,
-                             ArrayRef<unsigned> argIndices, TypeRange argTypes,
+void insertFunctionArguments(Operation *op, ArrayRef<unsigned> argIndices,
+                             TypeRange argTypes,
                              ArrayRef<DictionaryAttr> argAttrs,
                              ArrayRef<Location> argLocs,
                              unsigned originalNumArgs, Type newType);
 
 /// Insert the specified results and update the function type attribute.
-void insertFunctionResults(FunctionOpInterface op,
-                           ArrayRef<unsigned> resultIndices,
+void insertFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
                            TypeRange resultTypes,
                            ArrayRef<DictionaryAttr> resultAttrs,
                            unsigned originalNumResults, Type newType);
 
 /// Erase the specified arguments and update the function type attribute.
-void eraseFunctionArguments(FunctionOpInterface op, const BitVector &argIndices,
+void eraseFunctionArguments(Operation *op, const BitVector &argIndices,
                             Type newType);
 
 /// Erase the specified results and update the function type attribute.
-void eraseFunctionResults(FunctionOpInterface op,
-                          const BitVector &resultIndices, Type newType);
+void eraseFunctionResults(Operation *op, const BitVector &resultIndices,
+                          Type newType);
 
 /// Set a FunctionOpInterface operation's type signature.
-void setFunctionType(FunctionOpInterface op, Type newType);
+void setFunctionType(Operation *op, Type newType);
 
 /// Insert a set of `newTypes` into `oldTypes` at the given `indices`. If any
 /// types are inserted, `storage` is used to hold the new type list. The new
@@ -206,6 +207,10 @@ Attribute removeResultAttr(ConcreteType op, unsigned index, StringAttr name) {
 /// method on FunctionOpInterface::Trait.
 template <typename ConcreteOp>
 LogicalResult verifyTrait(ConcreteOp op) {
+  if (!op.getFunctionTypeAttr())
+    return op.emitOpError("requires a type attribute '")
+           << function_interface_impl::getTypeAttrName() << '\'';
+
   if (failed(op.verifyType()))
     return failure();
 
index e86057aa7ec2f3e2f11477ecf6b5b31b1fb8bbeb..c56129ea895d962886eab28d35d2eef41f43d8e6 100644 (file)
@@ -49,16 +49,6 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
           for each of the function results.
   }];
   let methods = [
-    InterfaceMethod<[{
-      Returns the type of the function.
-    }],
-    "::mlir::Type", "getFunctionType">,
-    InterfaceMethod<[{
-      Set the type of the function. This method should perform an unsafe
-      modification to the function type; it should not update argument or
-      result attributes.
-    }],
-    "void", "setFunctionTypeAttr", (ins "::mlir::TypeAttr":$type)>,
     InterfaceMethod<[{
       Returns the function argument types based exclusively on
       the type (to allow for this method may be called on function
@@ -149,7 +139,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
         ArrayRef<NamedAttribute> attrs, TypeRange inputTypes) {
       state.addAttribute(SymbolTable::getSymbolAttrName(),
                         builder.getStringAttr(name));
-      state.addAttribute(ConcreteOp::getFunctionTypeAttrName(state.name),
+      state.addAttribute(function_interface_impl::getTypeAttrName(),
                         TypeAttr::get(type));
       state.attributes.append(attrs.begin(), attrs.end());
 
@@ -254,6 +244,11 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
     // the derived operation, which should already have these defined
     // (via ODS).
 
+    /// Returns the name of the attribute used for function types.
+    static StringRef getTypeAttrName() {
+      return function_interface_impl::getTypeAttrName();
+    }
+
     /// Returns the name of the attribute used for function argument attributes.
     static StringRef getArgDictAttrName() {
       return function_interface_impl::getArgDictAttrName();
@@ -264,6 +259,15 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
       return function_interface_impl::getResultDictAttrName();
     }
 
+    /// Return the attribute containing the type of this function.
+    TypeAttr getFunctionTypeAttr() {
+      return this->getOperation()->template getAttrOfType<TypeAttr>(
+          getTypeAttrName());
+    }
+
+    /// Return the type of this function.
+    Type getFunctionType() { return getFunctionTypeAttr().getValue(); }
+
     //===------------------------------------------------------------------===//
     // Argument and Result Handling
     //===------------------------------------------------------------------===//
index 9f522aaa49f920a99cfdd42e59a0e97fa73b71be..d0e82de839c0c819ffa1eb2901acfbaae345ff73 100644 (file)
@@ -59,11 +59,12 @@ using namespace mlir;
 /// Only retain those attributes that are not constructed by
 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
 /// attributes.
-static void filterFuncAttributes(func::FuncOp func, bool filterArgAndResAttrs,
+static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
+                                 bool filterArgAndResAttrs,
                                  SmallVectorImpl<NamedAttribute> &result) {
-  for (const NamedAttribute &attr : func->getAttrs()) {
+  for (const auto &attr : attrs) {
     if (attr.getName() == SymbolTable::getSymbolAttrName() ||
-        attr.getName() == func.getFunctionTypeAttrName() ||
+        attr.getName() == FunctionOpInterface::getTypeAttrName() ||
         attr.getName() == "func.varargs" ||
         (filterArgAndResAttrs &&
          (attr.getName() == FunctionOpInterface::getArgDictAttrName() ||
@@ -137,7 +138,8 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
                                    LLVM::LLVMFuncOp newFuncOp) {
   auto type = funcOp.getFunctionType();
   SmallVector<NamedAttribute, 4> attributes;
-  filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes);
+  filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
+                       attributes);
   auto [wrapperFuncType, resultIsNowArg] =
       typeConverter.convertFunctionTypeCWrapper(type);
   if (resultIsNowArg)
@@ -202,7 +204,8 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   assert(wrapperType && "unexpected type conversion failure");
 
   SmallVector<NamedAttribute, 4> attributes;
-  filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes);
+  filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
+                       attributes);
 
   if (resultIsNowArg)
     prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments());
@@ -301,7 +304,8 @@ protected:
     // Propagate argument/result attributes to all converted arguments/result
     // obtained after converting a given original argument/result.
     SmallVector<NamedAttribute, 4> attributes;
-    filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/true, attributes);
+    filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true,
+                         attributes);
     if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
       assert(!resAttrDicts.empty() && "expected array to be non-empty");
       auto newResAttrDicts =
index 48effe24f674edaba6d0cce073db10483736a0c1..85001d54d093d6937969c5bc0ccd0766c1b72986 100644 (file)
@@ -60,7 +60,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
   SmallVector<NamedAttribute, 4> attributes;
   for (const auto &attr : gpuFuncOp->getAttrs()) {
     if (attr.getName() == SymbolTable::getSymbolAttrName() ||
-        attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
+        attr.getName() == FunctionOpInterface::getTypeAttrName() ||
         attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
       continue;
     attributes.push_back(attr);
index 2a8389598f36a9e0fb73590807b053821553c608..119b1d3dea91e8921a917a25637344f13b9868e0 100644 (file)
@@ -226,7 +226,7 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter,
       rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
                                std::nullopt));
   for (const auto &namedAttr : funcOp->getAttrs()) {
-    if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
+    if (namedAttr.getName() == FunctionOpInterface::getTypeAttrName() ||
         namedAttr.getName() == SymbolTable::getSymbolAttrName())
       continue;
     newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
index 064bf525db238e0505447f8b6979f86c6d773097..e0772b4dd90bc7ad7e6f47193685ef5f00557996 100644 (file)
@@ -332,7 +332,8 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                    ArrayRef<DictionaryAttr> argAttrs) {
   state.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
-  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
+  state.addAttribute(FunctionOpInterface::getTypeAttrName(),
+                     TypeAttr::get(type));
 
   state.attributes.append(attrs.begin(), attrs.end());
   state.addRegion();
@@ -351,13 +352,11 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
 }
 
 /// Check that the result type of async.func is not void and must be
index fc9bd115e2223172b8ef0de699de4beb1c00be7f..961cf2eb36e35d04954193c8eac0278974a25ec7 100644 (file)
@@ -244,7 +244,8 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                    ArrayRef<DictionaryAttr> argAttrs) {
   state.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
-  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
+  state.addAttribute(FunctionOpInterface::getTypeAttrName(),
+                     TypeAttr::get(type));
   state.attributes.append(attrs.begin(), attrs.end());
   state.addRegion();
 
@@ -262,13 +263,11 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
 }
 
 /// Clone the internal blocks from this function into dest and all attributes
index 80db6461ecc55fce3ef80258373f276efc045f9b..7f73a651d0e9bf6835ede29c48de87eab5ac5e32 100644 (file)
@@ -859,8 +859,7 @@ void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                       ArrayRef<NamedAttribute> attrs) {
   result.addAttribute(SymbolTable::getSymbolAttrName(),
                       builder.getStringAttr(name));
-  result.addAttribute(getFunctionTypeAttrName(result.name),
-                      TypeAttr::get(type));
+  result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
   result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                       builder.getI64IntegerAttr(workgroupAttributions.size()));
   result.addAttributes(attrs);
@@ -931,8 +930,7 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
   for (auto &arg : entryArgs)
     argTypes.push_back(arg.type);
   auto type = builder.getFunctionType(argTypes, resultTypes);
-  result.addAttribute(getFunctionTypeAttrName(result.name),
-                      TypeAttr::get(type));
+  result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type));
 
   function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
                                                 resultAttrs);
@@ -994,14 +992,19 @@ void GPUFuncOp::print(OpAsmPrinter &p) {
     p << ' ' << getKernelKeyword();
 
   function_interface_impl::printFunctionAttributes(
-      p, *this,
+      p, *this, type.getNumInputs(), type.getNumResults(),
       {getNumWorkgroupAttributionsAttrName(),
-       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName()});
+       GPUDialect::getKernelFuncAttrName()});
   p << ' ';
   p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
 }
 
 LogicalResult GPUFuncOp::verifyType() {
+  Type type = getFunctionTypeAttr().getValue();
+  if (!type.isa<FunctionType>())
+    return emitOpError("requires '" + getTypeAttrName() +
+                       "' attribute of function type");
+
   if (isKernel() && getFunctionType().getNumResults() != 0)
     return emitOpError() << "expected void return type for kernel function";
 
index 028f25ec3c1d2d1c68d6ad209e3b5846ee04fcbd..1087bcf5d32f479f162390e0229e67d0dd57a7f9 100644 (file)
@@ -2090,7 +2090,7 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
                             function_interface_impl::VariadicFlag(isVariadic));
   if (!type)
     return failure();
-  result.addAttribute(getFunctionTypeAttrName(result.name),
+  result.addAttribute(FunctionOpInterface::getTypeAttrName(),
                       TypeAttr::get(type));
 
   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
@@ -2130,8 +2130,8 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
   function_interface_impl::printFunctionSignature(p, *this, argTypes,
                                                   isVarArg(), resTypes);
   function_interface_impl::printFunctionAttributes(
-      p, *this,
-      {getFunctionTypeAttrName(), getLinkageAttrName(), getCConvAttrName()});
+      p, *this, argTypes.size(), resTypes.size(),
+      {getLinkageAttrName(), getCConvAttrName()});
 
   // Print the body if this is not an external function.
   Region &body = getBody();
index 27c613088df56fab5f2a232b0784c19815e666eb..2f1e4b93a6ac30d5952d64cafa1ed805822cf096 100644 (file)
@@ -152,13 +152,11 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
 }
 
 //===----------------------------------------------------------------------===//
@@ -315,13 +313,11 @@ ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void SubgraphOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
 }
 
 //===----------------------------------------------------------------------===//
index 28fc4dbf97aea62785af098ec3e405c757fb95ba..e8a61ef4c6a4d94793cba7ecb3ae24e2000526c2 100644 (file)
@@ -220,13 +220,11 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
 }
 
 //===----------------------------------------------------------------------===//
index 3ce3913f2814be8a8b6a508b40b5c479ad0953e3..52ad8ad5fe7c7b8067eef9b543d394e32ab7bcb2 100644 (file)
@@ -2382,7 +2382,7 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
   for (auto &arg : entryArgs)
     argTypes.push_back(arg.type);
   auto fnType = builder.getFunctionType(argTypes, resultTypes);
-  result.addAttribute(getFunctionTypeAttrName(result.name),
+  result.addAttribute(FunctionOpInterface::getTypeAttrName(),
                       TypeAttr::get(fnType));
 
   // Parse the optional function control keyword.
@@ -2417,9 +2417,8 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
   printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
           << "\"";
   function_interface_impl::printFunctionAttributes(
-      printer, *this,
-      {spirv::attributeName<spirv::FunctionControl>(),
-       getFunctionTypeAttrName(), getFunctionControlAttrName()});
+      printer, *this, fnType.getNumInputs(), fnType.getNumResults(),
+      {spirv::attributeName<spirv::FunctionControl>()});
 
   // Print the body if this is not an external function.
   Region &body = this->getBody();
@@ -2431,6 +2430,10 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
 }
 
 LogicalResult spirv::FuncOp::verifyType() {
+  auto type = getFunctionTypeAttr().getValue();
+  if (!type.isa<FunctionType>())
+    return emitOpError("requires '" + getTypeAttrName() +
+                       "' attribute of function type");
   if (getFunctionType().getNumResults() > 1)
     return emitOpError("cannot have more than one result");
   return success();
@@ -2470,7 +2473,7 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
                           ArrayRef<NamedAttribute> attrs) {
   state.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
-  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
+  state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
   state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
                      builder.getAttr<spirv::FunctionControlAttr>(control));
   state.attributes.append(attrs.begin(), attrs.end());
index 62e3a3de0f11b03cecaeada21aa419c818629cb4..2772c0150cda7079669ad33b188e68753fb729c1 100644 (file)
@@ -531,7 +531,7 @@ FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
 
   // Copy over all attributes other than the function name and type.
   for (const auto &namedAttr : funcOp->getAttrs()) {
-    if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
+    if (namedAttr.getName() != FunctionOpInterface::getTypeAttrName() &&
         namedAttr.getName() != SymbolTable::getSymbolAttrName())
       newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
   }
index 30c5f56f13886fbb969df09693df3d3869dbbd30..8c89ec8bba6cfd5167a244681602dd40a392c21d 100644 (file)
@@ -1311,13 +1311,11 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
 }
 
 //===----------------------------------------------------------------------===//
index af692befb0fde65415dfac89dc2e0b4d14c3b511..9481e4ae8175b0468e4c0245ebfcce9d0d3b4970 100644 (file)
@@ -163,7 +163,7 @@ void mlir::function_interface_impl::addArgAndResultAttrs(
 
 ParseResult mlir::function_interface_impl::parseFunctionOp(
     OpAsmParser &parser, OperationState &result, bool allowVariadic,
-    StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder) {
+    FuncTypeBuilder funcTypeBuilder) {
   SmallVector<OpAsmParser::Argument> entryArgs;
   SmallVector<DictionaryAttr> resultAttrs;
   SmallVector<Type> resultTypes;
@@ -197,7 +197,7 @@ ParseResult mlir::function_interface_impl::parseFunctionOp(
            << "failed to construct function type"
            << (errorMessage.empty() ? "" : ": ") << errorMessage;
   }
-  result.addAttribute(typeAttrName, TypeAttr::get(type));
+  result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
 
   // If function attributes are present, parse them.
   NamedAttrList parsedAttributes;
@@ -209,7 +209,7 @@ ParseResult mlir::function_interface_impl::parseFunctionOp(
   // dictionary.
   for (StringRef disallowed :
        {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(),
-        typeAttrName.getValue()}) {
+        getTypeAttrName()}) {
     if (parsedAttributes.get(disallowed))
       return parser.emitError(attributeDictLocation, "'")
              << disallowed
@@ -301,11 +301,12 @@ void mlir::function_interface_impl::printFunctionSignature(
 }
 
 void mlir::function_interface_impl::printFunctionAttributes(
-    OpAsmPrinter &p, Operation *op, ArrayRef<StringRef> elided) {
+    OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults,
+    ArrayRef<StringRef> elided) {
   // Print out function attributes, if present.
-  SmallVector<StringRef, 2> ignoredAttrs = {SymbolTable::getSymbolAttrName(),
-                                            getArgDictAttrName(),
-                                            getResultDictAttrName()};
+  SmallVector<StringRef, 2> ignoredAttrs = {
+      ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName(),
+      getArgDictAttrName(), getResultDictAttrName()};
   ignoredAttrs.append(elided.begin(), elided.end());
 
   p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
@@ -313,8 +314,7 @@ void mlir::function_interface_impl::printFunctionAttributes(
 
 void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p,
                                                     FunctionOpInterface op,
-                                                    bool isVariadic,
-                                                    StringRef typeAttrName) {
+                                                    bool isVariadic) {
   // Print the operation and the function name.
   auto funcName =
       op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
@@ -329,7 +329,8 @@ void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p,
   ArrayRef<Type> argTypes = op.getArgumentTypes();
   ArrayRef<Type> resultTypes = op.getResultTypes();
   printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
-  printFunctionAttributes(p, op, {visibilityAttrName, typeAttrName});
+  printFunctionAttributes(p, op, argTypes.size(), resultTypes.size(),
+                          {visibilityAttrName});
   // Print the body if this is not an external function.
   Region &body = op->getRegion(0);
   if (!body.empty()) {
index 9ba830366056c9c0533d46a9bd014db625dcab7f..3331aefc76d13be406b83cf2b396c10bba271472 100644 (file)
@@ -112,7 +112,7 @@ void mlir::function_interface_impl::setAllResultAttrDicts(
 }
 
 void mlir::function_interface_impl::insertFunctionArguments(
-    FunctionOpInterface op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
+    Operation *op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
     ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs,
     unsigned originalNumArgs, Type newType) {
   assert(argIndices.size() == argTypes.size());
@@ -152,15 +152,15 @@ void mlir::function_interface_impl::insertFunctionArguments(
   }
 
   // Update the function type and any entry block arguments.
-  op.setFunctionTypeAttr(TypeAttr::get(newType));
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
   for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
     entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]);
 }
 
 void mlir::function_interface_impl::insertFunctionResults(
-    FunctionOpInterface op, ArrayRef<unsigned> resultIndices,
-    TypeRange resultTypes, ArrayRef<DictionaryAttr> resultAttrs,
-    unsigned originalNumResults, Type newType) {
+    Operation *op, ArrayRef<unsigned> resultIndices, TypeRange resultTypes,
+    ArrayRef<DictionaryAttr> resultAttrs, unsigned originalNumResults,
+    Type newType) {
   assert(resultIndices.size() == resultTypes.size());
   assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
   if (resultIndices.empty())
@@ -196,11 +196,11 @@ void mlir::function_interface_impl::insertFunctionResults(
   }
 
   // Update the function type.
-  op.setFunctionTypeAttr(TypeAttr::get(newType));
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
 }
 
 void mlir::function_interface_impl::eraseFunctionArguments(
-    FunctionOpInterface op, const BitVector &argIndices, Type newType) {
+    Operation *op, const BitVector &argIndices, Type newType) {
   // There are 3 things that need to be updated:
   // - Function type.
   // - Arg attrs.
@@ -218,12 +218,12 @@ void mlir::function_interface_impl::eraseFunctionArguments(
   }
 
   // Update the function type and any entry block arguments.
-  op.setFunctionTypeAttr(TypeAttr::get(newType));
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
   entry.eraseArguments(argIndices);
 }
 
 void mlir::function_interface_impl::eraseFunctionResults(
-    FunctionOpInterface op, const BitVector &resultIndices, Type newType) {
+    Operation *op, const BitVector &resultIndices, Type newType) {
   // There are 2 things that need to be updated:
   // - Function type.
   // - Result attrs.
@@ -239,7 +239,7 @@ void mlir::function_interface_impl::eraseFunctionResults(
   }
 
   // Update the function type.
-  op.setFunctionTypeAttr(TypeAttr::get(newType));
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
 }
 
 TypeRange mlir::function_interface_impl::insertTypesInto(
@@ -276,13 +276,14 @@ TypeRange mlir::function_interface_impl::filterTypesOut(
 // Function type signature.
 //===----------------------------------------------------------------------===//
 
-void mlir::function_interface_impl::setFunctionType(FunctionOpInterface op,
+void mlir::function_interface_impl::setFunctionType(Operation *op,
                                                     Type newType) {
-  unsigned oldNumArgs = op.getNumArguments();
-  unsigned oldNumResults = op.getNumResults();
-  op.setFunctionTypeAttr(TypeAttr::get(newType));
-  unsigned newNumArgs = op.getNumArguments();
-  unsigned newNumResults = op.getNumResults();
+  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
+  unsigned oldNumArgs = funcOp.getNumArguments();
+  unsigned oldNumResults = funcOp.getNumResults();
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+  unsigned newNumArgs = funcOp.getNumArguments();
+  unsigned newNumResults = funcOp.getNumResults();
 
   // Functor used to update the argument and result attributes of the function.
   auto updateAttrFn = [&](StringRef attrName, unsigned oldCount,