class ModuleOp;
template <typename T>
class OperationPass;
+class Pass;
-#define GEN_PASS_DECL_CONVERTVULKANLAUNCHFUNCTOVULKANCALLS
+#define GEN_PASS_DECL_CONVERTVULKANLAUNCHFUNCTOVULKANCALLSPASS
#define GEN_PASS_DECL_CONVERTGPULAUNCHFUNCTOVULKANLAUNCHFUNC
#include "mlir/Conversion/Passes.h.inc"
-std::unique_ptr<OperationPass<ModuleOp>>
-createConvertVulkanLaunchFuncToVulkanCallsPass();
-
std::unique_ptr<OperationPass<mlir::ModuleOp>>
createConvertGpuLaunchFuncToVulkanLaunchFuncPass();
let dependentDialects = ["spirv::SPIRVDialect"];
}
-def ConvertVulkanLaunchFuncToVulkanCalls
+def ConvertVulkanLaunchFuncToVulkanCallsPass
: Pass<"launch-func-to-vulkan", "ModuleOp"> {
let summary = "Convert vulkanLaunch external call to Vulkan runtime external "
"calls";
let description = [{
This pass is only intended for the mlir-vulkan-runner.
}];
- let constructor = "mlir::createConvertVulkanLaunchFuncToVulkanCallsPass()";
+
+ let options = [
+ Option<"useOpaquePointers", "use-opaque-pointers", "bool",
+ /*default=*/"false", "Generate LLVM IR using opaque pointers "
+ "instead of typed pointers">
+ ];
+
let dependentDialects = ["LLVM::LLVMDialect"];
}
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
+static constexpr const char *kSPIRVElementTypesAttrName = "spirv_element_types";
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
namespace {
vulkanLaunchCallOp->setAttr(kSPIRVEntryPointAttrName,
launchOp.getKernelName());
+ // Add MemRef element types before they're lost when lowering to LLVM.
+ SmallVector<Type> elementTypes;
+ for (Type type : llvm::drop_begin(launchOp.getOperandTypes(),
+ gpu::LaunchOp::kNumConfigOperands)) {
+ // The below cast always succeeds as it has already been verified in
+ // 'declareVulkanLaunchFunc' that these are MemRefs with compatible element
+ // types.
+ elementTypes.push_back(type.cast<MemRefType>().getElementType());
+ }
+ vulkanLaunchCallOp->setAttr(kSPIRVElementTypesAttrName,
+ builder.getTypeArrayAttr(elementTypes));
+
launchOp.erase();
}
#include "llvm/Support/FormatVariadic.h"
namespace mlir {
-#define GEN_PASS_DEF_CONVERTVULKANLAUNCHFUNCTOVULKANCALLS
+#define GEN_PASS_DEF_CONVERTVULKANLAUNCHFUNCTOVULKANCALLSPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
+static constexpr const char *kSPIRVElementTypesAttrName = "spirv_element_types";
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
namespace {
/// * deinitVulkan -- deinitializes vulkan runtime
///
class VulkanLaunchFuncToVulkanCallsPass
- : public impl::ConvertVulkanLaunchFuncToVulkanCallsBase<
+ : public impl::ConvertVulkanLaunchFuncToVulkanCallsPassBase<
VulkanLaunchFuncToVulkanCallsPass> {
private:
void initializeCachedTypes() {
llvmFloatType = Float32Type::get(&getContext());
llvmVoidType = LLVM::LLVMVoidType::get(&getContext());
- llvmPointerType =
- LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8));
+ if (useOpaquePointers)
+ llvmPointerType = LLVM::LLVMPointerType::get(&getContext());
+ else
+ llvmPointerType =
+ LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8));
llvmInt32Type = IntegerType::get(&getContext(), 32);
llvmInt64Type = IntegerType::get(&getContext(), 64);
}
// int64_t sizes[Rank]; // omitted when rank == 0
// int64_t strides[Rank]; // omitted when rank == 0
// };
- auto llvmPtrToElementType = LLVM::LLVMPointerType::get(elemenType);
+ auto llvmPtrToElementType = useOpaquePointers
+ ? llvmPointerType
+ : LLVM::LLVMPointerType::get(elemenType);
auto llvmArrayRankElementSizeType =
LLVM::LLVMArrayType::get(getInt64Type(), rank);
/// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
- /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`.
- LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor,
- uint32_t &rank, Type &type);
+ /// Deduces a rank from the given 'launchCallArg`.
+ LogicalResult deduceMemRefRank(Value launchCallArg, uint32_t &rank);
/// Returns a string representation from the given `type`.
StringRef stringifyType(Type type) {
}
public:
+ using Base::Base;
+
void runOnOperation() override;
private:
Type llvmInt32Type;
Type llvmInt64Type;
+ struct SPIRVAttributes {
+ StringAttr blob;
+ StringAttr entryPoint;
+ SmallVector<Type> elementTypes;
+ };
+
// TODO: Use an associative array to support multiple vulkan launch calls.
- std::pair<StringAttr, StringAttr> spirvAttributes;
+ SPIRVAttributes spirvAttributes;
/// The number of vulkan launch configuration operands, placed at the leading
/// positions of the operand list.
static constexpr unsigned kVulkanLaunchNumConfigOperands = 3;
return signalPassFailure();
}
- spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
+ auto spirvElementTypesAttr =
+ vulkanLaunchCallOp->getAttrOfType<ArrayAttr>(kSPIRVElementTypesAttrName);
+ if (!spirvElementTypesAttr) {
+ vulkanLaunchCallOp.emitError()
+ << "missing " << kSPIRVElementTypesAttrName << " attribute";
+ return signalPassFailure();
+ }
+ if (llvm::any_of(spirvElementTypesAttr,
+ [](Attribute attr) { return !isa<TypeAttr>(attr); })) {
+ vulkanLaunchCallOp.emitError()
+ << "expected " << spirvElementTypesAttr << " to be an array of types";
+ return signalPassFailure();
+ }
+
+ spirvAttributes.blob = spirvBlobAttr;
+ spirvAttributes.entryPoint = spirvEntryPointNameAttr;
+ spirvAttributes.elementTypes =
+ llvm::to_vector(spirvElementTypesAttr.getAsValueRange<mlir::TypeAttr>());
}
void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
Value descriptorSet =
builder.create<LLVM::ConstantOp>(loc, getInt32Type(), 0);
- for (const auto &en :
+ for (auto [index, ptrToMemRefDescriptor] :
llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
kVulkanLaunchNumConfigOperands))) {
// Create LLVM constant for the descriptor binding index.
Value descriptorBinding =
- builder.create<LLVM::ConstantOp>(loc, getInt32Type(), en.index());
+ builder.create<LLVM::ConstantOp>(loc, getInt32Type(), index);
+
+ if (index >= spirvAttributes.elementTypes.size()) {
+ cInterfaceVulkanLaunchCallOp.emitError()
+ << kSPIRVElementTypesAttrName << " missing element type for "
+ << ptrToMemRefDescriptor;
+ return signalPassFailure();
+ }
- auto ptrToMemRefDescriptor = en.value();
uint32_t rank = 0;
- Type type;
- if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) {
+ Type type = spirvAttributes.elementTypes[index];
+ if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) {
cInterfaceVulkanLaunchCallOp.emitError()
<< "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
return signalPassFailure();
llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
// Special case for fp16 type. Since it is not a supported type in C we use
// int16_t and bitcast the descriptor.
- if (type.isa<Float16Type>()) {
+ if (!useOpaquePointers && type.isa<Float16Type>()) {
auto memRefTy = getMemRefType(rank, IntegerType::get(&getContext(), 16));
ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor);
}
}
-LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
- Value ptrToMemRefDescriptor, uint32_t &rank, Type &type) {
- auto llvmPtrDescriptorTy =
- ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMPointerType>();
- if (!llvmPtrDescriptorTy)
+LogicalResult
+VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value launchCallArg,
+ uint32_t &rank) {
+ // Deduce the rank from the type used to allocate the lowered MemRef.
+ auto alloca = launchCallArg.getDefiningOp<LLVM::AllocaOp>();
+ if (!alloca)
return failure();
- auto llvmDescriptorTy =
- llvmPtrDescriptorTy.getElementType().dyn_cast<LLVM::LLVMStructType>();
+ LLVM::LLVMStructType llvmDescriptorTy;
+ if (std::optional<Type> elementType = alloca.getElemType()) {
+ llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(*elementType);
+ } else {
+ // This case is only possible if we are not using opaque pointers
+ // since opaque pointer producing allocas require an element type.
+ llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(
+ alloca.getRes().getType().getElementType());
+ }
+
// template <typename Elem, size_t Rank>
// struct {
// Elem *allocated;
if (!llvmDescriptorTy)
return failure();
- type = llvmDescriptorTy.getBody()[0]
- .cast<LLVM::LLVMPointerType>()
- .getElementType();
if (llvmDescriptorTy.getBody().size() == 3) {
rank = 0;
return success();
auto fnType = LLVM::LLVMFunctionType::get(
getVoidType(),
{getPointerType(), getInt32Type(), getInt32Type(),
- LLVM::LLVMPointerType::get(getMemRefType(i, type))},
+ useOpaquePointers
+ ? llvmPointerType
+ : LLVM::LLVMPointerType::get(getMemRefType(i, type))},
/*isVarArg=*/false);
builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
}
std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
shaderName, LLVM::Linkage::Internal,
- /*TODO:useOpaquePointers=*/false);
+ useOpaquePointers);
}
void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
// Create LLVM global with SPIR-V binary data, so we can pass a pointer with
// that data to runtime call.
Value ptrToSPIRVBinary = LLVM::createGlobalString(
- loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
- LLVM::Linkage::Internal, /*TODO:useOpaquePointers=*/false);
+ loc, builder, kSPIRVBinary, spirvAttributes.blob.getValue(),
+ LLVM::Linkage::Internal, useOpaquePointers);
// Create LLVM constant for the size of SPIR-V binary shader.
Value binarySize = builder.create<LLVM::ConstantOp>(
- loc, getInt32Type(), spirvAttributes.first.getValue().size());
+ loc, getInt32Type(), spirvAttributes.blob.getValue().size());
// Create call to `bindMemRef` for each memref operand.
createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize});
// Create LLVM global with entry point name.
Value entryPointName = createEntryPointNameConstant(
- spirvAttributes.second.getValue(), loc, builder);
+ spirvAttributes.entryPoint.getValue(), loc, builder);
// Create call to `setEntryPoint` runtime function with the given pointer to
// entry point name.
builder.create<LLVM::CallOp>(loc, TypeRange(), kSetEntryPoint,
cInterfaceVulkanLaunchCallOp.erase();
}
-
-std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
-mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
- return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
-}
-// RUN: mlir-opt %s -launch-func-to-vulkan | FileCheck %s
+// RUN: mlir-opt %s -launch-func-to-vulkan='use-opaque-pointers=1' | FileCheck %s
// CHECK: llvm.mlir.global internal constant @kernel_spv_entry_point_name
// CHECK: llvm.mlir.global internal constant @SPIRV_BIN
-// CHECK: %[[Vulkan_Runtime_ptr:.*]] = llvm.call @initVulkan() : () -> !llvm.ptr<i8>
+// CHECK: %[[Vulkan_Runtime_ptr:.*]] = llvm.call @initVulkan() : () -> !llvm.ptr
// CHECK: %[[addressof_SPIRV_BIN:.*]] = llvm.mlir.addressof @SPIRV_BIN
// CHECK: %[[SPIRV_BIN_ptr:.*]] = llvm.getelementptr %[[addressof_SPIRV_BIN]]
// CHECK: %[[SPIRV_BIN_size:.*]] = llvm.mlir.constant
-// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> ()
-// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i32) -> ()
+// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, i32, i32, !llvm.ptr) -> ()
+// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr, !llvm.ptr, i32) -> ()
// CHECK: %[[addressof_entry_point:.*]] = llvm.mlir.addressof @kernel_spv_entry_point_name
// CHECK: %[[entry_point_ptr:.*]] = llvm.getelementptr %[[addressof_entry_point]]
-// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
-// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i64, i64, i64) -> ()
-// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
+// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr, !llvm.ptr) -> ()
+// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, i64, i64, i64) -> ()
+// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr) -> ()
+// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr) -> ()
-// CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<i16>, ptr<i16>, i64, array<1 x i64>, array<1 x i64>)>>)
+// CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr, i32, i32, !llvm.ptr)
module attributes {gpu.container_module} {
- llvm.func @malloc(i64) -> !llvm.ptr<i8>
+ llvm.func @malloc(i64) -> !llvm.ptr
llvm.func @foo() {
%0 = llvm.mlir.constant(12 : index) : i64
- %1 = llvm.mlir.null : !llvm.ptr<f32>
+ %1 = llvm.mlir.null : !llvm.ptr
%2 = llvm.mlir.constant(1 : index) : i64
- %3 = llvm.getelementptr %1[%2] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
- %4 = llvm.ptrtoint %3 : !llvm.ptr<f32> to i64
+ %3 = llvm.getelementptr %1[%2] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ %4 = llvm.ptrtoint %3 : !llvm.ptr to i64
%5 = llvm.mul %0, %4 : i64
- %6 = llvm.call @malloc(%5) : (i64) -> !llvm.ptr<i8>
- %7 = llvm.bitcast %6 : !llvm.ptr<i8> to !llvm.ptr<f32>
- %8 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %9 = llvm.insertvalue %7, %8[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %10 = llvm.insertvalue %7, %9[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %6 = llvm.call @malloc(%5) : (i64) -> !llvm.ptr
+ %8 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %9 = llvm.insertvalue %6, %8[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %10 = llvm.insertvalue %6, %9[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%11 = llvm.mlir.constant(0 : index) : i64
- %12 = llvm.insertvalue %11, %10[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %12 = llvm.insertvalue %11, %10[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%13 = llvm.mlir.constant(1 : index) : i64
- %14 = llvm.insertvalue %0, %12[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %15 = llvm.insertvalue %13, %14[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %14 = llvm.insertvalue %0, %12[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %15 = llvm.insertvalue %13, %14[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%16 = llvm.mlir.constant(1 : index) : i64
- %17 = llvm.extractvalue %15[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %18 = llvm.extractvalue %15[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %19 = llvm.extractvalue %15[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %20 = llvm.extractvalue %15[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %21 = llvm.extractvalue %15[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- llvm.call @vulkanLaunch(%16, %16, %16, %17, %18, %19, %20, %21) {spirv_blob = "\03\02#\07\00", spirv_entry_point = "kernel"}
- : (i64, i64, i64, !llvm.ptr<f32>, !llvm.ptr<f32>, i64, i64, i64) -> ()
+ %17 = llvm.extractvalue %15[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %18 = llvm.extractvalue %15[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %19 = llvm.extractvalue %15[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %20 = llvm.extractvalue %15[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %21 = llvm.extractvalue %15[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ llvm.call @vulkanLaunch(%16, %16, %16, %17, %18, %19, %20, %21) {spirv_blob = "\03\02#\07\00", spirv_element_types = [f32], spirv_entry_point = "kernel"}
+ : (i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64) -> ()
llvm.return
}
- llvm.func @vulkanLaunch(%arg0: i64, %arg1: i64, %arg2: i64, %arg6: !llvm.ptr<f32>, %arg7: !llvm.ptr<f32>, %arg8: i64, %arg9: i64, %arg10: i64) {
- %0 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %1 = llvm.insertvalue %arg6, %0[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %2 = llvm.insertvalue %arg7, %1[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %3 = llvm.insertvalue %arg8, %2[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %4 = llvm.insertvalue %arg9, %3[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- %5 = llvm.insertvalue %arg10, %4[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ llvm.func @vulkanLaunch(%arg0: i64, %arg1: i64, %arg2: i64, %arg6: !llvm.ptr, %arg7: !llvm.ptr, %arg8: i64, %arg9: i64, %arg10: i64) {
+ %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %1 = llvm.insertvalue %arg6, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %2 = llvm.insertvalue %arg7, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %3 = llvm.insertvalue %arg8, %2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %4 = llvm.insertvalue %arg9, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ %5 = llvm.insertvalue %arg10, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%6 = llvm.mlir.constant(1 : index) : i64
- %7 = llvm.alloca %6 x !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> : (i64) -> !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>
- llvm.store %5, %7 : !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>
- llvm.call @_mlir_ciface_vulkanLaunch(%arg0, %arg1, %arg2, %7) : (i64, i64, i64, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> ()
+ %7 = llvm.alloca %6 x !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> : (i64) -> !llvm.ptr
+ llvm.store %5, %7 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr
+ llvm.call @_mlir_ciface_vulkanLaunch(%arg0, %arg1, %arg2, %7) : (i64, i64, i64, !llvm.ptr) -> ()
llvm.return
}
- llvm.func @_mlir_ciface_vulkanLaunch(i64, i64, i64, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>)
+ llvm.func @_mlir_ciface_vulkanLaunch(i64, i64, i64, !llvm.ptr)
}
// CHECK: %[[resource:.*]] = memref.alloc() : memref<12xf32>
// CHECK: %[[index:.*]] = arith.constant 1 : index
-// CHECK: call @vulkanLaunch(%[[index]], %[[index]], %[[index]], %[[resource]]) {spirv_blob = "{{.*}}", spirv_entry_point = "kernel"}
+// CHECK: call @vulkanLaunch(%[[index]], %[[index]], %[[index]], %[[resource]]) {spirv_blob = "{{.*}}", spirv_element_types = [f32], spirv_entry_point = "kernel"}
module attributes {gpu.container_module} {
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]> {
--- /dev/null
+// RUN: mlir-opt %s -launch-func-to-vulkan='use-opaque-pointers=0' | FileCheck %s
+
+// CHECK: llvm.mlir.global internal constant @kernel_spv_entry_point_name
+// CHECK: llvm.mlir.global internal constant @SPIRV_BIN
+// CHECK: %[[Vulkan_Runtime_ptr:.*]] = llvm.call @initVulkan() : () -> !llvm.ptr<i8>
+// CHECK: %[[addressof_SPIRV_BIN:.*]] = llvm.mlir.addressof @SPIRV_BIN
+// CHECK: %[[SPIRV_BIN_ptr:.*]] = llvm.getelementptr %[[addressof_SPIRV_BIN]]
+// CHECK: %[[SPIRV_BIN_size:.*]] = llvm.mlir.constant
+// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> ()
+// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i32) -> ()
+// CHECK: %[[addressof_entry_point:.*]] = llvm.mlir.addressof @kernel_spv_entry_point_name
+// CHECK: %[[entry_point_ptr:.*]] = llvm.getelementptr %[[addressof_entry_point]]
+// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
+// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i64, i64, i64) -> ()
+// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
+// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
+
+// CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<i16>, ptr<i16>, i64, array<1 x i64>, array<1 x i64>)>>)
+
+module attributes {gpu.container_module} {
+ llvm.func @malloc(i64) -> !llvm.ptr<i8>
+ llvm.func @foo() {
+ %0 = llvm.mlir.constant(12 : index) : i64
+ %1 = llvm.mlir.null : !llvm.ptr<f32>
+ %2 = llvm.mlir.constant(1 : index) : i64
+ %3 = llvm.getelementptr %1[%2] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+ %4 = llvm.ptrtoint %3 : !llvm.ptr<f32> to i64
+ %5 = llvm.mul %0, %4 : i64
+ %6 = llvm.call @malloc(%5) : (i64) -> !llvm.ptr<i8>
+ %7 = llvm.bitcast %6 : !llvm.ptr<i8> to !llvm.ptr<f32>
+ %8 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %9 = llvm.insertvalue %7, %8[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %10 = llvm.insertvalue %7, %9[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %11 = llvm.mlir.constant(0 : index) : i64
+ %12 = llvm.insertvalue %11, %10[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %13 = llvm.mlir.constant(1 : index) : i64
+ %14 = llvm.insertvalue %0, %12[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %15 = llvm.insertvalue %13, %14[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %16 = llvm.mlir.constant(1 : index) : i64
+ %17 = llvm.extractvalue %15[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %18 = llvm.extractvalue %15[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %19 = llvm.extractvalue %15[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %20 = llvm.extractvalue %15[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %21 = llvm.extractvalue %15[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ llvm.call @vulkanLaunch(%16, %16, %16, %17, %18, %19, %20, %21) {spirv_blob = "\03\02#\07\00", spirv_element_types = [f32], spirv_entry_point = "kernel"}
+ : (i64, i64, i64, !llvm.ptr<f32>, !llvm.ptr<f32>, i64, i64, i64) -> ()
+ llvm.return
+ }
+ llvm.func @vulkanLaunch(%arg0: i64, %arg1: i64, %arg2: i64, %arg6: !llvm.ptr<f32>, %arg7: !llvm.ptr<f32>, %arg8: i64, %arg9: i64, %arg10: i64) {
+ %0 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %1 = llvm.insertvalue %arg6, %0[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %2 = llvm.insertvalue %arg7, %1[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %3 = llvm.insertvalue %arg8, %2[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %4 = llvm.insertvalue %arg9, %3[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %5 = llvm.insertvalue %arg10, %4[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %6 = llvm.mlir.constant(1 : index) : i64
+ %7 = llvm.alloca %6 x !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> : (i64) -> !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>
+ llvm.store %5, %7 : !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>
+ llvm.call @_mlir_ciface_vulkanLaunch(%arg0, %arg1, %arg2, %7) : (i64, i64, i64, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> ()
+ llvm.return
+ }
+ llvm.func @_mlir_ciface_vulkanLaunch(i64, i64, i64, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>)
+}
if (options.spirvWebGPUPrepare)
modulePM.addPass(spirv::createSPIRVWebGPUPreparePass());
+ auto enableOpaquePointers = [](auto passOption) {
+ passOption.useOpaquePointers = true;
+ return passOption;
+ };
+
passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass());
- passManager.addPass(createFinalizeMemRefToLLVMConversionPass());
- passManager.addPass(createConvertVectorToLLVMPass());
+ passManager.addPass(createFinalizeMemRefToLLVMConversionPass(
+ enableOpaquePointers(FinalizeMemRefToLLVMConversionPassOptions{})));
+ passManager.addPass(createConvertVectorToLLVMPass(
+ enableOpaquePointers(ConvertVectorToLLVMPassOptions{})));
passManager.nest<func::FuncOp>().addPass(LLVM::createRequestCWrappersPass());
ConvertFuncToLLVMPassOptions funcToLLVMOptions{};
funcToLLVMOptions.indexBitwidth =
DataLayout(module).getTypeSizeInBits(IndexType::get(module.getContext()));
- passManager.addPass(createConvertFuncToLLVMPass(funcToLLVMOptions));
+ passManager.addPass(
+ createConvertFuncToLLVMPass(enableOpaquePointers(funcToLLVMOptions)));
passManager.addPass(createReconcileUnrealizedCastsPass());
- passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass());
+ passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass(
+ enableOpaquePointers(ConvertVulkanLaunchFuncToVulkanCallsPassOptions{})));
return passManager.run(module);
}