Part of https://discourse.llvm.org/t/rfc-switching-the-llvm-dialect-and-dialect-lowerings-to-opaque-pointers/68179
This patch adds the new pass option 'use-opaque-pointers' to `-launch-func-to-vulkan` instructing the pass to emit LLVM opaque pointers instead of typed pointers.
Note that the pass as it was previously implemented relied on the fact that LLVM pointers carried an element type. The pass used this information to deduce both the rank of a "lowered-to-llvm" MemRef and its element type. Since the element type is completely erased when using LLVM opaque pointers, it is no longer possible to deduce the element type from the pointer.
I therefore added a new attribute that is attached to the `vulkanLaunch` call alongside the binary blob and entry point name by the `-convert-gpu-launch-to-vulkan-launch` pass. It simply attaches a type array specifying the element types of each memref. This way the `-launch-func-to-vulkan` can simply read out the element type from the attribute.
The rank can still be deduced from the auto-generated C interface from `FinalizeMemRefToLLVM`. This is admittedly a bit fragile but I was not sure whether it was worth the effort to also add a rank array attribute.
As a last step, the use of opaque pointers in `mlir-vulkan-runner`'s codegen pipeline was also enabled, since all conversion passes used fully support it.
Differential Revision: https://reviews.llvm.org/D144460
class ModuleOp;
template <typename T>
class OperationPass;
+class Pass;
-#define GEN_PASS_DECL_CONVERTVULKANLAUNCHFUNCTOVULKANCALLS
+#define GEN_PASS_DECL_CONVERTVULKANLAUNCHFUNCTOVULKANCALLSPASS
#define GEN_PASS_DECL_CONVERTGPULAUNCHFUNCTOVULKANLAUNCHFUNC
#include "mlir/Conversion/Passes.h.inc"
-std::unique_ptr<OperationPass<ModuleOp>>
-createConvertVulkanLaunchFuncToVulkanCallsPass();
-
std::unique_ptr<OperationPass<mlir::ModuleOp>>
createConvertGpuLaunchFuncToVulkanLaunchFuncPass();
let dependentDialects = ["spirv::SPIRVDialect"];
}
-def ConvertVulkanLaunchFuncToVulkanCalls
+def ConvertVulkanLaunchFuncToVulkanCallsPass
: Pass<"launch-func-to-vulkan", "ModuleOp"> {
let summary = "Convert vulkanLaunch external call to Vulkan runtime external "
"calls";
let description = [{
This pass is only intended for the mlir-vulkan-runner.
}];
- let constructor = "mlir::createConvertVulkanLaunchFuncToVulkanCallsPass()";
+
+ let options = [
+ Option<"useOpaquePointers", "use-opaque-pointers", "bool",
+ /*default=*/"false", "Generate LLVM IR using opaque pointers "
+ "instead of typed pointers">
+ ];
+
let dependentDialects = ["LLVM::LLVMDialect"];
}
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
+static constexpr const char *kSPIRVElementTypesAttrName = "spirv_element_types";
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
namespace {
vulkanLaunchCallOp->setAttr(kSPIRVEntryPointAttrName,
launchOp.getKernelName());
+ // Add MemRef element types before they're lost when lowering to LLVM.
+ SmallVector<Type> elementTypes;
+ for (Type type : llvm::drop_begin(launchOp.getOperandTypes(),
+ gpu::LaunchOp::kNumConfigOperands)) {
+ // The below cast always succeeds as it has already been verified in
+ // 'declareVulkanLaunchFunc' that these are MemRefs with compatible element
+ // types.
+ elementTypes.push_back(type.cast<MemRefType>().getElementType());
+ }
+ vulkanLaunchCallOp->setAttr(kSPIRVElementTypesAttrName,
+ builder.getTypeArrayAttr(elementTypes));
+
launchOp.erase();
}
#include "llvm/Support/FormatVariadic.h"
namespace mlir {
-#define GEN_PASS_DEF_CONVERTVULKANLAUNCHFUNCTOVULKANCALLS
+#define GEN_PASS_DEF_CONVERTVULKANLAUNCHFUNCTOVULKANCALLSPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
+static constexpr const char *kSPIRVElementTypesAttrName = "spirv_element_types";
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
namespace {
/// * deinitVulkan -- deinitializes vulkan runtime
///
class VulkanLaunchFuncToVulkanCallsPass
- : public impl::ConvertVulkanLaunchFuncToVulkanCallsBase<
+ : public impl::ConvertVulkanLaunchFuncToVulkanCallsPassBase<
VulkanLaunchFuncToVulkanCallsPass> {
private:
void initializeCachedTypes() {
llvmFloatType = Float32Type::get(&getContext());
llvmVoidType = LLVM::LLVMVoidType::get(&getContext());
- llvmPointerType =
- LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8));
+ if (useOpaquePointers)
+ llvmPointerType = LLVM::LLVMPointerType::get(&getContext());
+ else
+ llvmPointerType =
+ LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8));
llvmInt32Type = IntegerType::get(&getContext(), 32);
llvmInt64Type = IntegerType::get(&getContext(), 64);
}
// int64_t sizes[Rank]; // omitted when rank == 0
// int64_t strides[Rank]; // omitted when rank == 0
// };
- auto llvmPtrToElementType = LLVM::LLVMPointerType::get(elemenType);
+ auto llvmPtrToElementType = useOpaquePointers
+ ? llvmPointerType
+ : LLVM::LLVMPointerType::get(elemenType);
auto llvmArrayRankElementSizeType =
LLVM::LLVMArrayType::get(getInt64Type(), rank);
/// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
- /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`.
- LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor,
- uint32_t &rank, Type &type);
+ /// Deduces a rank from the given `launchCallArg`.
+ LogicalResult deduceMemRefRank(Value launchCallArg, uint32_t &rank);
/// Returns a string representation from the given `type`.
StringRef stringifyType(Type type) {
}
public:
+ using Base::Base;
+
void runOnOperation() override;
private:
Type llvmInt32Type;
Type llvmInt64Type;
+ struct SPIRVAttributes {
+ StringAttr blob;
+ StringAttr entryPoint;
+ SmallVector<Type> elementTypes;
+ };
+
// TODO: Use an associative array to support multiple vulkan launch calls.
- std::pair<StringAttr, StringAttr> spirvAttributes;
+ SPIRVAttributes spirvAttributes;
/// The number of vulkan launch configuration operands, placed at the leading
/// positions of the operand list.
static constexpr unsigned kVulkanLaunchNumConfigOperands = 3;
return signalPassFailure();
}
- spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
+ auto spirvElementTypesAttr =
+ vulkanLaunchCallOp->getAttrOfType<ArrayAttr>(kSPIRVElementTypesAttrName);
+ if (!spirvElementTypesAttr) {
+ vulkanLaunchCallOp.emitError()
+ << "missing " << kSPIRVElementTypesAttrName << " attribute";
+ return signalPassFailure();
+ }
+ if (llvm::any_of(spirvElementTypesAttr,
+ [](Attribute attr) { return !isa<TypeAttr>(attr); })) {
+ vulkanLaunchCallOp.emitError()
+ << "expected " << spirvElementTypesAttr << " to be an array of types";
+ return signalPassFailure();
+ }
+
+ spirvAttributes.blob = spirvBlobAttr;
+ spirvAttributes.entryPoint = spirvEntryPointNameAttr;
+ spirvAttributes.elementTypes =
+ llvm::to_vector(spirvElementTypesAttr.getAsValueRange<mlir::TypeAttr>());
}
void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
Value descriptorSet =
builder.create<LLVM::ConstantOp>(loc, getInt32Type(), 0);
- for (const auto &en :
+ for (auto [index, ptrToMemRefDescriptor] :
llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
kVulkanLaunchNumConfigOperands))) {
// Create LLVM constant for the descriptor binding index.
Value descriptorBinding =
- builder.create<LLVM::ConstantOp>(loc, getInt32Type(), en.index());
+ builder.create<LLVM::ConstantOp>(loc, getInt32Type(), index);
+
+ if (index >= spirvAttributes.elementTypes.size()) {
+ cInterfaceVulkanLaunchCallOp.emitError()
+ << kSPIRVElementTypesAttrName << " missing element type for "
+ << ptrToMemRefDescriptor;
+ return signalPassFailure();
+ }
- auto ptrToMemRefDescriptor = en.value();
uint32_t rank = 0;
- Type type;
- if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) {
+ Type type = spirvAttributes.elementTypes[index];
+ if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) {
cInterfaceVulkanLaunchCallOp.emitError()
<< "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
return signalPassFailure();
llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
// Special case for fp16 type. Since it is not a supported type in C we use
// int16_t and bitcast the descriptor.
- if (type.isa<Float16Type>()) {
+ if (!useOpaquePointers && type.isa<Float16Type>()) {
auto memRefTy = getMemRefType(rank, IntegerType::get(&getContext(), 16));
ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor);
}
}
-LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
- Value ptrToMemRefDescriptor, uint32_t &rank, Type &type) {
- auto llvmPtrDescriptorTy =
- ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMPointerType>();
- if (!llvmPtrDescriptorTy)
+LogicalResult
+VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value launchCallArg,
+ uint32_t &rank) {
+ // Deduce the rank from the type used to allocate the lowered MemRef.
+ auto alloca = launchCallArg.getDefiningOp<LLVM::AllocaOp>();
+ if (!alloca)
return failure();
- auto llvmDescriptorTy =
- llvmPtrDescriptorTy.getElementType().dyn_cast<LLVM::LLVMStructType>();
+ LLVM::LLVMStructType llvmDescriptorTy;
+ if (std::optional<Type> elementType = alloca.getElemType()) {
+ llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(*elementType);
+ } else {
+ // This case is only possible if we are not using opaque pointers
+ // since opaque pointer producing allocas require an element type.
+ llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(
+ alloca.getRes().getType().getElementType());
+ }
+
// template <typename Elem, size_t Rank>
// struct {
// Elem *allocated;
if (!llvmDescriptorTy)
return failure();
- type = llvmDescriptorTy.getBody()[0]
- .cast<LLVM::LLVMPointerType>()
- .getElementType();
if (llvmDescriptorTy.getBody().size() == 3) {
rank = 0;
return success();
auto fnType = LLVM::LLVMFunctionType::get(
getVoidType(),
{getPointerType(), getInt32Type(), getInt32Type(),
- LLVM::LLVMPointerType::get(getMemRefType(i, type))},
+ useOpaquePointers
+ ? llvmPointerType
+ : LLVM::LLVMPointerType::get(getMemRefType(i, type))},
/*isVarArg=*/false);
builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
}
std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
shaderName, LLVM::Linkage::Internal,
- /*TODO:useOpaquePointers=*/false);
+ useOpaquePointers);
}
void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
// Create LLVM global with SPIR-V binary data, so we can pass a pointer with
// that data to runtime call.
Value ptrToSPIRVBinary = LLVM::createGlobalString(
- loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
- LLVM::Linkage::Internal, /*TODO:useOpaquePointers=*/false);
+ loc, builder, kSPIRVBinary, spirvAttributes.blob.getValue(),
+ LLVM::Linkage::Internal, useOpaquePointers);
// Create LLVM constant for the size of SPIR-V binary shader.
Value binarySize = builder.create<LLVM::ConstantOp>(
- loc, getInt32Type(), spirvAttributes.first.getValue().size());
+ loc, getInt32Type(), spirvAttributes.blob.getValue().size());
// Create call to `bindMemRef` for each memref operand.
createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize});
// Create LLVM global with entry point name.
Value entryPointName = createEntryPointNameConstant(
- spirvAttributes.second.getValue(), loc, builder);
+ spirvAttributes.entryPoint.getValue(), loc, builder);
// Create call to `setEntryPoint` runtime function with the given pointer to
// entry point name.
builder.create<LLVM::CallOp>(loc, TypeRange(), kSetEntryPoint,
cInterfaceVulkanLaunchCallOp.erase();
}
-
-std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
-mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
- return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
-}
-// RUN: mlir-opt %s -launch-func-to-vulkan | FileCheck %s
+// RUN: mlir-opt %s -launch-func-to-vulkan='use-opaque-pointers=1' | FileCheck %s
// CHECK: llvm.mlir.global internal constant @kernel_spv_entry_point_name
// CHECK: llvm.mlir.global internal constant @SPIRV_BIN
-// CHECK: %[[Vulkan_Runtime_ptr:.*]] = llvm.call @initVulkan() : () -> !llvm.ptr<i8>
+// CHECK: %[[Vulkan_Runtime_ptr:.*]] = llvm.call @initVulkan() : () -> !llvm.ptr
// CHECK: %[[addressof_SPIRV_BIN:.*]] = llvm.mlir.addressof @SPIRV_BIN
// CHECK: %[[SPIRV_BIN_ptr:.*]] = llvm.getelementptr %[[addressof_SPIRV_BIN]]
// CHECK: %[[SPIRV_BIN_size:.*]] = llvm.mlir.constant
-// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> ()
-// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i32) -> ()
+// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, i32, i32, !llvm.ptr) -> ()
+// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr, !llvm.ptr, i32) -> ()
// CHECK: %[[addressof_entry_point:.*]] = llvm.mlir.addressof @kernel_spv_entry_point_name
// CHECK: %[[entry_point_ptr:.*]] = llvm.getelementptr %[[addressof_entry_point]]
-// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
-// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i64, i64, i64) -> ()
-// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
+// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr, !llvm.ptr) -> ()
+// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, i64, i64, i64) -> ()
+// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr) -> ()
+// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr) -> ()
-// CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<i16>, ptr<i16>, i64, array<1 x i64>, array<1 x i64>)>>)
+// CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr, i32, i32, !llvm.ptr)
module attributes {gpu.container_module} {
- llvm.func @malloc(i64) -> !llvm.ptr<i8>
+ llvm.func @malloc(i64) -> !llvm.ptr
llvm.func @foo() {
%0 = llvm.mlir.constant(12 : index) : i64
- %1 = llvm.mlir.null : !llvm.ptr<f32>
+ %1 = llvm.mlir.null : !llvm.ptr
%2 = llvm.mlir.constant(1 : index) : i64
- %3 = llvm.getelementptr %1[%2] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
- %4 = llvm.ptrtoint %3 : !llvm.ptr<f32> to i64
+ %3 = llvm.getelementptr %1[%2] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ %4 = llvm.ptrtoint %3 : !llvm.ptr to i64
%5 = llvm.mul %0, %4 : i64
- %6 = llvm.call @malloc(%5) : (i64) -> !llvm.ptr<i8>
- %7 = llvm.bitcast %6 : !llvm.ptr<i8> to !llvm.ptr<f32>
- %8 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %9 = llvm.insertvalue %7, %8[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %10 = llvm.insertvalue %7, %9[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %6 = llvm.call @malloc(%5) : (i64) -> !llvm.ptr
+ %8 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %9 = llvm.insertvalue %6, %8[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %10 = llvm.insertvalue %6, %9[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%11 = llvm.mlir.constant(0 : index) : i64
- %12 = llvm.insertvalue %11, %10[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %12 = llvm.insertvalue %11, %10[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%13 = llvm.mlir.constant(1 : index) : i64
- %14 = llvm.insertvalue %0, %12[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %15 = llvm.insertvalue %13, %14[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %14 = llvm.insertvalue %0, %12[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %15 = llvm.insertvalue %13, %14[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%16 = llvm.mlir.constant(1 : index) : i64
- %17 = llvm.extractvalue %15[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %18 = llvm.extractvalue %15[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %19 = llvm.extractvalue %15[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %20 = llvm.extractvalue %15[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %21 = llvm.extractvalue %15[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- llvm.call @vulkanLaunch(%16, %16, %16, %17, %18, %19, %20, %21) {spirv_blob = "\03\02#\07\00", spirv_entry_point = "kernel"}
- : (i64, i64, i64, !llvm.ptr<f32>, !llvm.ptr<f32>, i64, i64, i64) -> ()
+ %17 = llvm.extractvalue %15[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %18 = llvm.extractvalue %15[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %19 = llvm.extractvalue %15[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %20 = llvm.extractvalue %15[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %21 = llvm.extractvalue %15[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ llvm.call @vulkanLaunch(%16, %16, %16, %17, %18, %19, %20, %21) {spirv_blob = "\03\02#\07\00", spirv_element_types = [f32], spirv_entry_point = "kernel"}
+ : (i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64) -> ()
llvm.return
}
- llvm.func @vulkanLaunch(%arg0: i64, %arg1: i64, %arg2: i64, %arg6: !llvm.ptr<f32>, %arg7: !llvm.ptr<f32>, %arg8: i64, %arg9: i64, %arg10: i64) {
- %0 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %1 = llvm.insertvalue %arg6, %0[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %2 = llvm.insertvalue %arg7, %1[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %3 = llvm.insertvalue %arg8, %2[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %4 = llvm.insertvalue %arg9, %3[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %5 = llvm.insertvalue %arg10, %4[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ llvm.func @vulkanLaunch(%arg0: i64, %arg1: i64, %arg2: i64, %arg6: !llvm.ptr, %arg7: !llvm.ptr, %arg8: i64, %arg9: i64, %arg10: i64) {
+ %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %1 = llvm.insertvalue %arg6, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %2 = llvm.insertvalue %arg7, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %3 = llvm.insertvalue %arg8, %2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %4 = llvm.insertvalue %arg9, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %5 = llvm.insertvalue %arg10, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%6 = llvm.mlir.constant(1 : index) : i64
- %7 = llvm.alloca %6 x !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> : (i64) -> !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>
- llvm.store %5, %7 : !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>
- llvm.call @_mlir_ciface_vulkanLaunch(%arg0, %arg1, %arg2, %7) : (i64, i64, i64, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> ()
+ %7 = llvm.alloca %6 x !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> : (i64) -> !llvm.ptr
+ llvm.store %5, %7 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr
+ llvm.call @_mlir_ciface_vulkanLaunch(%arg0, %arg1, %arg2, %7) : (i64, i64, i64, !llvm.ptr) -> ()
llvm.return
}
- llvm.func @_mlir_ciface_vulkanLaunch(i64, i64, i64, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>)
+ llvm.func @_mlir_ciface_vulkanLaunch(i64, i64, i64, !llvm.ptr)
}
// CHECK: %[[resource:.*]] = memref.alloc() : memref<12xf32>
// CHECK: %[[index:.*]] = arith.constant 1 : index
-// CHECK: call @vulkanLaunch(%[[index]], %[[index]], %[[index]], %[[resource]]) {spirv_blob = "{{.*}}", spirv_entry_point = "kernel"}
+// CHECK: call @vulkanLaunch(%[[index]], %[[index]], %[[index]], %[[resource]]) {spirv_blob = "{{.*}}", spirv_element_types = [f32], spirv_entry_point = "kernel"}
module attributes {gpu.container_module} {
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]> {
--- /dev/null
+// RUN: mlir-opt %s -launch-func-to-vulkan='use-opaque-pointers=0' | FileCheck %s
+
+// CHECK: llvm.mlir.global internal constant @kernel_spv_entry_point_name
+// CHECK: llvm.mlir.global internal constant @SPIRV_BIN
+// CHECK: %[[Vulkan_Runtime_ptr:.*]] = llvm.call @initVulkan() : () -> !llvm.ptr<i8>
+// CHECK: %[[addressof_SPIRV_BIN:.*]] = llvm.mlir.addressof @SPIRV_BIN
+// CHECK: %[[SPIRV_BIN_ptr:.*]] = llvm.getelementptr %[[addressof_SPIRV_BIN]]
+// CHECK: %[[SPIRV_BIN_size:.*]] = llvm.mlir.constant
+// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> ()
+// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i32) -> ()
+// CHECK: %[[addressof_entry_point:.*]] = llvm.mlir.addressof @kernel_spv_entry_point_name
+// CHECK: %[[entry_point_ptr:.*]] = llvm.getelementptr %[[addressof_entry_point]]
+// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
+// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i64, i64, i64) -> ()
+// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
+// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
+
+// CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<i16>, ptr<i16>, i64, array<1 x i64>, array<1 x i64>)>>)
+
+module attributes {gpu.container_module} {
+ llvm.func @malloc(i64) -> !llvm.ptr<i8>
+ llvm.func @foo() {
+ %0 = llvm.mlir.constant(12 : index) : i64
+ %1 = llvm.mlir.null : !llvm.ptr<f32>
+ %2 = llvm.mlir.constant(1 : index) : i64
+ %3 = llvm.getelementptr %1[%2] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+ %4 = llvm.ptrtoint %3 : !llvm.ptr<f32> to i64
+ %5 = llvm.mul %0, %4 : i64
+ %6 = llvm.call @malloc(%5) : (i64) -> !llvm.ptr<i8>
+ %7 = llvm.bitcast %6 : !llvm.ptr<i8> to !llvm.ptr<f32>
+ %8 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %9 = llvm.insertvalue %7, %8[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %10 = llvm.insertvalue %7, %9[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %11 = llvm.mlir.constant(0 : index) : i64
+ %12 = llvm.insertvalue %11, %10[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %13 = llvm.mlir.constant(1 : index) : i64
+ %14 = llvm.insertvalue %0, %12[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %15 = llvm.insertvalue %13, %14[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %16 = llvm.mlir.constant(1 : index) : i64
+ %17 = llvm.extractvalue %15[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %18 = llvm.extractvalue %15[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %19 = llvm.extractvalue %15[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %20 = llvm.extractvalue %15[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %21 = llvm.extractvalue %15[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ llvm.call @vulkanLaunch(%16, %16, %16, %17, %18, %19, %20, %21) {spirv_blob = "\03\02#\07\00", spirv_element_types = [f32], spirv_entry_point = "kernel"}
+ : (i64, i64, i64, !llvm.ptr<f32>, !llvm.ptr<f32>, i64, i64, i64) -> ()
+ llvm.return
+ }
+ llvm.func @vulkanLaunch(%arg0: i64, %arg1: i64, %arg2: i64, %arg6: !llvm.ptr<f32>, %arg7: !llvm.ptr<f32>, %arg8: i64, %arg9: i64, %arg10: i64) {
+ %0 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %1 = llvm.insertvalue %arg6, %0[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %2 = llvm.insertvalue %arg7, %1[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %3 = llvm.insertvalue %arg8, %2[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %4 = llvm.insertvalue %arg9, %3[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %5 = llvm.insertvalue %arg10, %4[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %6 = llvm.mlir.constant(1 : index) : i64
+ %7 = llvm.alloca %6 x !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> : (i64) -> !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>
+ llvm.store %5, %7 : !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>
+ llvm.call @_mlir_ciface_vulkanLaunch(%arg0, %arg1, %arg2, %7) : (i64, i64, i64, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> ()
+ llvm.return
+ }
+ llvm.func @_mlir_ciface_vulkanLaunch(i64, i64, i64, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>)
+}
if (options.spirvWebGPUPrepare)
modulePM.addPass(spirv::createSPIRVWebGPUPreparePass());
+ auto enableOpaquePointers = [](auto passOption) {
+ passOption.useOpaquePointers = true;
+ return passOption;
+ };
+
passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass());
- passManager.addPass(createFinalizeMemRefToLLVMConversionPass());
- passManager.addPass(createConvertVectorToLLVMPass());
+ passManager.addPass(createFinalizeMemRefToLLVMConversionPass(
+ enableOpaquePointers(FinalizeMemRefToLLVMConversionPassOptions{})));
+ passManager.addPass(createConvertVectorToLLVMPass(
+ enableOpaquePointers(ConvertVectorToLLVMPassOptions{})));
passManager.nest<func::FuncOp>().addPass(LLVM::createRequestCWrappersPass());
ConvertFuncToLLVMPassOptions funcToLLVMOptions{};
funcToLLVMOptions.indexBitwidth =
DataLayout(module).getTypeSizeInBits(IndexType::get(module.getContext()));
- passManager.addPass(createConvertFuncToLLVMPass(funcToLLVMOptions));
+ passManager.addPass(
+ createConvertFuncToLLVMPass(enableOpaquePointers(funcToLLVMOptions)));
passManager.addPass(createReconcileUnrealizedCastsPass());
- passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass());
+ passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass(
+ enableOpaquePointers(ConvertVulkanLaunchFuncToVulkanCallsPassOptions{})));
return passManager.run(module);
}