--- /dev/null
+#pragma once
+
+#include <c10/core/ScalarType.h>
+
+#include <cuda.h>
+#include <library_types.h>
+
+namespace at {
+namespace cuda {
+
+template <typename scalar_t>
+cudaDataType getCudaDataType() {
+  TORCH_INTERNAL_ASSERT(false, "Cannot convert type ", typeid(scalar_t).name(), " to cudaDataType.");
+}
+
+template<> inline cudaDataType getCudaDataType<at::Half>() {
+  return CUDA_R_16F;
+}
+template<> inline cudaDataType getCudaDataType<float>() {
+  return CUDA_R_32F;
+}
+template<> inline cudaDataType getCudaDataType<double>() {
+  return CUDA_R_64F;
+}
+template<> inline cudaDataType getCudaDataType<c10::complex<c10::Half>>() {
+  return CUDA_C_16F;
+}
+template<> inline cudaDataType getCudaDataType<c10::complex<float>>() {
+  return CUDA_C_32F;
+}
+template<> inline cudaDataType getCudaDataType<c10::complex<double>>() {
+  return CUDA_C_64F;
+}
+
+// HIP doesn't define integral types
+#ifndef __HIP_PLATFORM_HCC__
+template<> inline cudaDataType getCudaDataType<uint8_t>() {
+  return CUDA_R_8U;
+}
+template<> inline cudaDataType getCudaDataType<int8_t>() {
+  return CUDA_R_8I;
+}
+template<> inline cudaDataType getCudaDataType<int>() {
+  return CUDA_R_32I;
+}
+#endif
+
+#if !defined(__HIP_PLATFORM_HCC__) && defined(CUDA_VERSION) && CUDA_VERSION >= 11000
+template<> inline cudaDataType getCudaDataType<int16_t>() {
+  return CUDA_R_16I;
+}
+template<> inline cudaDataType getCudaDataType<int64_t>() {
+  return CUDA_R_64I;
+}
+template<> inline cudaDataType getCudaDataType<at::BFloat16>() {
+  return CUDA_R_16BF;
+}
+#endif
+
+} // namespace cuda
+} // namespace at
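
For orientation, a minimal sketch (not part of this patch) of how the new helper is meant to be consumed: a templated routine that needs a cudaDataType for a cuSPARSE generic-API descriptor can call at::cuda::getCudaDataType<scalar_t>() instead of branching on the scalar type by hand. The make_dn_vec name below is hypothetical; TORCH_CUDASPARSE_CHECK is assumed to come from ATen/cuda/Exceptions.h, as used elsewhere in this PR.

#include <ATen/cuda/CUDADataType.h>
#include <ATen/cuda/Exceptions.h>
#include <cusparse.h>

// Hypothetical helper: builds a cuSPARSE dense-vector descriptor for any
// scalar_t that has a getCudaDataType<> specialization.
template <typename scalar_t>
cusparseDnVecDescr_t make_dn_vec(int64_t size, void* values) {
  cusparseDnVecDescr_t descr;
  // Resolved against the specializations above; an unsupported scalar_t
  // falls through to the TORCH_INTERNAL_ASSERT in the primary template at runtime.
  const cudaDataType value_type = at::cuda::getCudaDataType<scalar_t>();
  TORCH_CUDASPARSE_CHECK(cusparseCreateDnVec(&descr, size, values, value_type));
  return descr;
}
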
#include <THC/THCThrustAllocator.cuh>
#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDADataType.h>
#include <ATen/cuda/CUDAUtils.h>
#include <cusparse.h>
#include <ATen/native/sparse/cuda/SparseCUDABlas.h>
nnz_{nnz},
size_{size} {
#if IS_CUSPARSE11_AVAILABLE()
- cudaDataType cuda_data_type;
- if ( std::is_same<float, scalar_t>::value ) {
- cuda_data_type = CUDA_R_32F;
- } else if ( std::is_same<double, scalar_t>::value) {
- cuda_data_type = CUDA_R_64F;
- } else {
- TORCH_CHECK(false, "Tensor types must be either float32 or float64");
- }
+ cudaDataType cuda_data_type = at::cuda::getCudaDataType<scalar_t>();
TORCH_CUDASPARSE_CHECK(cusparseCreateCsr(
&description_,
this->size(0),
cusparseSpGEMMDescr_t spgemmDesc;
CusparseMatrixMultiplyOp() {
- static_assert(std::is_same<float, scalar_t>::value || std::is_same<double, scalar_t>::value,
- "cusparse csr sparse-sparse MM only supports data type of float and double.");
+ static_assert(
+ std::is_same<c10::Half, scalar_t>::value ||
+ std::is_same<c10::BFloat16, scalar_t>::value ||
+ std::is_same<float, scalar_t>::value ||
+ std::is_same<double, scalar_t>::value ||
+ std::is_same<c10::complex<float>, scalar_t>::value ||
+ std::is_same<c10::complex<double>, scalar_t>::value,
+      "cusparseSpGEMM only supports half, bfloat16, float, double, complex<float> and complex<double> data types.");
// SpGEMM Computation
TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_createDescr(&spgemmDesc));
}
const int B_num_cols = B.size(1);
- cudaDataType computeType;
- if ( std::is_same<float, scalar_t>::value ) {
- computeType = CUDA_R_32F;
- } else if ( std::is_same<double, scalar_t>::value) {
- computeType = CUDA_R_64F;
- } else {
- TORCH_CHECK(false, "Tensor types must be either float32 or float64");
- }
csrOutput out({A.size(0), B.size(1)});
out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));
cusparseSpMatDescr_t matC = C.description_;
//--------------------------------------------------------------------------
+ cudaDataType computeType = at::cuda::getCudaDataType<scalar_t>();
+
+  // If a specific GPU model does not provide native support for a given data type,
+  // the routine returns the CUSPARSE_STATUS_ARCH_MISMATCH error.
+ cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
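+  // Compute capability is compared as 10*major + minor, e.g. CC 5.3 -> 53.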
+ TORCH_CHECK(prop->major >= 5 && !((10*prop->major + prop->minor) < 53 && computeType == CUDA_R_16F),
+ "sparse_mm: CUDA Float16 requires compute capability >= 53 (current: ", prop->major, prop->minor, ")");
+ TORCH_CHECK(!(prop->major < 8 && computeType == CUDA_R_16BF),
+ "sparse_mm: CUDA BFloat16 requires compute capability >= 80 (current: ", prop->major, prop->minor, ")");
+
// ask bufferSize1 bytes for external memory
TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_workEstimation(
handle,
const Tensor& mat1,
const Tensor& mat2) {
- static_assert(std::is_same<float, scalar_t>::value || std::is_same<double, scalar_t>::value,
- "sparse_sparse_matmul_cuda_kernel only supports float and double value types");
+ static_assert(
+ std::is_same<c10::Half, scalar_t>::value ||
+ std::is_same<c10::BFloat16, scalar_t>::value ||
+ std::is_same<float, scalar_t>::value ||
+ std::is_same<double, scalar_t>::value ||
+ std::is_same<c10::complex<float>, scalar_t>::value ||
+ std::is_same<c10::complex<double>, scalar_t>::value,
+      "sparse_sparse_matmul_cuda_kernel only supports half, bfloat16, float, double, complex<float> and complex<double> data types.");
Tensor mat1_indices_ = mat1._indices().contiguous();
Tensor mat1_values = mat1._values().contiguous();
auto output = at::native::empty_like(mat1_);
output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0);
+#if IS_CUSPARSE11_AVAILABLE()
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, mat1_.scalar_type(), "sparse_matmul", [&] {
+ sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
+ });
+#else
AT_DISPATCH_FLOATING_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
});
+#endif
return output;
}
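
For readers unfamiliar with ATen's dispatch macros: the cuSPARSE-11 branch above selects the kernel instantiation from the runtime dtype. A simplified, hand-written sketch of what that dispatch amounts to (the real AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2 macro also defines scalar_t and handles the remaining dtypes and error reporting):

// Simplified illustration only -- not the literal macro expansion.
switch (mat1_.scalar_type()) {
  case at::kHalf:
    sparse_sparse_matmul_cuda_kernel<at::Half>(output, mat1_.coalesce(), mat2_.coalesce());
    break;
  case at::kFloat:
    sparse_sparse_matmul_cuda_kernel<float>(output, mat1_.coalesce(), mat2_.coalesce());
    break;
  case at::kDouble:
    sparse_sparse_matmul_cuda_kernel<double>(output, mat1_.coalesce(), mat2_.coalesce());
    break;
  // kBFloat16, kComplexFloat and kComplexDouble are handled the same way.
  default:
    TORCH_CHECK(false, "sparse_matmul: unsupported dtype ", mat1_.scalar_type());
}
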
from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
from numbers import Number
from typing import Dict, Any
+from torch.testing import get_all_complex_dtypes, get_all_fp_dtypes
+from torch.testing._internal.common_cuda import \
+ (SM53OrLater, SM80OrLater, CUDA11OrLater)
from torch.testing._internal.common_device_type import \
- (instantiate_device_type_tests, ops, dtypes, dtypesIfCPU, onlyCPU, onlyCUDA, deviceCountAtLeast)
+ (instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride,
+ deviceCountAtLeast)
from torch.testing._internal.common_methods_invocations import \
(sparse_unary_ufuncs)
# TODO: Check after why ROCm's cusparseXcsrgemm2Nnz function doesn't return the same nnz value as CUDA
@skipIfRocm
@coalescedonoff
- @dtypes(torch.double)
- @dtypesIfCPU(torch.double, torch.cdouble)
+ @dtypes(*get_all_complex_dtypes(),
+ *get_all_fp_dtypes(include_half=False, include_bfloat16=False))
+ @dtypesIfCUDA(*(get_all_complex_dtypes() if CUDA11OrLater else ()),
+ *get_all_fp_dtypes(
+ include_half=(CUDA11OrLater and SM53OrLater),
+ include_bfloat16=(CUDA11OrLater and SM80OrLater)))
+ @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})
def test_sparse_matmul(self, device, dtype, coalesced):
"""
This function test `torch.sparse.mm` when both the mat1 and mat2 are sparse tensors.
r2 = torch.sparse.mm(a, b)
self.assertEqual(r1, r2)
- a.requires_grad_(True)
- b.requires_grad_(True)
+ if dtype in [torch.double, torch.cdouble]:
+ a.requires_grad_(True)
+ b.requires_grad_(True)
- # check autograd support on sparse matmul
- def fn(D1, D2):
- return torch.sparse.mm(D1, D2).to_dense()
+ # check autograd support on sparse matmul
+ def fn(D1, D2):
+ return torch.sparse.mm(D1, D2).to_dense()
- if a.is_cuda:
- # For cuda, `nondet_tol` is set with `1e-5`
- # This is because cuSparse sometimes returns approximate zero values like `~e-323`
- # TODO: Check this cuSparse issue.
- # This happens when you do chain multiplication `torch.sparse.mm` operations
- gradcheck(fn, (a, b), check_sparse_nnz=True, nondet_tol=1e-5)
- else:
- gradcheck(fn, (a, b), check_sparse_nnz=True)
- grad_with_custom_sparsity_pattern_test_helper(sparse_dims, nnz, shape_a, shape_b)
+ if a.is_cuda:
+ # For cuda, `nondet_tol` is set with `1e-5`
+ # This is because cuSparse sometimes returns approximate zero values like `~e-323`
+ # TODO: Check this cuSparse issue.
+ # This happens when you do chain multiplication `torch.sparse.mm` operations
+ gradcheck(fn, (a, b), check_sparse_nnz=True, nondet_tol=1e-5)
+ else:
+ gradcheck(fn, (a, b), check_sparse_nnz=True)
+ grad_with_custom_sparsity_pattern_test_helper(sparse_dims, nnz, shape_a, shape_b)
def test_error_cases():
def fn(sparse_dims, nnz, shape_a, shape_b):
CUDA9 = torch.version.cuda and torch.version.cuda.startswith('9.')
SM53OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3)
SM60OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0)
+SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0)
TEST_MAGMA = TEST_CUDA
if TEST_CUDA:
),
("device_functions.h", ("hip/device_functions.h", CONV_INCLUDE, API_RUNTIME)),
("driver_types.h", ("hip/driver_types.h", CONV_INCLUDE, API_RUNTIME)),
+ ("library_types.h", ("hip/library_types.h", CONV_INCLUDE, API_RUNTIME)),
("cuComplex.h", ("hip/hip_complex.h", CONV_INCLUDE, API_RUNTIME)),
("cuda_fp16.h", ("hip/hip_fp16.h", CONV_INCLUDE, API_RUNTIME)),
(
),
),
("cudaDataType_t", ("hipDataType_t", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
- ("cudaDataType", ("hipDataType_t", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
- ("CUDA_R_16F", ("hipR16F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
- ("CUDA_C_16F", ("hipC16F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
- ("CUDA_R_32F", ("hipR32F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
- ("CUDA_C_32F", ("hipC32F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
- ("CUDA_R_64F", ("hipR64F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
- ("CUDA_C_64F", ("hipC64F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
- ("CUDA_R_8I", ("hipR8I", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
- ("CUDA_C_8I", ("hipC8I", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
- ("CUDA_R_8U", ("hipR8U", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
- ("CUDA_C_8U", ("hipC8U", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
- ("CUDA_R_32I", ("hipR32I", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
- ("CUDA_C_32I", ("hipC32I", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
- ("CUDA_R_32U", ("hipR32U", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
- ("CUDA_C_32U", ("hipC32U", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+ ("cudaDataType", ("hipDataType", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+ ("CUDA_R_16F", ("HIP_R_16F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+ ("CUDA_C_16F", ("HIP_C_16F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+ ("CUDA_R_32F", ("HIP_R_32F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+ ("CUDA_C_32F", ("HIP_C_32F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+ ("CUDA_R_64F", ("HIP_R_64F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+ ("CUDA_C_64F", ("HIP_C_64F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+ ("CUDA_R_8I", ("HIP_R_8I", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+ ("CUDA_C_8I", ("HIP_C_8I", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+ ("CUDA_R_8U", ("HIP_R_8U", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+ ("CUDA_C_8U", ("HIP_C_8U", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+ ("CUDA_R_32I", ("HIP_R_32I", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+ ("CUDA_C_32I", ("HIP_C_32I", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+ ("CUDA_R_32U", ("HIP_R_32U", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+ ("CUDA_C_32U", ("HIP_C_32U", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
(
"MAJOR_VERSION",
("hipLibraryMajorVersion", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED),