[mlir] Remove most uses of LLVMDialect::getModule
authorAlex Zinenko <zinenko@google.com>
Wed, 5 Aug 2020 22:52:10 +0000 (00:52 +0200)
committerAlex Zinenko <zinenko@google.com>
Thu, 6 Aug 2020 08:54:30 +0000 (10:54 +0200)
This prepares for the removal of llvm::Module and LLVMContext from the
mlir::LLVMDialect.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D85371

mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

index 43a1587..1bf46b9 100644 (file)
@@ -118,8 +118,7 @@ public:
   unsigned getPointerBitwidth(unsigned addressSpace = 0);
 
 protected:
-  /// LLVM IR module used to parse/create types.
-  llvm::Module *module;
+  /// Pointer to the LLVM dialect.
   LLVM::LLVMDialect *llvmDialect;
 
 private:
@@ -400,9 +399,6 @@ public:
   /// Returns the LLVM IR context.
   llvm::LLVMContext &getContext() const;
 
-  /// Returns the LLVM IR module associated with the LLVM dialect.
-  llvm::Module &getModule() const;
-
   /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
   /// defined by the used type converter.
   LLVM::LLVMType getIndexType() const;
@@ -437,8 +433,8 @@ public:
                              ConversionPatternRewriter &rewriter) const;
 
   Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
-                   ValueRange indices, ConversionPatternRewriter &rewriter,
-                   llvm::Module &module) const;
+                   ValueRange indices,
+                   ConversionPatternRewriter &rewriter) const;
 
   /// Returns the type of a pointer to an element of the memref.
   Type getElementPtrType(MemRefType type) const;
index 4d99bf2..6b265e7 100644 (file)
@@ -25,6 +25,7 @@ def LLVM_Dialect : Dialect {
     llvm::LLVMContext &getLLVMContext();
     llvm::Module &getLLVMModule();
     llvm::sys::SmartMutex<true> &getLLVMContextMutex();
+    const llvm::DataLayout &getDataLayout();
 
   private:
     friend LLVMType;
index 14011e0..c5ecaf7 100644 (file)
@@ -66,12 +66,7 @@ class GpuLaunchFuncToGpuRuntimeCallsPass
 private:
   LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
 
-  llvm::LLVMContext &getLLVMContext() {
-    return getLLVMDialect()->getLLVMContext();
-  }
-
   void initializeCachedTypes() {
-    const llvm::Module &module = llvmDialect->getLLVMModule();
     llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
     llvmPointerPointerType = llvmPointerType.getPointerTo();
@@ -79,7 +74,7 @@ private:
     llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
     llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
     llvmIntPtrType = LLVM::LLVMType::getIntNTy(
-        llvmDialect, module.getDataLayout().getPointerSizeInBits());
+        llvmDialect, llvmDialect->getDataLayout().getPointerSizeInBits());
   }
 
   LLVM::LLVMType getVoidType() { return llvmVoidType; }
@@ -95,9 +90,9 @@ private:
   LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
 
   LLVM::LLVMType getIntPtrType() {
-    const llvm::Module &module = getLLVMDialect()->getLLVMModule();
     return LLVM::LLVMType::getIntNTy(
-        getLLVMDialect(), module.getDataLayout().getPointerSizeInBits());
+        getLLVMDialect(),
+        getLLVMDialect()->getDataLayout().getPointerSizeInBits());
   }
 
   // Allocate a void pointer on the stack.
index e6527e0..c1a64bd 100644 (file)
@@ -59,10 +59,6 @@ class VulkanLaunchFuncToVulkanCallsPass
 private:
   LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
 
-  llvm::LLVMContext &getLLVMContext() {
-    return getLLVMDialect()->getLLVMContext();
-  }
-
   void initializeCachedTypes() {
     llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
     llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);
index 9777071..e7c8770 100644 (file)
@@ -128,10 +128,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
     : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()),
       options(options) {
   assert(llvmDialect && "LLVM IR dialect is not registered");
-  module = &llvmDialect->getLLVMModule();
   if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
     this->options.indexBitwidth =
-        module->getDataLayout().getPointerSizeInBits();
+        llvmDialect->getDataLayout().getPointerSizeInBits();
 
   // Register conversions for the standard types.
   addConversion([&](ComplexType type) { return convertComplexType(type); });
@@ -196,7 +195,7 @@ MLIRContext &LLVMTypeConverter::getContext() {
 
 /// Get the LLVM context.
 llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
-  return module->getContext();
+  return llvmDialect->getLLVMContext();
 }
 
 LLVM::LLVMType LLVMTypeConverter::getIndexType() {
@@ -204,7 +203,7 @@ LLVM::LLVMType LLVMTypeConverter::getIndexType() {
 }
 
 unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
-  return module->getDataLayout().getPointerSizeInBits(addressSpace);
+  return llvmDialect->getDataLayout().getPointerSizeInBits(addressSpace);
 }
 
 Type LLVMTypeConverter::convertIndexType(IndexType type) {
@@ -849,10 +848,6 @@ llvm::LLVMContext &ConvertToLLVMPattern::getContext() const {
   return typeConverter.getLLVMContext();
 }
 
-llvm::Module &ConvertToLLVMPattern::getModule() const {
-  return getDialect().getLLVMModule();
-}
-
 LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
   return typeConverter.getIndexType();
 }
@@ -910,10 +905,9 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
   return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
 }
 
-Value ConvertToLLVMPattern::getDataPtr(Location loc, MemRefType type,
-                                       Value memRefDesc, ValueRange indices,
-                                       ConversionPatternRewriter &rewriter,
-                                       llvm::Module &module) const {
+Value ConvertToLLVMPattern::getDataPtr(
+    Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
+    ConversionPatternRewriter &rewriter) const {
   LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
   int64_t offset;
   SmallVector<int64_t, 4> strides;
@@ -2451,7 +2445,7 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
     auto type = loadOp.getMemRefType();
 
     Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
-                               transformed.indices(), rewriter, getModule());
+                               transformed.indices(), rewriter);
     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
     return success();
   }
@@ -2469,7 +2463,7 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
     StoreOp::Adaptor transformed(operands);
 
     Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
-                               transformed.indices(), rewriter, getModule());
+                               transformed.indices(), rewriter);
     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
                                                dataPtr);
     return success();
@@ -2489,7 +2483,7 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
     auto type = prefetchOp.getMemRefType();
 
     Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
-                               transformed.indices(), rewriter, getModule());
+                               transformed.indices(), rewriter);
 
     // Replace with llvm.prefetch.
     auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
@@ -3086,7 +3080,7 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
     auto resultType = adaptor.value().getType();
     auto memRefType = atomicOp.getMemRefType();
     auto dataPtr = getDataPtr(op->getLoc(), memRefType, adaptor.memref(),
-                              adaptor.indices(), rewriter, getModule());
+                              adaptor.indices(), rewriter);
     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
         op, resultType, *maybeKind, dataPtr, adaptor.value(),
         LLVM::AtomicOrdering::acq_rel);
@@ -3152,7 +3146,7 @@ struct GenericAtomicRMWOpLowering
     rewriter.setInsertionPointToEnd(initBlock);
     auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
     auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
-                              adaptor.indices(), rewriter, getModule());
+                              adaptor.indices(), rewriter);
     Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
     rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
 
index 64d38e4..011143b 100644 (file)
@@ -131,7 +131,7 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
   LLVM::LLVMDialect *dialect = typeConverter.getDialect();
   align = LLVM::TypeToLLVMIRTranslator(dialect->getLLVMContext())
               .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
-                                     dialect->getLLVMModule().getDataLayout());
+                                     dialect->getDataLayout());
   return success();
 }
 
@@ -1152,7 +1152,7 @@ public:
     //    address space 0.
     // TODO: support alignment when possible.
     Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
-                               adaptor.indices(), rewriter, getModule());
+                               adaptor.indices(), rewriter);
     auto vecTy =
         toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
     Value vectorDataPtr;
index 9c8556d..f699ef0 100644 (file)
@@ -103,7 +103,7 @@ public:
     // indices, so no need to calculat offset size in bytes again in
     // the MUBUF instruction.
     Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
-                               adaptor.indices(), rewriter, getModule());
+                               adaptor.indices(), rewriter);
 
     // 1. Create and fill a <4 x i32> dwordConfig with:
     //    1st two elements holding the address of dataPtr.
index 47129d7..6a70af4 100644 (file)
@@ -1741,6 +1741,9 @@ llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }
 llvm::sys::SmartMutex<true> &LLVMDialect::getLLVMContextMutex() {
   return impl->mutex;
 }
+const llvm::DataLayout &LLVMDialect::getDataLayout() {
+  return impl->module.getDataLayout();
+}
 
 /// Parse a type registered to this dialect.
 Type LLVMDialect::parseType(DialectAsmParser &parser) const {