Cache several common LLVMTypes in the LLVMDialect.
authorRiver Riddle <riverriddle@google.com>
Mon, 24 Jun 2019 19:03:44 +0000 (12:03 -0700)
committerjpienaar <jpienaar@google.com>
Mon, 24 Jun 2019 20:50:10 +0000 (13:50 -0700)
LLVM is not thread-safe which means that several of the 'get' methods for LLVMType must be double locked to ensure thread-safety. This cl adds static caching, i.e. no lookups or locking, for several simple LLVM types(i1, half, void, etc.). It also cleans up the implementation of the double locking that is required for some types. In the future we could add a form of dynamic caching to only need to lock one mutex in the best case, but that requires analysis on the memory overhead/vs time lost to taking two locks.

PiperOrigin-RevId: 254806747

mlir/include/mlir/LLVMIR/LLVMDialect.h
mlir/lib/LLVMIR/IR/LLVMDialect.cpp

index b8272f9..bd3286d 100644 (file)
@@ -123,8 +123,13 @@ public:
 private:
   friend LLVMDialect;
 
-  /// Get an LLVM type with a pre-existing llvm type.
+  /// Get an LLVMType with a pre-existing llvm type.
   static LLVMType get(MLIRContext *context, llvm::Type *llvmType);
+
+  /// Get an LLVMType with an llvm type that may cause changes to the underlying
+  /// llvm context when constructed.
+  static LLVMType getLocked(LLVMDialect *dialect,
+                            llvm::function_ref<llvm::Type *()> typeBuilder);
 };
 
 ///// Ops /////
index e27994d..513af8f 100644 (file)
@@ -756,6 +756,12 @@ struct LLVMDialectImpl {
   llvm::LLVMContext llvmContext;
   llvm::Module module;
 
+  /// A set of LLVMTypes that are cached on construction to avoid any lookups or
+  /// locking.
+  LLVMType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
+  LLVMType doubleTy, floatTy, halfTy;
+  LLVMType voidTy;
+
   /// A smart mutex to lock access to the llvm context. Unlike MLIR, LLVM is not
   /// multi-threaded and requires locked access to prevent race conditions.
   llvm::sys::SmartMutex<true> mutex;
@@ -775,6 +781,22 @@ LLVMDialect::LLVMDialect(MLIRContext *context)
 
   // Support unknown operations because not all LLVM operations are registered.
   allowUnknownOperations();
+
+  // Cache some of the common LLVM types to avoid the need for lookups/locking.
+  auto &llvmContext = impl->llvmContext;
+  /// Integer Types.
+  impl->int1Ty = LLVMType::get(context, llvm::Type::getInt1Ty(llvmContext));
+  impl->int8Ty = LLVMType::get(context, llvm::Type::getInt8Ty(llvmContext));
+  impl->int16Ty = LLVMType::get(context, llvm::Type::getInt16Ty(llvmContext));
+  impl->int32Ty = LLVMType::get(context, llvm::Type::getInt32Ty(llvmContext));
+  impl->int64Ty = LLVMType::get(context, llvm::Type::getInt64Ty(llvmContext));
+  impl->int128Ty = LLVMType::get(context, llvm::Type::getInt128Ty(llvmContext));
+  /// Float Types.
+  impl->doubleTy = LLVMType::get(context, llvm::Type::getDoubleTy(llvmContext));
+  impl->floatTy = LLVMType::get(context, llvm::Type::getFloatTy(llvmContext));
+  impl->halfTy = LLVMType::get(context, llvm::Type::getHalfTy(llvmContext));
+  /// Other Types.
+  impl->voidTy = LLVMType::get(context, llvm::Type::getVoidTy(llvmContext));
 }
 
 LLVMDialect::~LLVMDialect() {}
@@ -847,6 +869,15 @@ LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) {
   return Base::get(context, FIRST_LLVM_TYPE, llvmType);
 }
 
+/// Get an LLVMType with an llvm type that may cause changes to the underlying
+/// llvm context when constructed.
+LLVMType LLVMType::getLocked(LLVMDialect *dialect,
+                             llvm::function_ref<llvm::Type *()> typeBuilder) {
+  // Lock access to the llvm context and build the type.
+  llvm::sys::SmartScopedLock<true> lock(dialect->impl->mutex);
+  return get(dialect->getContext(), typeBuilder());
+}
+
 LLVMDialect &LLVMType::getDialect() {
   return static_cast<LLVMDialect &>(Type::getDialect());
 }
@@ -863,8 +894,9 @@ LLVMType LLVMType::getArrayElementType() {
 /// Pointer type utilities.
 LLVMType LLVMType::getPointerTo(unsigned addrSpace) {
   // Lock access to the dialect as this may modify the LLVM context.
-  llvm::sys::SmartScopedLock<true> lock(getDialect().impl->mutex);
-  return get(getContext(), getUnderlyingType()->getPointerTo(addrSpace));
+  return getLocked(&getDialect(), [=] {
+    return getUnderlyingType()->getPointerTo(addrSpace);
+  });
 }
 LLVMType LLVMType::getPointerElementTy() {
   return get(getContext(), getUnderlyingType()->getPointerElementType());
@@ -877,33 +909,46 @@ LLVMType LLVMType::getStructElementType(unsigned i) {
 
 /// Utilities used to generate floating point types.
 LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) {
-  return get(dialect->getContext(),
-             llvm::Type::getDoubleTy(dialect->getLLVMContext()));
+  return dialect->impl->doubleTy;
 }
 LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) {
-  return get(dialect->getContext(),
-             llvm::Type::getFloatTy(dialect->getLLVMContext()));
+  return dialect->impl->floatTy;
 }
 LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) {
-  return get(dialect->getContext(),
-             llvm::Type::getHalfTy(dialect->getLLVMContext()));
+  return dialect->impl->halfTy;
 }
 
 /// Utilities used to generate integer types.
 LLVMType LLVMType::getIntNTy(LLVMDialect *dialect, unsigned numBits) {
+  switch (numBits) {
+  case 1:
+    return dialect->impl->int1Ty;
+  case 8:
+    return dialect->impl->int8Ty;
+  case 16:
+    return dialect->impl->int16Ty;
+  case 32:
+    return dialect->impl->int32Ty;
+  case 64:
+    return dialect->impl->int64Ty;
+  case 128:
+    return dialect->impl->int128Ty;
+  default:
+    break;
+  }
+
   // Lock access to the dialect as this may modify the LLVM context.
-  llvm::sys::SmartScopedLock<true> lock(dialect->impl->mutex);
-  return get(dialect->getContext(),
-             llvm::Type::getIntNTy(dialect->getLLVMContext(), numBits));
+  return getLocked(dialect, [=] {
+    return llvm::Type::getIntNTy(dialect->getLLVMContext(), numBits);
+  });
 }
 
 /// Utilities used to generate other miscellaneous types.
 LLVMType LLVMType::getArrayTy(LLVMType elementType, uint64_t numElements) {
   // Lock access to the dialect as this may modify the LLVM context.
-  llvm::sys::SmartScopedLock<true> lock(elementType.getDialect().impl->mutex);
-  return get(
-      elementType.getContext(),
-      llvm::ArrayType::get(elementType.getUnderlyingType(), numElements));
+  return getLocked(&elementType.getDialect(), [=] {
+    return llvm::ArrayType::get(elementType.getUnderlyingType(), numElements);
+  });
 }
 LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef<LLVMType> params,
                                  bool isVarArg) {
@@ -912,10 +957,10 @@ LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef<LLVMType> params,
     llvmParams.push_back(param.getUnderlyingType());
 
   // Lock access to the dialect as this may modify the LLVM context.
-  llvm::sys::SmartScopedLock<true> lock(result.getDialect().impl->mutex);
-  return get(result.getContext(),
-             llvm::FunctionType::get(result.getUnderlyingType(), llvmParams,
-                                     isVarArg));
+  return getLocked(&result.getDialect(), [=] {
+    return llvm::FunctionType::get(result.getUnderlyingType(), llvmParams,
+                                   isVarArg);
+  });
 }
 LLVMType LLVMType::getStructTy(LLVMDialect *dialect,
                                ArrayRef<LLVMType> elements, bool isPacked) {
@@ -924,19 +969,17 @@ LLVMType LLVMType::getStructTy(LLVMDialect *dialect,
     llvmElements.push_back(elt.getUnderlyingType());
 
   // Lock access to the dialect as this may modify the LLVM context.
-  llvm::sys::SmartScopedLock<true> lock(dialect->impl->mutex);
-  return get(
-      dialect->getContext(),
-      llvm::StructType::get(dialect->getLLVMContext(), llvmElements, isPacked));
+  return getLocked(dialect, [=] {
+    return llvm::StructType::get(dialect->getLLVMContext(), llvmElements,
+                                 isPacked);
+  });
 }
 LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) {
   // Lock access to the dialect as this may modify the LLVM context.
-  llvm::sys::SmartScopedLock<true> lock(elementType.getDialect().impl->mutex);
-  return get(
-      elementType.getContext(),
-      llvm::VectorType::get(elementType.getUnderlyingType(), numElements));
+  return getLocked(&elementType.getDialect(), [=] {
+    return llvm::VectorType::get(elementType.getUnderlyingType(), numElements);
+  });
 }
 LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) {
-  return get(dialect->getContext(),
-             llvm::Type::getVoidTy(dialect->getLLVMContext()));
+  return dialect->impl->voidTy;
 }