From d9592444cea11fcc9e8debe0f5eff331bdabbdc4 Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Fri, 21 Oct 2022 11:08:41 -0700 Subject: [PATCH] [mlir][llvm] Move LLVMFunctionType to a TypeDef Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D136485 --- mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h | 63 ---------------- mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td | 60 +++++++++++++++ mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 1 - mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp | 59 +-------------- mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 105 +++++++++++++++++--------- mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h | 56 -------------- 6 files changed, 132 insertions(+), 212 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h index da24d4d..237082b 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -73,69 +73,6 @@ DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType); #undef DEFINE_TRIVIAL_LLVM_TYPE //===----------------------------------------------------------------------===// -// LLVMFunctionType. -//===----------------------------------------------------------------------===// - -/// LLVM dialect function type. It consists of a single return type (unlike MLIR -/// which can have multiple), a list of parameter types and can optionally be -/// variadic. -class LLVMFunctionType : public Type::TypeBase { -public: - /// Inherit base constructors. - using Base::Base; - using Base::getChecked; - - /// Checks if the given type can be used an argument in a function type. - static bool isValidArgumentType(Type type); - - /// Checks if the given type can be used as a result in a function type. - static bool isValidResultType(Type type); - - /// Returns whether the function is variadic. - bool isVarArg() const; - - /// Gets or creates an instance of LLVM dialect function in the same context - /// as the `result` type. - static LLVMFunctionType get(Type result, ArrayRef arguments, - bool isVarArg = false); - static LLVMFunctionType - getChecked(function_ref emitError, Type result, - ArrayRef arguments, bool isVarArg = false); - - /// Returns a clone of this function type with the given argument - /// and result types. - LLVMFunctionType clone(TypeRange inputs, TypeRange results) const; - - /// Returns the result type of the function. - Type getReturnType() const; - - /// Returns the result type of the function as an ArrayRef, enabling better - /// integration with generic MLIR utilities. - ArrayRef getReturnTypes() const; - - /// Returns the number of arguments to the function. - unsigned getNumParams(); - - /// Returns `i`-th argument of the function. Asserts on out-of-bounds. - Type getParamType(unsigned i); - - /// Returns a list of argument types of the function. - ArrayRef getParams() const; - ArrayRef params() { return getParams(); } - - /// Verifies that the type about to be constructed is well-formed. - static LogicalResult verify(function_ref emitError, - Type result, ArrayRef arguments, bool); - - void walkImmediateSubElements(function_ref walkAttrsFn, - function_ref walkTypesFn) const; - Type replaceImmediateSubElements(ArrayRef replAttrs, - ArrayRef replTypes) const; -}; - -//===----------------------------------------------------------------------===// // LLVMPointerType. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td index b347406..6ddef17 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td @@ -58,4 +58,64 @@ def LLVMArrayType : LLVMType<"LLVMArray", "array", [ }]; } +//===----------------------------------------------------------------------===// +// LLVMFunctionType +//===----------------------------------------------------------------------===// + +def LLVMFunctionType : LLVMType<"LLVMFunction", "func", [ + DeclareTypeInterfaceMethods]> { + let summary = "LLVM function type"; + let description = [{ + The `!llvm.func` is a function type. It consists of a single return type + (unlike MLIR which can have multiple), a list of parameter types and can + optionally be variadic. + + Example: + + ```mlir + !llvm.func + ``` + }]; + + let parameters = (ins "Type":$returnType, ArrayRefParameter<"Type">:$params, + "bool":$varArg); + let assemblyFormat = [{ + `<` custom($returnType) ` ` `(` + custom($params, $varArg) `>` + }]; + + let genVerifyDecl = 1; + + let builders = [ + TypeBuilderWithInferredContext<(ins + "Type":$result, "ArrayRef":$arguments, + CArg<"bool", "false">:$isVarArg)> + ]; + + let extraClassDeclaration = [{ + /// Checks if the given type can be used an argument in a function type. + static bool isValidArgumentType(Type type); + + /// Checks if the given type can be used as a result in a function type. + static bool isValidResultType(Type type); + + /// Returns whether the function is variadic. + bool isVarArg() const { return getVarArg(); } + + /// Returns a clone of this function type with the given argument + /// and result types. + LLVMFunctionType clone(TypeRange inputs, TypeRange results) const; + + /// Returns the result type of the function as an ArrayRef, enabling better + /// integration with generic MLIR utilities. + ArrayRef getReturnTypes() const; + + /// Returns the number of arguments to the function. + unsigned getNumParams() const { return getParams().size(); } + + /// Returns `i`-th argument of the function. Asserts on out-of-bounds. + Type getParamType(unsigned i) { return getParams()[i]; } + }]; +} + #endif // LLVMTYPES_TD diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 72f37f9..71ba80d 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2573,7 +2573,6 @@ void LLVMDialect::initialize() { LLVMTokenType, LLVMLabelType, LLVMMetadataType, - LLVMFunctionType, LLVMPointerType, LLVMFixedVectorType, LLVMScalableVectorType, diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp index 9324774..566ef63 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -107,22 +107,6 @@ static void printVectorType(AsmPrinter &printer, TypeTy type) { printer << '>'; } -/// Prints a function type. -static void printFunctionType(AsmPrinter &printer, LLVMFunctionType funcType) { - printer << '<'; - dispatchPrint(printer, funcType.getReturnType()); - printer << " ("; - llvm::interleaveComma( - funcType.getParams(), printer.getStream(), - [&printer](Type subtype) { dispatchPrint(printer, subtype); }); - if (funcType.isVarArg()) { - if (funcType.getNumParams() != 0) - printer << ", "; - printer << "..."; - } - printer << ")>"; -} - /// Prints the given LLVM dialect type recursively. This leverages closedness of /// the LLVM dialect type system to avoid printing the dialect prefix /// repeatedly. For recursive structures, only prints the name of the structure @@ -171,7 +155,7 @@ void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) { return printStructType(printer, structType); if (auto funcType = type.dyn_cast()) - return printFunctionType(printer, funcType); + return funcType.print(printer); } //===----------------------------------------------------------------------===// @@ -180,45 +164,6 @@ void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) { static ParseResult dispatchParse(AsmParser &parser, Type &type); -/// Parses an LLVM dialect function type. -/// llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>` -static LLVMFunctionType parseFunctionType(AsmParser &parser) { - SMLoc loc = parser.getCurrentLocation(); - Type returnType; - if (parser.parseLess() || dispatchParse(parser, returnType) || - parser.parseLParen()) - return LLVMFunctionType(); - - // Function type without arguments. - if (succeeded(parser.parseOptionalRParen())) { - if (succeeded(parser.parseGreater())) - return parser.getChecked(loc, returnType, llvm::None, - /*isVarArg=*/false); - return LLVMFunctionType(); - } - - // Parse arguments. - SmallVector argTypes; - do { - if (succeeded(parser.parseOptionalEllipsis())) { - if (parser.parseOptionalRParen() || parser.parseOptionalGreater()) - return LLVMFunctionType(); - return parser.getChecked(loc, returnType, argTypes, - /*isVarArg=*/true); - } - - Type arg; - if (dispatchParse(parser, arg)) - return LLVMFunctionType(); - argTypes.push_back(arg); - } while (succeeded(parser.parseOptionalComma())); - - if (parser.parseOptionalRParen() || parser.parseOptionalGreater()) - return LLVMFunctionType(); - return parser.getChecked(loc, returnType, argTypes, - /*isVarArg=*/false); -} - /// Parses an LLVM dialect pointer type. /// llvm-type ::= `ptr<` llvm-type (`,` integer)? `>` /// | `ptr` (`<` integer `>`)? @@ -445,7 +390,7 @@ static Type dispatchParse(AsmParser &parser, bool allowAny = true) { .Case("token", [&] { return LLVMTokenType::get(ctx); }) .Case("label", [&] { return LLVMLabelType::get(ctx); }) .Case("metadata", [&] { return LLVMMetadataType::get(ctx); }) - .Case("func", [&] { return parseFunctionType(parser); }) + .Case("func", [&] { return LLVMFunctionType::parse(parser); }) .Case("ptr", [&] { return parsePointerType(parser); }) .Case("vec", [&] { return parseVectorType(parser); }) .Case("array", [&] { return LLVMArrayType::parse(parser); }) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index 10170b0..f55d2ae 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -28,6 +28,70 @@ using namespace mlir::LLVM; constexpr const static unsigned kBitsInByte = 8; //===----------------------------------------------------------------------===// +// custom +//===----------------------------------------------------------------------===// + +static ParseResult parseFunctionTypes(AsmParser &p, + FailureOr> ¶ms, + FailureOr &isVarArg) { + params.emplace(); + isVarArg = false; + // `(` `)` + if (succeeded(p.parseOptionalRParen())) + return success(); + + // `(` `...` `)` + if (succeeded(p.parseOptionalEllipsis())) { + isVarArg = true; + return p.parseRParen(); + } + + // type (`,` type)* (`,` `...`)? + FailureOr type; + if (parsePrettyLLVMType(p, type)) + return failure(); + params->push_back(*type); + while (succeeded(p.parseOptionalComma())) { + if (succeeded(p.parseOptionalEllipsis())) { + isVarArg = true; + return p.parseRParen(); + } + if (parsePrettyLLVMType(p, type)) + return failure(); + params->push_back(*type); + } + return p.parseRParen(); +} + +static void printFunctionTypes(AsmPrinter &p, ArrayRef params, + bool isVarArg) { + llvm::interleaveComma(params, p, + [&](Type type) { printPrettyLLVMType(p, type); }); + if (isVarArg) { + if (!params.empty()) + p << ", "; + p << "..."; + } + p << ')'; +} + +//===----------------------------------------------------------------------===// +// ODS-Generated Definitions +//===----------------------------------------------------------------------===// + +/// These are unused for now. +/// TODO: Move over to these once more types have been migrated to TypeDef. +LLVM_ATTRIBUTE_UNUSED static OptionalParseResult +generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value); +LLVM_ATTRIBUTE_UNUSED static LogicalResult +generatedTypePrinter(Type def, AsmPrinter &printer); + +#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc" + +//===----------------------------------------------------------------------===// // LLVMArrayType //===----------------------------------------------------------------------===// @@ -130,25 +194,8 @@ LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs, return get(results[0], llvm::to_vector(inputs), isVarArg()); } -Type LLVMFunctionType::getReturnType() const { - return getImpl()->getReturnType(); -} ArrayRef LLVMFunctionType::getReturnTypes() const { - return getImpl()->getReturnType(); -} - -unsigned LLVMFunctionType::getNumParams() { - return getImpl()->getArgumentTypes().size(); -} - -Type LLVMFunctionType::getParamType(unsigned i) { - return getImpl()->getArgumentTypes()[i]; -} - -bool LLVMFunctionType::isVarArg() const { return getImpl()->isVariadic(); } - -ArrayRef LLVMFunctionType::getParams() const { - return getImpl()->getArgumentTypes(); + return static_cast(getImpl())->returnType; } LogicalResult @@ -164,10 +211,14 @@ LLVMFunctionType::verify(function_ref emitError, return success(); } +//===----------------------------------------------------------------------===// +// SubElementTypeInterface + void LLVMFunctionType::walkImmediateSubElements( function_ref walkAttrsFn, function_ref walkTypesFn) const { - for (Type type : llvm::concat(getReturnTypes(), getParams())) + walkTypesFn(getReturnType()); + for (Type type : getParams()) walkTypesFn(type); } @@ -1006,22 +1057,6 @@ llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) { } //===----------------------------------------------------------------------===// -// ODS-Generated Definitions -//===----------------------------------------------------------------------===// - -/// These are unused for now. -/// TODO: Move over to these once more types have been migrated to TypeDef. -LLVM_ATTRIBUTE_UNUSED static OptionalParseResult -generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value); -LLVM_ATTRIBUTE_UNUSED static LogicalResult -generatedTypePrinter(Type def, AsmPrinter &printer); - -#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc" - -#define GET_TYPEDEF_CLASSES -#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc" - -//===----------------------------------------------------------------------===// // LLVMDialect //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h index 7ebfe8e..d13452f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h +++ b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h @@ -322,62 +322,6 @@ private: }; //===----------------------------------------------------------------------===// -// LLVMFunctionTypeStorage. -//===----------------------------------------------------------------------===// - -/// Type storage for LLVM dialect function types. These are uniqued using the -/// list of types they contain and the vararg bit. -struct LLVMFunctionTypeStorage : public TypeStorage { - using KeyTy = std::tuple, bool>; - - /// Construct a storage from the given components. The list is expected to be - /// allocated in the context. - LLVMFunctionTypeStorage(Type result, ArrayRef arguments, bool variadic) - : resultType(result), isVariadicFlag(variadic), - numArguments(arguments.size()), argumentTypes(arguments.data()) {} - - /// Hook into the type uniquing infrastructure. - static LLVMFunctionTypeStorage *construct(TypeStorageAllocator &allocator, - const KeyTy &key) { - return new (allocator.allocate()) - LLVMFunctionTypeStorage(std::get<0>(key), - allocator.copyInto(std::get<1>(key)), - std::get<2>(key)); - } - - static unsigned hashKey(const KeyTy &key) { - // LLVM doesn't like hashing bools in tuples. - return llvm::hash_combine(std::get<0>(key), std::get<1>(key), - static_cast(std::get<2>(key))); - } - - bool operator==(const KeyTy &key) const { - return std::make_tuple(getReturnType(), getArgumentTypes(), isVariadic()) == - key; - } - - /// Returns the list of function argument types. - ArrayRef getArgumentTypes() const { - return ArrayRef(argumentTypes, numArguments); - } - - /// Checks whether the function type is variadic. - bool isVariadic() const { return isVariadicFlag; } - - /// Returns the function result type. - const Type &getReturnType() const { return resultType; } - -private: - /// The result type of the function. - Type resultType; - /// Flag indicating if the function is variadic. - bool isVariadicFlag; - /// The argument types of the function. - unsigned numArguments; - const Type *argumentTypes; -}; - -//===----------------------------------------------------------------------===// // LLVMPointerTypeStorage. //===----------------------------------------------------------------------===// -- 2.7.4