From: Kun Wu Date: Thu, 1 Jun 2023 17:46:08 +0000 (+0000) Subject: [mlir][sparse][gpu] fixing broken literal names in cuda runner macros X-Git-Tag: upstream/17.0.6~6453 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=be6c5320059fcc6a86c775108ff440f7e53f84b6;p=platform%2Fupstream%2Fllvm.git [mlir][sparse][gpu] fixing broken literal names in cuda runner macros Differential Revision: https://reviews.llvm.org/D151910 --- diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp index c7367a8..17be418 100644 --- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp +++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp @@ -232,10 +232,10 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) { // Some macro magic to get float/double alpha and beta on host. #define ALPHABETA(dtp, alpha, beta) \ - __nv_bfloat16(alpha##bf16) = 1.0f; \ - __nv_bfloat16(beta##bf16) = 1.0f; \ - __half(alpha##f16) = 1.0f; \ - __half(beta##f16) = 1.0f; \ + __nv_bfloat16(alpha##16bf) = 1.0f; \ + __nv_bfloat16(beta##16bf) = 1.0f; \ + __half(alpha##16f) = 1.0f; \ + __half(beta##16f) = 1.0f; \ float(alpha##f) = 1.0f; \ float(beta##f) = 1.0f; \ double(alpha##d) = 1.0; \ @@ -251,11 +251,9 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) { } else if (dtp == CUDA_R_32F || dtp == CUDA_C_32F) { \ (alpha##p) = reinterpret_cast<void *>(&(alpha##f)); \ (beta##p) = reinterpret_cast<void *>(&(beta##f)); \ - } else if (dtp == CUDA_R_64F || dtp == CUDA_C_64F) { \ + } else { \ (alpha##p) = reinterpret_cast<void *>(&(alpha##d)); \ (beta##p) = reinterpret_cast<void *>(&(beta##d)); \ - } else { \ - llvm_unreachable("Unsupported data type"); \ } extern "C" MLIR_CUDA_WRAPPERS_EXPORT void * @@ -321,6 +319,7 @@ mgpuCreateCsr(intptr_t rows, intptr_t cols, intptr_t nnz, void *rowPos, cusparseSpMatDescr_t mat = nullptr; auto pTp = static_cast<cusparseIndexType_t>(ptp); auto iTp = static_cast<cusparseIndexType_t>(itp); + auto dTp = static_cast<cudaDataType_t>(dtp); 
CUSPARSE_REPORT_IF_ERROR(cusparseCreateCsr(&mat, rows, cols, nnz, rowPos, colIdxs, values, pTp, iTp, CUSPARSE_INDEX_BASE_ZERO, dtp))