From 46d9b0e431a890e4f130f6cd3e2e150b152f51f1 Mon Sep 17 00:00:00 2001 From: Jason Ekstrand Date: Wed, 13 Apr 2022 16:52:17 -0500 Subject: [PATCH] clc: Declare LLVMContexts on the stack This prevents more use-after-free errors. Passing them around using std::unique_ptr ensures that the LLVMContext gets destroyed but doesn't ensure destruction order. Declaring it on the stack ensures that the context doesn't get destroyed until right before the the function returns which is after any other LLVM stuff is destroyed. Reviewed-by: Jesse Natalie Reviewed-by: Icecream95 Part-of: --- src/compiler/clc/clc_helpers.cpp | 50 ++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/src/compiler/clc/clc_helpers.cpp b/src/compiler/clc/clc_helpers.cpp index 7e4ae62..01ef8a6 100644 --- a/src/compiler/clc/clc_helpers.cpp +++ b/src/compiler/clc/clc_helpers.cpp @@ -747,16 +747,11 @@ clc_free_kernels_info(const struct clc_kernel_info *kernels, free((void *)kernels); } -static std::pair, std::unique_ptr> -clc_compile_to_llvm_module(const struct clc_compile_args *args, +static std::unique_ptr<::llvm::Module> +clc_compile_to_llvm_module(LLVMContext &llvm_ctx, + const struct clc_compile_args *args, const struct clc_logger *logger) { - clc_initialize_llvm(); - - std::unique_ptr llvm_ctx { new LLVMContext }; - llvm_ctx->setDiagnosticHandlerCallBack(llvm_log_handler, - const_cast(logger)); - std::string diag_log_str; raw_string_ostream diag_log_stream { diag_log_str }; @@ -878,14 +873,14 @@ clc_compile_to_llvm_module(const struct clc_compile_args *args, ::llvm::MemoryBuffer::getMemBufferCopy(std::string(args->source.value)).release()); // Compile the code - clang::EmitLLVMOnlyAction act(llvm_ctx.get()); + clang::EmitLLVMOnlyAction act(&llvm_ctx); if (!c->ExecuteAction(act)) { clc_error(logger, "%sError executing LLVM compilation action.\n", diag_log_str.c_str()); return {}; } - return { act.takeModule(), std::move(llvm_ctx) }; + return act.takeModule(); } static SPIRV::VersionNumber @@ -906,7 +901,7 @@ spirv_version_to_llvm_spirv_translator_version(enum clc_spirv_version version) static int llvm_mod_to_spirv(std::unique_ptr<::llvm::Module> mod, - std::unique_ptr context, + LLVMContext &context, const struct clc_compile_args *args, const struct clc_logger *logger, struct clc_binary *out_spirv) @@ -969,13 +964,19 @@ clc_c_to_spir(const struct clc_compile_args *args, const struct clc_logger *logger, struct clc_binary *out_spir) { - auto pair = clc_compile_to_llvm_module(args, logger); - if (!pair.first) + clc_initialize_llvm(); + + LLVMContext llvm_ctx; + llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler, + const_cast(logger)); + + auto mod = clc_compile_to_llvm_module(llvm_ctx, args, logger); + if (!mod) return -1; ::llvm::SmallVector buffer; ::llvm::BitcodeWriter writer(buffer); - writer.writeModule(*pair.first); + writer.writeModule(*mod); out_spir->size = buffer.size_in_bytes(); out_spir->data = malloc(out_spir->size); @@ -989,10 +990,16 @@ clc_c_to_spirv(const struct clc_compile_args *args, const struct clc_logger *logger, struct clc_binary *out_spirv) { - auto pair = clc_compile_to_llvm_module(args, logger); - if (!pair.first) + clc_initialize_llvm(); + + LLVMContext llvm_ctx; + llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler, + const_cast(logger)); + + auto mod = clc_compile_to_llvm_module(llvm_ctx, args, logger); + if (!mod) return -1; - return llvm_mod_to_spirv(std::move(pair.first), std::move(pair.second), args, logger, out_spirv); + return llvm_mod_to_spirv(std::move(mod), llvm_ctx, args, logger, out_spirv); } int @@ -1002,13 +1009,16 @@ clc_spir_to_spirv(const struct clc_binary *in_spir, { clc_initialize_llvm(); - std::unique_ptr llvm_ctx{ new LLVMContext }; + LLVMContext llvm_ctx; + llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler, + const_cast(logger)); + ::llvm::StringRef spir_ref(static_cast(in_spir->data), in_spir->size); - auto mod = ::llvm::parseBitcodeFile(::llvm::MemoryBufferRef(spir_ref, ""), *llvm_ctx); + auto mod = ::llvm::parseBitcodeFile(::llvm::MemoryBufferRef(spir_ref, ""), llvm_ctx); if (!mod) return -1; - return llvm_mod_to_spirv(std::move(mod.get()), std::move(llvm_ctx), NULL, logger, out_spirv); + return llvm_mod_to_spirv(std::move(mod.get()), llvm_ctx, NULL, logger, out_spirv); } class SPIRVMessageConsumer { -- 2.7.4