From 5953d12b959a2905bd2f32b7429d3196ea274db2 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 22 May 2019 14:56:07 -0700 Subject: [PATCH] Add thread-safe utilities to LLVMType to allow constructing llvm types in a multi-threaded environment. The LLVMContext is not thread-safe and directly constructing a raw llvm::Type can create situations where the LLVMContext is modified by multiple threads at the same time. -- PiperOrigin-RevId: 249526233 --- .../Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp | 57 ++--- .../Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp | 8 +- mlir/examples/toy/Ch5/mlir/LateLowering.cpp | 10 +- mlir/include/mlir/LLVMIR/LLVMDialect.h | 81 ++++++- mlir/include/mlir/LLVMIR/LLVMLowering.h | 12 +- mlir/lib/LLVMIR/IR/LLVMDialect.cpp | 239 +++++++++++++++------ .../lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp | 103 ++++----- mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp | 56 ++--- 8 files changed, 343 insertions(+), 223 deletions(-) diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp index a23828d..7a0e1a5 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp @@ -62,22 +62,14 @@ Type linalg::convertLinalgType(Type t) { // Simple conversions. if (t.isa()) { int width = dialect->getLLVMModule().getDataLayout().getPointerSizeInBits(); - auto *integerTy = llvm::IntegerType::get(dialect->getLLVMContext(), width); - return LLVM::LLVMType::get(context, integerTy); - } - if (auto intTy = t.dyn_cast()) { - int width = intTy.getWidth(); - auto *integerTy = llvm::IntegerType::get(dialect->getLLVMContext(), width); - return LLVM::LLVMType::get(context, integerTy); - } - if (t.isF32()) { - auto *floatTy = llvm::Type::getFloatTy(dialect->getLLVMContext()); - return LLVM::LLVMType::get(context, floatTy); - } - if (t.isF64()) { - auto *doubleTy = llvm::Type::getDoubleTy(dialect->getLLVMContext()); - return LLVM::LLVMType::get(context, doubleTy); + return LLVM::LLVMType::getIntNTy(dialect, width); } + if (auto intTy = t.dyn_cast()) + return LLVM::LLVMType::getIntNTy(dialect, intTy.getWidth()); + if (t.isF32()) + return LLVM::LLVMType::getFloatTy(dialect); + if (t.isF64()) + return LLVM::LLVMType::getDoubleTy(dialect); // Range descriptor contains the range bounds and the step as 64-bit integers. // @@ -87,9 +79,8 @@ Type linalg::convertLinalgType(Type t) { // int64_t step; // }; if (auto rangeTy = t.dyn_cast()) { - auto *int64Ty = llvm::Type::getInt64Ty(dialect->getLLVMContext()); - auto *structTy = llvm::StructType::get(int64Ty, int64Ty, int64Ty); - return LLVM::LLVMType::get(context, structTy); + auto int64Ty = LLVM::LLVMType::getInt64Ty(dialect); + return LLVM::LLVMType::getStructTy(int64Ty, int64Ty, int64Ty); } // View descriptor contains the pointer to the data buffer, followed by a @@ -116,14 +107,12 @@ Type linalg::convertLinalgType(Type t) { // int64_t strides[Rank]; // }; if (auto viewTy = t.dyn_cast()) { - auto *elemTy = linalg::convertLinalgType(viewTy.getElementType()) - .cast() - .getUnderlyingType() - ->getPointerTo(); - auto *int64Ty = llvm::Type::getInt64Ty(dialect->getLLVMContext()); - auto *arrayTy = llvm::ArrayType::get(int64Ty, viewTy.getRank()); - auto *structTy = llvm::StructType::get(elemTy, int64Ty, arrayTy, arrayTy); - return LLVM::LLVMType::get(context, structTy); + auto elemTy = linalg::convertLinalgType(viewTy.getElementType()) + .cast() + .getPointerTo(); + auto int64Ty = LLVM::LLVMType::getInt64Ty(dialect); + auto arrayTy = LLVM::LLVMType::getArrayTy(int64Ty, viewTy.getRank()); + return LLVM::LLVMType::getStructTy(elemTy, int64Ty, arrayTy, arrayTy); } // All other types are kept as is. @@ -217,11 +206,9 @@ public: if (type.hasStaticShape()) return memref; - auto elementTy = LLVM::LLVMType::get( - type.getContext(), linalg::convertLinalgType(type.getElementType()) - .cast() - .getUnderlyingType() - ->getPointerTo()); + auto elementTy = linalg::convertLinalgType(type.getElementType()) + .cast() + .getPointerTo(); return intrinsics::extractvalue(elementTy, memref, pos(0)); }; @@ -307,11 +294,9 @@ public: auto sliceOp = cast(op); auto newViewDescriptorType = linalg::convertLinalgType(sliceOp.getViewType()); - auto elementType = rewriter.getType( - linalg::convertLinalgType(sliceOp.getElementType()) - .cast() - .getUnderlyingType() - ->getPointerTo()); + auto elementType = linalg::convertLinalgType(sliceOp.getElementType()) + .cast() + .getPointerTo(); auto int64Ty = linalg::convertLinalgType(rewriter.getIntegerType(64)); auto pos = [&rewriter](ArrayRef values) { diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index 98bd867..9f2d46c 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -67,11 +67,9 @@ public: auto loadOp = cast(op); auto elementType = loadOp.getViewType().template cast().getElementType(); - auto *llvmPtrType = linalg::convertLinalgType(elementType) - .template cast() - .getUnderlyingType() - ->getPointerTo(); - elementType = rewriter.getType(llvmPtrType); + elementType = linalg::convertLinalgType(elementType) + .template cast() + .getPointerTo(); auto int64Ty = linalg::convertLinalgType(rewriter.getIntegerType(64)); auto pos = [&rewriter](ArrayRef values) { diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 7911649..5a0a901 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -210,15 +210,11 @@ private: // Create a function declaration for printf, signature is `i32 (i8*, ...)` Builder builder(&module); - MLIRContext *context = module.getContext(); - auto *llvmDialect = + auto *dialect = module.getContext()->getRegisteredDialect(); - auto &llvmModule = llvmDialect->getLLVMModule(); - llvm::IRBuilder<> llvmBuilder(llvmModule.getContext()); - auto llvmI32Ty = LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(32)); - auto llvmI8PtrTy = - LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(8)->getPointerTo()); + auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(dialect); + auto llvmI8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo(); auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty}); printfFunc = new Function(builder.getUnknownLoc(), "printf", printfTy); // It should be variadic, but we don't support it fully just yet. diff --git a/mlir/include/mlir/LLVMIR/LLVMDialect.h b/mlir/include/mlir/LLVMIR/LLVMDialect.h index 170a5b3..b8272f9 100644 --- a/mlir/include/mlir/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/LLVMIR/LLVMDialect.h @@ -41,10 +41,12 @@ class LLVMContext; namespace mlir { namespace LLVM { +class LLVMDialect; namespace detail { struct LLVMTypeStorage; -} +struct LLVMDialectImpl; +} // namespace detail class LLVMType : public mlir::Type::TypeBase { @@ -57,9 +59,72 @@ public: static bool kindof(unsigned kind) { return kind == LLVM_TYPE; } - static LLVMType get(MLIRContext *context, llvm::Type *llvmType); - + LLVMDialect &getDialect(); llvm::Type *getUnderlyingType() const; + + /// Array type utilities. + LLVMType getArrayElementType(); + + /// Pointer type utilities. + LLVMType getPointerTo(unsigned addrSpace = 0); + LLVMType getPointerElementTy(); + + /// Struct type utilities. + LLVMType getStructElementType(unsigned i); + + /// Utilities used to generate floating point types. + static LLVMType getDoubleTy(LLVMDialect *dialect); + static LLVMType getFloatTy(LLVMDialect *dialect); + static LLVMType getHalfTy(LLVMDialect *dialect); + + /// Utilities used to generate integer types. + static LLVMType getIntNTy(LLVMDialect *dialect, unsigned numBits); + static LLVMType getInt1Ty(LLVMDialect *dialect) { + return getIntNTy(dialect, /*numBits=*/1); + } + static LLVMType getInt8Ty(LLVMDialect *dialect) { + return getIntNTy(dialect, /*numBits=*/8); + } + static LLVMType getInt8PtrTy(LLVMDialect *dialect) { + return getInt8Ty(dialect).getPointerTo(); + } + static LLVMType getInt16Ty(LLVMDialect *dialect) { + return getIntNTy(dialect, /*numBits=*/16); + } + static LLVMType getInt32Ty(LLVMDialect *dialect) { + return getIntNTy(dialect, /*numBits=*/32); + } + static LLVMType getInt64Ty(LLVMDialect *dialect) { + return getIntNTy(dialect, /*numBits=*/64); + } + + /// Utilities used to generate other miscellaneous types. + static LLVMType getArrayTy(LLVMType elementType, uint64_t numElements); + static LLVMType getFunctionTy(LLVMType result, ArrayRef params, + bool isVarArg); + static LLVMType getFunctionTy(LLVMType result, bool isVarArg) { + return getFunctionTy(result, llvm::None, isVarArg); + } + static LLVMType getStructTy(LLVMDialect *dialect, ArrayRef elements, + bool isPacked = false); + static LLVMType getStructTy(LLVMDialect *dialect, bool isPacked = false) { + return getStructTy(dialect, llvm::None, isPacked); + } + template + static typename std::enable_if::value, + LLVMType>::type + getStructTy(LLVMType elt1, Args... elts) { + SmallVector fields({elt1, elts...}); + return getStructTy(&elt1.getDialect(), fields); + } + static LLVMType getVectorTy(LLVMType elementType, unsigned numElements); + static LLVMType getVoidTy(LLVMDialect *dialect); + +private: + friend LLVMDialect; + + /// Get an LLVM type with a pre-existing llvm type. + static LLVMType get(MLIRContext *context, llvm::Type *llvmType); }; ///// Ops ///// @@ -69,10 +134,11 @@ public: class LLVMDialect : public Dialect { public: explicit LLVMDialect(MLIRContext *context); + ~LLVMDialect(); static StringRef getDialectNamespace() { return "llvm"; } - llvm::LLVMContext &getLLVMContext() { return llvmContext; } - llvm::Module &getLLVMModule() { return module; } + llvm::LLVMContext &getLLVMContext(); + llvm::Module &getLLVMModule(); /// Parse a type registered to this dialect. Type parseType(StringRef tyData, Location loc) const override; @@ -86,8 +152,9 @@ public: NamedAttribute argAttr) override; private: - llvm::LLVMContext llvmContext; - llvm::Module module; + friend LLVMType; + + std::unique_ptr impl; }; } // end namespace LLVM diff --git a/mlir/include/mlir/LLVMIR/LLVMLowering.h b/mlir/include/mlir/LLVMIR/LLVMLowering.h index c2bf040..9947f42 100644 --- a/mlir/include/mlir/LLVMIR/LLVMLowering.h +++ b/mlir/include/mlir/LLVMIR/LLVMLowering.h @@ -31,11 +31,12 @@ class IntegerType; class LLVMContext; class Module; class Type; -} +} // namespace llvm namespace mlir { namespace LLVM { class LLVMDialect; +class LLVMType; } /// Conversion from the Standard dialect to the LLVM IR dialect. Provides hooks @@ -55,6 +56,9 @@ public: /// Returns the LLVM context. llvm::LLVMContext &getLLVMContext(); + /// Returns the LLVM dialect. + LLVM::LLVMDialect *getDialect() { return llvmDialect; } + protected: /// Add a set of converters to the given pattern list. Store the module /// associated with the dialect for further type conversion. @@ -119,13 +123,13 @@ private: // Get the LLVM representation of the index type based on the bitwidth of the // pointer as defined by the data layout of the module. - llvm::IntegerType *getIndexType(); + LLVM::LLVMType getIndexType(); // Wrap the given LLVM IR type into an LLVM IR dialect type. Type wrap(llvm::Type *llvmType); - // Extract an LLVM IR type from the LLVM IR dialect type. - llvm::Type *unwrap(Type type); + // Extract an LLVM IR dialect type. + LLVM::LLVMType unwrap(Type type); }; /// Base class for operation conversions targeting the LLVM IR dialect. Provides diff --git a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp index fd197c0..950b1d4 100644 --- a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp @@ -29,40 +29,12 @@ #include "llvm/IR/Attributes.h" #include "llvm/IR/Function.h" #include "llvm/IR/Type.h" +#include "llvm/Support/Mutex.h" #include "llvm/Support/SourceMgr.h" using namespace mlir; using namespace mlir::LLVM; -namespace mlir { -namespace LLVM { -namespace detail { -struct LLVMTypeStorage : public ::mlir::TypeStorage { - LLVMTypeStorage(llvm::Type *ty) : underlyingType(ty) {} - - // LLVM types are pointer-unique. - using KeyTy = llvm::Type *; - bool operator==(const KeyTy &key) const { return key == underlyingType; } - - static LLVMTypeStorage *construct(TypeStorageAllocator &allocator, - llvm::Type *ty) { - return new (allocator.allocate()) LLVMTypeStorage(ty); - } - - llvm::Type *underlyingType; -}; -} // end namespace detail -} // end namespace LLVM -} // end namespace mlir - -LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) { - return Base::get(context, FIRST_LLVM_TYPE, llvmType); -} - -llvm::Type *LLVMType::getUnderlyingType() const { - return getImpl()->underlyingType; -} - static void printLLVMBinaryOp(OpAsmPrinter *p, Operation *op) { // Fallback to the generic form if the op is not well-formed (may happen // during incomplete rewrites, and used for debugging). @@ -161,14 +133,13 @@ static ParseResult parseICmpOp(OpAsmParser *parser, OperationState *result) { // The result type is either i1 or a vector type if the inputs are // vectors. auto *dialect = builder.getContext()->getRegisteredDialect(); - llvm::Type *llvmResultType = llvm::Type::getInt1Ty(dialect->getLLVMContext()); + auto resultType = LLVMType::getInt1Ty(dialect); auto argType = type.dyn_cast(); if (!argType) return parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type"); if (argType.getUnderlyingType()->isVectorTy()) - llvmResultType = llvm::VectorType::get( - llvmResultType, argType.getUnderlyingType()->getVectorNumElements()); - auto resultType = builder.getType(llvmResultType); + resultType = LLVMType::getVectorTy( + resultType, argType.getUnderlyingType()->getVectorNumElements()); result->attributes = attrs; result->addTypes({resultType}); @@ -180,9 +151,7 @@ static ParseResult parseICmpOp(OpAsmParser *parser, OperationState *result) { //===----------------------------------------------------------------------===// static void printAllocaOp(OpAsmPrinter *p, AllocaOp &op) { - auto *llvmPtrTy = op.getType().cast().getUnderlyingType(); - auto *llvmElemTy = llvm::cast(llvmPtrTy)->getElementType(); - auto elemTy = LLVM::LLVMType::get(op.getContext(), llvmElemTy); + auto elemTy = op.getType().cast().getPointerElementTy(); auto funcTy = FunctionType::get({op.arraySize()->getType()}, {op.getType()}, op.getContext()); @@ -291,13 +260,10 @@ static Type getLoadStoreElementType(OpAsmParser *parser, Type type, if (!llvmTy) return parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type"), nullptr; - auto *llvmPtrTy = dyn_cast(llvmTy.getUnderlyingType()); - if (!llvmPtrTy) + if (!llvmTy.getUnderlyingType()->isPointerTy()) return parser->emitError(trailingTypeLoc, "expected LLVM pointer type"), nullptr; - auto elemTy = LLVM::LLVMType::get(parser->getBuilder().getContext(), - llvmPtrTy->getElementType()); - return elemTy; + return llvmTy.getPointerElementTy(); } // ::= `llvm.load` ssa-use attribute-dict? `:` type @@ -465,33 +431,28 @@ static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) { Builder &builder = parser->getBuilder(); auto *llvmDialect = builder.getContext()->getRegisteredDialect(); - llvm::Type *llvmResultType; - Type wrappedResultType; + LLVM::LLVMType llvmResultType; if (funcType.getNumResults() == 0) { - llvmResultType = llvm::Type::getVoidTy(llvmDialect->getLLVMContext()); - wrappedResultType = builder.getType(llvmResultType); + llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect); } else { - wrappedResultType = funcType.getResult(0); - auto wrappedLLVMResultType = wrappedResultType.dyn_cast(); - if (!wrappedLLVMResultType) + llvmResultType = funcType.getResult(0).dyn_cast(); + if (!llvmResultType) return parser->emitError(trailingTypeLoc, "expected result to have LLVM type"); - llvmResultType = wrappedLLVMResultType.getUnderlyingType(); } - SmallVector argTypes; + SmallVector argTypes; argTypes.reserve(funcType.getNumInputs()); for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { auto argType = funcType.getInput(i).dyn_cast(); if (!argType) return parser->emitError(trailingTypeLoc, "expected LLVM types as inputs"); - argTypes.push_back(argType.getUnderlyingType()); + argTypes.push_back(argType); } - auto *llvmFuncType = llvm::FunctionType::get(llvmResultType, argTypes, - /*isVarArg=*/false); - auto wrappedFuncType = - builder.getType(llvmFuncType->getPointerTo()); + auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes, + /*isVarArg=*/false); + auto wrappedFuncType = llvmFuncType.getPointerTo(); auto funcArguments = ArrayRef(operands).drop_front(); @@ -505,7 +466,7 @@ static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) { parser->getNameLoc(), result->operands)) return failure(); - result->addTypes(wrappedResultType); + result->addTypes(llvmResultType); } result->attributes = attrs; @@ -544,7 +505,6 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser *parser, // type by taking the element type, indexed by the position attribute for // stuctures. Check the position index before accessing, it is supposed to be // in bounds. - llvm::Type *llvmContainerType = wrappedContainerType.getUnderlyingType(); for (Attribute subAttr : positionArrayAttr) { auto positionElementAttr = subAttr.dyn_cast(); if (!positionElementAttr) @@ -552,27 +512,27 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser *parser, "expected an array of integer literals"), nullptr; int position = positionElementAttr.getInt(); + auto *llvmContainerType = wrappedContainerType.getUnderlyingType(); if (llvmContainerType->isArrayTy()) { if (position < 0 || static_cast(position) >= llvmContainerType->getArrayNumElements()) return parser->emitError(attributeLoc, "position out of bounds"), nullptr; - llvmContainerType = llvmContainerType->getArrayElementType(); + wrappedContainerType = wrappedContainerType.getArrayElementType(); } else if (llvmContainerType->isStructTy()) { if (position < 0 || static_cast(position) >= llvmContainerType->getStructNumElements()) return parser->emitError(attributeLoc, "position out of bounds"), nullptr; - llvmContainerType = llvmContainerType->getStructElementType(position); + wrappedContainerType = + wrappedContainerType.getStructElementType(position); } else { return parser->emitError(typeLoc, "expected wrapped LLVM IR structure/array type"), nullptr; } } - - Builder &builder = parser->getBuilder(); - return builder.getType(llvmContainerType); + return wrappedContainerType; } // ::= `llvm.extractvalue` ssa-use @@ -730,8 +690,7 @@ static ParseResult parseCondBrOp(OpAsmParser *parser, OperationState *result) { Builder &builder = parser->getBuilder(); auto *llvmDialect = builder.getContext()->getRegisteredDialect(); - auto i1Type = builder.getType( - llvm::Type::getInt1Ty(llvmDialect->getLLVMContext())); + auto i1Type = LLVM::LLVMType::getInt1Ty(llvmDialect); if (parser->parseOperand(condition) || parser->parseComma() || parser->parseSuccessorAndUseList(trueDest, trueOperands) || @@ -844,9 +803,26 @@ static ParseResult parseConstantOp(OpAsmParser *parser, // LLVMDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// +namespace mlir { +namespace LLVM { +namespace detail { +struct LLVMDialectImpl { + LLVMDialectImpl() : module("LLVMDialectModule", llvmContext) {} + + llvm::LLVMContext llvmContext; + llvm::Module module; + + /// A smart mutex to lock access to the llvm context. Unlike MLIR, LLVM is not + /// multi-threaded and requires locked access to prevent race conditions. + llvm::sys::SmartMutex mutex; +}; +} // end namespace detail +} // end namespace LLVM +} // end namespace mlir + LLVMDialect::LLVMDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context), - module("LLVMDialectModule", llvmContext) { + impl(new detail::LLVMDialectImpl()) { addTypes(); addOperations< #define GET_OP_LIST @@ -857,13 +833,21 @@ LLVMDialect::LLVMDialect(MLIRContext *context) allowUnknownOperations(); } +LLVMDialect::~LLVMDialect() {} + #define GET_OP_CLASSES #include "mlir/LLVMIR/LLVMOps.cpp.inc" +llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->llvmContext; } +llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; } + /// Parse a type registered to this dialect. Type LLVMDialect::parseType(StringRef tyData, Location loc) const { + // LLVM is not thread-safe, so lock access to it. + llvm::sys::SmartScopedLock lock(impl->mutex); + llvm::SMDiagnostic errorMessage; - llvm::Type *type = llvm::parseType(tyData, errorMessage, module); + llvm::Type *type = llvm::parseType(tyData, errorMessage, impl->module); if (!type) return (getContext()->emitError(loc, errorMessage.getMessage()), nullptr); return LLVMType::get(getContext(), type); @@ -889,3 +873,126 @@ LogicalResult LLVMDialect::verifyFunctionArgAttribute(Function *func, } static DialectRegistration llvmDialect; + +//===----------------------------------------------------------------------===// +// LLVMType. +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace LLVM { +namespace detail { +struct LLVMTypeStorage : public ::mlir::TypeStorage { + LLVMTypeStorage(llvm::Type *ty) : underlyingType(ty) {} + + // LLVM types are pointer-unique. + using KeyTy = llvm::Type *; + bool operator==(const KeyTy &key) const { return key == underlyingType; } + + static LLVMTypeStorage *construct(TypeStorageAllocator &allocator, + llvm::Type *ty) { + return new (allocator.allocate()) LLVMTypeStorage(ty); + } + + llvm::Type *underlyingType; +}; +} // end namespace detail +} // end namespace LLVM +} // end namespace mlir + +LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) { + return Base::get(context, FIRST_LLVM_TYPE, llvmType); +} + +LLVMDialect &LLVMType::getDialect() { + return static_cast(Type::getDialect()); +} + +llvm::Type *LLVMType::getUnderlyingType() const { + return getImpl()->underlyingType; +} + +/// Array type utilities. +LLVMType LLVMType::getArrayElementType() { + return get(getContext(), getUnderlyingType()->getArrayElementType()); +} + +/// Pointer type utilities. +LLVMType LLVMType::getPointerTo(unsigned addrSpace) { + // Lock access to the dialect as this may modify the LLVM context. + llvm::sys::SmartScopedLock lock(getDialect().impl->mutex); + return get(getContext(), getUnderlyingType()->getPointerTo(addrSpace)); +} +LLVMType LLVMType::getPointerElementTy() { + return get(getContext(), getUnderlyingType()->getPointerElementType()); +} + +/// Struct type utilities. +LLVMType LLVMType::getStructElementType(unsigned i) { + return get(getContext(), getUnderlyingType()->getStructElementType(i)); +} + +/// Utilities used to generate floating point types. +LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) { + return get(dialect->getContext(), + llvm::Type::getDoubleTy(dialect->getLLVMContext())); +} +LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) { + return get(dialect->getContext(), + llvm::Type::getFloatTy(dialect->getLLVMContext())); +} +LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) { + return get(dialect->getContext(), + llvm::Type::getHalfTy(dialect->getLLVMContext())); +} + +/// Utilities used to generate integer types. +LLVMType LLVMType::getIntNTy(LLVMDialect *dialect, unsigned numBits) { + // Lock access to the dialect as this may modify the LLVM context. + llvm::sys::SmartScopedLock lock(dialect->impl->mutex); + return get(dialect->getContext(), + llvm::Type::getIntNTy(dialect->getLLVMContext(), numBits)); +} + +/// Utilities used to generate other miscellaneous types. +LLVMType LLVMType::getArrayTy(LLVMType elementType, uint64_t numElements) { + // Lock access to the dialect as this may modify the LLVM context. + llvm::sys::SmartScopedLock lock(elementType.getDialect().impl->mutex); + return get( + elementType.getContext(), + llvm::ArrayType::get(elementType.getUnderlyingType(), numElements)); +} +LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef params, + bool isVarArg) { + SmallVector llvmParams; + for (auto param : params) + llvmParams.push_back(param.getUnderlyingType()); + + // Lock access to the dialect as this may modify the LLVM context. + llvm::sys::SmartScopedLock lock(result.getDialect().impl->mutex); + return get(result.getContext(), + llvm::FunctionType::get(result.getUnderlyingType(), llvmParams, + isVarArg)); +} +LLVMType LLVMType::getStructTy(LLVMDialect *dialect, + ArrayRef elements, bool isPacked) { + SmallVector llvmElements; + for (auto elt : elements) + llvmElements.push_back(elt.getUnderlyingType()); + + // Lock access to the dialect as this may modify the LLVM context. + llvm::sys::SmartScopedLock lock(dialect->impl->mutex); + return get( + dialect->getContext(), + llvm::StructType::get(dialect->getLLVMContext(), llvmElements, isPacked)); +} +LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) { + // Lock access to the dialect as this may modify the LLVM context. + llvm::sys::SmartScopedLock lock(elementType.getDialect().impl->mutex); + return get( + elementType.getContext(), + llvm::VectorType::get(elementType.getUnderlyingType(), numElements)); +} +LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) { + return get(dialect->getContext(), + llvm::Type::getVoidTy(dialect->getLLVMContext())); +} diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index 36267e9..a2476dc 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -45,46 +45,37 @@ llvm::LLVMContext &LLVMLowering::getLLVMContext() { return module->getContext(); } -// Wrap the given LLVM IR type into an LLVM IR dialect type. -Type LLVMLowering::wrap(llvm::Type *llvmType) { - return LLVM::LLVMType::get(llvmDialect->getContext(), llvmType); -} - // Extract an LLVM IR type from the LLVM IR dialect type. -llvm::Type *LLVMLowering::unwrap(Type type) { +LLVM::LLVMType LLVMLowering::unwrap(Type type) { if (!type) return nullptr; auto *mlirContext = type.getContext(); auto wrappedLLVMType = type.dyn_cast(); if (!wrappedLLVMType) - return mlirContext->emitError(UnknownLoc::get(mlirContext), - "conversion resulted in a non-LLVM type"), - nullptr; - return wrappedLLVMType.getUnderlyingType(); + mlirContext->emitError(UnknownLoc::get(mlirContext), + "conversion resulted in a non-LLVM type"); + return wrappedLLVMType; } -llvm::IntegerType *LLVMLowering::getIndexType() { - return llvm::IntegerType::get(llvmDialect->getLLVMContext(), - module->getDataLayout().getPointerSizeInBits()); +LLVM::LLVMType LLVMLowering::getIndexType() { + return LLVM::LLVMType::getIntNTy( + llvmDialect, module->getDataLayout().getPointerSizeInBits()); } -Type LLVMLowering::convertIndexType(IndexType type) { - return wrap(getIndexType()); -} +Type LLVMLowering::convertIndexType(IndexType type) { return getIndexType(); } Type LLVMLowering::convertIntegerType(IntegerType type) { - return wrap( - llvm::Type::getIntNTy(llvmDialect->getLLVMContext(), type.getWidth())); + return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth()); } Type LLVMLowering::convertFloatType(FloatType type) { switch (type.getKind()) { case mlir::StandardTypes::F32: - return wrap(llvm::Type::getFloatTy(llvmDialect->getLLVMContext())); + return LLVM::LLVMType::getFloatTy(llvmDialect); case mlir::StandardTypes::F64: - return wrap(llvm::Type::getDoubleTy(llvmDialect->getLLVMContext())); + return LLVM::LLVMType::getDoubleTy(llvmDialect); case mlir::StandardTypes::F16: - return wrap(llvm::Type::getHalfTy(llvmDialect->getLLVMContext())); + return LLVM::LLVMType::getHalfTy(llvmDialect); case mlir::StandardTypes::BF16: { auto *mlirContext = llvmDialect->getContext(); return mlirContext->emitError(UnknownLoc::get(mlirContext), @@ -102,7 +93,7 @@ Type LLVMLowering::convertFloatType(FloatType type) { // they are into an LLVM StructType in their order of appearance. Type LLVMLowering::convertFunctionType(FunctionType type) { // Convert argument types one by one and check for errors. - SmallVector argTypes; + SmallVector argTypes; for (auto t : type.getInputs()) { auto converted = convertType(t); if (!converted) @@ -113,14 +104,14 @@ Type LLVMLowering::convertFunctionType(FunctionType type) { // If function does not return anything, create the void result type, // if it returns on element, convert it, otherwise pack the result types into // a struct. - llvm::Type *resultType = + LLVM::LLVMType resultType = type.getNumResults() == 0 - ? llvm::Type::getVoidTy(llvmDialect->getLLVMContext()) + ? LLVM::LLVMType::getVoidTy(llvmDialect) : unwrap(packFunctionResults(type.getResults())); if (!resultType) return {}; - return wrap(llvm::FunctionType::get(resultType, argTypes, /*isVarArg=*/false) - ->getPointerTo()); + return LLVM::LLVMType::getFunctionTy(resultType, argTypes, /*isVarArg=*/false) + .getPointerTo(); } // Convert a MemRef to an LLVM type. If the memref is statically-shaped, then @@ -129,21 +120,21 @@ Type LLVMLowering::convertFunctionType(FunctionType type) { // pointer to the elemental type of the MemRef and the following N elements are // values of the Index type, one for each of N dynamic dimensions of the MemRef. Type LLVMLowering::convertMemRefType(MemRefType type) { - llvm::Type *elementType = unwrap(convertType(type.getElementType())); + LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; - auto ptrType = elementType->getPointerTo(); + auto ptrType = elementType.getPointerTo(); // Extra value for the memory space. unsigned numDynamicSizes = type.getNumDynamicDims(); // If memref is statically-shaped we return the underlying pointer type. - if (numDynamicSizes == 0) { - return wrap(ptrType); - } - SmallVector types(numDynamicSizes + 1, getIndexType()); + if (numDynamicSizes == 0) + return ptrType; + + SmallVector types(numDynamicSizes + 1, getIndexType()); types.front() = ptrType; - return wrap(llvm::StructType::get(llvmDialect->getLLVMContext(), types)); + return LLVM::LLVMType::getStructTy(llvmDialect, types); } // Convert a 1D vector type to an LLVM vector type. @@ -155,9 +146,9 @@ Type LLVMLowering::convertVectorType(VectorType type) { return {}; } - llvm::Type *elementType = unwrap(convertType(type.getElementType())); + LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); return elementType - ? wrap(llvm::VectorType::get(elementType, type.getShape().front())) + ? LLVM::LLVMType::getVectorTy(elementType, type.getShape().front()) : Type(); } @@ -189,8 +180,7 @@ static Type getMemRefElementPtrType(MemRefType t, LLVMLowering &lowering) { auto converted = lowering.convertType(elementType); if (!converted) return {}; - llvm::Type *llvmType = converted.cast().getUnderlyingType(); - return LLVM::LLVMType::get(t.getContext(), llvmType->getPointerTo()); + return converted.cast().getPointerTo(); } LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context, @@ -226,15 +216,13 @@ public: // Get the MLIR type wrapping the LLVM integer type whose bit width is defined // by the pointer size used in the LLVM module. LLVM::LLVMType getIndexType() const { - llvm::Type *llvmType = llvm::Type::getIntNTy( - getContext(), getModule().getDataLayout().getPointerSizeInBits()); - return LLVM::LLVMType::get(dialect.getContext(), llvmType); + return LLVM::LLVMType::getIntNTy( + &dialect, getModule().getDataLayout().getPointerSizeInBits()); } // Get the MLIR type wrapping the LLVM i8* type. LLVM::LLVMType getVoidPtrType() const { - return LLVM::LLVMType::get(dialect.getContext(), - llvm::Type::getInt8PtrTy(getContext())); + return LLVM::LLVMType::getInt8PtrTy(&dialect); } // Create an LLVM IR pseudo-operation defining the given index constant. @@ -478,10 +466,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern { cumulativeSize) .getResult(0); auto structElementType = lowering.convertType(elementType); - auto elementPtrType = LLVM::LLVMType::get( - op->getContext(), structElementType.cast() - .getUnderlyingType() - ->getPointerTo()); + auto elementPtrType = + structElementType.cast().getPointerTo(); allocated = rewriter.create(op->getLoc(), elementPtrType, ArrayRef(allocated)); @@ -530,14 +516,9 @@ struct DeallocOpLowering : public LLVMLegalizationPattern { op->getFunction()->getModule()->getFunctions().push_back(freeFunc); } - auto *type = - operands[0]->getType().cast().getUnderlyingType(); - auto hasStaticShape = type->isPointerTy(); - Type elementPtrType = - (hasStaticShape) - ? rewriter.getType(type) - : rewriter.getType( - cast(type)->getStructElementType(0)); + auto type = operands[0]->getType().cast(); + auto hasStaticShape = type.getUnderlyingType()->isPointerTy(); + Type elementPtrType = hasStaticShape ? type : type.getStructElementType(0); Value *bufferPtr = extractMemRefElementPtr( rewriter, op->getLoc(), operands[0], elementPtrType, hasStaticShape); Value *casted = rewriter.create( @@ -964,10 +945,6 @@ Type LLVMLowering::convertType(Type t) { return {}; } -static llvm::Type *unwrapType(Type type) { - return type.cast().getUnderlyingType(); -} - // Create an LLVM IR structure type if there is more than one result. Type LLVMLowering::packFunctionResults(ArrayRef types) { assert(!types.empty() && "expected non-empty list of type"); @@ -975,18 +952,16 @@ Type LLVMLowering::packFunctionResults(ArrayRef types) { if (types.size() == 1) return convertType(types.front()); - SmallVector resultTypes; + SmallVector resultTypes; resultTypes.reserve(types.size()); for (auto t : types) { - Type converted = convertType(t); + auto converted = convertType(t).dyn_cast(); if (!converted) return {}; - resultTypes.push_back(unwrapType(converted)); + resultTypes.push_back(converted); } - return LLVM::LLVMType::get( - llvmDialect->getContext(), - llvm::StructType::get(llvmDialect->getLLVMContext(), resultTypes)); + return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes); } // Convert function signatures using the stored LLVM IR module. diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index ef762ff..8c2bdb7 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -64,12 +64,10 @@ using llvm_select = ValueBuilder; using icmp = ValueBuilder; template -static llvm::Type *getPtrToElementType(T containerType, - LLVMLowering &lowering) { +static LLVMType getPtrToElementType(T containerType, LLVMLowering &lowering) { return lowering.convertType(containerType.getElementType()) .template cast() - .getUnderlyingType() - ->getPointerTo(); + .getPointerTo(); } // Convert the given type to the LLVM IR Dialect type. The following @@ -82,9 +80,8 @@ static llvm::Type *getPtrToElementType(T containerType, // containing the respective dynamic values. static Type convertLinalgType(Type t, LLVMLowering &lowering) { auto *context = t.getContext(); - auto *int64Ty = lowering.convertType(IntegerType::get(64, context)) - .cast() - .getUnderlyingType(); + auto int64Ty = lowering.convertType(IntegerType::get(64, context)) + .cast(); // A buffer descriptor contains the pointer to a flat region of storage and // the size of the region. @@ -95,9 +92,8 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) { // int64_t size; // }; if (auto bufferType = t.dyn_cast()) { - auto *ptrTy = getPtrToElementType(bufferType, lowering); - auto *structTy = llvm::StructType::get(ptrTy, int64Ty); - return LLVMType::get(context, structTy); + auto ptrTy = getPtrToElementType(bufferType, lowering); + return LLVMType::getStructTy(ptrTy, int64Ty); } // Range descriptor contains the range bounds and the step as 64-bit integers. @@ -107,10 +103,8 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) { // int64_t max; // int64_t step; // }; - if (t.isa()) { - auto *structTy = llvm::StructType::get(int64Ty, int64Ty, int64Ty); - return LLVMType::get(context, structTy); - } + if (t.isa()) + return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty); // View descriptor contains the pointer to the data buffer, followed by a // 64-bit integer containing the distance between the beginning of the buffer @@ -136,10 +130,9 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) { // int64_t strides[Rank]; // }; if (auto viewType = t.dyn_cast()) { - auto *ptrTy = getPtrToElementType(viewType, lowering); - auto *arrayTy = llvm::ArrayType::get(int64Ty, viewType.getRank()); - auto *structTy = llvm::StructType::get(ptrTy, int64Ty, arrayTy, arrayTy); - return LLVMType::get(context, structTy); + auto ptrTy = getPtrToElementType(viewType, lowering); + auto arrayTy = LLVMType::getArrayTy(int64Ty, viewType.getRank()); + return LLVMType::getStructTy(ptrTy, int64Ty, arrayTy, arrayTy); } return Type(); @@ -165,9 +158,8 @@ public: void rewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const override { auto indexType = IndexType::get(op->getContext()); - auto voidPtrTy = LLVM::LLVMType::get( - op->getContext(), - llvm::IntegerType::get(lowering.getLLVMContext(), 8)->getPointerTo()); + auto voidPtrTy = + LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); auto int64Ty = lowering.convertType(operands[0]->getType()); // Insert the `malloc` declaration if it is not already present. auto *module = op->getFunction()->getModule(); @@ -187,8 +179,8 @@ public: llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8); else elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8); - auto elementPtrType = rewriter.getType(getPtrToElementType( - allocOp.getResult()->getType().cast(), lowering)); + auto elementPtrType = getPtrToElementType( + allocOp.getResult()->getType().cast(), lowering); auto bufferDescriptorType = convertLinalgType(allocOp.getResult()->getType(), lowering); @@ -221,9 +213,8 @@ public: void rewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const override { - auto voidPtrTy = LLVM::LLVMType::get( - op->getContext(), - llvm::IntegerType::get(lowering.getLLVMContext(), 8)->getPointerTo()); + auto voidPtrTy = + LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); // Insert the `free` declaration if it is not already present. auto *module = op->getFunction()->getModule(); Function *freeFunc = module->getNamedFunction("free"); @@ -235,8 +226,8 @@ public: // Get MLIR types for extracting element pointer. auto deallocOp = cast(op); - auto elementPtrTy = rewriter.getType(getPtrToElementType( - deallocOp.getOperand()->getType().cast(), lowering)); + auto elementPtrTy = getPtrToElementType( + deallocOp.getOperand()->getType().cast(), lowering); // Emit MLIR for buffer_dealloc. edsc::ScopedContext context(rewriter, op->getLoc()); @@ -298,8 +289,7 @@ public: ArrayRef indices, PatternRewriter &rewriter) const { auto loadOp = cast(op); - auto elementTy = rewriter.getType( - getPtrToElementType(loadOp.getViewType(), lowering)); + auto elementTy = getPtrToElementType(loadOp.getViewType(), lowering); auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); auto pos = [&rewriter](ArrayRef values) { return positionAttr(rewriter, values); @@ -425,8 +415,7 @@ public: // Helper function to obtain the ptr of the given `view`. auto getViewPtr = [pos, &rewriter, this](ViewType type, Value *view) -> Value * { - auto elementPtrTy = - rewriter.getType(getPtrToElementType(type, lowering)); + auto elementPtrTy = getPtrToElementType(type, lowering); return extractvalue(elementPtrTy, view, pos(0)); }; @@ -512,8 +501,7 @@ public: PatternRewriter &rewriter) const override { auto viewOp = cast(op); auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering); - auto elementTy = rewriter.getType( - getPtrToElementType(viewOp.getViewType(), lowering)); + auto elementTy = getPtrToElementType(viewOp.getViewType(), lowering); auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); auto pos = [&rewriter](ArrayRef values) { -- 2.7.4