#define FIXED_DIVISOR_DIV_MOD(d, n, q, r) (d.DivMod(n, q, r))
#endif // __HIP_PLATFORM_HCC__
+#ifdef __HIP_PLATFORM_HCC__
+using CUBLAS_HALF_TYPE = rocblas_half;
+#else // __HIP_PLATFORM_HCC__
+using CUBLAS_HALF_TYPE = __half;
+#endif // __HIP_PLATFORM_HCC__
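+// On ROCm the FP16 cuBLAS-style calls below resolve to rocBLAS, whose
+// half-precision routines take rocblas_half rather than __half;
+// CUBLAS_HALF_TYPE lets the casts further down stay identical on both platforms.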
+
#include "caffe2/utils/math_utils.h"
#if THRUST_VERSION >= 100800
at::Half* C,
CUDAContext* context,
TensorProto::DataType math_type) {
-#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
- CAFFE_THROW("HIP currently does not support FP16 yet.");
-#else
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
const int lda = (trans_A == CblasNoTrans) ? K : M;
if (math_type == TensorProto_DataType_FLOAT) {
CUBLAS_ENFORCE(cublasSetPointerMode(
context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#ifdef __HIP_PLATFORM_HCC__
+ // rocBLAS doesn't provide a cublasSgemmEx-style API yet. Its more general
+ // rocblas_gemm_ex API is closer to cublasGemmEx: rocblas_gemm_ex computes
+ // D = alpha * op(A) * op(B) + beta * C, whereas cublasGemmEx computes
+ // C = alpha * op(A) * op(B) + beta * C.
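+ // Passing C as both the C and D arguments below makes the call update C in
+ // place, which matches the cuBLAS behavior.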
+ ROCBLAS_ENFORCE(rocblas_gemm_ex(
+ context->rocblashandle(),
+ cu_trans_B,
+ cu_trans_A,
+ N,
+ M,
+ K,
+ &alpha,
+ B,
+ rocblas_datatype_f16_r,
+ ldb,
+ A,
+ rocblas_datatype_f16_r,
+ lda,
+ &beta,
+ C,
+ rocblas_datatype_f16_r,
+ N,
+ C, // D
+ rocblas_datatype_f16_r, // D type
+ N, // ldd
+ rocblas_datatype_f32_r, // compute type
+ rocblas_gemm_algo_standard, // rocblas_gemm_algo
+ 0, // solution index, reserved for future use
+ 0, // flags, reserved for future use
+ NULL, // size of workspace
+ NULL)); // workspace
+#else
CUBLAS_ENFORCE(cublasSgemmEx(
context->cublas_handle(),
cu_trans_B,
C,
CUDA_R_16F,
N));
+#endif // __HIP_PLATFORM_HCC__
} else if (math_type == TensorProto_DataType_FLOAT16) {
// convert alpha, beta from float -> __half
const __half alpha_fp16 = at::Half(alpha);
N,
M,
K,
- &alpha_fp16,
- (const __half*)B,
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(&alpha_fp16),
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(B),
ldb,
- (const __half*)A,
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(A),
lda,
- &beta_fp16,
- (__half*)C,
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(&beta_fp16),
+ reinterpret_cast<CUBLAS_HALF_TYPE*>(C),
N));
} else {
// fail
CAFFE_THROW("Unsupported math type");
}
-#endif
}
template <>
at::Half** C,
CUDAContext* context,
TensorProto::DataType math_type) {
-#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
- CAFFE_THROW("HIP currently does not support FP16 yet.");
-#else
#if __CUDACC_VER_MAJOR__ < 9
// loop over matrices in the batch
for (int i = 0; i < batch_size; ++i) {
CAFFE_THROW("Unsupported math type");
}
#endif
-#endif
}
template <>
const int C_stride,
CUDAContext* context,
TensorProto::DataType math_type) {
-#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
- CAFFE_THROW("HIP currently does not support FP16 yet.");
-#else
-#if __CUDACC_VER_MAJOR__ < 8
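+// __CUDACC_VER_MAJOR__ is an nvcc-only macro (undefined, hence 0, under HIP),
+// so exclude HIP explicitly and let it take the rocBLAS path below instead of
+// this per-matrix fallback loop.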
+#if __CUDACC_VER_MAJOR__ < 8 && !defined(__HIP_PLATFORM_HCC__)
// loop over matrices in the batch
for (int i = 0; i < batch_size; ++i) {
Gemm<at::Half, CUDAContext>(
const cublasOperation_t cu_trans_B =
(trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
if (math_type == TensorProto_DataType_FLOAT) {
-#if CUDA_VERSION < 9010
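+// CUDA_VERSION is likewise undefined under HIP, so keep HIP on the single
+// rocblas_gemm_strided_batched_ex call below rather than this loop.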
+#if CUDA_VERSION < 9010 && !defined(__HIP_PLATFORM_HCC__)
// loop over matrices in the batch
for (int i = 0; i < batch_size; ++i) {
Gemm<at::Half, CUDAContext>(
#else
CUBLAS_ENFORCE(cublasSetPointerMode(
context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#ifdef __HIP_PLATFORM_HCC__
+ // D[i*stride_d] = alpha*op(A[i*stride_a])*op(B[i*stride_b]) + beta*C[i*stride_c],
+ // for i in [0,batch_count-1]
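+ // As in Gemm above, C is also passed as D (same leading dimension and
+ // stride), making this the in-place equivalent of cublasGemmStridedBatchedEx.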
+ ROCBLAS_ENFORCE(rocblas_gemm_strided_batched_ex(
+ context->rocblashandle(),
+ cu_trans_B,
+ cu_trans_A,
+ N,
+ M,
+ K,
+ &alpha,
+ B,
+ rocblas_datatype_f16_r,
+ ldb,
+ B_stride,
+ A,
+ rocblas_datatype_f16_r,
+ lda,
+ A_stride,
+ &beta,
+ C,
+ rocblas_datatype_f16_r,
+ ldc,
+ C_stride,
+ C, // D
+ rocblas_datatype_f16_r, // D type
+ ldc, // ldd
+ C_stride, // D stride
+ batch_size,
+ rocblas_datatype_f32_r, // compute type
+ rocblas_gemm_algo_standard, // rocblas_gemm_algo
+ 0, // solution index, reserved for future use
+ 0, // flags, reserved for future use
+ NULL, // size of workspace
+ NULL)); // workspace
+#else
CUBLAS_ENFORCE(cublasGemmStridedBatchedEx(
context->cublas_handle(),
cu_trans_B,
batch_size,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+#endif // __HIP_PLATFORM_HCC__
#endif
} else if (math_type == TensorProto_DataType_FLOAT16) {
// Convert alpha, beta from float -> __half
N,
M,
K,
- &alpha_fp16,
- (const __half*)B,
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(&alpha_fp16),
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(B),
ldb,
B_stride,
- (const __half*)A,
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(A),
lda,
A_stride,
- &beta_fp16,
- (__half*)C,
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(&beta_fp16),
+ reinterpret_cast<CUBLAS_HALF_TYPE*>(C),
ldc,
C_stride,
batch_size));
CAFFE_THROW("Unsupported math type");
}
#endif
-#endif
}
#if CUDA_VERSION >= 9000
at::Half* y,
CUDAContext* context,
TensorProto::DataType math_type) {
-#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
- CAFFE_THROW("HIP currently does not support FP16 yet.");
-#else
const cublasOperation_t cu_trans_A =
(trans_A == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
if (math_type == TensorProto_DataType_FLOAT) {
CUBLAS_ENFORCE(cublasSetPointerMode(
context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#ifdef __HIP_PLATFORM_HCC__
+ // rocBLAS doesn't provide a cublasSgemmEx-style API yet. Its more general
+ // rocblas_gemm_ex API is closer to cublasGemmEx: rocblas_gemm_ex computes
+ // D = alpha * op(A) * op(B) + beta * C, whereas cublasGemmEx computes
+ // C = alpha * op(A) * op(B) + beta * C.
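+ // The FP16 gemv is expressed as a gemm with a single column (the 1 below);
+ // y is passed as both C and D, so the product is written in place.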
+ ROCBLAS_ENFORCE(rocblas_gemm_ex(
+ context->rocblashandle(),
+ cu_trans_A,
+ rocblas_operation_none,
+ m,
+ 1,
+ k,
+ &alpha,
+ A,
+ rocblas_datatype_f16_r,
+ lda,
+ x,
+ rocblas_datatype_f16_r,
+ k,
+ &beta,
+ y,
+ rocblas_datatype_f16_r,
+ ldc,
+ y, // D
+ rocblas_datatype_f16_r, // D type
+ ldc, // ldd
+ rocblas_datatype_f32_r, // compute type
+ rocblas_gemm_algo_standard, // rocblas_gemm_algo
+ 0, // solution index, reserved for future use
+ 0, // flags, reserved for future use
+ NULL, // size of workspace
+ NULL)); // workspace
+#else
CUBLAS_ENFORCE(cublasSgemmEx(
context->cublas_handle(),
cu_trans_A,
y,
CUDA_R_16F,
ldc));
+#endif // __HIP_PLATFORM_HCC__
} else if (math_type == TensorProto_DataType_FLOAT16) {
const __half alpha_fp16 = at::Half(alpha);
const __half beta_fp16 = at::Half(beta);
m,
1,
k,
- &alpha_fp16,
- (const __half*)A,
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(&alpha_fp16),
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(A),
lda,
- (const __half*)x,
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(x),
k,
- &beta_fp16,
- (__half*)y,
+ reinterpret_cast<const CUBLAS_HALF_TYPE*>(&beta_fp16),
+ reinterpret_cast<CUBLAS_HALF_TYPE*>(y),
ldc));
} else {
// fail
CAFFE_THROW("Unsupported math type");
}
-#endif
}
namespace {
const at::Half* b,
at::Half* y,
CUDAContext* context) {
-#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
- CAFFE_THROW("HIP currently does not support FP16 yet.");
+#if defined(__HIP_PLATFORM_HCC__)
+ CAFFE_THROW("HIP currently does not support FP16 completely yet.");
#else
// execute with 32-bit math
CUBLAS_ENFORCE(cublasSetPointerMode(
const at::Half* X,
at::Half* Y,
CUDAContext* context) {
-#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
- CAFFE_THROW("HIP currently does not support FP16 yet.");
+#if defined(__HIP_PLATFORM_HCC__)
+ CAFFE_THROW("HIP currently does not support FP16 completely yet.");
#else
CUBLAS_ENFORCE(
cublasSetPointerMode(context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
const at::Half* X,
at::Half* Y,
CUDAContext* context) {
-#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
- CAFFE_THROW("HIP currently does not support FP16 yet.");
+#if defined(__HIP_PLATFORM_HCC__)
+ CAFFE_THROW("HIP currently does not support FP16 completely yet.");
#else
CUBLAS_ENFORCE(cublasSetPointerMode(
context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE));