clc: Declare LLVMContexts on the stack
authorJason Ekstrand <jason.ekstrand@collabora.com>
Wed, 13 Apr 2022 21:52:17 +0000 (16:52 -0500)
committerMarge Bot <emma+marge@anholt.net>
Thu, 14 Apr 2022 21:19:56 +0000 (21:19 +0000)
This prevents more use-after-free errors.  Passing them around using
std::unique_ptr ensures that the LLVMContext gets destroyed but doesn't
ensure destruction order.  Declaring it on the stack ensures that the
context doesn't get destroyed until right before the the function
returns which is after any other LLVM stuff is destroyed.

Reviewed-by: Jesse Natalie <jenatali@microsoft.com>
Reviewed-by: Icecream95 <ixn@disroot.org>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/15937>

src/compiler/clc/clc_helpers.cpp

index 7e4ae62..01ef8a6 100644 (file)
@@ -747,16 +747,11 @@ clc_free_kernels_info(const struct clc_kernel_info *kernels,
    free((void *)kernels);
 }
 
-static std::pair<std::unique_ptr<::llvm::Module>, std::unique_ptr<LLVMContext>>
-clc_compile_to_llvm_module(const struct clc_compile_args *args,
+static std::unique_ptr<::llvm::Module>
+clc_compile_to_llvm_module(LLVMContext &llvm_ctx,
+                           const struct clc_compile_args *args,
                            const struct clc_logger *logger)
 {
-   clc_initialize_llvm();
-
-   std::unique_ptr<LLVMContext> llvm_ctx { new LLVMContext };
-   llvm_ctx->setDiagnosticHandlerCallBack(llvm_log_handler,
-                                          const_cast<clc_logger *>(logger));
-
    std::string diag_log_str;
    raw_string_ostream diag_log_stream { diag_log_str };
 
@@ -878,14 +873,14 @@ clc_compile_to_llvm_module(const struct clc_compile_args *args,
            ::llvm::MemoryBuffer::getMemBufferCopy(std::string(args->source.value)).release());
 
    // Compile the code
-   clang::EmitLLVMOnlyAction act(llvm_ctx.get());
+   clang::EmitLLVMOnlyAction act(&llvm_ctx);
    if (!c->ExecuteAction(act)) {
       clc_error(logger, "%sError executing LLVM compilation action.\n",
                 diag_log_str.c_str());
       return {};
    }
 
-   return { act.takeModule(), std::move(llvm_ctx) };
+   return act.takeModule();
 }
 
 static SPIRV::VersionNumber
@@ -906,7 +901,7 @@ spirv_version_to_llvm_spirv_translator_version(enum clc_spirv_version version)
 
 static int
 llvm_mod_to_spirv(std::unique_ptr<::llvm::Module> mod,
-                  std::unique_ptr<LLVMContext> context,
+                  LLVMContext &context,
                   const struct clc_compile_args *args,
                   const struct clc_logger *logger,
                   struct clc_binary *out_spirv)
@@ -969,13 +964,19 @@ clc_c_to_spir(const struct clc_compile_args *args,
               const struct clc_logger *logger,
               struct clc_binary *out_spir)
 {
-   auto pair = clc_compile_to_llvm_module(args, logger);
-   if (!pair.first)
+   clc_initialize_llvm();
+
+   LLVMContext llvm_ctx;
+   llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
+                                         const_cast<clc_logger *>(logger));
+
+   auto mod = clc_compile_to_llvm_module(llvm_ctx, args, logger);
+   if (!mod)
       return -1;
 
    ::llvm::SmallVector<char, 0> buffer;
    ::llvm::BitcodeWriter writer(buffer);
-   writer.writeModule(*pair.first);
+   writer.writeModule(*mod);
 
    out_spir->size = buffer.size_in_bytes();
    out_spir->data = malloc(out_spir->size);
@@ -989,10 +990,16 @@ clc_c_to_spirv(const struct clc_compile_args *args,
                const struct clc_logger *logger,
                struct clc_binary *out_spirv)
 {
-   auto pair = clc_compile_to_llvm_module(args, logger);
-   if (!pair.first)
+   clc_initialize_llvm();
+
+   LLVMContext llvm_ctx;
+   llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
+                                         const_cast<clc_logger *>(logger));
+
+   auto mod = clc_compile_to_llvm_module(llvm_ctx, args, logger);
+   if (!mod)
       return -1;
-   return llvm_mod_to_spirv(std::move(pair.first), std::move(pair.second), args, logger, out_spirv);
+   return llvm_mod_to_spirv(std::move(mod), llvm_ctx, args, logger, out_spirv);
 }
 
 int
@@ -1002,13 +1009,16 @@ clc_spir_to_spirv(const struct clc_binary *in_spir,
 {
    clc_initialize_llvm();
 
-   std::unique_ptr<LLVMContext> llvm_ctx{ new LLVMContext };
+   LLVMContext llvm_ctx;
+   llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
+                                         const_cast<clc_logger *>(logger));
+
    ::llvm::StringRef spir_ref(static_cast<const char*>(in_spir->data), in_spir->size);
-   auto mod = ::llvm::parseBitcodeFile(::llvm::MemoryBufferRef(spir_ref, "<spir>"), *llvm_ctx);
+   auto mod = ::llvm::parseBitcodeFile(::llvm::MemoryBufferRef(spir_ref, "<spir>"), llvm_ctx);
    if (!mod)
       return -1;
 
-   return llvm_mod_to_spirv(std::move(mod.get()), std::move(llvm_ctx), NULL, logger, out_spirv);
+   return llvm_mod_to_spirv(std::move(mod.get()), llvm_ctx, NULL, logger, out_spirv);
 }
 
 class SPIRVMessageConsumer {