Enable Half, BFloat16, and Complex dtypes for coo-coo sparse matmul [CUDA] (#59980)
author Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Mon, 30 Aug 2021 22:03:15 +0000 (15:03 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 30 Aug 2021 22:06:25 +0000 (15:06 -0700)
Summary:
This PR enables Half, BFloat16, ComplexFloat, and ComplexDouble support for matrix-matrix multiplication of COO sparse matrices.
The change is applied only to CUDA 11+ builds.

`cusparseSpGEMM` also supports `CUDA_C_16F` (complex float16) and `CUDA_C_16BF` (complex bfloat16), and PyTorch has a matching complex float16 dtype (`ScalarType::ComplexHalf`); however, there is no convenient dispatch macro for it yet, so that dtype is omitted from this PR.
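For illustration, a minimal sketch of what this enables from the Python side (assuming a CUDA 11+ build and a GPU that passes the compute-capability checks added below: >= 5.3 for float16, >= 8.0 for bfloat16):

```python
import torch

# COO sparse matrices in float16; previously the CUDA sparse-sparse matmul
# path only accepted float32 and float64.
a = torch.randn(4, 6, device="cuda").relu().to(torch.float16).to_sparse()
b = torch.randn(6, 8, device="cuda").relu().to(torch.float16).to_sparse()
c = torch.sparse.mm(a, b)    # COO @ COO -> COO, dispatched to cusparseSpGEMM
print(c.dtype, c.is_sparse)  # torch.float16 True

# Complex inputs take the same path.
az = torch.randn(4, 6, dtype=torch.cfloat, device="cuda").to_sparse()
bz = torch.randn(6, 8, dtype=torch.cfloat, device="cuda").to_sparse()
cz = torch.sparse.mm(az, bz)
```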

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59980

Reviewed By: ngimel

Differential Revision: D29699456

Pulled By: cpuhrsch

fbshipit-source-id: 407ae53392acb2f92396a62a57cbaeb0fe6e950b

aten/src/ATen/cuda/CUDADataType.h [new file with mode: 0644]
aten/src/ATen/native/sparse/cuda/SparseMatMul.cu
test/test_sparse.py
torch/testing/_internal/common_cuda.py
torch/utils/hipify/cuda_to_hip_mappings.py

diff --git a/aten/src/ATen/cuda/CUDADataType.h b/aten/src/ATen/cuda/CUDADataType.h
new file mode 100644
index 0000000..71c9af9
--- /dev/null
+++ b/aten/src/ATen/cuda/CUDADataType.h
@@ -0,0 +1,61 @@
+#pragma once
+
+#include <c10/core/ScalarType.h>
+
+#include <cuda.h>
+#include <library_types.h>
+
+namespace at {
+namespace cuda {
+
+template <typename scalar_t>
+cudaDataType getCudaDataType() {
+  TORCH_INTERNAL_ASSERT(false, "Cannot convert type ", typeid(scalar_t).name(), " to cudaDataType.")
+}
+
+template<> cudaDataType getCudaDataType<at::Half>() {
+  return CUDA_R_16F;
+}
+template<> cudaDataType getCudaDataType<float>() {
+  return CUDA_R_32F;
+}
+template<> cudaDataType getCudaDataType<double>() {
+  return CUDA_R_64F;
+}
+template<> cudaDataType getCudaDataType<c10::complex<c10::Half>>() {
+  return CUDA_C_16F;
+}
+template<> cudaDataType getCudaDataType<c10::complex<float>>() {
+  return CUDA_C_32F;
+}
+template<> cudaDataType getCudaDataType<c10::complex<double>>() {
+  return CUDA_C_64F;
+}
+
+// HIP doesn't define integral types
+#ifndef __HIP_PLATFORM_HCC__
+template<> cudaDataType getCudaDataType<uint8_t>() {
+  return CUDA_R_8U;
+}
+template<> cudaDataType getCudaDataType<int8_t>() {
+  return CUDA_R_8I;
+}
+template<> cudaDataType getCudaDataType<int>() {
+  return CUDA_R_32I;
+}
+#endif
+
+#if !defined(__HIP_PLATFORM_HCC__) && defined(CUDA_VERSION) && CUDA_VERSION >= 11000
+template<> cudaDataType getCudaDataType<int16_t>() {
+  return CUDA_R_16I;
+}
+template<> cudaDataType getCudaDataType<int64_t>() {
+  return CUDA_R_64I;
+}
+template<> cudaDataType getCudaDataType<at::BFloat16>() {
+  return CUDA_R_16BF;
+}
+#endif
+
+} // namespace cuda
+} // namespace at
diff --git a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu
index d5f31a1..a08c93d 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu
+++ b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu
@@ -16,6 +16,7 @@
 #include <THC/THCThrustAllocator.cuh>
 
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDADataType.h>
 #include <ATen/cuda/CUDAUtils.h>
 #include <cusparse.h>
 #include <ATen/native/sparse/cuda/SparseCUDABlas.h>
@@ -118,14 +119,7 @@ struct csrMatrixRef {
         nnz_{nnz},
         size_{size} {
     #if IS_CUSPARSE11_AVAILABLE()
-      cudaDataType cuda_data_type;
-      if ( std::is_same<float, scalar_t>::value ) {
-        cuda_data_type = CUDA_R_32F;
-      } else if ( std::is_same<double, scalar_t>::value) {
-        cuda_data_type = CUDA_R_64F;
-      } else {
-        TORCH_CHECK(false, "Tensor types must be either float32 or float64");
-      }
+      cudaDataType cuda_data_type = at::cuda::getCudaDataType<scalar_t>();
       TORCH_CUDASPARSE_CHECK(cusparseCreateCsr(
         &description_,
         this->size(0),
@@ -192,8 +186,14 @@ struct CusparseMatrixMultiplyOp {
   cusparseSpGEMMDescr_t spgemmDesc;
 
   CusparseMatrixMultiplyOp() {
-    static_assert(std::is_same<float, scalar_t>::value || std::is_same<double, scalar_t>::value,
-      "cusparse csr sparse-sparse MM only supports data type of float and double.");
+    static_assert(
+      std::is_same<c10::Half, scalar_t>::value ||
+          std::is_same<c10::BFloat16, scalar_t>::value ||
+          std::is_same<float, scalar_t>::value ||
+          std::is_same<double, scalar_t>::value ||
+          std::is_same<c10::complex<float>, scalar_t>::value ||
+          std::is_same<c10::complex<double>, scalar_t>::value,
+      "cusparseSpGEMM only supports data type of half, bfloat16, float, double and complex float, double.");
     // SpGEMM Computation
     TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_createDescr(&spgemmDesc));
   }
@@ -212,14 +212,6 @@ struct CusparseMatrixMultiplyOp {
 
     const int B_num_cols = B.size(1);
 
-    cudaDataType computeType;
-    if ( std::is_same<float, scalar_t>::value ) {
-      computeType = CUDA_R_32F;
-    } else if ( std::is_same<double, scalar_t>::value) {
-      computeType = CUDA_R_64F;
-    } else {
-      TORCH_CHECK(false, "Tensor types must be either float32 or float64");
-    }
     csrOutput out({A.size(0), B.size(1)});
 
     out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));
@@ -252,6 +244,16 @@ struct CusparseMatrixMultiplyOp {
     cusparseSpMatDescr_t matC = C.description_;
     //--------------------------------------------------------------------------
 
+    cudaDataType computeType = at::cuda::getCudaDataType<scalar_t>();
+
+    // If a specific GPU model does not provide native support for a given data type,
+    // the routine returns the CUSPARSE_STATUS_ARCH_MISMATCH error.
+    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
+    TORCH_CHECK(prop->major >= 5 && !((10*prop->major + prop->minor) < 53 && computeType == CUDA_R_16F),
+        "sparse_mm: CUDA Float16 requires compute capability >= 53 (current: ", prop->major, prop->minor, ")");
+    TORCH_CHECK(!(prop->major < 8 && computeType == CUDA_R_16BF),
+        "sparse_mm: CUDA BFloat16 requires compute capability >= 80 (current: ", prop->major, prop->minor, ")");
+
     // ask bufferSize1 bytes for external memory
     TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_workEstimation(
         handle,
@@ -646,8 +648,14 @@ void sparse_sparse_matmul_cuda_kernel(
     const Tensor& mat1,
     const Tensor& mat2) {
 
-  static_assert(std::is_same<float, scalar_t>::value || std::is_same<double, scalar_t>::value,
-    "sparse_sparse_matmul_cuda_kernel only supports float and double value types");
+  static_assert(
+    std::is_same<c10::Half, scalar_t>::value ||
+        std::is_same<c10::BFloat16, scalar_t>::value ||
+        std::is_same<float, scalar_t>::value ||
+        std::is_same<double, scalar_t>::value ||
+        std::is_same<c10::complex<float>, scalar_t>::value ||
+        std::is_same<c10::complex<double>, scalar_t>::value,
+    "sparse_sparse_matmul_cuda_kernel only supports data type of half, bfloat16, float, double and complex float, double.");
 
   Tensor mat1_indices_ = mat1._indices().contiguous();
   Tensor mat1_values = mat1._values().contiguous();
@@ -775,9 +783,15 @@ Tensor sparse_sparse_matmul_cuda(const Tensor& mat1_, const Tensor& mat2_) {
   auto output = at::native::empty_like(mat1_);
   output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0);
 
+#if IS_CUSPARSE11_AVAILABLE()
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, mat1_.scalar_type(), "sparse_matmul", [&] {
+    sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
+  });
+#else
   AT_DISPATCH_FLOATING_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
     sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
   });
+#endif
   return output;
 }
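
For completeness, a rough Python-side mirror of the capability guard above; `spgemm_supports` is a hypothetical helper for illustration (not part of this PR) and reuses the thresholds enforced by the `TORCH_CHECK`s:

```python
import torch

def spgemm_supports(dtype: torch.dtype) -> bool:
    # Mirrors the guard added in SparseMatMul.cu: cusparseSpGEMM needs
    # compute capability >= 5.3 for float16 and >= 8.0 for bfloat16.
    major, minor = torch.cuda.get_device_capability()
    cc = 10 * major + minor
    if dtype == torch.float16:
        return cc >= 53
    if dtype == torch.bfloat16:
        return cc >= 80
    return dtype in (torch.float32, torch.float64, torch.cfloat, torch.cdouble)
```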
 
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 333f29f..aaf045c 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -12,8 +12,12 @@ from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm
 from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
 from numbers import Number
 from typing import Dict, Any
+from torch.testing import get_all_complex_dtypes, get_all_fp_dtypes
+from torch.testing._internal.common_cuda import \
+    (SM53OrLater, SM80OrLater, CUDA11OrLater)
 from torch.testing._internal.common_device_type import \
-    (instantiate_device_type_tests, ops, dtypes, dtypesIfCPU, onlyCPU, onlyCUDA, deviceCountAtLeast)
+    (instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride,
+     deviceCountAtLeast)
 from torch.testing._internal.common_methods_invocations import \
     (sparse_unary_ufuncs)
 
@@ -3217,8 +3221,13 @@ class TestSparse(TestCase):
     # TODO: Check after why ROCm's cusparseXcsrgemm2Nnz function doesn't return the same nnz value as CUDA
     @skipIfRocm
     @coalescedonoff
-    @dtypes(torch.double)
-    @dtypesIfCPU(torch.double, torch.cdouble)
+    @dtypes(*get_all_complex_dtypes(),
+            *get_all_fp_dtypes(include_half=False, include_bfloat16=False))
+    @dtypesIfCUDA(*(get_all_complex_dtypes() if CUDA11OrLater else ()),
+                  *get_all_fp_dtypes(
+                      include_half=(CUDA11OrLater and SM53OrLater),
+                      include_bfloat16=(CUDA11OrLater and SM80OrLater)))
+    @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})
     def test_sparse_matmul(self, device, dtype, coalesced):
         """
        This function tests `torch.sparse.mm` when both mat1 and mat2 are sparse tensors.
@@ -3328,22 +3337,23 @@ class TestSparse(TestCase):
             r2 = torch.sparse.mm(a, b)
             self.assertEqual(r1, r2)
 
-            a.requires_grad_(True)
-            b.requires_grad_(True)
+            if dtype in [torch.double, torch.cdouble]:
+                a.requires_grad_(True)
+                b.requires_grad_(True)
 
-            # check autograd support on sparse matmul
-            def fn(D1, D2):
-                return torch.sparse.mm(D1, D2).to_dense()
+                # check autograd support on sparse matmul
+                def fn(D1, D2):
+                    return torch.sparse.mm(D1, D2).to_dense()
 
-            if a.is_cuda:
-                # For cuda, `nondet_tol` is set with `1e-5`
-                # This is because cuSparse sometimes returns approximate zero values like `~e-323`
-                # TODO: Check this cuSparse issue.
-                # This happens when you do chain multiplication `torch.sparse.mm` operations
-                gradcheck(fn, (a, b), check_sparse_nnz=True, nondet_tol=1e-5)
-            else:
-                gradcheck(fn, (a, b), check_sparse_nnz=True)
-            grad_with_custom_sparsity_pattern_test_helper(sparse_dims, nnz, shape_a, shape_b)
+                if a.is_cuda:
+                    # For cuda, `nondet_tol` is set with `1e-5`
+                    # This is because cuSparse sometimes returns approximate zero values like `~e-323`
+                    # TODO: Check this cuSparse issue.
+                    # This happens when you do chain multiplication `torch.sparse.mm` operations
+                    gradcheck(fn, (a, b), check_sparse_nnz=True, nondet_tol=1e-5)
+                else:
+                    gradcheck(fn, (a, b), check_sparse_nnz=True)
+                grad_with_custom_sparsity_pattern_test_helper(sparse_dims, nnz, shape_a, shape_b)
 
         def test_error_cases():
             def fn(sparse_dims, nnz, shape_a, shape_b):
diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py
index 5d0849b..36e7f8a 100644
--- a/torch/testing/_internal/common_cuda.py
+++ b/torch/testing/_internal/common_cuda.py
@@ -20,6 +20,7 @@ CUDA11OrLater = torch.version.cuda and distutils.version.LooseVersion(torch.vers
 CUDA9 = torch.version.cuda and torch.version.cuda.startswith('9.')
 SM53OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3)
 SM60OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0)
+SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0)
 
 TEST_MAGMA = TEST_CUDA
 if TEST_CUDA:
diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py
index 558acc2..6b60516 100644
--- a/torch/utils/hipify/cuda_to_hip_mappings.py
+++ b/torch/utils/hipify/cuda_to_hip_mappings.py
@@ -554,6 +554,7 @@ CUDA_INCLUDE_MAP = collections.OrderedDict(
         ),
         ("device_functions.h", ("hip/device_functions.h", CONV_INCLUDE, API_RUNTIME)),
         ("driver_types.h", ("hip/driver_types.h", CONV_INCLUDE, API_RUNTIME)),
+        ("library_types.h", ("hip/library_types.h", CONV_INCLUDE, API_RUNTIME)),
         ("cuComplex.h", ("hip/hip_complex.h", CONV_INCLUDE, API_RUNTIME)),
         ("cuda_fp16.h", ("hip/hip_fp16.h", CONV_INCLUDE, API_RUNTIME)),
         (
@@ -3786,21 +3787,21 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
             ),
         ),
         ("cudaDataType_t", ("hipDataType_t", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
-        ("cudaDataType", ("hipDataType_t", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
-        ("CUDA_R_16F", ("hipR16F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
-        ("CUDA_C_16F", ("hipC16F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
-        ("CUDA_R_32F", ("hipR32F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
-        ("CUDA_C_32F", ("hipC32F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
-        ("CUDA_R_64F", ("hipR64F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
-        ("CUDA_C_64F", ("hipC64F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
-        ("CUDA_R_8I", ("hipR8I", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
-        ("CUDA_C_8I", ("hipC8I", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
-        ("CUDA_R_8U", ("hipR8U", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
-        ("CUDA_C_8U", ("hipC8U", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
-        ("CUDA_R_32I", ("hipR32I", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
-        ("CUDA_C_32I", ("hipC32I", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
-        ("CUDA_R_32U", ("hipR32U", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
-        ("CUDA_C_32U", ("hipC32U", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+        ("cudaDataType", ("hipDataType", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+        ("CUDA_R_16F", ("HIP_R_16F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+        ("CUDA_C_16F", ("HIP_C_16F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+        ("CUDA_R_32F", ("HIP_R_32F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+        ("CUDA_C_32F", ("HIP_C_32F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+        ("CUDA_R_64F", ("HIP_R_64F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+        ("CUDA_C_64F", ("HIP_C_64F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+        ("CUDA_R_8I", ("HIP_R_8I", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+        ("CUDA_C_8I", ("HIP_C_8I", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+        ("CUDA_R_8U", ("HIP_R_8U", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+        ("CUDA_C_8U", ("HIP_C_8U", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+        ("CUDA_R_32I", ("HIP_R_32I", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+        ("CUDA_C_32I", ("HIP_C_32I", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+        ("CUDA_R_32U", ("HIP_R_32U", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+        ("CUDA_C_32U", ("HIP_C_32U", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
         (
             "MAJOR_VERSION",
             ("hipLibraryMajorVersion", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED),