From 7e2b074219fad6d2b09b379423e83b2295b29df2 Mon Sep 17 00:00:00 2001
From: rohithkrn
Date: Mon, 10 Dec 2018 17:25:46 -0800
Subject: [PATCH] Integrate rocBLAS fp16 api into Caffe2 (#14882)

Summary:
This PR integrates the rocBLAS half and mixed-precision APIs into Caffe2.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14882

Differential Revision: D13407840

Pulled By: bddppq

fbshipit-source-id: 75cb0d74da066776fa66575f1d255e879d36121e
---
 caffe2/core/common_gpu.h                         |   9 +-
 caffe2/operators/fully_connected_op_gpu.cc       |   2 -
 caffe2/python/operator_test/fc_operator_test.py  |  11 +-
 caffe2/python/operator_test/matmul_op_test.py    |   6 +-
 caffe2/python/operator_test/momentum_sgd_test.py |  10 +-
 caffe2/sgd/fp16_momentum_sgd_op.cu               |   2 +-
 caffe2/utils/math_gpu.cu                         | 173 ++++++++++++++++++-----
 cmake/Dependencies.cmake                         |   1 -
 tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py |   2 +-
 9 files changed, 159 insertions(+), 57 deletions(-)

diff --git a/caffe2/core/common_gpu.h b/caffe2/core/common_gpu.h
index 1af674e..db87887 100644
--- a/caffe2/core/common_gpu.h
+++ b/caffe2/core/common_gpu.h
@@ -69,7 +69,7 @@
 // CAFFE_HAS_CUDA_FP16 manually.
 
 #ifndef CAFFE_HAS_CUDA_FP16
-#if CUDA_VERSION >= 7050
+#if CUDA_VERSION >= 7050 || defined(__HIP_PLATFORM_HCC__)
 #define CAFFE_HAS_CUDA_FP16
 #endif // CUDA_VERSION >= 7050
 #endif // CAFFE_HAS_CUDA_FP16
@@ -78,6 +78,13 @@
 #include <cuda_fp16.h>
 #endif
 
+// CUDA major revision number below which fp16 compute is not supported
+#ifndef __HIP_PLATFORM_HCC__
+constexpr int kFp16CUDADevicePropMajor = 6;
+#else
+constexpr int kFp16CUDADevicePropMajor = 3;
+#endif
+
 // Re-enable strict aliasing diagnostic if it was disabled.
 #if CUDA_VERSION >= 9000
 #ifdef __GNUC__
diff --git a/caffe2/operators/fully_connected_op_gpu.cc b/caffe2/operators/fully_connected_op_gpu.cc
index 3f82283..4762692 100644
--- a/caffe2/operators/fully_connected_op_gpu.cc
+++ b/caffe2/operators/fully_connected_op_gpu.cc
@@ -6,8 +6,6 @@ namespace caffe2 {
 
 namespace {
 
-constexpr int kFp16CUDADevicePropMajor = 6;
-
 template
 bool RunFullyConnectedOpOnCUDADevice(
     const bool float16_compute,
diff --git a/caffe2/python/operator_test/fc_operator_test.py b/caffe2/python/operator_test/fc_operator_test.py
index d42e00c..466453c 100644
--- a/caffe2/python/operator_test/fc_operator_test.py
+++ b/caffe2/python/operator_test/fc_operator_test.py
@@ -16,9 +16,9 @@ import unittest
 class TestFcOperator(serial.SerializedTestCase):
     def _run_test(self, n, m, k, transposed, multi_dim, dtype, engine, gc, dc):
         if dtype == np.float16:
-            # fp16 only supported with CUDA
-            assume(gc.device_type == caffe2_pb2.CUDA)
-            dc = [d for d in dc if d.device_type == caffe2_pb2.CUDA]
+            # fp16 only supported with CUDA/HIP
+            assume(core.IsGPUDeviceType(gc.device_type))
+            dc = [d for d in dc if core.IsGPUDeviceType(d.device_type)]
 
         if engine == 'TENSORCORE':
             # TensorCore only makes sense with CUDA
@@ -54,18 +54,21 @@ class TestFcOperator(serial.SerializedTestCase):
             engine=engine,
         )
 
-        if dtype == np.float16 and gc.device_type == caffe2_pb2.CUDA:
+        if dtype == np.float16 and core.IsGPUDeviceType(gc.device_type):
            a = caffe2_pb2.Argument()
            a.i = 1
            a.name = "float16_compute"
            op.arg.extend([a])
 
        # Check against numpy reference
+       # ReferenceChecks is flaky on ROCm with a threshold of 1e-4 for fp16. Relaxing it to 1e-3.
+       threshold = 1e-3 if (gc.device_type == caffe2_pb2.HIP and dtype == np.float16) else 1e-4
        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[X, W, b],
            reference=fc_tranposed_op if transposed else fc_op,
+           threshold=threshold
        )
        # Check over multiple devices
        self.assertDeviceChecks(dc, op, [X, W, b], [0])
diff --git a/caffe2/python/operator_test/matmul_op_test.py b/caffe2/python/operator_test/matmul_op_test.py
index 1872a12..64e0e51 100644
--- a/caffe2/python/operator_test/matmul_op_test.py
+++ b/caffe2/python/operator_test/matmul_op_test.py
@@ -140,9 +140,9 @@ class TestBatchMatMul(serial.SerializedTestCase):
     )
     def test_batch_matmul(self, C, M, K, N, trans_a, trans_b, dtype, gc, dc):
         if dtype == np.float16:
-            # fp16 is only supported with CUDA
-            assume(gc.device_type == caffe2_pb2.CUDA)
-            dc = [d for d in dc if d.device_type == caffe2_pb2.CUDA]
+            # fp16 is only supported with CUDA/HIP
+            assume(core.IsGPUDeviceType(gc.device_type))
+            dc = [d for d in dc if core.IsGPUDeviceType(d.device_type)]
 
         batch_dims = np.random.randint(
             low=1,
diff --git a/caffe2/python/operator_test/momentum_sgd_test.py b/caffe2/python/operator_test/momentum_sgd_test.py
index 39e358f..27dcb78 100644
--- a/caffe2/python/operator_test/momentum_sgd_test.py
+++ b/caffe2/python/operator_test/momentum_sgd_test.py
@@ -7,8 +7,7 @@ from caffe2.python import core, workspace
 import caffe2.python.hypothesis_test_util as hu
 import caffe2.python.serialized_test.serialized_test_util as serial
 
-import hypothesis
-from hypothesis import given
+from hypothesis import given, assume
 import hypothesis.strategies as st
 import numpy as np
 import unittest
@@ -95,7 +94,7 @@ class TestMomentumSGD(serial.SerializedTestCase):
         )
 
         # Verify that the generated indices are unique
-        hypothesis.assume(
+        assume(
             np.array_equal(
                 np.unique(indices.flatten()),
                 np.sort(indices.flatten())))
@@ -139,9 +138,10 @@ class TestMomentumSGD(serial.SerializedTestCase):
             [grad, m, lr, w, indices],
             sparse)
 
-    @given(n=st.integers(4, 8), nesterov=st.booleans(), **hu.gcs_gpu_only)
-    @unittest.skipIf(not workspace.has_gpu_support, "No gpu support.")
+    @unittest.skipIf(not workspace.has_gpu_support and not workspace.has_hip_support, "No gpu support.")
+    @given(n=st.integers(4, 8), nesterov=st.booleans(), **hu.gcs)
     def test_fp16momentum_sgd(self, n, nesterov, gc, dc):
+        assume(core.IsGPUDeviceType(gc.device_type))
         gpuvers = workspace.GetDeviceProperties(0)["major"]
         if gpuvers < 6:
             print("No FP16 support because major version {} < 6".format(gpuvers))
diff --git a/caffe2/sgd/fp16_momentum_sgd_op.cu b/caffe2/sgd/fp16_momentum_sgd_op.cu
index 4da36da..b7ac0a7 100644
--- a/caffe2/sgd/fp16_momentum_sgd_op.cu
+++ b/caffe2/sgd/fp16_momentum_sgd_op.cu
@@ -198,7 +198,7 @@ void fp16_momentum_sgd_update(
     at::Half* param,
     CUDAContext* context) {
   const cudaDeviceProp& prop = GetDeviceProperty(0);
-  if (prop.major >= 6) {
+  if (prop.major >= kFp16CUDADevicePropMajor) {
     if (!fp32_update) {
       FP16MomentumSGDKernel<<<
           CAFFE_GET_BLOCKS(N / 2),
diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index 12abf42..dc7cb22 100644
--- a/caffe2/utils/math_gpu.cu
+++ b/caffe2/utils/math_gpu.cu
@@ -35,6 +35,12 @@
 #define FIXED_DIVISOR_DIV_MOD(d, n, q, r) (d.DivMod(n, q, r))
 #endif // __HIP_PLATFORM_HCC__
 
+#ifdef __HIP_PLATFORM_HCC__
+using CUBLAS_HALF_TYPE = rocblas_half;
+#else // __HIP_PLATFORM_HCC__
+using CUBLAS_HALF_TYPE = __half;
+#endif // __HIP_PLATFORM_HCC__
+
 #include "caffe2/utils/math_utils.h"
 
 #if THRUST_VERSION >= 100800
@@ -743,9 +749,6 @@ CAFFE2_CUDA_EXPORT void Gemm(
     at::Half* C,
     CUDAContext* context,
     TensorProto::DataType math_type) {
-#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
-  CAFFE_THROW("HIP currently does not support FP16 yet.");
-#else
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   const int lda = (trans_A == CblasNoTrans) ? K : M;
@@ -757,6 +760,39 @@ CAFFE2_CUDA_EXPORT void Gemm(
   if (math_type == TensorProto_DataType_FLOAT) {
     CUBLAS_ENFORCE(cublasSetPointerMode(
         context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#ifdef __HIP_PLATFORM_HCC__
+    // rocBLAS doesn't support a cublasSgemmEx-style API yet.
+    // It has the more general rocblas_gemm_ex API, which is closer to cublasGemmEx.
+    // rocblas_gemm_ex does D = alpha*op( A )*op( B ) + beta*C, whereas
+    // cublasGemmEx does C = alpha*op( A )*op( B ) + beta*C.
+    ROCBLAS_ENFORCE(rocblas_gemm_ex(
+        context->rocblashandle(),
+        cu_trans_B,
+        cu_trans_A,
+        N,
+        M,
+        K,
+        &alpha,
+        B,
+        rocblas_datatype_f16_r,
+        ldb,
+        A,
+        rocblas_datatype_f16_r,
+        lda,
+        &beta,
+        C,
+        rocblas_datatype_f16_r,
+        N,
+        C, // D
+        rocblas_datatype_f16_r, // D type
+        N, // ldd
+        rocblas_datatype_f32_r, // compute type
+        rocblas_gemm_algo_standard, // rocblas_gemm_algo
+        0, // solution index, reserved for future use
+        0, // flags, reserved for future use
+        NULL, // size of workspace
+        NULL)); // workspace
+#else
     CUBLAS_ENFORCE(cublasSgemmEx(
         context->cublas_handle(),
         cu_trans_B,
@@ -775,6 +811,7 @@ CAFFE2_CUDA_EXPORT void Gemm(
         C,
         CUDA_R_16F,
         N));
+#endif // __HIP_PLATFORM_HCC__
   } else if (math_type == TensorProto_DataType_FLOAT16) {
     // convert alpha, beta from float -> __half
     const __half alpha_fp16 = at::Half(alpha);
@@ -789,19 +826,18 @@ CAFFE2_CUDA_EXPORT void Gemm(
         N,
         M,
         K,
-        &alpha_fp16,
-        (const __half*)B,
+        reinterpret_cast<const CUBLAS_HALF_TYPE*>(&alpha_fp16),
+        reinterpret_cast<const CUBLAS_HALF_TYPE*>(B),
         ldb,
-        (const __half*)A,
+        reinterpret_cast<const CUBLAS_HALF_TYPE*>(A),
         lda,
-        &beta_fp16,
-        (__half*)C,
+        reinterpret_cast<const CUBLAS_HALF_TYPE*>(&beta_fp16),
+        reinterpret_cast<CUBLAS_HALF_TYPE*>(C),
         N));
   } else {
     // fail
     CAFFE_THROW("Unsupported math type");
   }
-#endif
 }
 
 template <>
@@ -968,9 +1004,6 @@ CAFFE2_CUDA_EXPORT void GemmBatched(
     at::Half** C,
     CUDAContext* context,
     TensorProto::DataType math_type) {
-#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
-  CAFFE_THROW("HIP currently does not support FP16 yet.");
-#else
 #if __CUDACC_VER_MAJOR__ < 9
   // loop over matrices in the batch
   for (int i = 0; i < batch_size; ++i) {
@@ -1083,7 +1116,6 @@ CAFFE2_CUDA_EXPORT void GemmBatched(
     CAFFE_THROW("Unsupported math type");
   }
 #endif
-#endif
 }
 
 template <>
@@ -1104,10 +1136,7 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched(
     const int C_stride,
     CUDAContext* context,
     TensorProto::DataType math_type) {
-#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
-  CAFFE_THROW("HIP currently does not support FP16 yet.");
-#else
-#if __CUDACC_VER_MAJOR__ < 8
+#if __CUDACC_VER_MAJOR__ < 8 && !defined(__HIP_PLATFORM_HCC__)
   // loop over matrices in the batch
   for (int i = 0; i < batch_size; ++i) {
     Gemm(
@@ -1127,7 +1156,7 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched(
   const cublasOperation_t cu_trans_B = (trans_B == CblasNoTrans) ?
       CUBLAS_OP_N : CUBLAS_OP_T;
   if (math_type == TensorProto_DataType_FLOAT) {
-#if CUDA_VERSION < 9010
+#if CUDA_VERSION < 9010 && !defined(__HIP_PLATFORM_HCC__)
     // loop over matrices in the batch
     for (int i = 0; i < batch_size; ++i) {
       Gemm(
@@ -1139,6 +1168,42 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched(
 #else
     CUBLAS_ENFORCE(cublasSetPointerMode(
         context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#ifdef __HIP_PLATFORM_HCC__
+    // D[i*stride_d] = alpha*op(A[i*stride_a])*op(B[i*stride_b]) + beta*C[i*stride_c],
+    // for i in [0,batch_count-1]
+    ROCBLAS_ENFORCE(rocblas_gemm_strided_batched_ex(
+        context->rocblashandle(),
+        cu_trans_B,
+        cu_trans_A,
+        N,
+        M,
+        K,
+        &alpha,
+        B,
+        rocblas_datatype_f16_r,
+        ldb,
+        B_stride,
+        A,
+        rocblas_datatype_f16_r,
+        lda,
+        A_stride,
+        &beta,
+        C,
+        rocblas_datatype_f16_r,
+        ldc,
+        C_stride,
+        C, // D
+        rocblas_datatype_f16_r, // D type
+        ldc, // ldd
+        C_stride, // D stride
+        batch_size,
+        rocblas_datatype_f32_r, // compute type
+        rocblas_gemm_algo_standard, // rocblas_gemm_algo
+        0, // solution index, reserved for future use
+        0, // flags, reserved for future use
+        NULL, // size of workspace
+        NULL)); // workspace
+#else
     CUBLAS_ENFORCE(cublasGemmStridedBatchedEx(
         context->cublas_handle(),
         cu_trans_B,
@@ -1163,6 +1228,7 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched(
         batch_size,
         CUDA_R_32F,
         CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+#endif // __HIP_PLATFORM_HCC__
 #endif
   } else if (math_type == TensorProto_DataType_FLOAT16) {
     // Convert alpha, beta from float -> __half
@@ -1177,15 +1243,15 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched(
         N,
         M,
         K,
-        &alpha_fp16,
-        (const __half*)B,
+        reinterpret_cast<const CUBLAS_HALF_TYPE*>(&alpha_fp16),
+        reinterpret_cast<const CUBLAS_HALF_TYPE*>(B),
         ldb,
         B_stride,
-        (const __half*)A,
+        reinterpret_cast<const CUBLAS_HALF_TYPE*>(A),
         lda,
         A_stride,
-        &beta_fp16,
-        (__half*)C,
+        reinterpret_cast<const CUBLAS_HALF_TYPE*>(&beta_fp16),
+        reinterpret_cast<CUBLAS_HALF_TYPE*>(C),
         ldc,
         C_stride,
         batch_size));
@@ -1193,7 +1259,6 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched(
     CAFFE_THROW("Unsupported math type");
   }
 #endif
-#endif
 }
 
 #if CUDA_VERSION >= 9000
@@ -1479,9 +1544,6 @@ CAFFE2_CUDA_EXPORT void Gemv(
     at::Half* y,
     CUDAContext* context,
     TensorProto::DataType math_type) {
-#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
-  CAFFE_THROW("HIP currently does not support FP16 yet.");
-#else
   const cublasOperation_t cu_trans_A =
       (trans_A == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
 
@@ -1494,6 +1556,39 @@ CAFFE2_CUDA_EXPORT void Gemv(
   if (math_type == TensorProto_DataType_FLOAT) {
     CUBLAS_ENFORCE(cublasSetPointerMode(
         context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#ifdef __HIP_PLATFORM_HCC__
+    // rocBLAS doesn't support a cublasSgemmEx-style API yet.
+    // It has the more general rocblas_gemm_ex API, which is closer to cublasGemmEx.
+    // rocblas_gemm_ex does D = alpha*op( A )*op( B ) + beta*C, whereas
+    // cublasGemmEx does C = alpha*op( A )*op( B ) + beta*C.
+    ROCBLAS_ENFORCE(rocblas_gemm_ex(
+        context->rocblashandle(),
+        cu_trans_A,
+        rocblas_operation_none,
+        m,
+        1,
+        k,
+        &alpha,
+        A,
+        rocblas_datatype_f16_r,
+        lda,
+        x,
+        rocblas_datatype_f16_r,
+        k,
+        &beta,
+        y,
+        rocblas_datatype_f16_r,
+        ldc,
+        y, // D
+        rocblas_datatype_f16_r, // D type
+        ldc, // ldd
+        rocblas_datatype_f32_r, // compute type
+        rocblas_gemm_algo_standard, // rocblas_gemm_algo
+        0, // solution index, reserved for future use
+        0, // flags, reserved for future use
+        NULL, // size of workspace
+        NULL)); // workspace
+#else
     CUBLAS_ENFORCE(cublasSgemmEx(
         context->cublas_handle(),
         cu_trans_A,
@@ -1512,6 +1607,7 @@ CAFFE2_CUDA_EXPORT void Gemv(
         y,
         CUDA_R_16F,
         ldc));
+#endif // __HIP_PLATFORM_HCC__
   } else if (math_type == TensorProto_DataType_FLOAT16) {
     const __half alpha_fp16 = at::Half(alpha);
     const __half beta_fp16 = at::Half(beta);
@@ -1524,19 +1620,18 @@ CAFFE2_CUDA_EXPORT void Gemv(
         m,
         1,
         k,
-        &alpha_fp16,
-        (const __half*)A,
+        reinterpret_cast<const CUBLAS_HALF_TYPE*>(&alpha_fp16),
+        reinterpret_cast<const CUBLAS_HALF_TYPE*>(A),
         lda,
-        (const __half*)x,
+        reinterpret_cast<const CUBLAS_HALF_TYPE*>(x),
         k,
-        &beta_fp16,
-        (__half*)y,
+        reinterpret_cast<const CUBLAS_HALF_TYPE*>(&beta_fp16),
+        reinterpret_cast<CUBLAS_HALF_TYPE*>(y),
         ldc));
   } else {
     // fail
     CAFFE_THROW("Unsupported math type");
   }
-#endif
 }
 
 namespace {
@@ -1727,8 +1822,8 @@ CAFFE2_CUDA_EXPORT void Dot(
     const at::Half* b,
     at::Half* y,
     CUDAContext* context) {
-#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
-  CAFFE_THROW("HIP currently does not support FP16 yet.");
+#if defined(__HIP_PLATFORM_HCC__)
+  CAFFE_THROW("HIP currently does not support FP16 completely yet.");
 #else
   // execute with 32-bit math
   CUBLAS_ENFORCE(cublasSetPointerMode(
@@ -2358,8 +2453,8 @@ CAFFE2_CUDA_EXPORT void Axpy(
     const at::Half* X,
     at::Half* Y,
     CUDAContext* context) {
-#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
-  CAFFE_THROW("HIP currently does not support FP16 yet.");
+#if defined(__HIP_PLATFORM_HCC__)
+  CAFFE_THROW("HIP currently does not support FP16 completely yet.");
 #else
   CUBLAS_ENFORCE(
       cublasSetPointerMode(context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
@@ -2397,8 +2492,8 @@ CAFFE2_CUDA_EXPORT void Axpy(
     const at::Half* X,
     at::Half* Y,
     CUDAContext* context) {
-#if defined(__HIP_PLATFORM_HCC__) && !ROCBLAS_FP16
-  CAFFE_THROW("HIP currently does not support FP16 yet.");
+#if defined(__HIP_PLATFORM_HCC__)
+  CAFFE_THROW("HIP currently does not support FP16 completely yet.");
 #else
   CUBLAS_ENFORCE(cublasSetPointerMode(
       context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE));
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index 07cab96..1b84bf1 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -746,7 +746,6 @@ if(USE_ROCM)
     list(APPEND HIP_CXX_FLAGS -Wno-unused-command-line-argument)
     list(APPEND HIP_CXX_FLAGS -Wno-duplicate-decl-specifier)
     list(APPEND HIP_CXX_FLAGS -DCAFFE2_USE_MIOPEN)
-    list(APPEND HIP_CXX_FLAGS -DROCBLAS_FP16=0)
     set(HIP_HCC_FLAGS ${HIP_CXX_FLAGS})
 
     # Ask hcc to generate device code during compilation so we can use
diff --git a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py
index c530c50..22aa972 100644
--- a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py
+++ b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py
@@ -1729,7 +1729,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict([
     ("cublasCgemmStridedBatched",
("rocblas_cgemm_strided_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), ("cublasCgemm3mStridedBatched", ("rocblas_cgemm_3m_strided_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), ("cublasZgemmStridedBatched", ("rocblas_zgemm_strided_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), - ("cublasHgemmStridedBatched", ("rocblas_hgemm_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasHgemmStridedBatched", ("rocblas_hgemm_strided_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), ("cublasSgemm", ("rocblas_sgemm", CONV_MATH_FUNC, API_BLAS)), ("cublasDgemm", ("rocblas_dgemm", CONV_MATH_FUNC, API_BLAS)), ("cublasCgemm", ("rocblas_cgemm", CONV_MATH_FUNC, API_BLAS)), -- 2.7.4