#include "llvm/Support/TargetSelect.h"
#include <cuda.h>
-#include <nvPTXCompiler.h>
using namespace mlir;
-static void emitNvptxError(const llvm::Twine &expr,
- nvPTXCompilerHandle compiler,
- nvPTXCompileResult result, Location loc) {
+static void emitCudaError(const llvm::Twine &expr, const char *buffer,
+ CUresult result, Location loc) {
const char *error;
- auto GetErrMsg = [](nvPTXCompileResult result) -> const char * {
- switch (result) {
- case NVPTXCOMPILE_SUCCESS:
- return "Success";
- case NVPTXCOMPILE_ERROR_INVALID_COMPILER_HANDLE:
- return "Invalid compiler handle";
- case NVPTXCOMPILE_ERROR_INVALID_INPUT:
- return "Invalid input";
- case NVPTXCOMPILE_ERROR_COMPILATION_FAILURE:
- return "Compilation failure";
- case NVPTXCOMPILE_ERROR_INTERNAL:
- return "Internal error";
- case NVPTXCOMPILE_ERROR_OUT_OF_MEMORY:
- return "Out of memory";
- case NVPTXCOMPILE_ERROR_COMPILER_INVOCATION_INCOMPLETE:
- return "Invocation incomplete";
- case NVPTXCOMPILE_ERROR_UNSUPPORTED_PTX_VERSION:
- return "Unsupported PTX version";
- }
- };
- size_t errorSize;
- auto status = nvPTXCompilerGetErrorLogSize(compiler, &errorSize);
- std::string error_log;
- if (status == NVPTXCOMPILE_SUCCESS) {
- error_log.resize(errorSize);
- status = nvPTXCompilerGetErrorLog(compiler, error_log.data());
- if (status != NVPTXCOMPILE_SUCCESS)
- error_log = "<failed to retrieve compiler error log>";
- }
+ cuGetErrorString(result, &error);
emitError(loc, expr.concat(" failed with error code ")
- .concat(llvm::Twine{GetErrMsg(result)})
+ .concat(llvm::Twine{error})
.concat("[")
- .concat(error_log)
+ .concat(buffer)
.concat("]"));
}
#define RETURN_ON_CUDA_ERROR(expr) \
do { \
if (auto status = (expr)) { \
- emitNvptxError(#expr, compiler, status, loc); \
- return {}; \
- } \
- } while (false)
-
-#define RETURN_ON_NVPTX_ERROR(expr) \
- do { \
- nvPTXCompileResult result = (expr); \
- if (result != NVPTXCOMPILE_SUCCESS) { \
- emitNvptxError(#expr, compiler, result, loc); \
+ emitCudaError(#expr, jitErrorBuffer, status, loc); \
return {}; \
} \
} while (false)
SerializeToCubinPass::serializeISA(const std::string &isa) {
Location loc = getOperation().getLoc();
char jitErrorBuffer[4096] = {0};
- nvPTXCompilerHandle compiler;
- nvPTXCompilerCreate(&compiler, isa.length(), isa.c_str());
-
- nvPTXCompilerCompile(compiler, 0, nullptr);
+ RETURN_ON_CUDA_ERROR(cuInit(0));
+
+ // Linking requires a device context.
+ CUdevice device;
+ RETURN_ON_CUDA_ERROR(cuDeviceGet(&device, 0));
+ CUcontext context;
+ RETURN_ON_CUDA_ERROR(cuCtxCreate(&context, 0, device));
+ CUlinkState linkState;
+
+ CUjit_option jitOptions[] = {CU_JIT_ERROR_LOG_BUFFER,
+ CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
+ void *jitOptionsVals[] = {jitErrorBuffer,
+ reinterpret_cast<void *>(sizeof(jitErrorBuffer))};
+
+ RETURN_ON_CUDA_ERROR(cuLinkCreate(2, /* number of jit options */
+ jitOptions, /* jit options */
+ jitOptionsVals, /* jit option values */
+ &linkState));
+
+ auto kernelName = getOperation().getName().str();
+ RETURN_ON_CUDA_ERROR(cuLinkAddData(
+ linkState, CUjitInputType::CU_JIT_INPUT_PTX,
+ const_cast<void *>(static_cast<const void *>(isa.c_str())), isa.length(),
+ kernelName.c_str(), 0, /* number of jit options */
+ nullptr, /* jit options */
+ nullptr /* jit option values */
+ ));
+
+ void *cubinData;
size_t cubinSize;
- nvPTXCompilerGetCompiledProgramSize(compiler, &cubinSize);
+ RETURN_ON_CUDA_ERROR(cuLinkComplete(linkState, &cubinData, &cubinSize));
+
+ char *cubinAsChar = static_cast<char *>(cubinData);
+ auto result =
+ std::make_unique<std::vector<char>>(cubinAsChar, cubinAsChar + cubinSize);
- auto result = std::make_unique<std::vector<char>>(cubinSize);
- nvPTXCompilerGetCompiledProgram(compiler, result->data());
- nvPTXCompilerDestroy(&compiler);
+ // This will also destroy the cubin data.
+ RETURN_ON_CUDA_ERROR(cuLinkDestroy(linkState));
+ RETURN_ON_CUDA_ERROR(cuCtxDestroy(context));
return result;
}