static constexpr const char *kMcuMemHostRegisterPtr = "mcuMemHostRegisterPtr";
static constexpr const char *kCubinAnnotation = "nvvm.cubin";
-static constexpr const char *kCubinGetterAnnotation = "nvvm.cubingetter";
-static constexpr const char *kCubinGetterSuffix = "_cubin";
static constexpr const char *kCubinStorageSuffix = "_cubin_cst";
namespace {
Value *setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder);
Value *generateKernelNameConstant(FuncOp kernelFunction, Location &loc,
OpBuilder &builder);
- FuncOp generateCubinAccessor(FuncOp kernelFunc, StringAttr blob);
void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp);
public:
initializeCachedTypes();
getModule().walk([this](mlir::gpu::LaunchFuncOp op) {
- auto gpuModule =
- getModule().lookupSymbol<ModuleOp>(op.getKernelModuleName());
- auto kernelFunc = gpuModule.lookupSymbol<FuncOp>(op.kernel());
- auto cubinAttr = kernelFunc.getAttrOfType<StringAttr>(kCubinAnnotation);
- if (!cubinAttr)
- return signalPassFailure();
- FuncOp getter = generateCubinAccessor(kernelFunc, cubinAttr);
-
- // Store the name of the getter on the function for easier lookup and
- // remove the original CUBIN annotation.
- kernelFunc.setAttr(
- kCubinGetterAnnotation,
- SymbolRefAttr::get(getter.getName(), getter.getContext()));
- kernelFunc.removeAttr(kCubinAnnotation);
-
translateGpuLaunchCalls(op);
});
llvmDialect);
}
-// Inserts a global constant string containing `blob` into the grand-parent
-// module of `kernelFunc` and generates the function that returns the address of
-// the first character of this string.
-FuncOp GpuLaunchFuncToCudaCallsPass::generateCubinAccessor(FuncOp kernelFunc,
- StringAttr blob) {
- Location loc = kernelFunc.getLoc();
- SmallString<128> nameBuffer(kernelFunc.getName());
- ModuleOp module = getModule();
- assert(kernelFunc.getParentOp() &&
- kernelFunc.getParentOp()->getParentOp() == module &&
- "expected one level of module nesting");
-
- // Insert the getter function just after the GPU kernel module containing
- // `kernelFunc`.
- OpBuilder moduleBuilder(module.getBody());
- moduleBuilder.setInsertionPointAfter(kernelFunc.getParentOp());
- auto getterType = moduleBuilder.getFunctionType(
- llvm::None, LLVM::LLVMType::getInt8PtrTy(llvmDialect));
- nameBuffer.append(kCubinGetterSuffix);
- auto result = moduleBuilder.create<FuncOp>(
- loc, StringRef(nameBuffer), getterType, ArrayRef<NamedAttribute>());
- Block *entryBlock = result.addEntryBlock();
-
- // Drop the getter suffix before appending the storage suffix.
- nameBuffer.resize(kernelFunc.getName().size());
- nameBuffer.append(kCubinStorageSuffix);
-
- // Obtain the address of the first character of the global string containing
- // the cubin and return from the getter.
- OpBuilder builder(entryBlock);
- Value *startPtr = LLVM::createGlobalString(
- loc, builder, StringRef(nameBuffer), blob.getValue(), llvmDialect);
- builder.create<LLVM::ReturnOp>(loc, startPtr);
- return result;
-}
-
// Emits LLVM IR to launch a kernel function. Expects the module that contains
// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute of the
// kernel function in the IR.
auto zero = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
builder.getI32IntegerAttr(0));
- // Emit a call to the cubin getter to retrieve a pointer to the data that
- // represents the cubin at runtime.
- // TODO(herhut): This should rather be a static global once supported.
+ // Create an LLVM global with CUBIN extracted from the kernel annotation and
+ // obtain a pointer to the first byte in it.
auto kernelModule =
getModule().lookupSymbol<ModuleOp>(launchOp.getKernelModuleName());
assert(kernelModule && "expected a kernel module");
auto kernelFunction = kernelModule.lookupSymbol<FuncOp>(launchOp.kernel());
assert(kernelFunction && "expected a kernel function");
- auto cubinGetter =
- kernelFunction.getAttrOfType<SymbolRefAttr>(kCubinGetterAnnotation);
- if (!cubinGetter) {
- kernelFunction.emitError("missing ")
- << kCubinGetterAnnotation << " attribute.";
+ auto cubinAttr = kernelFunction.getAttrOfType<StringAttr>(kCubinAnnotation);
+ if (!cubinAttr) {
+ kernelFunction.emitOpError()
+ << "missing " << kCubinAnnotation << " attribute";
return signalPassFailure();
}
- auto data = builder.create<LLVM::CallOp>(
- loc, ArrayRef<Type>{getPointerType()}, cubinGetter, ArrayRef<Value *>{});
+ assert(kernelModule.getName() && "expected a named module");
+ SmallString<128> nameBuffer(*kernelModule.getName());
+ nameBuffer.append(kCubinStorageSuffix);
+ Value *data = LLVM::createGlobalString(
+ loc, builder, nameBuffer.str(), cubinAttr.getValue(), getLLVMDialect());
+
// Emit the load module call to load the module data. Error checking is done
// in the called helper function.
auto cuModule = allocatePointer(builder, loc);
FuncOp cuModuleLoad = getModule().lookupSymbol<FuncOp>(cuModuleLoadName);
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
builder.getSymbolRefAttr(cuModuleLoad),
- ArrayRef<Value *>{cuModule, data.getResult(0)});
+ ArrayRef<Value *>{cuModule, data});
// Get the function from the module. The name corresponds to the name of
// the kernel function.
auto cuOwningModuleRef =
attributes { gpu.kernel, nvvm.cubin = "CUBIN" }
}
-// CHECK: func @[[getter:.*]]() -> !llvm<"i8*">
-// CHECK: %[[addressof:.*]] = llvm.mlir.addressof @[[global]]
-// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index)
-// CHECK: %[[gep:.*]] = llvm.getelementptr %[[addressof]][%[[c0]], %[[c0]]]
-// CHECK-SAME: -> !llvm<"i8*">
-// CHECK: llvm.return %[[gep]] : !llvm<"i8*">
-
func @foo() {
%0 = "op"() : () -> (!llvm.float)
%1 = "op"() : () -> (!llvm<"float*">)
%cst = constant 8 : index
- // CHECK: [[cubin_ptr:%.*]] = llvm.call @[[getter]]
- // CHECK: [[module_ptr:%.*]] = llvm.alloca {{.*}} x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
- // CHECK: llvm.call @mcuModuleLoad([[module_ptr]], [[cubin_ptr]]) : (!llvm<"i8**">, !llvm<"i8*">) -> !llvm.i32
- // CHECK: [[func_ptr:%.*]] = llvm.alloca {{.*}} x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
- // CHECK: llvm.call @mcuModuleGetFunction([[func_ptr]], {{.*}}, {{.*}}) : (!llvm<"i8**">, !llvm<"i8*">, !llvm<"i8*">) -> !llvm.i32
+ // CHECK: %[[addressof:.*]] = llvm.mlir.addressof @[[global]]
+ // CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index)
+ // CHECK: %[[cubin_ptr:.*]] = llvm.getelementptr %[[addressof]][%[[c0]], %[[c0]]]
+ // CHECK-SAME: -> !llvm<"i8*">
+ // CHECK: %[[module_ptr:.*]] = llvm.alloca {{.*}} x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
+ // CHECK: llvm.call @mcuModuleLoad(%[[module_ptr]], %[[cubin_ptr]]) : (!llvm<"i8**">, !llvm<"i8*">) -> !llvm.i32
+ // CHECK: %[[func_ptr:.*]] = llvm.alloca {{.*}} x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
+ // CHECK: llvm.call @mcuModuleGetFunction(%[[func_ptr]], {{.*}}, {{.*}}) : (!llvm<"i8**">, !llvm<"i8*">, !llvm<"i8*">) -> !llvm.i32
// CHECK: llvm.call @mcuGetStreamHelper
// CHECK: llvm.call @mcuLaunchKernel
// CHECK: llvm.call @mcuStreamSynchronize