fix compilation error with cuda 11 (#6213)
authorJamie Zhuang <lanchongyizu@gmail.com>
Wed, 5 Aug 2020 17:04:24 +0000 (01:04 +0800)
committerGitHub <noreply@github.com>
Wed, 5 Aug 2020 17:04:24 +0000 (10:04 -0700)
src/runtime/contrib/cublas/cublas.cc

index ff20445..467ae5f 100644 (file)
@@ -172,7 +172,11 @@ inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret, cublasLtHandle_t hdl) {
   cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
   cublasLtOrder_t order_COL4_4R2_8C = CUBLASLT_ORDER_COL4_4R2_8C;
   cublasLtMatmulDesc_t operationDesc = nullptr;
+#if CUDART_VERSION >= 11000
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, CUBLAS_COMPUTE_32I, CUDA_R_32I));
+#elif
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, CUDA_R_32I));
+#endif
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB,
                                                     &opTranspose, sizeof(opTranspose)));
   cublasOperation_t opTransA = BooleanToTranspose(transa);