[mlir][GPUToVulkan] Port conversion passes and `mlir-vulkan-runner` to opaque pointers
authorMarkus Böck <markus.boeck02@gmail.com>
Tue, 21 Feb 2023 09:43:18 +0000 (10:43 +0100)
committerMarkus Böck <markus.boeck02@gmail.com>
Fri, 24 Feb 2023 16:04:42 +0000 (17:04 +0100)
Part of https://discourse.llvm.org/t/rfc-switching-the-llvm-dialect-and-dialect-lowerings-to-opaque-pointers/68179

This patch adds the new pass option 'use-opaque-pointers' to `-launch-func-to-vulkan` instructing the pass to emit LLVM opaque pointers instead of typed pointers.

Note that the pass as it was previously implemented relied on the fact LLVM pointers carried an element type. The passed used this information to deduce both the rank of a "lowered-to-llvm" MemRef as well as the element type. Since the element type when using LLVM opaque pointers is completely erased it is not possible to deduce the element type.

I therefore added a new attribute that is attached to the `vulkanLaunch` call alongside the binary blob and entry point name by the `-convert-gpu-launch-to-vulkan-launch` pass. It simply attaches a type array specifying the element types of each memref. This way the `-launch-func-to-vulkan` can simply read out the element type from the attribute.
The rank can still be deduced from the auto-generated C interface from `FinalizeMemRefToLLVM`. This is admittedly a bit fragile but I was not sure whether it was worth the effort to also add a rank array attribute.

As a last step, the use of opaque-pointers in `mlir-vulkan-runners` codegen pipeline was also enabled, since all covnersion passes used fully support it.

Differential Revision: https://reviews.llvm.org/D144460

mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h
mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir
mlir/test/Conversion/GPUToVulkan/typed-pointers.mlir [new file with mode: 0644]
mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp

index 5a528df18e6cb274aea7ffa32d0ce5736a3bbd2e..f69720328f2a42262ae27a5b0dfcadebf4a8dbf3 100644 (file)
@@ -23,14 +23,12 @@ namespace mlir {
 class ModuleOp;
 template <typename T>
 class OperationPass;
+class Pass;
 
-#define GEN_PASS_DECL_CONVERTVULKANLAUNCHFUNCTOVULKANCALLS
+#define GEN_PASS_DECL_CONVERTVULKANLAUNCHFUNCTOVULKANCALLSPASS
 #define GEN_PASS_DECL_CONVERTGPULAUNCHFUNCTOVULKANLAUNCHFUNC
 #include "mlir/Conversion/Passes.h.inc"
 
-std::unique_ptr<OperationPass<ModuleOp>>
-createConvertVulkanLaunchFuncToVulkanCallsPass();
-
 std::unique_ptr<OperationPass<mlir::ModuleOp>>
 createConvertGpuLaunchFuncToVulkanLaunchFuncPass();
 
index b069627298e1d3ceeafb11b604d00419fba52a26..70ecfba5ea49f26d529a2b1346783099af48406b 100644 (file)
@@ -492,14 +492,20 @@ def ConvertGpuLaunchFuncToVulkanLaunchFunc
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
-def ConvertVulkanLaunchFuncToVulkanCalls
+def ConvertVulkanLaunchFuncToVulkanCallsPass
     : Pass<"launch-func-to-vulkan", "ModuleOp"> {
   let summary = "Convert vulkanLaunch external call to Vulkan runtime external "
                 "calls";
   let description = [{
     This pass is only intended for the mlir-vulkan-runner.
   }];
-  let constructor = "mlir::createConvertVulkanLaunchFuncToVulkanCallsPass()";
+
+  let options = [
+     Option<"useOpaquePointers", "use-opaque-pointers", "bool",
+            /*default=*/"false", "Generate LLVM IR using opaque pointers "
+            "instead of typed pointers">
+  ];
+
   let dependentDialects = ["LLVM::LLVMDialect"];
 }
 
index b6aff16407495d60aeac7d6c1b219c23cf5b306b..e4ac64252acc430b5767782a5b4d5f7bb0bc25b8 100644 (file)
@@ -35,6 +35,7 @@ using namespace mlir;
 
 static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
 static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
+static constexpr const char *kSPIRVElementTypesAttrName = "spirv_element_types";
 static constexpr const char *kVulkanLaunch = "vulkanLaunch";
 
 namespace {
@@ -189,6 +190,18 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
   vulkanLaunchCallOp->setAttr(kSPIRVEntryPointAttrName,
                               launchOp.getKernelName());
 
+  // Add MemRef element types before they're lost when lowering to LLVM.
+  SmallVector<Type> elementTypes;
+  for (Type type : llvm::drop_begin(launchOp.getOperandTypes(),
+                                    gpu::LaunchOp::kNumConfigOperands)) {
+    // The below cast always succeeds as it has already been verified in
+    // 'declareVulkanLaunchFunc' that these are MemRefs with compatible element
+    // types.
+    elementTypes.push_back(type.cast<MemRefType>().getElementType());
+  }
+  vulkanLaunchCallOp->setAttr(kSPIRVElementTypesAttrName,
+                              builder.getTypeArrayAttr(elementTypes));
+
   launchOp.erase();
 }
 
index 9fca86317ed115f822bf1aa47c4031fa1b3b9e5a..78d1f6790c859ae0c1fd97edd234566171f22f81 100644 (file)
@@ -25,7 +25,7 @@
 #include "llvm/Support/FormatVariadic.h"
 
 namespace mlir {
-#define GEN_PASS_DEF_CONVERTVULKANLAUNCHFUNCTOVULKANCALLS
+#define GEN_PASS_DEF_CONVERTVULKANLAUNCHFUNCTOVULKANCALLSPASS
 #include "mlir/Conversion/Passes.h.inc"
 } // namespace mlir
 
@@ -42,6 +42,7 @@ static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
 static constexpr const char *kSPIRVBinary = "SPIRV_BIN";
 static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
 static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
+static constexpr const char *kSPIRVElementTypesAttrName = "spirv_element_types";
 static constexpr const char *kVulkanLaunch = "vulkanLaunch";
 
 namespace {
@@ -58,14 +59,17 @@ namespace {
 /// * deinitVulkan         -- deinitializes vulkan runtime
 ///
 class VulkanLaunchFuncToVulkanCallsPass
-    : public impl::ConvertVulkanLaunchFuncToVulkanCallsBase<
+    : public impl::ConvertVulkanLaunchFuncToVulkanCallsPassBase<
           VulkanLaunchFuncToVulkanCallsPass> {
 private:
   void initializeCachedTypes() {
     llvmFloatType = Float32Type::get(&getContext());
     llvmVoidType = LLVM::LLVMVoidType::get(&getContext());
-    llvmPointerType =
-        LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8));
+    if (useOpaquePointers)
+      llvmPointerType = LLVM::LLVMPointerType::get(&getContext());
+    else
+      llvmPointerType =
+          LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8));
     llvmInt32Type = IntegerType::get(&getContext(), 32);
     llvmInt64Type = IntegerType::get(&getContext(), 64);
   }
@@ -81,7 +85,9 @@ private:
     //   int64_t sizes[Rank]; // omitted when rank == 0
     //   int64_t strides[Rank]; // omitted when rank == 0
     // };
-    auto llvmPtrToElementType = LLVM::LLVMPointerType::get(elemenType);
+    auto llvmPtrToElementType = useOpaquePointers
+                                    ? llvmPointerType
+                                    : LLVM::LLVMPointerType::get(elemenType);
     auto llvmArrayRankElementSizeType =
         LLVM::LLVMArrayType::get(getInt64Type(), rank);
 
@@ -131,9 +137,8 @@ private:
   /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
   void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
 
-  /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`.
-  LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor,
-                                        uint32_t &rank, Type &type);
+  /// Deduces a rank from the given 'launchCallArg`.
+  LogicalResult deduceMemRefRank(Value launchCallArg, uint32_t &rank);
 
   /// Returns a string representation from the given `type`.
   StringRef stringifyType(Type type) {
@@ -154,6 +159,8 @@ private:
   }
 
 public:
+  using Base::Base;
+
   void runOnOperation() override;
 
 private:
@@ -163,8 +170,14 @@ private:
   Type llvmInt32Type;
   Type llvmInt64Type;
 
+  struct SPIRVAttributes {
+    StringAttr blob;
+    StringAttr entryPoint;
+    SmallVector<Type> elementTypes;
+  };
+
   // TODO: Use an associative array to support multiple vulkan launch calls.
-  std::pair<StringAttr, StringAttr> spirvAttributes;
+  SPIRVAttributes spirvAttributes;
   /// The number of vulkan launch configuration operands, placed at the leading
   /// positions of the operand list.
   static constexpr unsigned kVulkanLaunchNumConfigOperands = 3;
@@ -209,7 +222,24 @@ void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
     return signalPassFailure();
   }
 
-  spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
+  auto spirvElementTypesAttr =
+      vulkanLaunchCallOp->getAttrOfType<ArrayAttr>(kSPIRVElementTypesAttrName);
+  if (!spirvElementTypesAttr) {
+    vulkanLaunchCallOp.emitError()
+        << "missing " << kSPIRVElementTypesAttrName << " attribute";
+    return signalPassFailure();
+  }
+  if (llvm::any_of(spirvElementTypesAttr,
+                   [](Attribute attr) { return !isa<TypeAttr>(attr); })) {
+    vulkanLaunchCallOp.emitError()
+        << "expected " << spirvElementTypesAttr << " to be an array of types";
+    return signalPassFailure();
+  }
+
+  spirvAttributes.blob = spirvBlobAttr;
+  spirvAttributes.entryPoint = spirvEntryPointNameAttr;
+  spirvAttributes.elementTypes =
+      llvm::to_vector(spirvElementTypesAttr.getAsValueRange<mlir::TypeAttr>());
 }
 
 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
@@ -226,17 +256,23 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
   Value descriptorSet =
       builder.create<LLVM::ConstantOp>(loc, getInt32Type(), 0);
 
-  for (const auto &en :
+  for (auto [index, ptrToMemRefDescriptor] :
        llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
            kVulkanLaunchNumConfigOperands))) {
     // Create LLVM constant for the descriptor binding index.
     Value descriptorBinding =
-        builder.create<LLVM::ConstantOp>(loc, getInt32Type(), en.index());
+        builder.create<LLVM::ConstantOp>(loc, getInt32Type(), index);
+
+    if (index >= spirvAttributes.elementTypes.size()) {
+      cInterfaceVulkanLaunchCallOp.emitError()
+          << kSPIRVElementTypesAttrName << " missing element type for "
+          << ptrToMemRefDescriptor;
+      return signalPassFailure();
+    }
 
-    auto ptrToMemRefDescriptor = en.value();
     uint32_t rank = 0;
-    Type type;
-    if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) {
+    Type type = spirvAttributes.elementTypes[index];
+    if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) {
       cInterfaceVulkanLaunchCallOp.emitError()
           << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
       return signalPassFailure();
@@ -246,7 +282,7 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
         llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
     // Special case for fp16 type. Since it is not a supported type in C we use
     // int16_t and bitcast the descriptor.
-    if (type.isa<Float16Type>()) {
+    if (!useOpaquePointers && type.isa<Float16Type>()) {
       auto memRefTy = getMemRefType(rank, IntegerType::get(&getContext(), 16));
       ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
           loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor);
@@ -259,15 +295,24 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
   }
 }
 
-LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
-    Value ptrToMemRefDescriptor, uint32_t &rank, Type &type) {
-  auto llvmPtrDescriptorTy =
-      ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMPointerType>();
-  if (!llvmPtrDescriptorTy)
+LogicalResult
+VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value launchCallArg,
+                                                    uint32_t &rank) {
+  // Deduce the rank from the type used to allocate the lowered MemRef.
+  auto alloca = launchCallArg.getDefiningOp<LLVM::AllocaOp>();
+  if (!alloca)
     return failure();
 
-  auto llvmDescriptorTy =
-      llvmPtrDescriptorTy.getElementType().dyn_cast<LLVM::LLVMStructType>();
+  LLVM::LLVMStructType llvmDescriptorTy;
+  if (std::optional<Type> elementType = alloca.getElemType()) {
+    llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(*elementType);
+  } else {
+    // This case is only possible if we are not using opaque pointers
+    // since opaque pointer producing allocas require an element type.
+    llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(
+        alloca.getRes().getType().getElementType());
+  }
+
   // template <typename Elem, size_t Rank>
   // struct {
   //   Elem *allocated;
@@ -279,9 +324,6 @@ LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
   if (!llvmDescriptorTy)
     return failure();
 
-  type = llvmDescriptorTy.getBody()[0]
-             .cast<LLVM::LLVMPointerType>()
-             .getElementType();
   if (llvmDescriptorTy.getBody().size() == 3) {
     rank = 0;
     return success();
@@ -339,7 +381,9 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
         auto fnType = LLVM::LLVMFunctionType::get(
             getVoidType(),
             {getPointerType(), getInt32Type(), getInt32Type(),
-             LLVM::LLVMPointerType::get(getMemRefType(i, type))},
+             useOpaquePointers
+                 ? llvmPointerType
+                 : LLVM::LLVMPointerType::get(getMemRefType(i, type))},
             /*isVarArg=*/false);
         builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
       }
@@ -368,7 +412,7 @@ Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
   std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
   return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
                                   shaderName, LLVM::Linkage::Internal,
-                                  /*TODO:useOpaquePointers=*/false);
+                                  useOpaquePointers);
 }
 
 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
@@ -385,12 +429,12 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
   // that data to runtime call.
   Value ptrToSPIRVBinary = LLVM::createGlobalString(
-      loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
-      LLVM::Linkage::Internal, /*TODO:useOpaquePointers=*/false);
+      loc, builder, kSPIRVBinary, spirvAttributes.blob.getValue(),
+      LLVM::Linkage::Internal, useOpaquePointers);
 
   // Create LLVM constant for the size of SPIR-V binary shader.
   Value binarySize = builder.create<LLVM::ConstantOp>(
-      loc, getInt32Type(), spirvAttributes.first.getValue().size());
+      loc, getInt32Type(), spirvAttributes.blob.getValue().size());
 
   // Create call to `bindMemRef` for each memref operand.
   createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
@@ -402,7 +446,7 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
       ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize});
   // Create LLVM global with entry point name.
   Value entryPointName = createEntryPointNameConstant(
-      spirvAttributes.second.getValue(), loc, builder);
+      spirvAttributes.entryPoint.getValue(), loc, builder);
   // Create call to `setEntryPoint` runtime function with the given pointer to
   // entry point name.
   builder.create<LLVM::CallOp>(loc, TypeRange(), kSetEntryPoint,
@@ -428,8 +472,3 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
 
   cInterfaceVulkanLaunchCallOp.erase();
 }
-
-std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
-mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
-  return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
-}
index b99584ba126f0801eaab1586744563cf18309912..c77bf238c8ca212b6a7dadb4bf5e282c06f4db0e 100644 (file)
@@ -1,63 +1,62 @@
-// RUN: mlir-opt %s -launch-func-to-vulkan | FileCheck %s
+// RUN: mlir-opt %s -launch-func-to-vulkan='use-opaque-pointers=1' | FileCheck %s
 
 // CHECK: llvm.mlir.global internal constant @kernel_spv_entry_point_name
 // CHECK: llvm.mlir.global internal constant @SPIRV_BIN
-// CHECK: %[[Vulkan_Runtime_ptr:.*]] = llvm.call @initVulkan() : () -> !llvm.ptr<i8>
+// CHECK: %[[Vulkan_Runtime_ptr:.*]] = llvm.call @initVulkan() : () -> !llvm.ptr
 // CHECK: %[[addressof_SPIRV_BIN:.*]] = llvm.mlir.addressof @SPIRV_BIN
 // CHECK: %[[SPIRV_BIN_ptr:.*]] = llvm.getelementptr %[[addressof_SPIRV_BIN]]
 // CHECK: %[[SPIRV_BIN_size:.*]] = llvm.mlir.constant
-// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> ()
-// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i32) -> ()
+// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, i32, i32, !llvm.ptr) -> ()
+// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr, !llvm.ptr, i32) -> ()
 // CHECK: %[[addressof_entry_point:.*]] = llvm.mlir.addressof @kernel_spv_entry_point_name
 // CHECK: %[[entry_point_ptr:.*]] = llvm.getelementptr %[[addressof_entry_point]]
-// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
-// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i64, i64, i64) -> ()
-// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
+// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr, !llvm.ptr) -> ()
+// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, i64, i64, i64) -> ()
+// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr) -> ()
+// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr) -> ()
 
-// CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<i16>, ptr<i16>, i64, array<1 x i64>, array<1 x i64>)>>)
+// CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr, i32, i32, !llvm.ptr)
 
 module attributes {gpu.container_module} {
-  llvm.func @malloc(i64) -> !llvm.ptr<i8>
+  llvm.func @malloc(i64) -> !llvm.ptr
   llvm.func @foo() {
     %0 = llvm.mlir.constant(12 : index) : i64
-    %1 = llvm.mlir.null : !llvm.ptr<f32>
+    %1 = llvm.mlir.null : !llvm.ptr
     %2 = llvm.mlir.constant(1 : index) : i64
-    %3 = llvm.getelementptr %1[%2] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-    %4 = llvm.ptrtoint %3 : !llvm.ptr<f32> to i64
+    %3 = llvm.getelementptr %1[%2] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    %4 = llvm.ptrtoint %3 : !llvm.ptr to i64
     %5 = llvm.mul %0, %4 : i64
-    %6 = llvm.call @malloc(%5) : (i64) -> !llvm.ptr<i8>
-    %7 = llvm.bitcast %6 : !llvm.ptr<i8> to !llvm.ptr<f32>
-    %8 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %9 = llvm.insertvalue %7, %8[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %10 = llvm.insertvalue %7, %9[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %6 = llvm.call @malloc(%5) : (i64) -> !llvm.ptr
+    %8 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %9 = llvm.insertvalue %6, %8[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %10 = llvm.insertvalue %6, %9[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     %11 = llvm.mlir.constant(0 : index) : i64
-    %12 = llvm.insertvalue %11, %10[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %12 = llvm.insertvalue %11, %10[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     %13 = llvm.mlir.constant(1 : index) : i64
-    %14 = llvm.insertvalue %0, %12[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %15 = llvm.insertvalue %13, %14[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %14 = llvm.insertvalue %0, %12[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %15 = llvm.insertvalue %13, %14[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     %16 = llvm.mlir.constant(1 : index) : i64
-    %17 = llvm.extractvalue %15[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %18 = llvm.extractvalue %15[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %19 = llvm.extractvalue %15[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %20 = llvm.extractvalue %15[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %21 = llvm.extractvalue %15[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    llvm.call @vulkanLaunch(%16, %16, %16, %17, %18, %19, %20, %21) {spirv_blob = "\03\02#\07\00", spirv_entry_point = "kernel"}
-    : (i64, i64, i64, !llvm.ptr<f32>, !llvm.ptr<f32>, i64, i64, i64) -> ()
+    %17 = llvm.extractvalue %15[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %18 = llvm.extractvalue %15[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %19 = llvm.extractvalue %15[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %20 = llvm.extractvalue %15[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %21 = llvm.extractvalue %15[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    llvm.call @vulkanLaunch(%16, %16, %16, %17, %18, %19, %20, %21) {spirv_blob = "\03\02#\07\00", spirv_element_types = [f32], spirv_entry_point = "kernel"}
+    : (i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64) -> ()
     llvm.return
   }
-  llvm.func @vulkanLaunch(%arg0: i64, %arg1: i64, %arg2: i64, %arg6: !llvm.ptr<f32>, %arg7: !llvm.ptr<f32>, %arg8: i64, %arg9: i64, %arg10: i64) {
-    %0 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %1 = llvm.insertvalue %arg6, %0[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %2 = llvm.insertvalue %arg7, %1[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %3 = llvm.insertvalue %arg8, %2[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %4 = llvm.insertvalue %arg9, %3[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %5 = llvm.insertvalue %arg10, %4[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  llvm.func @vulkanLaunch(%arg0: i64, %arg1: i64, %arg2: i64, %arg6: !llvm.ptr, %arg7: !llvm.ptr, %arg8: i64, %arg9: i64, %arg10: i64) {
+    %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %1 = llvm.insertvalue %arg6, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %2 = llvm.insertvalue %arg7, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %3 = llvm.insertvalue %arg8, %2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %4 = llvm.insertvalue %arg9, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %5 = llvm.insertvalue %arg10, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     %6 = llvm.mlir.constant(1 : index) : i64
-    %7 = llvm.alloca %6 x !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> : (i64) -> !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>
-    llvm.store %5, %7 : !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>
-    llvm.call @_mlir_ciface_vulkanLaunch(%arg0, %arg1, %arg2, %7) : (i64, i64, i64, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> ()
+    %7 = llvm.alloca %6 x !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> : (i64) -> !llvm.ptr
+    llvm.store %5, %7 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr
+    llvm.call @_mlir_ciface_vulkanLaunch(%arg0, %arg1, %arg2, %7) : (i64, i64, i64, !llvm.ptr) -> ()
     llvm.return
   }
-  llvm.func @_mlir_ciface_vulkanLaunch(i64, i64, i64, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>)
+  llvm.func @_mlir_ciface_vulkanLaunch(i64, i64, i64, !llvm.ptr)
 }
index 3b176c3cf8346d259728f999d0b7f2548a94f6f2..13eb3a194df44498bb86e9c53fc46b6f7d5b9128 100644 (file)
@@ -2,7 +2,7 @@
 
 // CHECK: %[[resource:.*]] = memref.alloc() : memref<12xf32>
 // CHECK: %[[index:.*]] = arith.constant 1 : index
-// CHECK: call @vulkanLaunch(%[[index]], %[[index]], %[[index]], %[[resource]]) {spirv_blob = "{{.*}}", spirv_entry_point = "kernel"}
+// CHECK: call @vulkanLaunch(%[[index]], %[[index]], %[[index]], %[[resource]]) {spirv_blob = "{{.*}}", spirv_element_types = [f32], spirv_entry_point = "kernel"}
 
 module attributes {gpu.container_module} {
   spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]> {
diff --git a/mlir/test/Conversion/GPUToVulkan/typed-pointers.mlir b/mlir/test/Conversion/GPUToVulkan/typed-pointers.mlir
new file mode 100644 (file)
index 0000000..67bd640
--- /dev/null
@@ -0,0 +1,63 @@
+// RUN: mlir-opt %s -launch-func-to-vulkan='use-opaque-pointers=0' | FileCheck %s
+
+// CHECK: llvm.mlir.global internal constant @kernel_spv_entry_point_name
+// CHECK: llvm.mlir.global internal constant @SPIRV_BIN
+// CHECK: %[[Vulkan_Runtime_ptr:.*]] = llvm.call @initVulkan() : () -> !llvm.ptr<i8>
+// CHECK: %[[addressof_SPIRV_BIN:.*]] = llvm.mlir.addressof @SPIRV_BIN
+// CHECK: %[[SPIRV_BIN_ptr:.*]] = llvm.getelementptr %[[addressof_SPIRV_BIN]]
+// CHECK: %[[SPIRV_BIN_size:.*]] = llvm.mlir.constant
+// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> ()
+// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i32) -> ()
+// CHECK: %[[addressof_entry_point:.*]] = llvm.mlir.addressof @kernel_spv_entry_point_name
+// CHECK: %[[entry_point_ptr:.*]] = llvm.getelementptr %[[addressof_entry_point]]
+// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
+// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i64, i64, i64) -> ()
+// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
+// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
+
+// CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<i16>, ptr<i16>, i64, array<1 x i64>, array<1 x i64>)>>)
+
+module attributes {gpu.container_module} {
+  llvm.func @malloc(i64) -> !llvm.ptr<i8>
+  llvm.func @foo() {
+    %0 = llvm.mlir.constant(12 : index) : i64
+    %1 = llvm.mlir.null : !llvm.ptr<f32>
+    %2 = llvm.mlir.constant(1 : index) : i64
+    %3 = llvm.getelementptr %1[%2] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+    %4 = llvm.ptrtoint %3 : !llvm.ptr<f32> to i64
+    %5 = llvm.mul %0, %4 : i64
+    %6 = llvm.call @malloc(%5) : (i64) -> !llvm.ptr<i8>
+    %7 = llvm.bitcast %6 : !llvm.ptr<i8> to !llvm.ptr<f32>
+    %8 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %9 = llvm.insertvalue %7, %8[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %10 = llvm.insertvalue %7, %9[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %11 = llvm.mlir.constant(0 : index) : i64
+    %12 = llvm.insertvalue %11, %10[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %13 = llvm.mlir.constant(1 : index) : i64
+    %14 = llvm.insertvalue %0, %12[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %15 = llvm.insertvalue %13, %14[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %16 = llvm.mlir.constant(1 : index) : i64
+    %17 = llvm.extractvalue %15[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %18 = llvm.extractvalue %15[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %19 = llvm.extractvalue %15[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %20 = llvm.extractvalue %15[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %21 = llvm.extractvalue %15[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    llvm.call @vulkanLaunch(%16, %16, %16, %17, %18, %19, %20, %21) {spirv_blob = "\03\02#\07\00", spirv_element_types = [f32], spirv_entry_point = "kernel"}
+    : (i64, i64, i64, !llvm.ptr<f32>, !llvm.ptr<f32>, i64, i64, i64) -> ()
+    llvm.return
+  }
+  llvm.func @vulkanLaunch(%arg0: i64, %arg1: i64, %arg2: i64, %arg6: !llvm.ptr<f32>, %arg7: !llvm.ptr<f32>, %arg8: i64, %arg9: i64, %arg10: i64) {
+    %0 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %1 = llvm.insertvalue %arg6, %0[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %2 = llvm.insertvalue %arg7, %1[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %3 = llvm.insertvalue %arg8, %2[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %4 = llvm.insertvalue %arg9, %3[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %5 = llvm.insertvalue %arg10, %4[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %6 = llvm.mlir.constant(1 : index) : i64
+    %7 = llvm.alloca %6 x !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> : (i64) -> !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>
+    llvm.store %5, %7 : !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>
+    llvm.call @_mlir_ciface_vulkanLaunch(%arg0, %arg1, %arg2, %7) : (i64, i64, i64, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> ()
+    llvm.return
+  }
+  llvm.func @_mlir_ciface_vulkanLaunch(i64, i64, i64, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>)
+}
index caec6439ec747c5253645b265344d8a56fee9305..d196902eb169ff16ea4c01536746b14c44ca6deb 100644 (file)
@@ -68,16 +68,25 @@ static LogicalResult runMLIRPasses(Operation *op,
   if (options.spirvWebGPUPrepare)
     modulePM.addPass(spirv::createSPIRVWebGPUPreparePass());
 
+  auto enableOpaquePointers = [](auto passOption) {
+    passOption.useOpaquePointers = true;
+    return passOption;
+  };
+
   passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass());
-  passManager.addPass(createFinalizeMemRefToLLVMConversionPass());
-  passManager.addPass(createConvertVectorToLLVMPass());
+  passManager.addPass(createFinalizeMemRefToLLVMConversionPass(
+      enableOpaquePointers(FinalizeMemRefToLLVMConversionPassOptions{})));
+  passManager.addPass(createConvertVectorToLLVMPass(
+      enableOpaquePointers(ConvertVectorToLLVMPassOptions{})));
   passManager.nest<func::FuncOp>().addPass(LLVM::createRequestCWrappersPass());
   ConvertFuncToLLVMPassOptions funcToLLVMOptions{};
   funcToLLVMOptions.indexBitwidth =
       DataLayout(module).getTypeSizeInBits(IndexType::get(module.getContext()));
-  passManager.addPass(createConvertFuncToLLVMPass(funcToLLVMOptions));
+  passManager.addPass(
+      createConvertFuncToLLVMPass(enableOpaquePointers(funcToLLVMOptions)));
   passManager.addPass(createReconcileUnrealizedCastsPass());
-  passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass());
+  passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass(
+      enableOpaquePointers(ConvertVulkanLaunchFuncToVulkanCallsPassOptions{})));
 
   return passManager.run(module);
 }