[mlir] take MLIRContext instead of LLVMDialect in getters of LLVMType's
authorAlex Zinenko <zinenko@google.com>
Wed, 5 Aug 2020 22:52:20 +0000 (00:52 +0200)
committerAlex Zinenko <zinenko@google.com>
Thu, 6 Aug 2020 09:05:40 +0000 (11:05 +0200)
Historical modeling of the LLVM dialect types had been wrapping LLVM IR types
and therefore needed access to the instance of LLVMContext stored in the
LLVMDialect. The new modeling does not rely on that and only needs the
MLIRContext that is used for uniquing, similarly to other MLIR types. Change
LLVMType::get<Kind>Ty functions to take `MLIRContext *` instead of
`LLVMDialect *` as first argument. This brings the code base closer to
completely removing the dependence on LLVMContext from the LLVMDialect,
together with additional support for thread-safety of its use.

Depends On D85371

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D85372

19 files changed:
mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp

index af4130c..74b32dc 100644 (file)
@@ -56,19 +56,15 @@ public:
     auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
     auto memRefShape = memRefType.getShape();
     auto loc = op->getLoc();
-    auto *llvmDialect =
-        op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
-    assert(llvmDialect && "expected llvm dialect to be registered");
 
     ModuleOp parentModule = op->getParentOfType<ModuleOp>();
 
     // Get a symbol reference to the printf function, inserting it if necessary.
-    auto printfRef = getOrInsertPrintf(rewriter, parentModule, llvmDialect);
+    auto printfRef = getOrInsertPrintf(rewriter, parentModule);
     Value formatSpecifierCst = getOrCreateGlobalString(
-        loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule,
-        llvmDialect);
+        loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule);
     Value newLineCst = getOrCreateGlobalString(
-        loc, rewriter, "nl", StringRef("\n\0", 2), parentModule, llvmDialect);
+        loc, rewriter, "nl", StringRef("\n\0", 2), parentModule);
 
     // Create a loop for each of the dimensions within the shape.
     SmallVector<Value, 4> loopIvs;
@@ -108,16 +104,15 @@ private:
   /// Return a symbol reference to the printf function, inserting it into the
   /// module if necessary.
   static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
-                                             ModuleOp module,
-                                             LLVM::LLVMDialect *llvmDialect) {
+                                             ModuleOp module) {
     auto *context = module.getContext();
     if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
       return SymbolRefAttr::get("printf", context);
 
     // Create a function declaration for printf, the signature is:
     //   * `i32 (i8*, ...)`
-    auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
-    auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+    auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(context);
+    auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
     auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
                                                     /*isVarArg=*/true);
 
@@ -132,15 +127,14 @@ private:
   /// name, creating the string if necessary.
   static Value getOrCreateGlobalString(Location loc, OpBuilder &builder,
                                        StringRef name, StringRef value,
-                                       ModuleOp module,
-                                       LLVM::LLVMDialect *llvmDialect) {
+                                       ModuleOp module) {
     // Create the global at the entry of the module.
     LLVM::GlobalOp global;
     if (!(global = module.lookupSymbol<LLVM::GlobalOp>(name))) {
       OpBuilder::InsertionGuard insertGuard(builder);
       builder.setInsertionPointToStart(module.getBody());
       auto type = LLVM::LLVMType::getArrayTy(
-          LLVM::LLVMType::getInt8Ty(llvmDialect), value.size());
+          LLVM::LLVMType::getInt8Ty(builder.getContext()), value.size());
       global = builder.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
                                               LLVM::Linkage::Internal, name,
                                               builder.getStringAttr(value));
@@ -149,10 +143,10 @@ private:
     // Get the pointer to the first character in the global string.
     Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
     Value cst0 = builder.create<LLVM::ConstantOp>(
-        loc, LLVM::LLVMType::getInt64Ty(llvmDialect),
+        loc, LLVM::LLVMType::getInt64Ty(builder.getContext()),
         builder.getIntegerAttr(builder.getIndexType(), 0));
     return builder.create<LLVM::GEPOp>(
-        loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
+        loc, LLVM::LLVMType::getInt8PtrTy(builder.getContext()), globalPtr,
         ArrayRef<Value>({cst0, cst0}));
   }
 };
index af4130c..74b32dc 100644 (file)
@@ -56,19 +56,15 @@ public:
     auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
     auto memRefShape = memRefType.getShape();
     auto loc = op->getLoc();
-    auto *llvmDialect =
-        op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
-    assert(llvmDialect && "expected llvm dialect to be registered");
 
     ModuleOp parentModule = op->getParentOfType<ModuleOp>();
 
     // Get a symbol reference to the printf function, inserting it if necessary.
-    auto printfRef = getOrInsertPrintf(rewriter, parentModule, llvmDialect);
+    auto printfRef = getOrInsertPrintf(rewriter, parentModule);
     Value formatSpecifierCst = getOrCreateGlobalString(
-        loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule,
-        llvmDialect);
+        loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule);
     Value newLineCst = getOrCreateGlobalString(
-        loc, rewriter, "nl", StringRef("\n\0", 2), parentModule, llvmDialect);
+        loc, rewriter, "nl", StringRef("\n\0", 2), parentModule);
 
     // Create a loop for each of the dimensions within the shape.
     SmallVector<Value, 4> loopIvs;
@@ -108,16 +104,15 @@ private:
   /// Return a symbol reference to the printf function, inserting it into the
   /// module if necessary.
   static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
-                                             ModuleOp module,
-                                             LLVM::LLVMDialect *llvmDialect) {
+                                             ModuleOp module) {
     auto *context = module.getContext();
     if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
       return SymbolRefAttr::get("printf", context);
 
     // Create a function declaration for printf, the signature is:
     //   * `i32 (i8*, ...)`
-    auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
-    auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+    auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(context);
+    auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
     auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
                                                     /*isVarArg=*/true);
 
@@ -132,15 +127,14 @@ private:
   /// name, creating the string if necessary.
   static Value getOrCreateGlobalString(Location loc, OpBuilder &builder,
                                        StringRef name, StringRef value,
-                                       ModuleOp module,
-                                       LLVM::LLVMDialect *llvmDialect) {
+                                       ModuleOp module) {
     // Create the global at the entry of the module.
     LLVM::GlobalOp global;
     if (!(global = module.lookupSymbol<LLVM::GlobalOp>(name))) {
       OpBuilder::InsertionGuard insertGuard(builder);
       builder.setInsertionPointToStart(module.getBody());
       auto type = LLVM::LLVMType::getArrayTy(
-          LLVM::LLVMType::getInt8Ty(llvmDialect), value.size());
+          LLVM::LLVMType::getInt8Ty(builder.getContext()), value.size());
       global = builder.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
                                               LLVM::Linkage::Internal, name,
                                               builder.getStringAttr(value));
@@ -149,10 +143,10 @@ private:
     // Get the pointer to the first character in the global string.
     Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
     Value cst0 = builder.create<LLVM::ConstantOp>(
-        loc, LLVM::LLVMType::getInt64Ty(llvmDialect),
+        loc, LLVM::LLVMType::getInt64Ty(builder.getContext()),
         builder.getIntegerAttr(builder.getIndexType(), 0));
     return builder.create<LLVM::GEPOp>(
-        loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
+        loc, LLVM::LLVMType::getInt8PtrTy(builder.getContext()), globalPtr,
         ArrayRef<Value>({cst0, cst0}));
   }
 };
index 2853ef6..1614a24 100644 (file)
@@ -59,8 +59,7 @@ struct LLVMDialectImpl;
 /// global and use it to compute the address of the first character in the
 /// string (operations inserted at the builder insertion point).
 Value createGlobalString(Location loc, OpBuilder &builder, StringRef name,
-                         StringRef value, LLVM::Linkage linkage,
-                         LLVM::LLVMDialect *llvmDialect);
+                         StringRef value, LLVM::Linkage linkage);
 
 /// LLVM requires some operations to be inside of a Module operation. This
 /// function confirms that the Operation has the desired properties.
index 6b265e7..fb4001f 100644 (file)
@@ -58,8 +58,7 @@ class LLVMI<int width>
          "$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy(" # width # ")">]>,
          "LLVM dialect " # width # "-bit integer">,
       BuildableType<
-        "::mlir::LLVM::LLVMType::getIntNTy("
-        "$_builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>(),"
+        "::mlir::LLVM::LLVMType::getIntNTy($_builder.getContext(),"
          # width # ")">;
 
 def LLVMI1 : LLVMI<1>;
index 768d8db..a4a0db1 100644 (file)
@@ -151,8 +151,7 @@ def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>,
   let builders = [OpBuilder<
     "OpBuilder &b, OperationState &result, ICmpPredicate predicate, Value lhs, "
     "Value rhs", [{
-      LLVMDialect *dialect = &lhs.getType().cast<LLVMType>().getDialect();
-      build(b, result, LLVMType::getInt1Ty(dialect),
+      build(b, result, LLVMType::getInt1Ty(lhs.getType().getContext()),
             b.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
     }]>];
   let parser = [{ return parseCmpOp<ICmpPredicate>(parser, result); }];
@@ -198,8 +197,7 @@ def LLVM_FCmpOp : LLVM_OneResultOp<"fcmp", [NoSideEffect]>,
   let builders = [OpBuilder<
     "OpBuilder &b, OperationState &result, FCmpPredicate predicate, Value lhs, "
     "Value rhs", [{
-      LLVMDialect *dialect = &lhs.getType().cast<LLVMType>().getDialect();
-      build(b, result, LLVMType::getInt1Ty(dialect),
+      build(b, result, LLVMType::getInt1Ty(lhs.getType().getContext()),
             b.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
     }]>];
   let parser = [{ return parseCmpOp<FCmpPredicate>(parser, result); }];
index 7d7839c..0d3a6f3 100644 (file)
@@ -152,32 +152,32 @@ public:
   bool isStructTy();
 
   /// Utilities used to generate floating point types.
-  static LLVMType getDoubleTy(LLVMDialect *dialect);
-  static LLVMType getFloatTy(LLVMDialect *dialect);
-  static LLVMType getBFloatTy(LLVMDialect *dialect);
-  static LLVMType getHalfTy(LLVMDialect *dialect);
-  static LLVMType getFP128Ty(LLVMDialect *dialect);
-  static LLVMType getX86_FP80Ty(LLVMDialect *dialect);
+  static LLVMType getDoubleTy(MLIRContext *context);
+  static LLVMType getFloatTy(MLIRContext *context);
+  static LLVMType getBFloatTy(MLIRContext *context);
+  static LLVMType getHalfTy(MLIRContext *context);
+  static LLVMType getFP128Ty(MLIRContext *context);
+  static LLVMType getX86_FP80Ty(MLIRContext *context);
 
   /// Utilities used to generate integer types.
-  static LLVMType getIntNTy(LLVMDialect *dialect, unsigned numBits);
-  static LLVMType getInt1Ty(LLVMDialect *dialect) {
-    return getIntNTy(dialect, /*numBits=*/1);
+  static LLVMType getIntNTy(MLIRContext *context, unsigned numBits);
+  static LLVMType getInt1Ty(MLIRContext *context) {
+    return getIntNTy(context, /*numBits=*/1);
   }
-  static LLVMType getInt8Ty(LLVMDialect *dialect) {
-    return getIntNTy(dialect, /*numBits=*/8);
+  static LLVMType getInt8Ty(MLIRContext *context) {
+    return getIntNTy(context, /*numBits=*/8);
   }
-  static LLVMType getInt8PtrTy(LLVMDialect *dialect) {
-    return getInt8Ty(dialect).getPointerTo();
+  static LLVMType getInt8PtrTy(MLIRContext *context) {
+    return getInt8Ty(context).getPointerTo();
   }
-  static LLVMType getInt16Ty(LLVMDialect *dialect) {
-    return getIntNTy(dialect, /*numBits=*/16);
+  static LLVMType getInt16Ty(MLIRContext *context) {
+    return getIntNTy(context, /*numBits=*/16);
   }
-  static LLVMType getInt32Ty(LLVMDialect *dialect) {
-    return getIntNTy(dialect, /*numBits=*/32);
+  static LLVMType getInt32Ty(MLIRContext *context) {
+    return getIntNTy(context, /*numBits=*/32);
   }
-  static LLVMType getInt64Ty(LLVMDialect *dialect) {
-    return getIntNTy(dialect, /*numBits=*/64);
+  static LLVMType getInt64Ty(MLIRContext *context) {
+    return getIntNTy(context, /*numBits=*/64);
   }
 
   /// Utilities used to generate other miscellaneous types.
@@ -187,33 +187,33 @@ public:
   static LLVMType getFunctionTy(LLVMType result, bool isVarArg) {
     return getFunctionTy(result, llvm::None, isVarArg);
   }
-  static LLVMType getStructTy(LLVMDialect *dialect, ArrayRef<LLVMType> elements,
+  static LLVMType getStructTy(MLIRContext *context, ArrayRef<LLVMType> elements,
                               bool isPacked = false);
-  static LLVMType getStructTy(LLVMDialect *dialect, bool isPacked = false) {
-    return getStructTy(dialect, llvm::None, isPacked);
+  static LLVMType getStructTy(MLIRContext *context, bool isPacked = false) {
+    return getStructTy(context, llvm::None, isPacked);
   }
   template <typename... Args>
   static typename std::enable_if<llvm::are_base_of<LLVMType, Args...>::value,
                                  LLVMType>::type
   getStructTy(LLVMType elt1, Args... elts) {
     SmallVector<LLVMType, 8> fields({elt1, elts...});
-    return getStructTy(&elt1.getDialect(), fields);
+    return getStructTy(elt1.getContext(), fields);
   }
   static LLVMType getVectorTy(LLVMType elementType, unsigned numElements);
 
   /// Void type utilities.
-  static LLVMType getVoidTy(LLVMDialect *dialect);
+  static LLVMType getVoidTy(MLIRContext *context);
   bool isVoidTy();
 
   // Creation and setting of LLVM's identified struct types
-  static LLVMType createStructTy(LLVMDialect *dialect,
+  static LLVMType createStructTy(MLIRContext *context,
                                  ArrayRef<LLVMType> elements,
                                  Optional<StringRef> name,
                                  bool isPacked = false);
 
-  static LLVMType createStructTy(LLVMDialect *dialect,
+  static LLVMType createStructTy(MLIRContext *context,
                                  Optional<StringRef> name) {
-    return createStructTy(dialect, llvm::None, name);
+    return createStructTy(context, llvm::None, name);
   }
 
   static LLVMType createStructTy(ArrayRef<LLVMType> elements,
@@ -222,7 +222,7 @@ public:
     assert(!elements.empty() &&
            "This method may not be invoked with an empty list");
     LLVMType ele0 = elements.front();
-    return createStructTy(&ele0.getDialect(), elements, name, isPacked);
+    return createStructTy(ele0.getContext(), elements, name, isPacked);
   }
 
   template <typename... Args>
@@ -231,7 +231,7 @@ public:
   createStructTy(StringRef name, LLVMType elt1, Args... elts) {
     SmallVector<LLVMType, 8> fields({elt1, elts...});
     Optional<StringRef> opt_name(name);
-    return createStructTy(&elt1.getDialect(), fields, opt_name);
+    return createStructTy(elt1.getContext(), fields, opt_name);
   }
 
   static LLVMType setStructTyBody(LLVMType structType,
index c5ecaf7..e186a33 100644 (file)
@@ -67,14 +67,14 @@ private:
   LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
 
   void initializeCachedTypes() {
-    llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
-    llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+    llvmVoidType = LLVM::LLVMType::getVoidTy(&getContext());
+    llvmPointerType = LLVM::LLVMType::getInt8PtrTy(&getContext());
     llvmPointerPointerType = llvmPointerType.getPointerTo();
-    llvmInt8Type = LLVM::LLVMType::getInt8Ty(llvmDialect);
-    llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
-    llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
+    llvmInt8Type = LLVM::LLVMType::getInt8Ty(&getContext());
+    llvmInt32Type = LLVM::LLVMType::getInt32Ty(&getContext());
+    llvmInt64Type = LLVM::LLVMType::getInt64Ty(&getContext());
     llvmIntPtrType = LLVM::LLVMType::getIntNTy(
-        llvmDialect, llvmDialect->getDataLayout().getPointerSizeInBits());
+        &getContext(), llvmDialect->getDataLayout().getPointerSizeInBits());
   }
 
   LLVM::LLVMType getVoidType() { return llvmVoidType; }
@@ -91,7 +91,7 @@ private:
 
   LLVM::LLVMType getIntPtrType() {
     return LLVM::LLVMType::getIntNTy(
-        getLLVMDialect(),
+        &getContext(),
         getLLVMDialect()->getDataLayout().getPointerSizeInBits());
   }
 
@@ -340,7 +340,7 @@ Value GpuLaunchFuncToGpuRuntimeCallsPass::generateKernelNameConstant(
       std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
   return LLVM::createGlobalString(
       loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
-      LLVM::Linkage::Internal, llvmDialect);
+      LLVM::Linkage::Internal);
 }
 
 // Emits LLVM IR to launch a kernel function. Expects the module that contains
@@ -378,9 +378,9 @@ void GpuLaunchFuncToGpuRuntimeCallsPass::translateGpuLaunchCalls(
 
   SmallString<128> nameBuffer(kernelModule.getName());
   nameBuffer.append(kGpuBinaryStorageSuffix);
-  Value data = LLVM::createGlobalString(
-      loc, builder, nameBuffer.str(), binaryAttr.getValue(),
-      LLVM::Linkage::Internal, getLLVMDialect());
+  Value data =
+      LLVM::createGlobalString(loc, builder, nameBuffer.str(),
+                               binaryAttr.getValue(), LLVM::Linkage::Internal);
 
   // Emit the load module call to load the module data. Error checking is done
   // in the called helper function.
index f6aede4..f15ccdc 100644 (file)
@@ -89,7 +89,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
       // Rewrite workgroup memory attributions to addresses of global buffers.
       rewriter.setInsertionPointToStart(&gpuFuncOp.front());
       unsigned numProperArguments = gpuFuncOp.getNumArguments();
-      auto i32Type = LLVM::LLVMType::getInt32Ty(typeConverter.getDialect());
+      auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
 
       Value zero = nullptr;
       if (!workgroupBuffers.empty())
@@ -117,7 +117,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
       // Rewrite private memory attributions to alloca'ed buffers.
       unsigned numWorkgroupAttributions =
           gpuFuncOp.getNumWorkgroupAttributions();
-      auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
+      auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
       for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
         Value attribution = en.value();
         auto type = attribution.getType().cast<MemRefType>();
index 6d4b99a..9338105 100644 (file)
@@ -46,17 +46,17 @@ public:
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op->getLoc();
-    auto dialect = typeConverter.getDialect();
+    MLIRContext *context = rewriter.getContext();
     Value newOp;
     switch (dimensionToIndex(cast<Op>(op))) {
     case X:
-      newOp = rewriter.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
+      newOp = rewriter.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(context));
       break;
     case Y:
-      newOp = rewriter.create<YOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
+      newOp = rewriter.create<YOp>(loc, LLVM::LLVMType::getInt32Ty(context));
       break;
     case Z:
-      newOp = rewriter.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
+      newOp = rewriter.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(context));
       break;
     default:
       return failure();
@@ -64,10 +64,10 @@ public:
 
     if (indexBitwidth > 32) {
       newOp = rewriter.create<LLVM::SExtOp>(
-          loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
+          loc, LLVM::LLVMType::getIntNTy(context, indexBitwidth), newOp);
     } else if (indexBitwidth < 32) {
       newOp = rewriter.create<LLVM::TruncOp>(
-          loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
+          loc, LLVM::LLVMType::getIntNTy(context, indexBitwidth), newOp);
     }
 
     rewriter.replaceOp(op, {newOp});
index 58b5f1d..fc74382 100644 (file)
@@ -85,7 +85,7 @@ private:
       return operand;
 
     return rewriter.create<LLVM::FPExtOp>(
-        operand.getLoc(), LLVM::LLVMType::getFloatTy(&type.getDialect()),
+        operand.getLoc(), LLVM::LLVMType::getFloatTy(rewriter.getContext()),
         operand);
   }
 
index afb6d28..76c1668 100644 (file)
@@ -57,11 +57,11 @@ struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
     Location loc = op->getLoc();
     gpu::ShuffleOpAdaptor adaptor(operands);
 
-    auto dialect = typeConverter.getDialect();
     auto valueTy = adaptor.value().getType().cast<LLVM::LLVMType>();
-    auto int32Type = LLVM::LLVMType::getInt32Ty(dialect);
-    auto predTy = LLVM::LLVMType::getInt1Ty(dialect);
-    auto resultTy = LLVM::LLVMType::getStructTy(dialect, {valueTy, predTy});
+    auto int32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
+    auto predTy = LLVM::LLVMType::getInt1Ty(rewriter.getContext());
+    auto resultTy =
+        LLVM::LLVMType::getStructTy(rewriter.getContext(), {valueTy, predTy});
 
     Value one = rewriter.create<LLVM::ConstantOp>(
         loc, int32Type, rewriter.getI32IntegerAttr(1));
index c1a64bd..377b4f0 100644 (file)
@@ -57,15 +57,12 @@ class VulkanLaunchFuncToVulkanCallsPass
     : public ConvertVulkanLaunchFuncToVulkanCallsBase<
           VulkanLaunchFuncToVulkanCallsPass> {
 private:
-  LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
-
   void initializeCachedTypes() {
-    llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
-    llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);
-    llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
-    llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
-    llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
-    llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
+    llvmFloatType = LLVM::LLVMType::getFloatTy(&getContext());
+    llvmVoidType = LLVM::LLVMType::getVoidTy(&getContext());
+    llvmPointerType = LLVM::LLVMType::getInt8PtrTy(&getContext());
+    llvmInt32Type = LLVM::LLVMType::getInt32Ty(&getContext());
+    llvmInt64Type = LLVM::LLVMType::getInt64Ty(&getContext());
   }
 
   LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) {
@@ -87,7 +84,7 @@ private:
     // `!llvm<"{ `element-type`*, `element-type`*, i64,
     // [`rank` x i64], [`rank` x i64]}">`.
     return LLVM::LLVMType::getStructTy(
-        llvmDialect,
+        &getContext(),
         {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(),
          llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
   }
@@ -153,7 +150,6 @@ public:
   void runOnOperation() override;
 
 private:
-  LLVM::LLVMDialect *llvmDialect;
   LLVM::LLVMType llvmFloatType;
   LLVM::LLVMType llvmVoidType;
   LLVM::LLVMType llvmPointerType;
@@ -245,7 +241,7 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
     // int16_t and bitcast the descriptor.
     if (type.isHalfTy()) {
       auto memRefTy =
-          getMemRefType(rank, LLVM::LLVMType::getInt16Ty(llvmDialect));
+          getMemRefType(rank, LLVM::LLVMType::getInt16Ty(&getContext()));
       ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
           loc, memRefTy.getPointerTo(), ptrToMemRefDescriptor);
     }
@@ -324,15 +320,15 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
   }
 
   for (unsigned i = 1; i <= 3; i++) {
-    for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(llvmDialect),
-                                LLVM::LLVMType::getInt32Ty(llvmDialect),
-                                LLVM::LLVMType::getInt16Ty(llvmDialect),
-                                LLVM::LLVMType::getInt8Ty(llvmDialect),
-                                LLVM::LLVMType::getHalfTy(llvmDialect)}) {
+    for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(&getContext()),
+                                LLVM::LLVMType::getInt32Ty(&getContext()),
+                                LLVM::LLVMType::getInt16Ty(&getContext()),
+                                LLVM::LLVMType::getInt8Ty(&getContext()),
+                                LLVM::LLVMType::getHalfTy(&getContext())}) {
       std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
                            std::string(stringifyType(type));
       if (type.isHalfTy())
-        type = getMemRefType(i, LLVM::LLVMType::getInt16Ty(llvmDialect));
+        type = getMemRefType(i, LLVM::LLVMType::getInt16Ty(&getContext()));
       if (!module.lookupSymbol(fnName)) {
         auto fnType = LLVM::LLVMType::getFunctionTy(
             getVoidType(),
@@ -368,8 +364,7 @@ Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
 
   std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
   return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
-                                  shaderName, LLVM::Linkage::Internal,
-                                  getLLVMDialect());
+                                  shaderName, LLVM::Linkage::Internal);
 }
 
 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
@@ -388,7 +383,7 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
   // that data to runtime call.
   Value ptrToSPIRVBinary = LLVM::createGlobalString(
       loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
-      LLVM::Linkage::Internal, getLLVMDialect());
+      LLVM::Linkage::Internal);
 
   // Create LLVM constant for the size of SPIR-V binary shader.
   Value binarySize = builder.create<LLVM::ConstantOp>(
index 0c32628..024f2b1 100644 (file)
@@ -186,15 +186,15 @@ static Type convertStructTypePacked(spirv::StructType type,
       llvm::map_range(type.getElementTypes(), [&](Type elementType) {
         return converter.convertType(elementType).cast<LLVM::LLVMType>();
       }));
-  return LLVM::LLVMType::getStructTy(converter.getDialect(), elementsVector,
+  return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector,
                                      /*isPacked=*/true);
 }
 
 /// Creates LLVM dialect constant with the given value.
 static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
-                                 LLVMTypeConverter &converter, unsigned value) {
+                                 unsigned value) {
   return rewriter.create<LLVM::ConstantOp>(
-      loc, LLVM::LLVMType::getInt32Ty(converter.getDialect()),
+      loc, LLVM::LLVMType::getInt32Ty(rewriter.getContext()),
       rewriter.getIntegerAttr(rewriter.getI32Type(), value));
 }
 
@@ -1002,7 +1002,7 @@ public:
       return failure();
 
     Location loc = varOp.getLoc();
-    Value size = createI32ConstantOf(loc, rewriter, typeConverter, 1);
+    Value size = createI32ConstantOf(loc, rewriter, 1);
     if (!init) {
       rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, size);
       return success();
index e7c8770..a5ecbe4 100644 (file)
@@ -199,7 +199,7 @@ llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
 }
 
 LLVM::LLVMType LLVMTypeConverter::getIndexType() {
-  return LLVM::LLVMType::getIntNTy(llvmDialect, getIndexTypeBitwidth());
+  return LLVM::LLVMType::getIntNTy(&getContext(), getIndexTypeBitwidth());
 }
 
 unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
@@ -211,19 +211,19 @@ Type LLVMTypeConverter::convertIndexType(IndexType type) {
 }
 
 Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
-  return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth());
+  return LLVM::LLVMType::getIntNTy(&getContext(), type.getWidth());
 }
 
 Type LLVMTypeConverter::convertFloatType(FloatType type) {
   switch (type.getKind()) {
   case mlir::StandardTypes::F32:
-    return LLVM::LLVMType::getFloatTy(llvmDialect);
+    return LLVM::LLVMType::getFloatTy(&getContext());
   case mlir::StandardTypes::F64:
-    return LLVM::LLVMType::getDoubleTy(llvmDialect);
+    return LLVM::LLVMType::getDoubleTy(&getContext());
   case mlir::StandardTypes::F16:
-    return LLVM::LLVMType::getHalfTy(llvmDialect);
+    return LLVM::LLVMType::getHalfTy(&getContext());
   case mlir::StandardTypes::BF16: {
-    return LLVM::LLVMType::getBFloatTy(llvmDialect);
+    return LLVM::LLVMType::getBFloatTy(&getContext());
   }
   default:
     llvm_unreachable("non-float type in convertFloatType");
@@ -238,7 +238,7 @@ static constexpr unsigned kRealPosInComplexNumberStruct = 0;
 static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
 Type LLVMTypeConverter::convertComplexType(ComplexType type) {
   auto elementType = convertType(type.getElementType()).cast<LLVM::LLVMType>();
-  return LLVM::LLVMType::getStructTy(llvmDialect, {elementType, elementType});
+  return LLVM::LLVMType::getStructTy(&getContext(), {elementType, elementType});
 }
 
 // Except for signatures, MLIR function types are converted into LLVM
@@ -274,7 +274,7 @@ LLVMTypeConverter::convertMemRefSignature(MemRefType type) {
 /// In signatures, unranked MemRef descriptors are expanded into a pair "rank,
 /// pointer to descriptor".
 SmallVector<Type, 2> LLVMTypeConverter::convertUnrankedMemRefSignature() {
-  return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(llvmDialect)};
+  return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(&getContext())};
 }
 
 // Function types are converted to LLVM Function types by recursively converting
@@ -307,7 +307,7 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
   // a struct.
   LLVM::LLVMType resultType =
       type.getNumResults() == 0
-          ? LLVM::LLVMType::getVoidTy(llvmDialect)
+          ? LLVM::LLVMType::getVoidTy(&getContext())
           : unwrap(packFunctionResults(type.getResults()));
   if (!resultType)
     return {};
@@ -331,7 +331,7 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
 
   LLVM::LLVMType resultType =
       type.getNumResults() == 0
-          ? LLVM::LLVMType::getVoidTy(llvmDialect)
+          ? LLVM::LLVMType::getVoidTy(&getContext())
           : unwrap(packFunctionResults(type.getResults()));
   if (!resultType)
     return {};
@@ -400,7 +400,7 @@ static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1;
 
 Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
   auto rankTy = getIndexType();
-  auto ptrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+  auto ptrTy = LLVM::LLVMType::getInt8PtrTy(&getContext());
   return LLVM::LLVMType::getStructTy(rankTy, ptrTy);
 }
 
@@ -853,11 +853,11 @@ LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
 }
 
 LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const {
-  return LLVM::LLVMType::getVoidTy(&getDialect());
+  return LLVM::LLVMType::getVoidTy(&typeConverter.getContext());
 }
 
 LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const {
-  return LLVM::LLVMType::getInt8PtrTy(&getDialect());
+  return LLVM::LLVMType::getInt8PtrTy(&typeConverter.getContext());
 }
 
 Value ConvertToLLVMPattern::createIndexConstant(
@@ -2025,9 +2025,10 @@ static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
                                          unrankedMemrefs, sizes);
 
   // Get frequently used types.
-  auto voidType = LLVM::LLVMType::getVoidTy(typeConverter.getDialect());
-  auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(typeConverter.getDialect());
-  auto i1Type = LLVM::LLVMType::getInt1Ty(typeConverter.getDialect());
+  MLIRContext *context = builder.getContext();
+  auto voidType = LLVM::LLVMType::getVoidTy(context);
+  auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(context);
+  auto i1Type = LLVM::LLVMType::getInt1Ty(context);
   LLVM::LLVMType indexType = typeConverter.getIndexType();
 
   // Find the malloc and free, or declare them if necessary.
@@ -3168,7 +3169,7 @@ struct GenericAtomicRMWOpLowering
     // Append the cmpxchg op to the end of the loop block.
     auto successOrdering = LLVM::AtomicOrdering::acq_rel;
     auto failureOrdering = LLVM::AtomicOrdering::monotonic;
-    auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect());
+    auto boolType = LLVM::LLVMType::getInt1Ty(rewriter.getContext());
     auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType);
     auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
         loc, pairType, dataPtr, loopArgument, result, successOrdering,
@@ -3330,13 +3331,13 @@ Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
     resultTypes.push_back(converted);
   }
 
-  return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes);
+  return LLVM::LLVMType::getStructTy(&getContext(), resultTypes);
 }
 
 Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
                                                     OpBuilder &builder) {
   auto *context = builder.getContext();
-  auto int64Ty = LLVM::LLVMType::getInt64Ty(getDialect());
+  auto int64Ty = LLVM::LLVMType::getInt64Ty(builder.getContext());
   auto indexType = IndexType::get(context);
   // Alloca with proper alignment. We do not expect optimizations of this
   // alloca op and so we omit allocating at the entry block.
index 011143b..f5d66f2 100644 (file)
@@ -715,7 +715,7 @@ public:
 
     // Remaining extraction of element from 1-D LLVM vector
     auto position = positionAttrs.back().cast<IntegerAttr>();
-    auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
+    auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
     extracted =
         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
@@ -832,7 +832,7 @@ public:
     }
 
     // Insertion of an element into a 1-D LLVM vector.
-    auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
+    auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
     Value inserted = rewriter.create<LLVM::InsertElementOp>(
         loc, typeConverter.convertType(oneDVectorType), extracted,
@@ -1074,7 +1074,7 @@ public:
     if (failed(successStrides) || !isContiguous)
       return failure();
 
-    auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
+    auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
 
     // Create descriptor.
     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
@@ -1263,11 +1263,10 @@ private:
                  int64_t rank) const {
     Location loc = op->getLoc();
     if (rank == 0) {
-      if (value.getType() ==
-          LLVM::LLVMType::getInt1Ty(typeConverter.getDialect())) {
+      if (value.getType() == LLVM::LLVMType::getInt1Ty(rewriter.getContext())) {
         // Convert i1 (bool) to i32 so we can use the print_i32 method.
         // This avoids the need for a print_i1 method with an unclear ABI.
-        auto i32Type = LLVM::LLVMType::getInt32Ty(typeConverter.getDialect());
+        auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
         auto trueVal = rewriter.create<ConstantOp>(
             loc, i32Type, rewriter.getI32IntegerAttr(1));
         auto falseVal = rewriter.create<ConstantOp>(
@@ -1303,8 +1302,8 @@ private:
   }
 
   // Helper for printer method declaration (first hit) and lookup.
-  static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect,
-                             StringRef name, ArrayRef<LLVM::LLVMType> params) {
+  static Operation *getPrint(Operation *op, StringRef name,
+                             ArrayRef<LLVM::LLVMType> params) {
     auto module = op->getParentOfType<ModuleOp>();
     auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
     if (func)
@@ -1312,42 +1311,39 @@ private:
     OpBuilder moduleBuilder(module.getBodyRegion());
     return moduleBuilder.create<LLVM::LLVMFuncOp>(
         op->getLoc(), name,
-        LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect),
-                                      params, /*isVarArg=*/false));
+        LLVM::LLVMType::getFunctionTy(
+            LLVM::LLVMType::getVoidTy(op->getContext()), params,
+            /*isVarArg=*/false));
   }
 
   // Helpers for method names.
   Operation *getPrintI32(Operation *op) const {
-    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
-    return getPrint(op, dialect, "print_i32",
-                    LLVM::LLVMType::getInt32Ty(dialect));
+    return getPrint(op, "print_i32",
+                    LLVM::LLVMType::getInt32Ty(op->getContext()));
   }
   Operation *getPrintI64(Operation *op) const {
-    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
-    return getPrint(op, dialect, "print_i64",
-                    LLVM::LLVMType::getInt64Ty(dialect));
+    return getPrint(op, "print_i64",
+                    LLVM::LLVMType::getInt64Ty(op->getContext()));
   }
   Operation *getPrintFloat(Operation *op) const {
-    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
-    return getPrint(op, dialect, "print_f32",
-                    LLVM::LLVMType::getFloatTy(dialect));
+    return getPrint(op, "print_f32",
+                    LLVM::LLVMType::getFloatTy(op->getContext()));
   }
   Operation *getPrintDouble(Operation *op) const {
-    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
-    return getPrint(op, dialect, "print_f64",
-                    LLVM::LLVMType::getDoubleTy(dialect));
+    return getPrint(op, "print_f64",
+                    LLVM::LLVMType::getDoubleTy(op->getContext()));
   }
   Operation *getPrintOpen(Operation *op) const {
-    return getPrint(op, typeConverter.getDialect(), "print_open", {});
+    return getPrint(op, "print_open", {});
   }
   Operation *getPrintClose(Operation *op) const {
-    return getPrint(op, typeConverter.getDialect(), "print_close", {});
+    return getPrint(op, "print_close", {});
   }
   Operation *getPrintComma(Operation *op) const {
-    return getPrint(op, typeConverter.getDialect(), "print_comma", {});
+    return getPrint(op, "print_comma", {});
   }
   Operation *getPrintNewline(Operation *op) const {
-    return getPrint(op, typeConverter.getDialect(), "print_newline", {});
+    return getPrint(op, "print_newline", {});
   }
 };
 
index 6a70af4..876fd05 100644 (file)
@@ -101,8 +101,7 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
 
   // The result type is either i1 or a vector type <? x i1> if the inputs are
   // vectors.
-  auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>();
-  auto resultType = LLVMType::getInt1Ty(dialect);
+  auto resultType = LLVMType::getInt1Ty(builder.getContext());
   auto argType = type.dyn_cast<LLVM::LLVMType>();
   if (!argType)
     return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type");
@@ -393,11 +392,9 @@ static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
       return parser.emitError(trailingTypeLoc,
                               "expected function with 0 or 1 result");
 
-    auto *llvmDialect =
-        builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
     LLVM::LLVMType llvmResultType;
     if (funcType.getNumResults() == 0) {
-      llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect);
+      llvmResultType = LLVM::LLVMType::getVoidTy(builder.getContext());
     } else {
       llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
       if (!llvmResultType)
@@ -601,11 +598,9 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
                               "expected function with 0 or 1 result");
 
     Builder &builder = parser.getBuilder();
-    auto *llvmDialect =
-        builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
     LLVM::LLVMType llvmResultType;
     if (funcType.getNumResults() == 0) {
-      llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect);
+      llvmResultType = LLVM::LLVMType::getVoidTy(builder.getContext());
     } else {
       llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
       if (!llvmResultType)
@@ -1101,9 +1096,8 @@ static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
   if (types.empty()) {
     if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) {
       MLIRContext *context = parser.getBuilder().getContext();
-      auto *dialect = context->getRegisteredDialect<LLVMDialect>();
       auto arrayType = LLVM::LLVMType::getArrayTy(
-          LLVM::LLVMType::getInt8Ty(dialect), strAttr.getValue().size());
+          LLVM::LLVMType::getInt8Ty(context), strAttr.getValue().size());
       types.push_back(arrayType);
     } else {
       return parser.emitError(parser.getNameLoc(),
@@ -1265,14 +1259,8 @@ static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc,
     llvmInputs.push_back(llvmTy);
   }
 
-  // Get the dialect from the input type, if any exist.  Look it up in the
-  // context otherwise.
-  LLVMDialect *dialect =
-      llvmInputs.empty() ? b.getContext()->getRegisteredDialect<LLVMDialect>()
-                         : &llvmInputs.front().getDialect();
-
   // No output is denoted as "void" in LLVM type system.
-  LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(dialect)
+  LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(b.getContext())
                                         : outputs.front().dyn_cast<LLVMType>();
   if (!llvmOutput) {
     parser.emitError(loc, "failed to construct function type: expected LLVM "
@@ -1605,8 +1593,7 @@ static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser,
       parser.resolveOperand(val, type, result.operands))
     return failure();
 
-  auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>();
-  auto boolType = LLVMType::getInt1Ty(dialect);
+  auto boolType = LLVMType::getInt1Ty(builder.getContext());
   auto resultType = LLVMType::getStructTy(type, boolType);
   result.addTypes(resultType);
 
@@ -1777,8 +1764,7 @@ LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
 
 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
                                      StringRef name, StringRef value,
-                                     LLVM::Linkage linkage,
-                                     LLVM::LLVMDialect *llvmDialect) {
+                                     LLVM::Linkage linkage) {
   assert(builder.getInsertionBlock() &&
          builder.getInsertionBlock()->getParentOp() &&
          "expected builder to point to a block constrained in an op");
@@ -1788,8 +1774,9 @@ Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
 
   // Create the global at the entry of the module.
   OpBuilder moduleBuilder(module.getBodyRegion());
-  auto type = LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(llvmDialect),
-                                         value.size());
+  MLIRContext *ctx = builder.getContext();
+  auto type =
+      LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(ctx), value.size());
   auto global = moduleBuilder.create<LLVM::GlobalOp>(
       loc, type, /*isConstant=*/true, linkage, name,
       builder.getStringAttr(value));
@@ -1797,10 +1784,9 @@ Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
   // Get the pointer to the first character in the global string.
   Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
   Value cst0 = builder.create<LLVM::ConstantOp>(
-      loc, LLVM::LLVMType::getInt64Ty(llvmDialect),
+      loc, LLVM::LLVMType::getInt64Ty(ctx),
       builder.getIntegerAttr(builder.getIndexType(), 0));
-  return builder.create<LLVM::GEPOp>(loc,
-                                     LLVM::LLVMType::getInt8PtrTy(llvmDialect),
+  return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMType::getInt8PtrTy(ctx),
                                      globalPtr, ArrayRef<Value>({cst0, cst0}));
 }
 
index fa25f2d..f8cadeb 100644 (file)
@@ -127,35 +127,35 @@ bool LLVMType::isStructTy() { return isa<LLVMStructType>(); }
 //----------------------------------------------------------------------------//
 // Utilities used to generate floating point types.
 
-LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) {
-  return LLVMDoubleType::get(dialect->getContext());
+LLVMType LLVMType::getDoubleTy(MLIRContext *context) {
+  return LLVMDoubleType::get(context);
 }
 
-LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) {
-  return LLVMFloatType::get(dialect->getContext());
+LLVMType LLVMType::getFloatTy(MLIRContext *context) {
+  return LLVMFloatType::get(context);
 }
 
-LLVMType LLVMType::getBFloatTy(LLVMDialect *dialect) {
-  return LLVMBFloatType::get(dialect->getContext());
+LLVMType LLVMType::getBFloatTy(MLIRContext *context) {
+  return LLVMBFloatType::get(context);
 }
 
-LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) {
-  return LLVMHalfType::get(dialect->getContext());
+LLVMType LLVMType::getHalfTy(MLIRContext *context) {
+  return LLVMHalfType::get(context);
 }
 
-LLVMType LLVMType::getFP128Ty(LLVMDialect *dialect) {
-  return LLVMFP128Type::get(dialect->getContext());
+LLVMType LLVMType::getFP128Ty(MLIRContext *context) {
+  return LLVMFP128Type::get(context);
 }
 
-LLVMType LLVMType::getX86_FP80Ty(LLVMDialect *dialect) {
-  return LLVMX86FP80Type::get(dialect->getContext());
+LLVMType LLVMType::getX86_FP80Ty(MLIRContext *context) {
+  return LLVMX86FP80Type::get(context);
 }
 
 //----------------------------------------------------------------------------//
 // Utilities used to generate integer types.
 
-LLVMType LLVMType::getIntNTy(LLVMDialect *dialect, unsigned numBits) {
-  return LLVMIntegerType::get(dialect->getContext(), numBits);
+LLVMType LLVMType::getIntNTy(MLIRContext *context, unsigned numBits) {
+  return LLVMIntegerType::get(context, numBits);
 }
 
 //----------------------------------------------------------------------------//
@@ -170,9 +170,9 @@ LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef<LLVMType> params,
   return LLVMFunctionType::get(result, params, isVarArg);
 }
 
-LLVMType LLVMType::getStructTy(LLVMDialect *dialect,
+LLVMType LLVMType::getStructTy(MLIRContext *context,
                                ArrayRef<LLVMType> elements, bool isPacked) {
-  return LLVMStructType::getLiteral(dialect->getContext(), elements, isPacked);
+  return LLVMStructType::getLiteral(context, elements, isPacked);
 }
 
 LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) {
@@ -182,8 +182,8 @@ LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) {
 //----------------------------------------------------------------------------//
 // Void type utilities.
 
-LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) {
-  return LLVMVoidType::get(dialect->getContext());
+LLVMType LLVMType::getVoidTy(MLIRContext *context) {
+  return LLVMVoidType::get(context);
 }
 
 bool LLVMType::isVoidTy() { return isa<LLVMVoidType>(); }
@@ -191,7 +191,7 @@ bool LLVMType::isVoidTy() { return isa<LLVMVoidType>(); }
 //----------------------------------------------------------------------------//
 // Creation and setting of LLVM's identified struct types
 
-LLVMType LLVMType::createStructTy(LLVMDialect *dialect,
+LLVMType LLVMType::createStructTy(MLIRContext *context,
                                   ArrayRef<LLVMType> elements,
                                   Optional<StringRef> name, bool isPacked) {
   assert(name.hasValue() &&
@@ -200,8 +200,7 @@ LLVMType LLVMType::createStructTy(LLVMDialect *dialect,
   std::string stringName = stringNameBase.str();
   unsigned counter = 0;
   do {
-    auto type =
-        LLVMStructType::getIdentified(dialect->getContext(), stringName);
+    auto type = LLVMStructType::getIdentified(context, stringName);
     if (type.isInitialized() || failed(type.setBody(elements, isPacked))) {
       counter += 1;
       stringName =
index 9a694a5..9a09488 100644 (file)
@@ -41,12 +41,6 @@ static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
     p << " : " << op->getResultTypes();
 }
 
-static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) {
-  return parser.getBuilder()
-      .getContext()
-      ->getRegisteredDialect<LLVM::LLVMDialect>();
-}
-
 // <operation> ::=
 //     `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask`
 //      ({return_value_and_is_valid})? : result_type
@@ -69,7 +63,7 @@ static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
     break;
   }
 
-  auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser));
+  auto int32Ty = LLVM::LLVMType::getInt32Ty(parser.getBuilder().getContext());
   return parser.resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty},
                                 parser.getNameLoc(), result.operands);
 }
@@ -77,9 +71,9 @@ static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
 // <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
 static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
                                          OperationState &result) {
-  auto llvmDialect = getLlvmDialect(parser);
-  auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
-  auto int1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
+  MLIRContext *context = parser.getBuilder().getContext();
+  auto int32Ty = LLVM::LLVMType::getInt32Ty(context);
+  auto int1Ty = LLVM::LLVMType::getInt1Ty(context);
 
   SmallVector<OpAsmParser::OperandType, 8> ops;
   Type type;
@@ -92,14 +86,14 @@ static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
 }
 
 static LogicalResult verify(MmaOp op) {
-  auto dialect = op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
-  auto f16Ty = LLVM::LLVMType::getHalfTy(dialect);
+  MLIRContext *context = op.getContext();
+  auto f16Ty = LLVM::LLVMType::getHalfTy(context);
   auto f16x2Ty = LLVM::LLVMType::getVectorTy(f16Ty, 2);
-  auto f32Ty = LLVM::LLVMType::getFloatTy(dialect);
+  auto f32Ty = LLVM::LLVMType::getFloatTy(context);
   auto f16x2x4StructTy = LLVM::LLVMType::getStructTy(
-      dialect, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
+      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
   auto f32x8StructTy = LLVM::LLVMType::getStructTy(
-      dialect, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
+      context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
 
   SmallVector<Type, 12> operand_types(op.getOperandTypes().begin(),
                                       op.getOperandTypes().end());
index f3771dd..47089b9 100644 (file)
@@ -34,12 +34,6 @@ using namespace ROCDL;
 // Parsing for ROCDL ops
 //===----------------------------------------------------------------------===//
 
-static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) {
-  return parser.getBuilder()
-      .getContext()
-      ->getRegisteredDialect<LLVM::LLVMDialect>();
-}
-
 // <operation> ::=
 //     `llvm.amdgcn.buffer.load.* %rsrc, %vindex, %offset, %glc, %slc :
 //     result_type`
@@ -51,8 +45,9 @@ static ParseResult parseROCDLMubufLoadOp(OpAsmParser &parser,
       parser.addTypeToList(type, result.types))
     return failure();
 
-  auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser));
-  auto int1Ty = LLVM::LLVMType::getInt1Ty(getLlvmDialect(parser));
+  MLIRContext *context = parser.getBuilder().getContext();
+  auto int32Ty = LLVM::LLVMType::getInt32Ty(context);
+  auto int1Ty = LLVM::LLVMType::getInt1Ty(context);
   auto i32x4Ty = LLVM::LLVMType::getVectorTy(int32Ty, 4);
   return parser.resolveOperands(ops,
                                 {i32x4Ty, int32Ty, int32Ty, int1Ty, int1Ty},
@@ -69,8 +64,9 @@ static ParseResult parseROCDLMubufStoreOp(OpAsmParser &parser,
   if (parser.parseOperandList(ops, 6) || parser.parseColonType(type))
     return failure();
 
-  auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser));
-  auto int1Ty = LLVM::LLVMType::getInt1Ty(getLlvmDialect(parser));
+  MLIRContext *context = parser.getBuilder().getContext();
+  auto int32Ty = LLVM::LLVMType::getInt32Ty(context);
+  auto int1Ty = LLVM::LLVMType::getInt1Ty(context);
   auto i32x4Ty = LLVM::LLVMType::getVectorTy(int32Ty, 4);
 
   if (parser.resolveOperands(ops,