From dbb9608de8f9da217e65ea960e8b7e909ef4c673 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 9 May 2019 05:40:54 -0700 Subject: [PATCH] Conversion to LLVM Dialect: integrate TypeConverter into LLVMLowering Historically, the conversion from standard and built-in types to the LLVM IR dialect types was performed by a dedicated class, TypeConverter. This class served to contain references to the LLVM IR dialect and to the LLVM IR Module to allow querying the data layout. Recently, the LLVMLowering class was introduced to make the conversion to the LLVM IR dialect extensible to other source dialects. This class also includes the references to the LLVM IR dialect and module. TypeConverter was extended with basic support for dialect-specific type conversion through callbacks. This is not sufficient in cases where dialect-specific types appear inside other types, such as function or container types. Integrate TypeConverter into LLVMLowering. Whenever a subtype needs to be converted during standard type conversion (e.g. an argument or a result of a FunctionType), the conversion will call to the virtual function `LLVMLowering::convertType`, which can be extended to support dialect-specific types. Provide a new LLVMOpConversion class that serves as a base class for all conversions to the LLVM IR dialect and gives them access to LLVMLowering for the purpose of type conversion. Update Linalg to LLVM IR lowering to use this class. -- PiperOrigin-RevId: 247407314 --- mlir/include/mlir/LLVMIR/LLVMLowering.h | 72 ++++- mlir/include/mlir/LLVMIR/Transforms.h | 4 - .../lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp | 332 +++++++-------------- mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp | 92 +++--- 4 files changed, 214 insertions(+), 286 deletions(-) diff --git a/mlir/include/mlir/LLVMIR/LLVMLowering.h b/mlir/include/mlir/LLVMIR/LLVMLowering.h index 590d973..02bc816 100644 --- a/mlir/include/mlir/LLVMIR/LLVMLowering.h +++ b/mlir/include/mlir/LLVMIR/LLVMLowering.h @@ -27,7 +27,9 @@ #include "mlir/Transforms/DialectConversion.h" namespace llvm { +class IntegerType; class Module; +class Type; } namespace mlir { @@ -38,6 +40,17 @@ class LLVMDialect; /// Conversion from the Standard dialect to the LLVM IR dialect. Provides hooks /// for derived classes to extend the conversion. class LLVMLowering : public DialectConversion { +public: + /// Convert types to LLVM IR. This calls `convertAdditionalType` to convert + /// non-standard or non-builtin types. + Type convertType(Type t) override final; + + /// Convert a non-empty list of types to be returned from a function into a + /// supported LLVM IR type. In particular, if more than one values is + /// returned, create an LLVM IR structure type with elements that correspond + /// to each of the MLIR types converted with `convertType`. + Type packFunctionResults(ArrayRef types); + protected: /// Create a set of converters that live in the pass object by passing them a /// reference to the LLVM IR dialect. Store the module associated with the @@ -52,9 +65,6 @@ protected: return {}; }; - /// Convert standard and builtin types to LLVM IR. - Type convertType(Type t) override final; - /// Derived classes can override this function to convert custom types. It /// will be called by convertType if the default conversion from standard and /// builtin types fails. Derived classes can thus call convertType whenever @@ -73,6 +83,62 @@ protected: /// LLVM IR module used to parse/create types. llvm::Module *module; LLVM::LLVMDialect *llvmDialect; + +private: + Type convertStandardType(Type type); + + // Convert a function type. The arguments and results are converted one by + // one. Additionally, if the function returns more than one value, pack the + // results into an LLVM IR structure type so that the converted function type + // returns at most one result. + Type convertFunctionType(FunctionType type); + + // Convert the index type. Uses llvmModule data layout to create an integer + // of the pointer bitwidth. + Type convertIndexType(IndexType type); + + // Convert an integer type `i*` to `!llvm<"i*">`. + Type convertIntegerType(IntegerType type); + + // Convert a floating point type: `f16` to `!llvm.half`, `f32` to + // `!llvm.float` and `f64` to `!llvm.double`. `bf16` is not supported + // by LLVM. + Type convertFloatType(FloatType type); + + // Convert a memref type into an LLVM type that captures the relevant data. + // For statically-shaped memrefs, the resulting type is a pointer to the + // (converted) memref element type. For dynamically-shaped memrefs, the + // resulting type is an LLVM structure type that contains: + // 1. a pointer to the (converted) memref element type + // 2. as many index types as memref has dynamic dimensions. + Type convertMemRefType(MemRefType type); + + // Convert a 1D vector type into an LLVM vector type. + Type convertVectorType(VectorType type); + + // Get the LLVM representation of the index type based on the bitwidth of the + // pointer as defined by the data layout of the module. + llvm::IntegerType *getIndexType(); + + // Wrap the given LLVM IR type into an LLVM IR dialect type. + Type wrap(llvm::Type *llvmType); + + // Extract an LLVM IR type from the LLVM IR dialect type. + llvm::Type *unwrap(Type type); +}; + +/// Base class for operation conversions targeting the LLVM IR dialect. Provides +/// conversion patterns with an access to the containing LLVMLowering for the +/// purpose of type conversions. +class LLVMOpLowering : public DialectOpConversion { +public: + LLVMOpLowering(StringRef rootOpName, MLIRContext *context, + LLVMLowering &lowering); + +protected: + // Back-reference to the lowering class, used to call type and function + // conversions accounting for potential extensions. + LLVMLowering &lowering; }; } // namespace mlir diff --git a/mlir/include/mlir/LLVMIR/Transforms.h b/mlir/include/mlir/LLVMIR/Transforms.h index b021981..95244b8 100644 --- a/mlir/include/mlir/LLVMIR/Transforms.h +++ b/mlir/include/mlir/LLVMIR/Transforms.h @@ -44,10 +44,6 @@ namespace LLVM { /// another block as a successor more than once with different values, insert /// a new dummy block for LLVM PHI nodes to tell the sources apart. void ensureDistinctSuccessors(Module *m); - -/// Converts a type in either MLIR standard or builtin type into LLVMIR dialect -/// type. -Type convertToLLVMDialectType(Type t, llvm::Module &llvmModule); } // namespace LLVM } // namespace mlir diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index 5f57ebc..893a063 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -40,173 +40,62 @@ using namespace mlir; -namespace { -// Type converter for the LLVM IR dialect. Converts MLIR standard and builtin -// types into equivalent LLVM IR dialect types. -class TypeConverter { -public: - // Convert one type `t ` and register it in the `llvmModule`. The latter may - // be used to extract information specific to the data layout. - // Dispatches to the private functions below based on the actual type. - static Type convert(Type t, llvm::Module &llvmModule); - - // Convert the element type of the memref `t` to to an LLVM type, get a - // pointer LLVM type pointing to the converted `t`, wrap it into the MLIR LLVM - // dialect type and return. - static Type getMemRefElementPtrType(MemRefType t, llvm::Module &llvmModule); - - // Convert a non-empty list of types to an LLVM IR dialect type wrapping an - // LLVM IR structure type, elements of which are formed by converting - // individual types in the given list. Register the type in the `llvmModule`. - // The module may be also used to query the data layout. - static Type pack(ArrayRef types, llvm::Module &llvmModule, - MLIRContext &context); - - // Convert a function signature type to the LLVM IR dialect. The outer - // function type remains `mlir::FunctionType`. Argument types are converted - // to LLVM IR using `typeConversionCallback` if provided and using - // `TypeConverter::convert` otherwise. If the function returns a single - // result, its type is converted. Otherwise, the types of results are packed - // into an LLVM IR structure type. - static FunctionType convertFunctionSignature( - FunctionType t, llvm::Module &llvmModule, - llvm::function_ref typeConversionCallback = {}); - -private: - // Construct a type converter. - explicit TypeConverter(llvm::Module &llvmModule, MLIRContext *context) - : module(llvmModule), llvmContext(llvmModule.getContext()), - builder(llvmModule.getContext()), mlirContext(context) {} - - // Convert a function type. The arguments and results are converted one by - // one. Additionally, if the function returns more than one value, pack the - // results into an LLVM IR structure type so that the converted function type - // returns at most one result. - Type convertFunctionType(FunctionType type); - - // Convert function type arguments and results without converting the - // function type itself. - FunctionType convertFunctionSignatureType( - FunctionType type, llvm::function_ref typeConversionCallback); - - // Convert the index type. Uses llvmModule data layout to create an integer - // of the pointer bitwidth. - Type convertIndexType(IndexType type); - - // Convert an integer type `i*` to `!llvm<"i*">`. - Type convertIntegerType(IntegerType type); - - // Convert a floating point type: `f16` to `!llvm.half`, `f32` to - // `!llvm.float` and `f64` to `!llvm.double`. `bf16` is not supported - // by LLVM. - Type convertFloatType(FloatType type); - - // Convert a memref type into an LLVM type that captures the relevant data. - // For statically-shaped memrefs, the resulting type is a pointer to the - // (converted) memref element type. For dynamically-shaped memrefs, the - // resulting type is an LLVM structure type that contains: - // 1. a pointer to the (converted) memref element type - // 2. as many index types as memref has dynamic dimensions. - Type convertMemRefType(MemRefType type); - - // Convert a 1D vector type into an LLVM vector type. - Type convertVectorType(VectorType type); - - // Convert a non-empty list of types into an LLVM structure type containing - // those types. If the list contains a single element, convert the element - // directly. - Type getPackedResultType(ArrayRef types); - - // Convert a type to the LLVM IR dialect. Returns a null type in case of - // error. - Type convertType(Type type); - - // Get the LLVM representation of the index type based on the bitwidth of the - // pointer as defined by the data layout of the module. - llvm::IntegerType *getIndexType(); - - // Wrap the given LLVM IR type into an LLVM IR dialect type. - Type wrap(llvm::Type *llvmType) { - return LLVM::LLVMType::get(mlirContext, llvmType); - } - - // Extract an LLVM IR type from the LLVM IR dialect type. - llvm::Type *unwrap(Type type) { - if (!type) - return nullptr; - auto wrappedLLVMType = type.dyn_cast(); - if (!wrappedLLVMType) - return mlirContext->emitError(UnknownLoc::get(mlirContext), - "conversion resulted in a non-LLVM type"), - nullptr; - return wrappedLLVMType.getUnderlyingType(); - } - - llvm::Module &module; - llvm::LLVMContext &llvmContext; - llvm::IRBuilder<> builder; +// Wrap the given LLVM IR type into an LLVM IR dialect type. +Type LLVMLowering::wrap(llvm::Type *llvmType) { + return LLVM::LLVMType::get(llvmDialect->getContext(), llvmType); +} - MLIRContext *mlirContext; -}; -} // end anonymous namespace +// Extract an LLVM IR type from the LLVM IR dialect type. +llvm::Type *LLVMLowering::unwrap(Type type) { + if (!type) + return nullptr; + auto *mlirContext = type.getContext(); + auto wrappedLLVMType = type.dyn_cast(); + if (!wrappedLLVMType) + return mlirContext->emitError(UnknownLoc::get(mlirContext), + "conversion resulted in a non-LLVM type"), + nullptr; + return wrappedLLVMType.getUnderlyingType(); +} -llvm::IntegerType *TypeConverter::getIndexType() { - return builder.getIntNTy(module.getDataLayout().getPointerSizeInBits()); +llvm::IntegerType *LLVMLowering::getIndexType() { + return llvm::IntegerType::get(llvmDialect->getLLVMContext(), + module->getDataLayout().getPointerSizeInBits()); } -Type TypeConverter::convertIndexType(IndexType type) { +Type LLVMLowering::convertIndexType(IndexType type) { return wrap(getIndexType()); } -Type TypeConverter::convertIntegerType(IntegerType type) { - return wrap(builder.getIntNTy(type.getWidth())); +Type LLVMLowering::convertIntegerType(IntegerType type) { + return wrap( + llvm::Type::getIntNTy(llvmDialect->getLLVMContext(), type.getWidth())); } -Type TypeConverter::convertFloatType(FloatType type) { +Type LLVMLowering::convertFloatType(FloatType type) { switch (type.getKind()) { case mlir::StandardTypes::F32: - return wrap(builder.getFloatTy()); + return wrap(llvm::Type::getFloatTy(llvmDialect->getLLVMContext())); case mlir::StandardTypes::F64: - return wrap(builder.getDoubleTy()); + return wrap(llvm::Type::getDoubleTy(llvmDialect->getLLVMContext())); case mlir::StandardTypes::F16: - return wrap(builder.getHalfTy()); - case mlir::StandardTypes::BF16: + return wrap(llvm::Type::getHalfTy(llvmDialect->getLLVMContext())); + case mlir::StandardTypes::BF16: { + auto *mlirContext = llvmDialect->getContext(); return mlirContext->emitError(UnknownLoc::get(mlirContext), "unsupported type: BF16"), Type(); + } default: llvm_unreachable("non-float type in convertFloatType"); } } -// If `types` has more than one type, pack them into an LLVM StructType, -// otherwise just convert the type. -Type TypeConverter::getPackedResultType(ArrayRef types) { - // We don't convert zero-valued functions to one-valued functions returning - // void yet. - assert(!types.empty() && "empty type list"); - - // Convert result types one by one and check for errors. - SmallVector resultTypes; - for (auto t : types) { - llvm::Type *converted = unwrap(convertType(t)); - if (!converted) - return {}; - resultTypes.push_back(converted); - } - - // LLVM does not support tuple returns. If there are more than 2 results, - // pack them into an LLVM struct type. - if (resultTypes.size() == 1) - return wrap(resultTypes.front()); - return wrap(llvm::StructType::get(llvmContext, resultTypes)); -} - // Function types are converted to LLVM Function types by recursively converting // argument and result types. If MLIR Function has zero results, the LLVM // Function has one VoidType result. If MLIR Function has more than one result, // they are into an LLVM StructType in their order of appearance. -Type TypeConverter::convertFunctionType(FunctionType type) { +Type LLVMLowering::convertFunctionType(FunctionType type) { // Convert argument types one by one and check for errors. SmallVector argTypes; for (auto t : type.getInputs()) { @@ -219,45 +108,22 @@ Type TypeConverter::convertFunctionType(FunctionType type) { // If function does not return anything, create the void result type, // if it returns on element, convert it, otherwise pack the result types into // a struct. - llvm::Type *resultType = type.getNumResults() == 0 - ? llvm::Type::getVoidTy(llvmContext) - : unwrap(getPackedResultType(type.getResults())); + llvm::Type *resultType = + type.getNumResults() == 0 + ? llvm::Type::getVoidTy(llvmDialect->getLLVMContext()) + : unwrap(packFunctionResults(type.getResults())); if (!resultType) return {}; return wrap(llvm::FunctionType::get(resultType, argTypes, /*isVarArg=*/false) ->getPointerTo()); } -FunctionType TypeConverter::convertFunctionSignatureType( - FunctionType type, llvm::function_ref typeConversionCallback) { - if (!typeConversionCallback) - typeConversionCallback = [this](Type t) { return convertType(t); }; - - SmallVector argTypes; - for (auto t : type.getInputs()) { - auto converted = typeConversionCallback(t); - if (!converted) - return {}; - argTypes.push_back(converted); - } - - // If function does not return anything, return immediately. - if (type.getNumResults() == 0) - return FunctionType::get(argTypes, {}, mlirContext); - - // Otherwise pack the result types into a struct. - if (auto result = getPackedResultType(type.getResults())) - return FunctionType::get(argTypes, {result}, mlirContext); - - return {}; -} - // Convert a MemRef to an LLVM type. If the memref is statically-shaped, then // we return a pointer to the converted element type. Otherwise we return an // LLVM stucture type, where the first element of the structure type is a // pointer to the elemental type of the MemRef and the following N elements are // values of the Index type, one for each of N dynamic dimensions of the MemRef. -Type TypeConverter::convertMemRefType(MemRefType type) { +Type LLVMLowering::convertMemRefType(MemRefType type) { llvm::Type *elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; @@ -272,12 +138,13 @@ Type TypeConverter::convertMemRefType(MemRefType type) { SmallVector types(numDynamicSizes + 1, getIndexType()); types.front() = ptrType; - return wrap(llvm::StructType::get(llvmContext, types)); + return wrap(llvm::StructType::get(llvmDialect->getLLVMContext(), types)); } // Convert a 1D vector type to an LLVM vector type. -Type TypeConverter::convertVectorType(VectorType type) { +Type LLVMLowering::convertVectorType(VectorType type) { if (type.getRank() != 1) { + auto *mlirContext = llvmDialect->getContext(); mlirContext->emitError(UnknownLoc::get(mlirContext), "only 1D vectors are supported"); return {}; @@ -290,7 +157,7 @@ Type TypeConverter::convertVectorType(VectorType type) { } // Dispatch based on the actual type. Return null type on error. -Type TypeConverter::convertType(Type type) { +Type LLVMLowering::convertStandardType(Type type) { if (auto funcType = type.dyn_cast()) return convertFunctionType(funcType); if (auto intType = type.dyn_cast()) @@ -309,44 +176,36 @@ Type TypeConverter::convertType(Type type) { return {}; } -Type TypeConverter::convert(Type t, llvm::Module &module) { - return TypeConverter(module, t.getContext()).convertType(t); -} - -FunctionType TypeConverter::convertFunctionSignature( - FunctionType t, llvm::Module &module, - llvm::function_ref typeConversionCallback) { - return TypeConverter(module, t.getContext()) - .convertFunctionSignatureType(t, typeConversionCallback); -} - -Type TypeConverter::getMemRefElementPtrType(MemRefType t, - llvm::Module &module) { +// Convert the element type of the memref `t` to to an LLVM type using +// `lowering`, get a pointer LLVM type pointing to the converted `t`, wrap it +// into the MLIR LLVM dialect type and return. +static Type getMemRefElementPtrType(MemRefType t, LLVMLowering &lowering) { auto elementType = t.getElementType(); - auto converted = convert(elementType, module); + auto converted = lowering.convertType(elementType); if (!converted) return {}; llvm::Type *llvmType = converted.cast().getUnderlyingType(); return LLVM::LLVMType::get(t.getContext(), llvmType->getPointerTo()); } -Type TypeConverter::pack(ArrayRef types, llvm::Module &module, - MLIRContext &mlirContext) { - return TypeConverter(module, &mlirContext).getPackedResultType(types); -} +LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context, + LLVMLowering &lowering_) + : DialectOpConversion(rootOpName, /*benefit=*/1, context), + lowering(lowering_) {} namespace { // Base class for Standard to LLVM IR op conversions. Matches the Op type // provided as template argument. Carries a reference to the LLVM dialect in // case it is necessary for rewriters. template -class LLVMLegalizationPattern : public DialectOpConversion { +class LLVMLegalizationPattern : public LLVMOpLowering { public: // Construct a conversion pattern. - explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect) - : DialectOpConversion(SourceOp::getOperationName(), 1, - dialect.getContext()), - dialect(dialect) {} + explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect_, + LLVMLowering &lowering_) + : LLVMOpLowering(SourceOp::getOperationName(), dialect_.getContext(), + lowering_), + dialect(dialect_) {} // Match by type. PatternMatchResult match(Operation *op) const override { @@ -436,13 +295,11 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern { SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { unsigned numResults = op->getNumResults(); - auto *mlirContext = op->getContext(); Type packedType; if (numResults != 0) { packedType = - TypeConverter::pack(getTypes(op->getResults()), - this->dialect.getLLVMModule(), *mlirContext); + this->lowering.packFunctionResults(getTypes(op->getResults())); assert(packedType && "type conversion failed, such operation should not " "have been matched"); } @@ -461,8 +318,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern { SmallVector results; results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { - auto type = TypeConverter::convert(op->getResult(i)->getType(), - this->dialect.getLLVMModule()); + auto type = this->lowering.convertType(op->getResult(i)->getType()); results.push_back(rewriter.create( op->getLoc(), type, newOp.getOperation()->getResult(0), this->getIntegerArrayAttr(rewriter, i))); @@ -623,7 +479,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern { rewriter.getFunctionAttr(mallocFunc), cumulativeSize) .getResult(0); - auto structElementType = TypeConverter::convert(elementType, getModule()); + auto structElementType = lowering.convertType(elementType); auto elementPtrType = LLVM::LLVMType::get( op->getContext(), structElementType.cast() .getUnderlyingType() @@ -637,7 +493,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern { } // Create the MemRef descriptor. - auto structType = TypeConverter::convert(type, getModule()); + auto structType = lowering.convertType(type); Value *memRefDescriptor = rewriter.create( op->getLoc(), structType, ArrayRef{}); @@ -718,8 +574,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern { auto sourceType = memRefCastOp.getOperand()->getType().cast(); // Copy the data buffer pointer. - auto elementTypePtr = - TypeConverter::getMemRefElementPtrType(targetType, getModule()); + auto elementTypePtr = getMemRefElementPtrType(targetType, lowering); Value *buffer = extractMemRefElementPtr(rewriter, op->getLoc(), operands[0], elementTypePtr, sourceType.hasStaticShape()); @@ -729,7 +584,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern { } // Create the new MemRef descriptor. - auto structType = TypeConverter::convert(targetType, getModule()); + auto structType = lowering.convertType(targetType); Value *newDescriptor = rewriter.create( op->getLoc(), structType, ArrayRef{}); // Otherwise target type is dynamic memref, so create a proper descriptor. @@ -921,7 +776,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern { Value *getDataPtr(Location loc, MemRefType type, Value *dataPtr, ArrayRef indices, FuncBuilder &rewriter, llvm::Module &module) const { - auto ptrType = TypeConverter::getMemRefElementPtrType(type, module); + auto ptrType = getMemRefElementPtrType(type, this->lowering); auto shape = type.getShape(); if (type.hasStaticShape()) { // NB: If memref was statically-shaped, dataPtr is pointer to raw data. @@ -944,8 +799,7 @@ struct LoadOpLowering : public LoadStoreOpLowering { Value *dataPtr = getDataPtr(op->getLoc(), type, operands.front(), operands.drop_front(), rewriter, getModule()); - auto elementType = - TypeConverter::convert(type.getElementType(), getModule()); + auto elementType = lowering.convertType(type.getElementType()); SmallVector results; results.push_back(rewriter.create( @@ -1018,9 +872,7 @@ struct ReturnOpLowering : public LLVMLegalizationPattern { // Otherwise, we need to pack the arguments into an LLVM struct type before // returning. - auto *mlirContext = op->getContext(); - auto packedType = TypeConverter::pack( - getTypes(op->getOperands()), dialect.getLLVMModule(), *mlirContext); + auto packedType = lowering.packFunctionResults(getTypes(op->getOperands())); Value *packed = rewriter.create(op->getLoc(), packedType); for (unsigned i = 0; i < numArguments; ++i) { @@ -1119,7 +971,8 @@ LLVMLowering::initConverters(MLIRContext *mlirContext) { LoadOpLowering, MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering, RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering, SelectOpLowering, StoreOpLowering, SubFOpLowering, - SubIOpLowering, XOrOpLowering>::build(&converterStorage, *llvmDialect); + SubIOpLowering, XOrOpLowering>::build(&converterStorage, *llvmDialect, + *this); auto extraConverters = initAdditionalConverters(); converters.insert(extraConverters.begin(), extraConverters.end()); return converters; @@ -1127,7 +980,7 @@ LLVMLowering::initConverters(MLIRContext *mlirContext) { // Convert types using the stored LLVM IR module. Type LLVMLowering::convertType(Type t) { - if (auto result = TypeConverter::convert(t, *module)) + if (auto result = convertStandardType(t)) return result; if (auto result = convertAdditionalType(t)) return result; @@ -1138,16 +991,57 @@ Type LLVMLowering::convertType(Type t) { return {}; } +static llvm::Type *unwrapType(Type type) { + return type.cast().getUnderlyingType(); +} + +// Create an LLVM IR structure type if there is more than one result. +Type LLVMLowering::packFunctionResults(ArrayRef types) { + assert(!types.empty() && "expected non-empty list of type"); + + if (types.size() == 1) + return convertType(types.front()); + + SmallVector resultTypes; + resultTypes.reserve(types.size()); + for (auto t : types) { + Type converted = convertType(t); + if (!converted) + return {}; + resultTypes.push_back(unwrapType(converted)); + } + + return LLVM::LLVMType::get( + llvmDialect->getContext(), + llvm::StructType::get(llvmDialect->getLLVMContext(), resultTypes)); +} + // Convert function signatures using the stored LLVM IR module. FunctionType LLVMLowering::convertFunctionSignatureType( - FunctionType t, ArrayRef argAttrs, + FunctionType type, ArrayRef argAttrs, SmallVectorImpl &convertedArgAttrs) { convertedArgAttrs.reserve(argAttrs.size()); for (auto attr : argAttrs) convertedArgAttrs.push_back(attr); - return TypeConverter::convertFunctionSignature( - t, *module, [this](Type t) { return convertType(t); }); + + SmallVector argTypes; + for (auto t : type.getInputs()) { + auto converted = convertType(t); + if (!converted) + return {}; + argTypes.push_back(converted); + } + + // If function does not return anything, return immediately. + if (type.getNumResults() == 0) + return FunctionType::get(argTypes, {}, type.getContext()); + + // Otherwise pack the result types into a struct. + if (auto result = packFunctionResults(type.getResults())) + return FunctionType::get(argTypes, result, result.getContext()); + + return {}; } namespace { @@ -1184,7 +1078,3 @@ std::unique_ptr mlir::createStdToLLVMConverter() { static PassRegistration pass("lower-to-llvm", "Convert all functions to the LLVM IR dialect"); - -Type mlir::LLVM::convertToLLVMDialectType(Type t, llvm::Module &llvmModule) { - return TypeConverter::convert(t, llvmModule); -} diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 7463c71..108da6c 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -54,21 +54,10 @@ using add = ValueBuilder; using sub = ValueBuilder; using mul = ValueBuilder; -static llvm::Module *getLLVMModule(MLIRContext *context) { - auto *llvmDialect = - static_cast(context->getRegisteredDialect("llvm")); - if (!llvmDialect) { - context->emitError(UnknownLoc::get(context), - "LLVM IR dialect is not registered"); - return nullptr; - } - return &llvmDialect->getLLVMModule(); -} - template static llvm::Type *getPtrToElementType(T containerType, - llvm::Module &llvmModule) { - return convertToLLVMDialectType(containerType.getElementType(), llvmModule) + LLVMLowering &lowering) { + return lowering.convertType(containerType.getElementType()) .template cast() .getUnderlyingType() ->getPointerTo(); @@ -82,9 +71,11 @@ static llvm::Type *getPtrToElementType(T containerType, // - an F32 type is converted into an LLVM float type // - a Buffer, Range or View is converted into an LLVM structure type // containing the respective dynamic values. -static Type convertLinalgType(Type t, llvm::Module &llvmModule) { +static Type convertLinalgType(Type t, LLVMLowering &lowering) { auto *context = t.getContext(); - auto *int64Ty = llvm::Type::getInt64Ty(llvmModule.getContext()); + auto *int64Ty = lowering.convertType(IntegerType::get(64, context)) + .cast() + .getUnderlyingType(); // A buffer descriptor contains the pointer to a flat region of storage and // the size of the region. @@ -95,7 +86,7 @@ static Type convertLinalgType(Type t, llvm::Module &llvmModule) { // int64_t size; // }; if (auto bufferTy = t.dyn_cast()) { - auto *ptrTy = getPtrToElementType(bufferTy, llvmModule); + auto *ptrTy = getPtrToElementType(bufferTy, lowering); auto *structTy = llvm::StructType::get(ptrTy, int64Ty); return LLVMType::get(context, structTy); } @@ -136,7 +127,7 @@ static Type convertLinalgType(Type t, llvm::Module &llvmModule) { // int64_t strides[Rank]; // }; if (auto viewTy = t.dyn_cast()) { - auto *ptrTy = getPtrToElementType(viewTy, llvmModule); + auto *ptrTy = getPtrToElementType(viewTy, lowering); auto *arrayTy = llvm::ArrayType::get(int64Ty, viewTy.getRank()); auto *structTy = llvm::StructType::get(ptrTy, int64Ty, arrayTy, arrayTy); return LLVMType::get(context, structTy); @@ -157,36 +148,31 @@ static ArrayAttr makePositionAttr(FuncBuilder &builder, } // BufferSizeOp creates a new `index` value. -class BufferSizeOpConversion : public DialectOpConversion { +class BufferSizeOpConversion : public LLVMOpLowering { public: - explicit BufferSizeOpConversion(MLIRContext *context) - : DialectOpConversion(BufferSizeOp::getOperationName(), 1, context), - llvmModule(*getLLVMModule(context)) {} + BufferSizeOpConversion(MLIRContext *context, LLVMLowering &lowering_) + : LLVMOpLowering(BufferSizeOp::getOperationName(), context, lowering_) {} SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { - auto bufferSizeType = - convertToLLVMDialectType(operands[0]->getType(), llvmModule); + auto bufferSizeType = lowering.convertType(operands[0]->getType()); edsc::ScopedContext context(rewriter, op->getLoc()); return {extractvalue(bufferSizeType, operands[0], makePositionAttr(rewriter, 1))}; } - - llvm::Module &llvmModule; }; // RangeOp creates a new range descriptor. -class RangeOpConversion : public DialectOpConversion { +class RangeOpConversion : public LLVMOpLowering { public: - explicit RangeOpConversion(MLIRContext *context) - : DialectOpConversion(RangeOp::getOperationName(), 1, context), - llvmModule(*getLLVMModule(context)) {} + explicit RangeOpConversion(MLIRContext *context, LLVMLowering &lowering_) + : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {} SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { auto rangeOp = op->cast(); auto rangeDescriptorType = - convertLinalgType(rangeOp.getResult()->getType(), llvmModule); + convertLinalgType(rangeOp.getResult()->getType(), lowering); edsc::ScopedContext context(rewriter, op->getLoc()); @@ -201,24 +187,20 @@ public: return {desc}; } - - llvm::Module &llvmModule; }; -class SliceOpConversion : public DialectOpConversion { +class SliceOpConversion : public LLVMOpLowering { public: - explicit SliceOpConversion(MLIRContext *context) - : DialectOpConversion(SliceOp::getOperationName(), 1, context), - llvmModule(*getLLVMModule(context)) {} + explicit SliceOpConversion(MLIRContext *context, LLVMLowering &lowering_) + : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {} SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { auto sliceOp = op->cast(); auto viewDescriptorType = - convertLinalgType(sliceOp.getViewType(), llvmModule); + convertLinalgType(sliceOp.getViewType(), lowering); auto viewType = sliceOp.getBaseViewType(); - auto int64Ty = - convertToLLVMDialectType(rewriter.getIntegerType(64), llvmModule); + auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); // Helper function to create an integer array attribute out of a list of // values. @@ -229,7 +211,7 @@ public: auto getViewPtr = [pos, &rewriter, this](ViewType type, Value *view) -> Value * { auto elementPtrTy = - rewriter.getType(getPtrToElementType(type, llvmModule)); + rewriter.getType(getPtrToElementType(type, lowering)); return extractvalue(elementPtrTy, view, pos(0)); }; @@ -288,25 +270,20 @@ public: return {desc}; } - - llvm::Module &llvmModule; }; -class ViewOpConversion : public DialectOpConversion { +class ViewOpConversion : public LLVMOpLowering { public: - explicit ViewOpConversion(MLIRContext *context) - : DialectOpConversion(ViewOp::getOperationName(), 1, context), - llvmModule(*getLLVMModule(context)) {} + explicit ViewOpConversion(MLIRContext *context, LLVMLowering &lowering_) + : LLVMOpLowering(ViewOp::getOperationName(), context, lowering_) {} SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { auto viewOp = op->cast(); - auto viewDescriptorType = - convertLinalgType(viewOp.getViewType(), llvmModule); + auto viewDescriptorType = convertLinalgType(viewOp.getViewType(), lowering); auto elementType = rewriter.getType( - getPtrToElementType(viewOp.getViewType(), llvmModule)); - auto int64Ty = - convertToLLVMDialectType(rewriter.getIntegerType(64), llvmModule); + getPtrToElementType(viewOp.getViewType(), lowering)); + auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); auto pos = [&rewriter](ArrayRef values) { return makePositionAttr(rewriter, values); @@ -350,15 +327,13 @@ public: return {desc}; } - - llvm::Module &llvmModule; }; // DotOp creates a new range descriptor. -class DotOpConversion : public DialectOpConversion { +class DotOpConversion : public LLVMOpLowering { public: - explicit DotOpConversion(MLIRContext *context) - : DialectOpConversion(DotOp::getOperationName(), 1, context) {} + explicit DotOpConversion(MLIRContext *context, LLVMLowering &lowering_) + : LLVMOpLowering(DotOp::getOperationName(), context, lowering_) {} static StringRef libraryFunctionName() { return "linalg_dot"; } @@ -384,11 +359,12 @@ protected: return ConversionListBuilder< BufferSizeOpConversion, DotOpConversion, RangeOpConversion, SliceOpConversion, ViewOpConversion>::build(&converterStorage, - llvmDialect->getContext()); + llvmDialect->getContext(), + *this); } Type convertAdditionalType(Type t) override { - return convertLinalgType(t, *module); + return convertLinalgType(t, *this); } }; } // end anonymous namespace -- 2.7.4