From f08f24cd559b5824a1874a0e76d339875e43f366 Mon Sep 17 00:00:00 2001
From: Ben Barsdell
Date: Thu, 10 May 2018 11:06:01 -0700
Subject: [PATCH] Add GPU support for float16 batched matmul (#18436)

* Add GPU support for float16 batched matmul

- Uses cublasGemmBatchedEx introduced in CUDA 9.1.
- Includes support for Tensor Op math.
- Falls back to a loop over non-batched gemm calls on older CUDA versions or
  GPU architectures.

* Refactor GPU batched gemm into one internal func
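
For reference, a minimal sketch (not part of this patch) of the cublasGemmBatchedEx call pattern the new code path relies on, written against CUDA 9.1-era headers (newer toolkits take a cublasComputeType_t such as CUBLAS_COMPUTE_32F instead of CUDA_R_32F). The helper name HalfGemmBatched and its parameters are illustrative only; d_a, d_b and d_c are assumed to be device-resident arrays of per-batch device pointers, and alpha/beta are supplied in the compute type (fp32 here), matching the Coefficient = float choice in the kernel change below.

// Sketch only: batched fp16 GEMM with fp32 accumulation via cublasGemmBatchedEx.
#include <cublas_v2.h>
#include <cuda_fp16.h>

cublasStatus_t HalfGemmBatched(cublasHandle_t handle, int m, int n, int k,
                               const __half** d_a, int lda,
                               const __half** d_b, int ldb, __half** d_c,
                               int ldc, int batch_count, bool use_tensor_ops) {
  // Scaling factors are given in the compute type (fp32 here).
  const float alpha = 1.0f;
  const float beta = 0.0f;
  // Tensor Op math needs a Volta-class GPU; DFALT works on any architecture.
  cublasGemmAlgo_t algo =
      use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT;
  // fp16 inputs/outputs, fp32 accumulation; pointer arrays live on the device.
  return cublasGemmBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
                             reinterpret_cast<const void**>(d_a), CUDA_R_16F,
                             lda, reinterpret_cast<const void**>(d_b),
                             CUDA_R_16F, ldb, &beta,
                             reinterpret_cast<void**>(d_c), CUDA_R_16F, ldc,
                             batch_count, CUDA_R_32F, algo);
}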
---
 tensorflow/core/kernels/batch_matmul_op_impl.h  | 106 ++++++++++++++++++++++--
 tensorflow/core/kernels/batch_matmul_op_real.cc |   4 +
 tensorflow/stream_executor/blas.h               |  14 ++++
 tensorflow/stream_executor/cuda/cuda_blas.cc    | 106 +++++++++++++++++++++---
 tensorflow/stream_executor/cuda/cuda_blas.h     |   6 +-
 tensorflow/stream_executor/stream.cc            |  34 ++++++++
 tensorflow/stream_executor/stream.h             |  14 ++++
 7 files changed, 262 insertions(+), 22 deletions(-)

diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h
index a1c03f9..475bda8 100644
--- a/tensorflow/core/kernels/batch_matmul_op_impl.h
+++ b/tensorflow/core/kernels/batch_matmul_op_impl.h
@@ -329,6 +329,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
       c_ptrs.push_back(&c_device_memory.back());
     }
 
+    typedef Scalar Coefficient;
+
     // Cublas does
     // C = A x B
     // where A, B and C are assumed to be in column major.
@@ -352,9 +354,9 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
       bool blas_launch_status =
           stream
              ->ThenBlasGemv(gemv_trans_a, adj_x ? m : k, adj_x ? k : m,
-                             static_cast<Scalar>(1.0), *(a_ptrs[0]),
+                             static_cast<Coefficient>(1.0), *(a_ptrs[0]),
                              adj_x ? m : k, *(b_ptrs[0]), 1,
-                             static_cast<Scalar>(0.0), c_ptrs[0], 1)
+                             static_cast<Coefficient>(0.0), c_ptrs[0], 1)
               .ok();
       if (!blas_launch_status) {
         context->SetStatus(errors::Internal(
@@ -366,9 +368,9 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
       bool blas_launch_status =
           stream
              ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
-                             static_cast<Scalar>(1.0), *(b_ptrs[0]),
+                             static_cast<Coefficient>(1.0), *(b_ptrs[0]),
                              adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
-                             static_cast<Scalar>(0.0), c_ptrs[0], n)
+                             static_cast<Coefficient>(0.0), c_ptrs[0], n)
               .ok();
       if (!blas_launch_status) {
         context->SetStatus(errors::Internal(
@@ -383,8 +385,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
           stream
               ->ThenBlasGemmBatchedWithScratch(
                   blas_transpose_b, blas_transpose_a, n, m, k,
-                  static_cast<Scalar>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
-                  adj_x ? m : k, static_cast<Scalar>(0.0), c_ptrs, n,
+                  static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
+                  adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
                   batch_size, &scratch_allocator)
               .ok();
       if (!blas_launch_status) {
@@ -398,6 +400,98 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
   }
 };
 
+template <>
+struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
+  static void Launch(OpKernelContext* context, const Tensor& in_x,
+                     const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
+    typedef Eigen::half Scalar;
+    constexpr perftools::gputools::blas::Transpose kTranspose =
+        is_complex<Scalar>::value
+            ? perftools::gputools::blas::Transpose::kConjugateTranspose
+            : perftools::gputools::blas::Transpose::kTranspose;
+    perftools::gputools::blas::Transpose trans[] = {
+        perftools::gputools::blas::Transpose::kNoTranspose, kTranspose};
+    const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
+    const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
+    const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
+    const uint64 batch_size = in_x.dim_size(0);
+    auto blas_transpose_a = trans[adj_x];
+    auto blas_transpose_b = trans[adj_y];
+
+    auto* stream = context->op_device_context()->stream();
+    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+
+    typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
+    std::vector<DeviceMemoryType> a_device_memory;
+    std::vector<DeviceMemoryType> b_device_memory;
+    std::vector<DeviceMemoryType> c_device_memory;
+    std::vector<DeviceMemoryType*> a_ptrs;
+    std::vector<DeviceMemoryType*> b_ptrs;
+    std::vector<DeviceMemoryType*> c_ptrs;
+    a_device_memory.reserve(batch_size);
+    b_device_memory.reserve(batch_size);
+    c_device_memory.reserve(batch_size);
+    a_ptrs.reserve(batch_size);
+    b_ptrs.reserve(batch_size);
+    c_ptrs.reserve(batch_size);
+    auto* a_base_ptr = in_x.template flat<Scalar>().data();
+    auto* b_base_ptr = in_y.template flat<Scalar>().data();
+    auto* c_base_ptr = out->template flat<Scalar>().data();
+    for (int64 i = 0; i < batch_size; ++i) {
+      a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
+      b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
+      c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
+      a_ptrs.push_back(&a_device_memory.back());
+      b_ptrs.push_back(&b_device_memory.back());
+      c_ptrs.push_back(&c_device_memory.back());
+    }
+
+    typedef float Coefficient;
+
+    // Cublas does
+    // C = A x B
+    // where A, B and C are assumed to be in column major.
+    // We want the output to be in row-major, so we can compute
+    // C' = B' x A', where ' stands for transpose (not adjoint).
+    // TODO(yangzihao): Choose the best of the three strategies using autotune.
+    if (batch_size == 1) {
+      // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
+      // overhead of the scratch allocator and the batch interface.
+      // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS
+      bool blas_launch_status =
+          stream
+              ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
+                             static_cast<Coefficient>(1.0), *(b_ptrs[0]),
+                             adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
+                             static_cast<Coefficient>(0.0), c_ptrs[0], n)
+              .ok();
+      if (!blas_launch_status) {
+        context->SetStatus(errors::Internal(
+            "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
+            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
+            ", k=", k));
+      }
+    } else {
+      CublasScratchAllocator scratch_allocator(context);
+      bool blas_launch_status =
+          stream
+              ->ThenBlasGemmBatchedWithScratch(
+                  blas_transpose_b, blas_transpose_a, n, m, k,
+                  static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
+                  adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
+                  batch_size, &scratch_allocator)
+              .ok();
+      if (!blas_launch_status) {
+        context->SetStatus(
+            errors::Internal("Blas xGEMMBatched launch failed : a.shape=",
+                             in_x.shape().DebugString(), ", b.shape=",
+                             in_y.shape().DebugString(), ", m=", m, ", n=", n,
+                             ", k=", k, ", batch_size=", batch_size));
+      }
+    }
+  }
+};
+
 #endif  // GOOGLE_CUDA
 
 #ifdef TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/batch_matmul_op_real.cc b/tensorflow/core/kernels/batch_matmul_op_real.cc
index 7e1e2aa..2bb22bb 100644
--- a/tensorflow/core/kernels/batch_matmul_op_real.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_real.cc
@@ -15,6 +15,10 @@ limitations under the License.
 
#include "tensorflow/core/kernels/batch_matmul_op_impl.h" +#if GOOGLE_CUDA +#include "cuda/include/cuda.h" +#endif // GOOGLE_CUDA + namespace tensorflow { #if !defined(INTEL_MKL) diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h index be0b0bf..ea87744 100644 --- a/tensorflow/stream_executor/blas.h +++ b/tensorflow/stream_executor/blas.h @@ -1086,6 +1086,13 @@ class BlasSupport { virtual bool DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, float alpha, + const port::ArraySlice *> &a, int lda, + const port::ArraySlice *> &b, int ldb, + float beta, const port::ArraySlice *> &c, + int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0; + virtual bool DoBlasGemmBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const port::ArraySlice *> &a, int lda, const port::ArraySlice *> &b, int ldb, float beta, const port::ArraySlice *> &c, int ldc, @@ -1948,6 +1955,13 @@ class BlasSupport { bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, float alpha, \ + const port::ArraySlice *> &a, int lda, \ + const port::ArraySlice *> &b, int ldb, \ + float beta, const port::ArraySlice *> &c, \ + int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \ + bool DoBlasGemmBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, float alpha, \ const port::ArraySlice *> &a, int lda, \ const port::ArraySlice *> &b, int ldb, float beta, \ const port::ArraySlice *> &c, int ldc, \ diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index 3c1353a..38e33d4 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -292,6 +292,10 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasGetMathMode) STREAM_EXECUTOR_CUBLAS_WRAP(cublasSetMathMode) #endif +#if CUDA_VERSION >= 9010 +PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasGemmBatchedEx) +#endif + } // namespace wrap static string ToString(cublasStatus_t status) { @@ -2342,13 +2346,23 @@ bool CUDABlas::DoBlasGemmWithAlgorithm( computation_type, algorithm, output_profile_result); } -template +template +struct HalfAsFloat { + typedef T type; +}; + +template <> +struct HalfAsFloat { + typedef float type; +}; + +template port::Status CUDABlas::DoBlasGemmBatchedInternal( FuncT cublas_func, Stream *stream, blas::Transpose transa, - blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha, + blas::Transpose transb, uint64 m, uint64 n, uint64 k, Scalar alpha, const port::ArraySlice *> &a_ptrs_to_wrappers, int lda, const port::ArraySlice *> &b_ptrs_to_wrappers, int ldb, - T beta, const port::ArraySlice *> &c_ptrs_to_wrappers, + Scalar beta, const port::ArraySlice *> &c_ptrs_to_wrappers, int ldc, int batch_count, ScratchAllocator *scratch_allocator) { std::vector a_raw_ptrs, b_raw_ptrs, c_raw_ptrs; for (int i = 0; i < batch_count; ++i) { @@ -2357,7 +2371,7 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal( c_raw_ptrs.push_back(static_cast(c_ptrs_to_wrappers[i]->opaque())); } - typedef typename CUDAComplexT::type CUDA_T; + typedef typename HalfAsFloat::type>::type CUDA_T; const size_t size = batch_count * sizeof(CUDA_T *); @@ -2409,18 +2423,84 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal( "CUDABlas::DoBlasGemmBatched"); } - bool ok = DoBlasInternal( - cublas_func, stream, true /* = 
-      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
-      CUDAComplex(&alpha), const_cast<const CUDA_T **>(CUDAMemory(a)), lda,
-      const_cast<const CUDA_T **>(CUDAMemory(b)), ldb, CUDAComplex(&beta),
-      const_cast<CUDA_T **>(CUDAMemory(c)), ldc, batch_count);
+  cudaDataType_t data_type = CUDADataType<T>::type;
 
-  if (ok) {
+#if CUDA_VERSION >= 9010
+  int cc_major, cc_minor;
+  if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
+          &cc_major, &cc_minor) &&
+      cc_major >= 5) {
+    bool use_tensor_ops = TensorOpMathEnabled() && data_type == CUDA_R_16F;
+    cublasGemmAlgo_t algo =
+        (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
+    cudaDataType_t compute_type =
+        (data_type == CUDA_R_16F ? CUDA_R_32F : data_type);
+    const void **a_void_ptrs = reinterpret_cast<const void **>(
+        const_cast<const CUDA_T **>(CUDAMemory(a)));
+    const void **b_void_ptrs = reinterpret_cast<const void **>(
+        const_cast<const CUDA_T **>(CUDAMemory(b)));
+    void **c_void_ptrs =
+        reinterpret_cast<void **>(const_cast<CUDA_T **>(CUDAMemory(c)));
+    bool ok;
+    ok = DoBlasInternalImpl(
+        wrap::cublasGemmBatchedEx, stream, true /* = pointer_mode_host */,
+        true /* = err_on_failure */, use_tensor_ops, CUDABlasTranspose(transa),
+        CUDABlasTranspose(transb), m, n, k, &alpha, a_void_ptrs, data_type, lda,
+        b_void_ptrs, data_type, ldb, &beta, c_void_ptrs, data_type, ldc,
+        batch_count, compute_type, algo);
+    if (ok) {
+      return port::Status::OK();
+    }
+    return port::Status(port::error::INTERNAL,
+                        "failed BLAS call, see log for details");
+  }
+#endif
+  // either CUDA_VERSION < 9.1 or SM < 5.0
+  if (data_type != CUDA_R_16F) {
+    bool ok = DoBlasInternal(
+        cublas_func, stream, true /* = pointer_mode_host */,
+        CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+        CUDAComplex(&alpha), const_cast<const CUDA_T **>(CUDAMemory(a)), lda,
+        const_cast<const CUDA_T **>(CUDAMemory(b)), ldb, CUDAComplex(&beta),
+        const_cast<CUDA_T **>(CUDAMemory(c)), ldc, batch_count);
+    if (ok) {
+      return port::Status::OK();
+    }
+    return port::Status(port::error::INTERNAL,
+                        "failed BLAS call, see log for details");
+  } else {
+    // Fall back to a loop for fp16
+    for (int b = 0; b < batch_count; ++b) {
+      const DeviceMemory<T> &a_matrix = *a_ptrs_to_wrappers[b];
+      const DeviceMemory<T> &b_matrix = *b_ptrs_to_wrappers[b];
+      DeviceMemory<T> *c_matrix = c_ptrs_to_wrappers[b];
+      bool ok = DoBlasGemm(stream, transa, transb, m, n, k, alpha, a_matrix,
+                           lda, b_matrix, ldb, beta, c_matrix, ldc);
+      if (!ok) {
+        return port::Status(port::error::INTERNAL,
+                            "failed BLAS call, see log for details");
+      }
+    }
     return port::Status::OK();
   }
-  return port::Status(port::error::INTERNAL,
-                      "failed BLAS call, see log for details");
+}
+
+bool CUDABlas::DoBlasGemmBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, float alpha,
+    const port::ArraySlice<DeviceMemory<Eigen::half> *> &a_array, int lda,
+    const port::ArraySlice<DeviceMemory<Eigen::half> *> &b_array, int ldb,
+    float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c_array,
+    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+  // Note: The func passed here (cublasSgemmBatched) is not actually called,
+  // due to special handling of fp16 inside DoBlasGemmBatchedInternal.
+  port::Status status = DoBlasGemmBatchedInternal(
+      wrap::cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array,
+      lda, b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
+  if (!status.ok()) {
+    LOG(ERROR) << status;
+  }
+  return status.ok();
 }
 
 bool CUDABlas::DoBlasGemmBatched(
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h
index 12dc5e4..42b3fde 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.h
+++ b/tensorflow/stream_executor/cuda/cuda_blas.h
@@ -107,12 +107,12 @@ class CUDABlas : public blas::BlasSupport {
 
   // A helper function to implement DoBlasGemmBatched interfaces for generic
   // types.
-  template <typename T, typename FuncT>
+  template <typename T, typename Scalar, typename FuncT>
   port::Status DoBlasGemmBatchedInternal(
       FuncT cublas_func, Stream *stream, blas::Transpose transa,
-      blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
+      blas::Transpose transb, uint64 m, uint64 n, uint64 k, Scalar alpha,
       const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda,
-      const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, T beta,
+      const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, Scalar beta,
       const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
       int batch_count, ScratchAllocator *scratch_allocator);
 
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 093f0c9..330320c 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -4482,6 +4482,40 @@ Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
 Stream &Stream::ThenBlasGemmBatched(
     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+    uint64 k, float alpha,
+    const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+    const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
+    const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
+    int batch_count) {
+  return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
+                                        b, ldb, beta, c, ldc, batch_count,
+                                        /*scratch_allocator=*/nullptr);
+}
+
+Stream &Stream::ThenBlasGemmBatchedWithScratch(
+    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+    uint64 k, float alpha,
+    const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+    const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
+    const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
+    int batch_count, ScratchAllocator *scratch_allocator) {
+  VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+            PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+            PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
+
+  ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
+               const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
+               const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
+               float, const port::ArraySlice<DeviceMemory<Eigen::half> *> &,
+               int, int, ScratchAllocator *>
+      impl;
+  return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
+              k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
+              scratch_allocator);
+}
+
+Stream &Stream::ThenBlasGemmBatched(
+    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
     uint64 k, float alpha,
     const port::ArraySlice<DeviceMemory<float> *> &a, int lda,
     const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,
     const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 3d1b011..99d27b5 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -1474,6 +1474,13 @@ class Stream {
       blas::ProfileResult *output_profile_result);
 
   // See BlasSupport::DoBlasGemmBatched.
+  Stream &ThenBlasGemmBatched(
+      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+      uint64 k, float alpha,
+      const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+      const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
+      float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
+      int ldc, int batch_count);
   Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
                               uint64 m, uint64 n, uint64 k, float alpha,
                               const port::ArraySlice<DeviceMemory<float> *> &a,
@@ -1508,6 +1515,13 @@ class Stream {
       int batch_count);
   Stream &ThenBlasGemmBatchedWithScratch(
       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+      uint64 k, float alpha,
+      const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+      const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
+      float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
+      int ldc, int batch_count, ScratchAllocator *scratch_allocator);
+  Stream &ThenBlasGemmBatchedWithScratch(
+      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
       uint64 k, float alpha,
       const port::ArraySlice<DeviceMemory<float> *> &a, int lda,
       const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,
       const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
-- 
2.7.4
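
A note on the swapped GEMM arguments used throughout this patch (illustrative, not part of the change): cuBLAS is column-major while TensorFlow tensors are row-major, and a row-major matrix reinterpreted as column-major is its transpose, so the kernels request C' = B' x A' with swapped operands and dimensions. The self-contained CPU sketch below, using made-up 2x3 and 3x2 inputs and a plain reference GEMM, shows that the swapped call leaves the row-major product A x B in the output buffer.

// Illustrative only: why calling a column-major GEMM as C' = B' x A' yields
// a row-major C = A x B.
#include <cstdio>

// Column-major reference GEMM: C[m x n] = A[m x k] * B[k x n].
static void GemmColMajor(int m, int n, int k, const float* A, int lda,
                         const float* B, int ldb, float* C, int ldc) {
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) acc += A[i + p * lda] * B[p + j * ldb];
      C[i + j * ldc] = acc;
    }
}

int main() {
  const int m = 2, k = 3, n = 2;
  // Row-major inputs.
  const float A[m * k] = {1, 2, 3, 4, 5, 6};     // 2x3
  const float B[k * n] = {7, 8, 9, 10, 11, 12};  // 3x2
  float C[m * n] = {0};                          // 2x2, row-major result
  // Treat row-major A and B as column-major A' and B' (their transposes) and
  // compute C'(n x m) = B'(n x k) * A'(k x m); C' stored column-major is the
  // same memory layout as C = A * B stored row-major.
  GemmColMajor(n, m, k, B, n, A, k, C, n);
  // Expect row-major C = A*B = [[58, 64], [139, 154]].
  printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]);
  return 0;
}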