[mlir][gpu] Fix leaked stream and module when lowering gpu.launch_func to runtime...
authorChristian Sigg <csigg@google.com>
Thu, 29 Oct 2020 07:17:27 +0000 (08:17 +0100)
committerChristian Sigg <csigg@google.com>
Thu, 29 Oct 2020 07:40:51 +0000 (08:40 +0100)
Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D90370

mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp

index ae112d5..a8a416d 100644 (file)
@@ -88,6 +88,8 @@ protected:
       "mgpuModuleLoad",
       llvmPointerType /* void *module */,
       {llvmPointerType /* void *cubin */}};
+  FunctionCallBuilder moduleUnloadCallBuilder = {
+      "mgpuModuleUnload", llvmVoidType, {llvmPointerType /* void *module */}};
   FunctionCallBuilder moduleGetFunctionCallBuilder = {
       "mgpuModuleGetFunction",
       llvmPointerType /* void *function */,
@@ -490,6 +492,8 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
        kernelParams,                /* kernel params */
        nullpointer /* extra */});
   streamSynchronizeCallBuilder.create(loc, rewriter, stream.getResult(0));
+  streamDestroyCallBuilder.create(loc, rewriter, stream.getResult(0));
+  moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));
 
   rewriter.eraseOp(op);
   return success();
index 908604e..c639368 100644 (file)
@@ -48,4 +48,6 @@ module attributes {gpu.container_module} {
   // CHECK-SAME: [[C8]], [[C8]], [[C8]], [[C0_I32]], [[STREAM]],
   // CHECK-SAME: [[PARAMS]], [[EXTRA_PARAMS]])
   // CHECK: llvm.call @mgpuStreamSynchronize
+  // CHECK: llvm.call @mgpuStreamDestroy
+  // CHECK: llvm.call @mgpuModuleUnload
 }
index a32c37d..5556700 100644 (file)
@@ -47,6 +47,10 @@ extern "C" CUmodule mgpuModuleLoad(void *data) {
   return module;
 }
 
+extern "C" void mgpuModuleUnload(CUmodule module) {
+  CUDA_REPORT_IF_ERROR(cuModuleUnload(module));
+}
+
 extern "C" CUfunction mgpuModuleGetFunction(CUmodule module, const char *name) {
   CUfunction function = nullptr;
   CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name));
index 999b80c..d600a3f 100644 (file)
@@ -46,6 +46,10 @@ extern "C" hipModule_t mgpuModuleLoad(void *data) {
   return module;
 }
 
+extern "C" void mgpuModuleUnload(hipModule_t module) {
+  HIP_REPORT_IF_ERROR(hipModuleUnload(module));
+}
+
 extern "C" hipFunction_t mgpuModuleGetFunction(hipModule_t module,
                                                const char *name) {
   hipFunction_t function = nullptr;