[mlir][sparse][gpu] fixing broken literal names in cuda runner macros
authorKun Wu <kunww@google.com>
Thu, 1 Jun 2023 17:46:08 +0000 (17:46 +0000)
committerKun Wu <kunww@google.com>
Thu, 1 Jun 2023 17:52:58 +0000 (17:52 +0000)
Differential Revision: https://reviews.llvm.org/D151910

mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp

index c7367a8..17be418 100644 (file)
@@ -232,10 +232,10 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) {
 
 // Some macro magic to get float/double alpha and beta on host.
 #define ALPHABETA(dtp, alpha, beta)                                            \
-  __nv_bfloat16(alpha##bf16) = 1.0f;                                           \
-  __nv_bfloat16(beta##bf16) = 1.0f;                                            \
-  __half(alpha##f16) = 1.0f;                                                   \
-  __half(beta##f16) = 1.0f;                                                    \
+  __nv_bfloat16(alpha##16bf) = 1.0f;                                           \
+  __nv_bfloat16(beta##16bf) = 1.0f;                                            \
+  __half(alpha##16f) = 1.0f;                                                   \
+  __half(beta##16f) = 1.0f;                                                    \
   float(alpha##f) = 1.0f;                                                      \
   float(beta##f) = 1.0f;                                                       \
   double(alpha##d) = 1.0;                                                      \
@@ -251,11 +251,9 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) {
   } else if (dtp == CUDA_R_32F || dtp == CUDA_C_32F) {                         \
     (alpha##p) = reinterpret_cast<void *>(&(alpha##f));                        \
     (beta##p) = reinterpret_cast<void *>(&(beta##f));                          \
-  } else if (dtp == CUDA_R_64F || dtp == CUDA_C_64F) {                         \
+  } else {                                                                     \
     (alpha##p) = reinterpret_cast<void *>(&(alpha##d));                        \
     (beta##p) = reinterpret_cast<void *>(&(beta##d));                          \
-  } else {                                                                     \
-    llvm_unreachable("Unsupported data type");                                 \
   }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
@@ -321,6 +319,7 @@ mgpuCreateCsr(intptr_t rows, intptr_t cols, intptr_t nnz, void *rowPos,
   cusparseSpMatDescr_t mat = nullptr;
   auto pTp = static_cast<cusparseIndexType_t>(ptp);
   auto iTp = static_cast<cusparseIndexType_t>(itp);
+  auto dTp = static_cast<cudaDataType_t>(dtp);
   CUSPARSE_REPORT_IF_ERROR(cusparseCreateCsr(&mat, rows, cols, nnz, rowPos,
                                              colIdxs, values, pTp, iTp,
                                              CUSPARSE_INDEX_BASE_ZERO, dtp))