Add thread-safe utilities to LLVMType to allow constructing llvm types in a multi...
authorRiver Riddle <riverriddle@google.com>
Wed, 22 May 2019 21:56:07 +0000 (14:56 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 02:57:03 +0000 (19:57 -0700)
--

PiperOrigin-RevId: 249526233

mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
mlir/examples/toy/Ch5/mlir/LateLowering.cpp
mlir/include/mlir/LLVMIR/LLVMDialect.h
mlir/include/mlir/LLVMIR/LLVMLowering.h
mlir/lib/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp

index a23828d..7a0e1a5 100644 (file)
@@ -62,22 +62,14 @@ Type linalg::convertLinalgType(Type t) {
   // Simple conversions.
   if (t.isa<IndexType>()) {
     int width = dialect->getLLVMModule().getDataLayout().getPointerSizeInBits();
-    auto *integerTy = llvm::IntegerType::get(dialect->getLLVMContext(), width);
-    return LLVM::LLVMType::get(context, integerTy);
-  }
-  if (auto intTy = t.dyn_cast<IntegerType>()) {
-    int width = intTy.getWidth();
-    auto *integerTy = llvm::IntegerType::get(dialect->getLLVMContext(), width);
-    return LLVM::LLVMType::get(context, integerTy);
-  }
-  if (t.isF32()) {
-    auto *floatTy = llvm::Type::getFloatTy(dialect->getLLVMContext());
-    return LLVM::LLVMType::get(context, floatTy);
-  }
-  if (t.isF64()) {
-    auto *doubleTy = llvm::Type::getDoubleTy(dialect->getLLVMContext());
-    return LLVM::LLVMType::get(context, doubleTy);
+    return LLVM::LLVMType::getIntNTy(dialect, width);
   }
+  if (auto intTy = t.dyn_cast<IntegerType>())
+    return LLVM::LLVMType::getIntNTy(dialect, intTy.getWidth());
+  if (t.isF32())
+    return LLVM::LLVMType::getFloatTy(dialect);
+  if (t.isF64())
+    return LLVM::LLVMType::getDoubleTy(dialect);
 
   // Range descriptor contains the range bounds and the step as 64-bit integers.
   //
@@ -87,9 +79,8 @@ Type linalg::convertLinalgType(Type t) {
   //   int64_t step;
   // };
   if (auto rangeTy = t.dyn_cast<linalg::RangeType>()) {
-    auto *int64Ty = llvm::Type::getInt64Ty(dialect->getLLVMContext());
-    auto *structTy = llvm::StructType::get(int64Ty, int64Ty, int64Ty);
-    return LLVM::LLVMType::get(context, structTy);
+    auto int64Ty = LLVM::LLVMType::getInt64Ty(dialect);
+    return LLVM::LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
   }
 
   // View descriptor contains the pointer to the data buffer, followed by a
@@ -116,14 +107,12 @@ Type linalg::convertLinalgType(Type t) {
   //   int64_t strides[Rank];
   // };
   if (auto viewTy = t.dyn_cast<linalg::ViewType>()) {
-    auto *elemTy = linalg::convertLinalgType(viewTy.getElementType())
-                       .cast<LLVM::LLVMType>()
-                       .getUnderlyingType()
-                       ->getPointerTo();
-    auto *int64Ty = llvm::Type::getInt64Ty(dialect->getLLVMContext());
-    auto *arrayTy = llvm::ArrayType::get(int64Ty, viewTy.getRank());
-    auto *structTy = llvm::StructType::get(elemTy, int64Ty, arrayTy, arrayTy);
-    return LLVM::LLVMType::get(context, structTy);
+    auto elemTy = linalg::convertLinalgType(viewTy.getElementType())
+                      .cast<LLVM::LLVMType>()
+                      .getPointerTo();
+    auto int64Ty = LLVM::LLVMType::getInt64Ty(dialect);
+    auto arrayTy = LLVM::LLVMType::getArrayTy(int64Ty, viewTy.getRank());
+    return LLVM::LLVMType::getStructTy(elemTy, int64Ty, arrayTy, arrayTy);
   }
 
   // All other types are kept as is.
@@ -217,11 +206,9 @@ public:
       if (type.hasStaticShape())
         return memref;
 
-      auto elementTy = LLVM::LLVMType::get(
-          type.getContext(), linalg::convertLinalgType(type.getElementType())
-                                 .cast<LLVM::LLVMType>()
-                                 .getUnderlyingType()
-                                 ->getPointerTo());
+      auto elementTy = linalg::convertLinalgType(type.getElementType())
+                           .cast<LLVM::LLVMType>()
+                           .getPointerTo();
       return intrinsics::extractvalue(elementTy, memref, pos(0));
     };
 
@@ -307,11 +294,9 @@ public:
     auto sliceOp = cast<linalg::SliceOp>(op);
     auto newViewDescriptorType =
         linalg::convertLinalgType(sliceOp.getViewType());
-    auto elementType = rewriter.getType<LLVM::LLVMType>(
-        linalg::convertLinalgType(sliceOp.getElementType())
-            .cast<LLVM::LLVMType>()
-            .getUnderlyingType()
-            ->getPointerTo());
+    auto elementType = linalg::convertLinalgType(sliceOp.getElementType())
+                           .cast<LLVM::LLVMType>()
+                           .getPointerTo();
     auto int64Ty = linalg::convertLinalgType(rewriter.getIntegerType(64));
 
     auto pos = [&rewriter](ArrayRef<int> values) {
index 98bd867..9f2d46c 100644 (file)
@@ -67,11 +67,9 @@ public:
     auto loadOp = cast<Op>(op);
     auto elementType =
         loadOp.getViewType().template cast<linalg::ViewType>().getElementType();
-    auto *llvmPtrType = linalg::convertLinalgType(elementType)
-                            .template cast<LLVM::LLVMType>()
-                            .getUnderlyingType()
-                            ->getPointerTo();
-    elementType = rewriter.getType<LLVM::LLVMType>(llvmPtrType);
+    elementType = linalg::convertLinalgType(elementType)
+                      .template cast<LLVM::LLVMType>()
+                      .getPointerTo();
     auto int64Ty = linalg::convertLinalgType(rewriter.getIntegerType(64));
 
     auto pos = [&rewriter](ArrayRef<int> values) {
index 7911649..5a0a901 100644 (file)
@@ -210,15 +210,11 @@ private:
 
     // Create a function declaration for printf, signature is `i32 (i8*, ...)`
     Builder builder(&module);
-    MLIRContext *context = module.getContext();
-    auto *llvmDialect =
+    auto *dialect =
         module.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
-    auto &llvmModule = llvmDialect->getLLVMModule();
-    llvm::IRBuilder<> llvmBuilder(llvmModule.getContext());
 
-    auto llvmI32Ty = LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(32));
-    auto llvmI8PtrTy =
-        LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(8)->getPointerTo());
+    auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(dialect);
+    auto llvmI8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo();
     auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty});
     printfFunc = new Function(builder.getUnknownLoc(), "printf", printfTy);
     // It should be variadic, but we don't support it fully just yet.
index 170a5b3..b8272f9 100644 (file)
@@ -41,10 +41,12 @@ class LLVMContext;
 
 namespace mlir {
 namespace LLVM {
+class LLVMDialect;
 
 namespace detail {
 struct LLVMTypeStorage;
-}
+struct LLVMDialectImpl;
+} // namespace detail
 
 class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
                                              detail::LLVMTypeStorage> {
@@ -57,9 +59,72 @@ public:
 
   static bool kindof(unsigned kind) { return kind == LLVM_TYPE; }
 
-  static LLVMType get(MLIRContext *context, llvm::Type *llvmType);
-
+  LLVMDialect &getDialect();
   llvm::Type *getUnderlyingType() const;
+
+  /// Array type utilities.
+  LLVMType getArrayElementType();
+
+  /// Pointer type utilities.
+  LLVMType getPointerTo(unsigned addrSpace = 0);
+  LLVMType getPointerElementTy();
+
+  /// Struct type utilities.
+  LLVMType getStructElementType(unsigned i);
+
+  /// Utilities used to generate floating point types.
+  static LLVMType getDoubleTy(LLVMDialect *dialect);
+  static LLVMType getFloatTy(LLVMDialect *dialect);
+  static LLVMType getHalfTy(LLVMDialect *dialect);
+
+  /// Utilities used to generate integer types.
+  static LLVMType getIntNTy(LLVMDialect *dialect, unsigned numBits);
+  static LLVMType getInt1Ty(LLVMDialect *dialect) {
+    return getIntNTy(dialect, /*numBits=*/1);
+  }
+  static LLVMType getInt8Ty(LLVMDialect *dialect) {
+    return getIntNTy(dialect, /*numBits=*/8);
+  }
+  static LLVMType getInt8PtrTy(LLVMDialect *dialect) {
+    return getInt8Ty(dialect).getPointerTo();
+  }
+  static LLVMType getInt16Ty(LLVMDialect *dialect) {
+    return getIntNTy(dialect, /*numBits=*/16);
+  }
+  static LLVMType getInt32Ty(LLVMDialect *dialect) {
+    return getIntNTy(dialect, /*numBits=*/32);
+  }
+  static LLVMType getInt64Ty(LLVMDialect *dialect) {
+    return getIntNTy(dialect, /*numBits=*/64);
+  }
+
+  /// Utilities used to generate other miscellaneous types.
+  static LLVMType getArrayTy(LLVMType elementType, uint64_t numElements);
+  static LLVMType getFunctionTy(LLVMType result, ArrayRef<LLVMType> params,
+                                bool isVarArg);
+  static LLVMType getFunctionTy(LLVMType result, bool isVarArg) {
+    return getFunctionTy(result, llvm::None, isVarArg);
+  }
+  static LLVMType getStructTy(LLVMDialect *dialect, ArrayRef<LLVMType> elements,
+                              bool isPacked = false);
+  static LLVMType getStructTy(LLVMDialect *dialect, bool isPacked = false) {
+    return getStructTy(dialect, llvm::None, isPacked);
+  }
+  template <typename... Args>
+  static typename std::enable_if<llvm::are_base_of<LLVMType, Args...>::value,
+                                 LLVMType>::type
+  getStructTy(LLVMType elt1, Args... elts) {
+    SmallVector<LLVMType, 8> fields({elt1, elts...});
+    return getStructTy(&elt1.getDialect(), fields);
+  }
+  static LLVMType getVectorTy(LLVMType elementType, unsigned numElements);
+  static LLVMType getVoidTy(LLVMDialect *dialect);
+
+private:
+  friend LLVMDialect;
+
+  /// Get an LLVM type with a pre-existing llvm type.
+  static LLVMType get(MLIRContext *context, llvm::Type *llvmType);
 };
 
 ///// Ops /////
@@ -69,10 +134,11 @@ public:
 class LLVMDialect : public Dialect {
 public:
   explicit LLVMDialect(MLIRContext *context);
+  ~LLVMDialect();
   static StringRef getDialectNamespace() { return "llvm"; }
 
-  llvm::LLVMContext &getLLVMContext() { return llvmContext; }
-  llvm::Module &getLLVMModule() { return module; }
+  llvm::LLVMContext &getLLVMContext();
+  llvm::Module &getLLVMModule();
 
   /// Parse a type registered to this dialect.
   Type parseType(StringRef tyData, Location loc) const override;
@@ -86,8 +152,9 @@ public:
                                            NamedAttribute argAttr) override;
 
 private:
-  llvm::LLVMContext llvmContext;
-  llvm::Module module;
+  friend LLVMType;
+
+  std::unique_ptr<detail::LLVMDialectImpl> impl;
 };
 
 } // end namespace LLVM
index c2bf040..9947f42 100644 (file)
@@ -31,11 +31,12 @@ class IntegerType;
 class LLVMContext;
 class Module;
 class Type;
-}
+} // namespace llvm
 
 namespace mlir {
 namespace LLVM {
 class LLVMDialect;
+class LLVMType;
 }
 
 /// Conversion from the Standard dialect to the LLVM IR dialect.  Provides hooks
@@ -55,6 +56,9 @@ public:
   /// Returns the LLVM context.
   llvm::LLVMContext &getLLVMContext();
 
+  /// Returns the LLVM dialect.
+  LLVM::LLVMDialect *getDialect() { return llvmDialect; }
+
 protected:
   /// Add a set of converters to the given pattern list. Store the module
   /// associated with the dialect for further type conversion.
@@ -119,13 +123,13 @@ private:
 
   // Get the LLVM representation of the index type based on the bitwidth of the
   // pointer as defined by the data layout of the module.
-  llvm::IntegerType *getIndexType();
+  LLVM::LLVMType getIndexType();
 
   // Wrap the given LLVM IR type into an LLVM IR dialect type.
   Type wrap(llvm::Type *llvmType);
 
-  // Extract an LLVM IR type from the LLVM IR dialect type.
-  llvm::Type *unwrap(Type type);
+  // Extract an LLVM IR dialect type.
+  LLVM::LLVMType unwrap(Type type);
 };
 
 /// Base class for operation conversions targeting the LLVM IR dialect. Provides
index fd197c0..950b1d4 100644 (file)
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Type.h"
+#include "llvm/Support/Mutex.h"
 #include "llvm/Support/SourceMgr.h"
 
 using namespace mlir;
 using namespace mlir::LLVM;
 
-namespace mlir {
-namespace LLVM {
-namespace detail {
-struct LLVMTypeStorage : public ::mlir::TypeStorage {
-  LLVMTypeStorage(llvm::Type *ty) : underlyingType(ty) {}
-
-  // LLVM types are pointer-unique.
-  using KeyTy = llvm::Type *;
-  bool operator==(const KeyTy &key) const { return key == underlyingType; }
-
-  static LLVMTypeStorage *construct(TypeStorageAllocator &allocator,
-                                    llvm::Type *ty) {
-    return new (allocator.allocate<LLVMTypeStorage>()) LLVMTypeStorage(ty);
-  }
-
-  llvm::Type *underlyingType;
-};
-} // end namespace detail
-} // end namespace LLVM
-} // end namespace mlir
-
-LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) {
-  return Base::get(context, FIRST_LLVM_TYPE, llvmType);
-}
-
-llvm::Type *LLVMType::getUnderlyingType() const {
-  return getImpl()->underlyingType;
-}
-
 static void printLLVMBinaryOp(OpAsmPrinter *p, Operation *op) {
   // Fallback to the generic form if the op is not well-formed (may happen
   // during incomplete rewrites, and used for debugging).
@@ -161,14 +133,13 @@ static ParseResult parseICmpOp(OpAsmParser *parser, OperationState *result) {
   // The result type is either i1 or a vector type <? x i1> if the inputs are
   // vectors.
   auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>();
-  llvm::Type *llvmResultType = llvm::Type::getInt1Ty(dialect->getLLVMContext());
+  auto resultType = LLVMType::getInt1Ty(dialect);
   auto argType = type.dyn_cast<LLVM::LLVMType>();
   if (!argType)
     return parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type");
   if (argType.getUnderlyingType()->isVectorTy())
-    llvmResultType = llvm::VectorType::get(
-        llvmResultType, argType.getUnderlyingType()->getVectorNumElements());
-  auto resultType = builder.getType<LLVM::LLVMType>(llvmResultType);
+    resultType = LLVMType::getVectorTy(
+        resultType, argType.getUnderlyingType()->getVectorNumElements());
 
   result->attributes = attrs;
   result->addTypes({resultType});
@@ -180,9 +151,7 @@ static ParseResult parseICmpOp(OpAsmParser *parser, OperationState *result) {
 //===----------------------------------------------------------------------===//
 
 static void printAllocaOp(OpAsmPrinter *p, AllocaOp &op) {
-  auto *llvmPtrTy = op.getType().cast<LLVM::LLVMType>().getUnderlyingType();
-  auto *llvmElemTy = llvm::cast<llvm::PointerType>(llvmPtrTy)->getElementType();
-  auto elemTy = LLVM::LLVMType::get(op.getContext(), llvmElemTy);
+  auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy();
 
   auto funcTy = FunctionType::get({op.arraySize()->getType()}, {op.getType()},
                                   op.getContext());
@@ -291,13 +260,10 @@ static Type getLoadStoreElementType(OpAsmParser *parser, Type type,
   if (!llvmTy)
     return parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type"),
            nullptr;
-  auto *llvmPtrTy = dyn_cast<llvm::PointerType>(llvmTy.getUnderlyingType());
-  if (!llvmPtrTy)
+  if (!llvmTy.getUnderlyingType()->isPointerTy())
     return parser->emitError(trailingTypeLoc, "expected LLVM pointer type"),
            nullptr;
-  auto elemTy = LLVM::LLVMType::get(parser->getBuilder().getContext(),
-                                    llvmPtrTy->getElementType());
-  return elemTy;
+  return llvmTy.getPointerElementTy();
 }
 
 // <operation> ::= `llvm.load` ssa-use attribute-dict? `:` type
@@ -465,33 +431,28 @@ static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
     Builder &builder = parser->getBuilder();
     auto *llvmDialect =
         builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
-    llvm::Type *llvmResultType;
-    Type wrappedResultType;
+    LLVM::LLVMType llvmResultType;
     if (funcType.getNumResults() == 0) {
-      llvmResultType = llvm::Type::getVoidTy(llvmDialect->getLLVMContext());
-      wrappedResultType = builder.getType<LLVM::LLVMType>(llvmResultType);
+      llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect);
     } else {
-      wrappedResultType = funcType.getResult(0);
-      auto wrappedLLVMResultType = wrappedResultType.dyn_cast<LLVM::LLVMType>();
-      if (!wrappedLLVMResultType)
+      llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
+      if (!llvmResultType)
         return parser->emitError(trailingTypeLoc,
                                  "expected result to have LLVM type");
-      llvmResultType = wrappedLLVMResultType.getUnderlyingType();
     }
 
-    SmallVector<llvm::Type *, 8> argTypes;
+    SmallVector<LLVM::LLVMType, 8> argTypes;
     argTypes.reserve(funcType.getNumInputs());
     for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) {
       auto argType = funcType.getInput(i).dyn_cast<LLVM::LLVMType>();
       if (!argType)
         return parser->emitError(trailingTypeLoc,
                                  "expected LLVM types as inputs");
-      argTypes.push_back(argType.getUnderlyingType());
+      argTypes.push_back(argType);
     }
-    auto *llvmFuncType = llvm::FunctionType::get(llvmResultType, argTypes,
-                                                 /*isVarArg=*/false);
-    auto wrappedFuncType =
-        builder.getType<LLVM::LLVMType>(llvmFuncType->getPointerTo());
+    auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
+                                                      /*isVarArg=*/false);
+    auto wrappedFuncType = llvmFuncType.getPointerTo();
 
     auto funcArguments =
         ArrayRef<OpAsmParser::OperandType>(operands).drop_front();
@@ -505,7 +466,7 @@ static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
                                 parser->getNameLoc(), result->operands))
       return failure();
 
-    result->addTypes(wrappedResultType);
+    result->addTypes(llvmResultType);
   }
 
   result->attributes = attrs;
@@ -544,7 +505,6 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser *parser,
   // type by taking the element type, indexed by the position attribute for
   // stuctures.  Check the position index before accessing, it is supposed to be
   // in bounds.
-  llvm::Type *llvmContainerType = wrappedContainerType.getUnderlyingType();
   for (Attribute subAttr : positionArrayAttr) {
     auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
     if (!positionElementAttr)
@@ -552,27 +512,27 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser *parser,
                                "expected an array of integer literals"),
              nullptr;
     int position = positionElementAttr.getInt();
+    auto *llvmContainerType = wrappedContainerType.getUnderlyingType();
     if (llvmContainerType->isArrayTy()) {
       if (position < 0 || static_cast<unsigned>(position) >=
                               llvmContainerType->getArrayNumElements())
         return parser->emitError(attributeLoc, "position out of bounds"),
                nullptr;
-      llvmContainerType = llvmContainerType->getArrayElementType();
+      wrappedContainerType = wrappedContainerType.getArrayElementType();
     } else if (llvmContainerType->isStructTy()) {
       if (position < 0 || static_cast<unsigned>(position) >=
                               llvmContainerType->getStructNumElements())
         return parser->emitError(attributeLoc, "position out of bounds"),
                nullptr;
-      llvmContainerType = llvmContainerType->getStructElementType(position);
+      wrappedContainerType =
+          wrappedContainerType.getStructElementType(position);
     } else {
       return parser->emitError(typeLoc,
                                "expected wrapped LLVM IR structure/array type"),
              nullptr;
     }
   }
-
-  Builder &builder = parser->getBuilder();
-  return builder.getType<LLVM::LLVMType>(llvmContainerType);
+  return wrappedContainerType;
 }
 
 // <operation> ::= `llvm.extractvalue` ssa-use
@@ -730,8 +690,7 @@ static ParseResult parseCondBrOp(OpAsmParser *parser, OperationState *result) {
   Builder &builder = parser->getBuilder();
   auto *llvmDialect =
       builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
-  auto i1Type = builder.getType<LLVM::LLVMType>(
-      llvm::Type::getInt1Ty(llvmDialect->getLLVMContext()));
+  auto i1Type = LLVM::LLVMType::getInt1Ty(llvmDialect);
 
   if (parser->parseOperand(condition) || parser->parseComma() ||
       parser->parseSuccessorAndUseList(trueDest, trueOperands) ||
@@ -844,9 +803,26 @@ static ParseResult parseConstantOp(OpAsmParser *parser,
 // LLVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
 
+namespace mlir {
+namespace LLVM {
+namespace detail {
+struct LLVMDialectImpl {
+  LLVMDialectImpl() : module("LLVMDialectModule", llvmContext) {}
+
+  llvm::LLVMContext llvmContext;
+  llvm::Module module;
+
+  /// A smart mutex to lock access to the llvm context. Unlike MLIR, LLVM is not
+  /// multi-threaded and requires locked access to prevent race conditions.
+  llvm::sys::SmartMutex<true> mutex;
+};
+} // end namespace detail
+} // end namespace LLVM
+} // end namespace mlir
+
 LLVMDialect::LLVMDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context),
-      module("LLVMDialectModule", llvmContext) {
+      impl(new detail::LLVMDialectImpl()) {
   addTypes<LLVMType>();
   addOperations<
 #define GET_OP_LIST
@@ -857,13 +833,21 @@ LLVMDialect::LLVMDialect(MLIRContext *context)
   allowUnknownOperations();
 }
 
+LLVMDialect::~LLVMDialect() {}
+
 #define GET_OP_CLASSES
 #include "mlir/LLVMIR/LLVMOps.cpp.inc"
 
+llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->llvmContext; }
+llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }
+
 /// Parse a type registered to this dialect.
 Type LLVMDialect::parseType(StringRef tyData, Location loc) const {
+  // LLVM is not thread-safe, so lock access to it.
+  llvm::sys::SmartScopedLock<true> lock(impl->mutex);
+
   llvm::SMDiagnostic errorMessage;
-  llvm::Type *type = llvm::parseType(tyData, errorMessage, module);
+  llvm::Type *type = llvm::parseType(tyData, errorMessage, impl->module);
   if (!type)
     return (getContext()->emitError(loc, errorMessage.getMessage()), nullptr);
   return LLVMType::get(getContext(), type);
@@ -889,3 +873,126 @@ LogicalResult LLVMDialect::verifyFunctionArgAttribute(Function *func,
 }
 
 static DialectRegistration<LLVMDialect> llvmDialect;
+
+//===----------------------------------------------------------------------===//
+// LLVMType.
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace LLVM {
+namespace detail {
+struct LLVMTypeStorage : public ::mlir::TypeStorage {
+  LLVMTypeStorage(llvm::Type *ty) : underlyingType(ty) {}
+
+  // LLVM types are pointer-unique.
+  using KeyTy = llvm::Type *;
+  bool operator==(const KeyTy &key) const { return key == underlyingType; }
+
+  static LLVMTypeStorage *construct(TypeStorageAllocator &allocator,
+                                    llvm::Type *ty) {
+    return new (allocator.allocate<LLVMTypeStorage>()) LLVMTypeStorage(ty);
+  }
+
+  llvm::Type *underlyingType;
+};
+} // end namespace detail
+} // end namespace LLVM
+} // end namespace mlir
+
+LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) {
+  return Base::get(context, FIRST_LLVM_TYPE, llvmType);
+}
+
+LLVMDialect &LLVMType::getDialect() {
+  return static_cast<LLVMDialect &>(Type::getDialect());
+}
+
+llvm::Type *LLVMType::getUnderlyingType() const {
+  return getImpl()->underlyingType;
+}
+
+/// Array type utilities.
+LLVMType LLVMType::getArrayElementType() {
+  return get(getContext(), getUnderlyingType()->getArrayElementType());
+}
+
+/// Pointer type utilities.
+LLVMType LLVMType::getPointerTo(unsigned addrSpace) {
+  // Lock access to the dialect as this may modify the LLVM context.
+  llvm::sys::SmartScopedLock<true> lock(getDialect().impl->mutex);
+  return get(getContext(), getUnderlyingType()->getPointerTo(addrSpace));
+}
+LLVMType LLVMType::getPointerElementTy() {
+  return get(getContext(), getUnderlyingType()->getPointerElementType());
+}
+
+/// Struct type utilities.
+LLVMType LLVMType::getStructElementType(unsigned i) {
+  return get(getContext(), getUnderlyingType()->getStructElementType(i));
+}
+
+/// Utilities used to generate floating point types.
+LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) {
+  return get(dialect->getContext(),
+             llvm::Type::getDoubleTy(dialect->getLLVMContext()));
+}
+LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) {
+  return get(dialect->getContext(),
+             llvm::Type::getFloatTy(dialect->getLLVMContext()));
+}
+LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) {
+  return get(dialect->getContext(),
+             llvm::Type::getHalfTy(dialect->getLLVMContext()));
+}
+
+/// Utilities used to generate integer types.
+LLVMType LLVMType::getIntNTy(LLVMDialect *dialect, unsigned numBits) {
+  // Lock access to the dialect as this may modify the LLVM context.
+  llvm::sys::SmartScopedLock<true> lock(dialect->impl->mutex);
+  return get(dialect->getContext(),
+             llvm::Type::getIntNTy(dialect->getLLVMContext(), numBits));
+}
+
+/// Utilities used to generate other miscellaneous types.
+LLVMType LLVMType::getArrayTy(LLVMType elementType, uint64_t numElements) {
+  // Lock access to the dialect as this may modify the LLVM context.
+  llvm::sys::SmartScopedLock<true> lock(elementType.getDialect().impl->mutex);
+  return get(
+      elementType.getContext(),
+      llvm::ArrayType::get(elementType.getUnderlyingType(), numElements));
+}
+LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef<LLVMType> params,
+                                 bool isVarArg) {
+  SmallVector<llvm::Type *, 8> llvmParams;
+  for (auto param : params)
+    llvmParams.push_back(param.getUnderlyingType());
+
+  // Lock access to the dialect as this may modify the LLVM context.
+  llvm::sys::SmartScopedLock<true> lock(result.getDialect().impl->mutex);
+  return get(result.getContext(),
+             llvm::FunctionType::get(result.getUnderlyingType(), llvmParams,
+                                     isVarArg));
+}
+LLVMType LLVMType::getStructTy(LLVMDialect *dialect,
+                               ArrayRef<LLVMType> elements, bool isPacked) {
+  SmallVector<llvm::Type *, 8> llvmElements;
+  for (auto elt : elements)
+    llvmElements.push_back(elt.getUnderlyingType());
+
+  // Lock access to the dialect as this may modify the LLVM context.
+  llvm::sys::SmartScopedLock<true> lock(dialect->impl->mutex);
+  return get(
+      dialect->getContext(),
+      llvm::StructType::get(dialect->getLLVMContext(), llvmElements, isPacked));
+}
+LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) {
+  // Lock access to the dialect as this may modify the LLVM context.
+  llvm::sys::SmartScopedLock<true> lock(elementType.getDialect().impl->mutex);
+  return get(
+      elementType.getContext(),
+      llvm::VectorType::get(elementType.getUnderlyingType(), numElements));
+}
+LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) {
+  return get(dialect->getContext(),
+             llvm::Type::getVoidTy(dialect->getLLVMContext()));
+}
index 36267e9..a2476dc 100644 (file)
@@ -45,46 +45,37 @@ llvm::LLVMContext &LLVMLowering::getLLVMContext() {
   return module->getContext();
 }
 
-// Wrap the given LLVM IR type into an LLVM IR dialect type.
-Type LLVMLowering::wrap(llvm::Type *llvmType) {
-  return LLVM::LLVMType::get(llvmDialect->getContext(), llvmType);
-}
-
 // Extract an LLVM IR type from the LLVM IR dialect type.
-llvm::Type *LLVMLowering::unwrap(Type type) {
+LLVM::LLVMType LLVMLowering::unwrap(Type type) {
   if (!type)
     return nullptr;
   auto *mlirContext = type.getContext();
   auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
   if (!wrappedLLVMType)
-    return mlirContext->emitError(UnknownLoc::get(mlirContext),
-                                  "conversion resulted in a non-LLVM type"),
-           nullptr;
-  return wrappedLLVMType.getUnderlyingType();
+    mlirContext->emitError(UnknownLoc::get(mlirContext),
+                           "conversion resulted in a non-LLVM type");
+  return wrappedLLVMType;
 }
 
-llvm::IntegerType *LLVMLowering::getIndexType() {
-  return llvm::IntegerType::get(llvmDialect->getLLVMContext(),
-                                module->getDataLayout().getPointerSizeInBits());
+LLVM::LLVMType LLVMLowering::getIndexType() {
+  return LLVM::LLVMType::getIntNTy(
+      llvmDialect, module->getDataLayout().getPointerSizeInBits());
 }
 
-Type LLVMLowering::convertIndexType(IndexType type) {
-  return wrap(getIndexType());
-}
+Type LLVMLowering::convertIndexType(IndexType type) { return getIndexType(); }
 
 Type LLVMLowering::convertIntegerType(IntegerType type) {
-  return wrap(
-      llvm::Type::getIntNTy(llvmDialect->getLLVMContext(), type.getWidth()));
+  return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth());
 }
 
 Type LLVMLowering::convertFloatType(FloatType type) {
   switch (type.getKind()) {
   case mlir::StandardTypes::F32:
-    return wrap(llvm::Type::getFloatTy(llvmDialect->getLLVMContext()));
+    return LLVM::LLVMType::getFloatTy(llvmDialect);
   case mlir::StandardTypes::F64:
-    return wrap(llvm::Type::getDoubleTy(llvmDialect->getLLVMContext()));
+    return LLVM::LLVMType::getDoubleTy(llvmDialect);
   case mlir::StandardTypes::F16:
-    return wrap(llvm::Type::getHalfTy(llvmDialect->getLLVMContext()));
+    return LLVM::LLVMType::getHalfTy(llvmDialect);
   case mlir::StandardTypes::BF16: {
     auto *mlirContext = llvmDialect->getContext();
     return mlirContext->emitError(UnknownLoc::get(mlirContext),
@@ -102,7 +93,7 @@ Type LLVMLowering::convertFloatType(FloatType type) {
 // they are into an LLVM StructType in their order of appearance.
 Type LLVMLowering::convertFunctionType(FunctionType type) {
   // Convert argument types one by one and check for errors.
-  SmallVector<llvm::Type *, 8> argTypes;
+  SmallVector<LLVM::LLVMType, 8> argTypes;
   for (auto t : type.getInputs()) {
     auto converted = convertType(t);
     if (!converted)
@@ -113,14 +104,14 @@ Type LLVMLowering::convertFunctionType(FunctionType type) {
   // If function does not return anything, create the void result type,
   // if it returns on element, convert it, otherwise pack the result types into
   // a struct.
-  llvm::Type *resultType =
+  LLVM::LLVMType resultType =
       type.getNumResults() == 0
-          ? llvm::Type::getVoidTy(llvmDialect->getLLVMContext())
+          ? LLVM::LLVMType::getVoidTy(llvmDialect)
           : unwrap(packFunctionResults(type.getResults()));
   if (!resultType)
     return {};
-  return wrap(llvm::FunctionType::get(resultType, argTypes, /*isVarArg=*/false)
-                  ->getPointerTo());
+  return LLVM::LLVMType::getFunctionTy(resultType, argTypes, /*isVarArg=*/false)
+      .getPointerTo();
 }
 
 // Convert a MemRef to an LLVM type. If the memref is statically-shaped, then
@@ -129,21 +120,21 @@ Type LLVMLowering::convertFunctionType(FunctionType type) {
 // pointer to the elemental type of the MemRef and the following N elements are
 // values of the Index type, one for each of N dynamic dimensions of the MemRef.
 Type LLVMLowering::convertMemRefType(MemRefType type) {
-  llvm::Type *elementType = unwrap(convertType(type.getElementType()));
+  LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
   if (!elementType)
     return {};
-  auto ptrType = elementType->getPointerTo();
+  auto ptrType = elementType.getPointerTo();
 
   // Extra value for the memory space.
   unsigned numDynamicSizes = type.getNumDynamicDims();
   // If memref is statically-shaped we return the underlying pointer type.
-  if (numDynamicSizes == 0) {
-    return wrap(ptrType);
-  }
-  SmallVector<llvm::Type *, 8> types(numDynamicSizes + 1, getIndexType());
+  if (numDynamicSizes == 0)
+    return ptrType;
+
+  SmallVector<LLVM::LLVMType, 8> types(numDynamicSizes + 1, getIndexType());
   types.front() = ptrType;
 
-  return wrap(llvm::StructType::get(llvmDialect->getLLVMContext(), types));
+  return LLVM::LLVMType::getStructTy(llvmDialect, types);
 }
 
 // Convert a 1D vector type to an LLVM vector type.
@@ -155,9 +146,9 @@ Type LLVMLowering::convertVectorType(VectorType type) {
     return {};
   }
 
-  llvm::Type *elementType = unwrap(convertType(type.getElementType()));
+  LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
   return elementType
-             ? wrap(llvm::VectorType::get(elementType, type.getShape().front()))
+             ? LLVM::LLVMType::getVectorTy(elementType, type.getShape().front())
              : Type();
 }
 
@@ -189,8 +180,7 @@ static Type getMemRefElementPtrType(MemRefType t, LLVMLowering &lowering) {
   auto converted = lowering.convertType(elementType);
   if (!converted)
     return {};
-  llvm::Type *llvmType = converted.cast<LLVM::LLVMType>().getUnderlyingType();
-  return LLVM::LLVMType::get(t.getContext(), llvmType->getPointerTo());
+  return converted.cast<LLVM::LLVMType>().getPointerTo();
 }
 
 LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
@@ -226,15 +216,13 @@ public:
   // Get the MLIR type wrapping the LLVM integer type whose bit width is defined
   // by the pointer size used in the LLVM module.
   LLVM::LLVMType getIndexType() const {
-    llvm::Type *llvmType = llvm::Type::getIntNTy(
-        getContext(), getModule().getDataLayout().getPointerSizeInBits());
-    return LLVM::LLVMType::get(dialect.getContext(), llvmType);
+    return LLVM::LLVMType::getIntNTy(
+        &dialect, getModule().getDataLayout().getPointerSizeInBits());
   }
 
   // Get the MLIR type wrapping the LLVM i8* type.
   LLVM::LLVMType getVoidPtrType() const {
-    return LLVM::LLVMType::get(dialect.getContext(),
-                               llvm::Type::getInt8PtrTy(getContext()));
+    return LLVM::LLVMType::getInt8PtrTy(&dialect);
   }
 
   // Create an LLVM IR pseudo-operation defining the given index constant.
@@ -478,10 +466,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
                                   cumulativeSize)
             .getResult(0);
     auto structElementType = lowering.convertType(elementType);
-    auto elementPtrType = LLVM::LLVMType::get(
-        op->getContext(), structElementType.cast<LLVM::LLVMType>()
-                              .getUnderlyingType()
-                              ->getPointerTo());
+    auto elementPtrType =
+        structElementType.cast<LLVM::LLVMType>().getPointerTo();
     allocated = rewriter.create<LLVM::BitcastOp>(op->getLoc(), elementPtrType,
                                                  ArrayRef<Value *>(allocated));
 
@@ -530,14 +516,9 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
       op->getFunction()->getModule()->getFunctions().push_back(freeFunc);
     }
 
-    auto *type =
-        operands[0]->getType().cast<LLVM::LLVMType>().getUnderlyingType();
-    auto hasStaticShape = type->isPointerTy();
-    Type elementPtrType =
-        (hasStaticShape)
-            ? rewriter.getType<LLVM::LLVMType>(type)
-            : rewriter.getType<LLVM::LLVMType>(
-                  cast<llvm::StructType>(type)->getStructElementType(0));
+    auto type = operands[0]->getType().cast<LLVM::LLVMType>();
+    auto hasStaticShape = type.getUnderlyingType()->isPointerTy();
+    Type elementPtrType = hasStaticShape ? type : type.getStructElementType(0);
     Value *bufferPtr = extractMemRefElementPtr(
         rewriter, op->getLoc(), operands[0], elementPtrType, hasStaticShape);
     Value *casted = rewriter.create<LLVM::BitcastOp>(
@@ -964,10 +945,6 @@ Type LLVMLowering::convertType(Type t) {
   return {};
 }
 
-static llvm::Type *unwrapType(Type type) {
-  return type.cast<LLVM::LLVMType>().getUnderlyingType();
-}
-
 // Create an LLVM IR structure type if there is more than one result.
 Type LLVMLowering::packFunctionResults(ArrayRef<Type> types) {
   assert(!types.empty() && "expected non-empty list of type");
@@ -975,18 +952,16 @@ Type LLVMLowering::packFunctionResults(ArrayRef<Type> types) {
   if (types.size() == 1)
     return convertType(types.front());
 
-  SmallVector<llvm::Type *, 8> resultTypes;
+  SmallVector<LLVM::LLVMType, 8> resultTypes;
   resultTypes.reserve(types.size());
   for (auto t : types) {
-    Type converted = convertType(t);
+    auto converted = convertType(t).dyn_cast<LLVM::LLVMType>();
     if (!converted)
       return {};
-    resultTypes.push_back(unwrapType(converted));
+    resultTypes.push_back(converted);
   }
 
-  return LLVM::LLVMType::get(
-      llvmDialect->getContext(),
-      llvm::StructType::get(llvmDialect->getLLVMContext(), resultTypes));
+  return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes);
 }
 
 // Convert function signatures using the stored LLVM IR module.
index ef762ff..8c2bdb7 100644 (file)
@@ -64,12 +64,10 @@ using llvm_select = ValueBuilder<LLVM::SelectOp>;
 using icmp = ValueBuilder<LLVM::ICmpOp>;
 
 template <typename T>
-static llvm::Type *getPtrToElementType(T containerType,
-                                       LLVMLowering &lowering) {
+static LLVMType getPtrToElementType(T containerType, LLVMLowering &lowering) {
   return lowering.convertType(containerType.getElementType())
       .template cast<LLVMType>()
-      .getUnderlyingType()
-      ->getPointerTo();
+      .getPointerTo();
 }
 
 // Convert the given type to the LLVM IR Dialect type.  The following
@@ -82,9 +80,8 @@ static llvm::Type *getPtrToElementType(T containerType,
 //     containing the respective dynamic values.
 static Type convertLinalgType(Type t, LLVMLowering &lowering) {
   auto *context = t.getContext();
-  auto *int64Ty = lowering.convertType(IntegerType::get(64, context))
-                      .cast<LLVM::LLVMType>()
-                      .getUnderlyingType();
+  auto int64Ty = lowering.convertType(IntegerType::get(64, context))
+                     .cast<LLVM::LLVMType>();
 
   // A buffer descriptor contains the pointer to a flat region of storage and
   // the size of the region.
@@ -95,9 +92,8 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) {
   //   int64_t size;
   // };
   if (auto bufferType = t.dyn_cast<BufferType>()) {
-    auto *ptrTy = getPtrToElementType(bufferType, lowering);
-    auto *structTy = llvm::StructType::get(ptrTy, int64Ty);
-    return LLVMType::get(context, structTy);
+    auto ptrTy = getPtrToElementType(bufferType, lowering);
+    return LLVMType::getStructTy(ptrTy, int64Ty);
   }
 
   // Range descriptor contains the range bounds and the step as 64-bit integers.
@@ -107,10 +103,8 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) {
   //   int64_t max;
   //   int64_t step;
   // };
-  if (t.isa<RangeType>()) {
-    auto *structTy = llvm::StructType::get(int64Ty, int64Ty, int64Ty);
-    return LLVMType::get(context, structTy);
-  }
+  if (t.isa<RangeType>())
+    return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
 
   // View descriptor contains the pointer to the data buffer, followed by a
   // 64-bit integer containing the distance between the beginning of the buffer
@@ -136,10 +130,9 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) {
   //   int64_t strides[Rank];
   // };
   if (auto viewType = t.dyn_cast<ViewType>()) {
-    auto *ptrTy = getPtrToElementType(viewType, lowering);
-    auto *arrayTy = llvm::ArrayType::get(int64Ty, viewType.getRank());
-    auto *structTy = llvm::StructType::get(ptrTy, int64Ty, arrayTy, arrayTy);
-    return LLVMType::get(context, structTy);
+    auto ptrTy = getPtrToElementType(viewType, lowering);
+    auto arrayTy = LLVMType::getArrayTy(int64Ty, viewType.getRank());
+    return LLVMType::getStructTy(ptrTy, int64Ty, arrayTy, arrayTy);
   }
 
   return Type();
@@ -165,9 +158,8 @@ public:
   void rewrite(Operation *op, ArrayRef<Value *> operands,
                PatternRewriter &rewriter) const override {
     auto indexType = IndexType::get(op->getContext());
-    auto voidPtrTy = LLVM::LLVMType::get(
-        op->getContext(),
-        llvm::IntegerType::get(lowering.getLLVMContext(), 8)->getPointerTo());
+    auto voidPtrTy =
+        LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
     auto int64Ty = lowering.convertType(operands[0]->getType());
     // Insert the `malloc` declaration if it is not already present.
     auto *module = op->getFunction()->getModule();
@@ -187,8 +179,8 @@ public:
                     llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
     else
       elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
-    auto elementPtrType = rewriter.getType<LLVMType>(getPtrToElementType(
-        allocOp.getResult()->getType().cast<BufferType>(), lowering));
+    auto elementPtrType = getPtrToElementType(
+        allocOp.getResult()->getType().cast<BufferType>(), lowering);
     auto bufferDescriptorType =
         convertLinalgType(allocOp.getResult()->getType(), lowering);
 
@@ -221,9 +213,8 @@ public:
 
   void rewrite(Operation *op, ArrayRef<Value *> operands,
                PatternRewriter &rewriter) const override {
-    auto voidPtrTy = LLVM::LLVMType::get(
-        op->getContext(),
-        llvm::IntegerType::get(lowering.getLLVMContext(), 8)->getPointerTo());
+    auto voidPtrTy =
+        LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
     // Insert the `free` declaration if it is not already present.
     auto *module = op->getFunction()->getModule();
     Function *freeFunc = module->getNamedFunction("free");
@@ -235,8 +226,8 @@ public:
 
     // Get MLIR types for extracting element pointer.
     auto deallocOp = cast<BufferDeallocOp>(op);
-    auto elementPtrTy = rewriter.getType<LLVMType>(getPtrToElementType(
-        deallocOp.getOperand()->getType().cast<BufferType>(), lowering));
+    auto elementPtrTy = getPtrToElementType(
+        deallocOp.getOperand()->getType().cast<BufferType>(), lowering);
 
     // Emit MLIR for buffer_dealloc.
     edsc::ScopedContext context(rewriter, op->getLoc());
@@ -298,8 +289,7 @@ public:
                        ArrayRef<Value *> indices,
                        PatternRewriter &rewriter) const {
     auto loadOp = cast<Op>(op);
-    auto elementTy = rewriter.getType<LLVMType>(
-        getPtrToElementType(loadOp.getViewType(), lowering));
+    auto elementTy = getPtrToElementType(loadOp.getViewType(), lowering);
     auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
     auto pos = [&rewriter](ArrayRef<int> values) {
       return positionAttr(rewriter, values);
@@ -425,8 +415,7 @@ public:
     // Helper function to obtain the ptr of the given `view`.
     auto getViewPtr = [pos, &rewriter, this](ViewType type,
                                              Value *view) -> Value * {
-      auto elementPtrTy =
-          rewriter.getType<LLVMType>(getPtrToElementType(type, lowering));
+      auto elementPtrTy = getPtrToElementType(type, lowering);
       return extractvalue(elementPtrTy, view, pos(0));
     };
 
@@ -512,8 +501,7 @@ public:
                PatternRewriter &rewriter) const override {
     auto viewOp = cast<ViewOp>(op);
     auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering);
-    auto elementTy = rewriter.getType<LLVMType>(
-        getPtrToElementType(viewOp.getViewType(), lowering));
+    auto elementTy = getPtrToElementType(viewOp.getViewType(), lowering);
     auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
 
     auto pos = [&rewriter](ArrayRef<int> values) {