Add GPU support for float16 batched matmul (#18436)
authorBen Barsdell <benbarsdell@gmail.com>
Thu, 10 May 2018 18:06:01 +0000 (11:06 -0700)
committerJonathan Hseu <vomjom@vomjom.net>
Thu, 10 May 2018 18:06:01 +0000 (11:06 -0700)
* Add GPU support for float16 batched matmul

- Uses cublasGemmBatchedEx introduced in CUDA 9.1.
- Includes support for Tensor Op math.
- Falls back to a loop over non-batched gemm calls on older CUDA
  versions or GPU architectures.

* Refactor GPU batched gemm into one internal func

tensorflow/core/kernels/batch_matmul_op_impl.h
tensorflow/core/kernels/batch_matmul_op_real.cc
tensorflow/stream_executor/blas.h
tensorflow/stream_executor/cuda/cuda_blas.cc
tensorflow/stream_executor/cuda/cuda_blas.h
tensorflow/stream_executor/stream.cc
tensorflow/stream_executor/stream.h

index a1c03f9..475bda8 100644 (file)
@@ -329,6 +329,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
       c_ptrs.push_back(&c_device_memory.back());
     }
 
+    typedef Scalar Coefficient;
+
     // Cublas does
     // C = A x B
     // where A, B and C are assumed to be in column major.
@@ -352,9 +354,9 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
         bool blas_launch_status =
             stream
                 ->ThenBlasGemv(gemv_trans_a, adj_x ? m : k, adj_x ? k : m,
-                               static_cast<Scalar>(1.0), *(a_ptrs[0]),
+                               static_cast<Coefficient>(1.0), *(a_ptrs[0]),
                                adj_x ? m : k, *(b_ptrs[0]), 1,
-                               static_cast<Scalar>(0.0), c_ptrs[0], 1)
+                               static_cast<Coefficient>(0.0), c_ptrs[0], 1)
                 .ok();
         if (!blas_launch_status) {
           context->SetStatus(errors::Internal(
@@ -366,9 +368,9 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
         bool blas_launch_status =
             stream
                 ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
-                               static_cast<Scalar>(1.0), *(b_ptrs[0]),
+                               static_cast<Coefficient>(1.0), *(b_ptrs[0]),
                                adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
-                               static_cast<Scalar>(0.0), c_ptrs[0], n)
+                               static_cast<Coefficient>(0.0), c_ptrs[0], n)
                 .ok();
         if (!blas_launch_status) {
           context->SetStatus(errors::Internal(
@@ -383,8 +385,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
           stream
               ->ThenBlasGemmBatchedWithScratch(
                   blas_transpose_b, blas_transpose_a, n, m, k,
-                  static_cast<Scalar>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
-                  adj_x ? m : k, static_cast<Scalar>(0.0), c_ptrs, n,
+                  static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
+                  adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
                   batch_size, &scratch_allocator)
               .ok();
       if (!blas_launch_status) {
@@ -398,6 +400,98 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
   }
 };
 
+template <>
+struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
+  static void Launch(OpKernelContext* context, const Tensor& in_x,
+                     const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
+    typedef Eigen::half Scalar;
+    constexpr perftools::gputools::blas::Transpose kTranspose =
+        is_complex<Scalar>::value
+            ? perftools::gputools::blas::Transpose::kConjugateTranspose
+            : perftools::gputools::blas::Transpose::kTranspose;
+    perftools::gputools::blas::Transpose trans[] = {
+        perftools::gputools::blas::Transpose::kNoTranspose, kTranspose};
+    const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
+    const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
+    const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
+    const uint64 batch_size = in_x.dim_size(0);
+    auto blas_transpose_a = trans[adj_x];
+    auto blas_transpose_b = trans[adj_y];
+
+    auto* stream = context->op_device_context()->stream();
+    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+
+    typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
+    std::vector<DeviceMemoryType> a_device_memory;
+    std::vector<DeviceMemoryType> b_device_memory;
+    std::vector<DeviceMemoryType> c_device_memory;
+    std::vector<DeviceMemoryType*> a_ptrs;
+    std::vector<DeviceMemoryType*> b_ptrs;
+    std::vector<DeviceMemoryType*> c_ptrs;
+    a_device_memory.reserve(batch_size);
+    b_device_memory.reserve(batch_size);
+    c_device_memory.reserve(batch_size);
+    a_ptrs.reserve(batch_size);
+    b_ptrs.reserve(batch_size);
+    c_ptrs.reserve(batch_size);
+    auto* a_base_ptr = in_x.template flat<Scalar>().data();
+    auto* b_base_ptr = in_y.template flat<Scalar>().data();
+    auto* c_base_ptr = out->template flat<Scalar>().data();
+    for (int64 i = 0; i < batch_size; ++i) {
+      a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
+      b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
+      c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
+      a_ptrs.push_back(&a_device_memory.back());
+      b_ptrs.push_back(&b_device_memory.back());
+      c_ptrs.push_back(&c_device_memory.back());
+    }
+
+    typedef float Coefficient;
+
+    // Cublas does
+    // C = A x B
+    // where A, B and C are assumed to be in column major.
+    // We want the output to be in row-major, so we can compute
+    // C' = B' x A', where ' stands for transpose (not adjoint).
+    // TODO(yangzihao): Choose the best of the three strategies using autotune.
+    if (batch_size == 1) {
+      // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
+      // overhead of the scratch allocator and the batch interface.
+      // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS
+      bool blas_launch_status =
+          stream
+              ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
+                             static_cast<Coefficient>(1.0), *(b_ptrs[0]),
+                             adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
+                             static_cast<Coefficient>(0.0), c_ptrs[0], n)
+              .ok();
+      if (!blas_launch_status) {
+        context->SetStatus(errors::Internal(
+            "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
+            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
+            ", k=", k));
+      }
+    } else {
+      CublasScratchAllocator scratch_allocator(context);
+      bool blas_launch_status =
+          stream
+              ->ThenBlasGemmBatchedWithScratch(
+                  blas_transpose_b, blas_transpose_a, n, m, k,
+                  static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
+                  adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
+                  batch_size, &scratch_allocator)
+              .ok();
+      if (!blas_launch_status) {
+        context->SetStatus(
+            errors::Internal("Blas xGEMMBatched launch failed : a.shape=",
+                             in_x.shape().DebugString(), ", b.shape=",
+                             in_y.shape().DebugString(), ", m=", m, ", n=", n,
+                             ", k=", k, ", batch_size=", batch_size));
+      }
+    }
+  }
+};
+
 #endif  // GOOGLE_CUDA
 
 #ifdef TENSORFLOW_USE_SYCL
index 7e1e2aa..2bb22bb 100644 (file)
@@ -15,6 +15,10 @@ limitations under the License.
 
 #include "tensorflow/core/kernels/batch_matmul_op_impl.h"
 
+#if GOOGLE_CUDA
+#include "cuda/include/cuda.h"
+#endif  // GOOGLE_CUDA
+
 namespace tensorflow {
 
 #if !defined(INTEL_MKL)
index be0b0bf..ea87744 100644 (file)
@@ -1086,6 +1086,13 @@ class BlasSupport {
   virtual bool DoBlasGemmBatched(
       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
       uint64 n, uint64 k, float alpha,
+      const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+      const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
+      float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
+      int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
+  virtual bool DoBlasGemmBatched(
+      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+      uint64 n, uint64 k, float alpha,
       const port::ArraySlice<DeviceMemory<float> *> &a, int lda,
       const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,
       const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
@@ -1948,6 +1955,13 @@ class BlasSupport {
   bool DoBlasGemmBatched(                                                      \
       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
       uint64 m, uint64 n, uint64 k, float alpha,                               \
+      const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,         \
+      const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,         \
+      float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,      \
+      int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
+  bool DoBlasGemmBatched(                                                      \
+      Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
+      uint64 m, uint64 n, uint64 k, float alpha,                               \
       const port::ArraySlice<DeviceMemory<float> *> &a, int lda,               \
       const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,   \
       const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,               \
index 3c1353a..38e33d4 100644 (file)
@@ -292,6 +292,10 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasGetMathMode)
 STREAM_EXECUTOR_CUBLAS_WRAP(cublasSetMathMode)
 #endif
 
+#if CUDA_VERSION >= 9010
+PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasGemmBatchedEx)
+#endif
+
 }  // namespace wrap
 
 static string ToString(cublasStatus_t status) {
@@ -2342,13 +2346,23 @@ bool CUDABlas::DoBlasGemmWithAlgorithm(
       computation_type, algorithm, output_profile_result);
 }
 
-template <typename T, typename FuncT>
+template <typename T>
+struct HalfAsFloat {
+  typedef T type;
+};
+
+template <>
+struct HalfAsFloat<Eigen::half> {
+  typedef float type;
+};
+
+template <typename T, typename Scalar, typename FuncT>
 port::Status CUDABlas::DoBlasGemmBatchedInternal(
     FuncT cublas_func, Stream *stream, blas::Transpose transa,
-    blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
+    blas::Transpose transb, uint64 m, uint64 n, uint64 k, Scalar alpha,
     const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
     const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
-    T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
+    Scalar beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
   std::vector<T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
   for (int i = 0; i < batch_count; ++i) {
@@ -2357,7 +2371,7 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal(
     c_raw_ptrs.push_back(static_cast<T *>(c_ptrs_to_wrappers[i]->opaque()));
   }
 
-  typedef typename CUDAComplexT<T>::type CUDA_T;
+  typedef typename HalfAsFloat<typename CUDAComplexT<T>::type>::type CUDA_T;
 
   const size_t size = batch_count * sizeof(CUDA_T *);
 
@@ -2409,18 +2423,84 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal(
                         "CUDABlas::DoBlasGemmBatched");
   }
 
-  bool ok = DoBlasInternal(
-      cublas_func, stream, true /* = pointer_mode_host */,
-      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
-      CUDAComplex(&alpha), const_cast<const CUDA_T **>(CUDAMemory(a)), lda,
-      const_cast<const CUDA_T **>(CUDAMemory(b)), ldb, CUDAComplex(&beta),
-      const_cast<CUDA_T **>(CUDAMemory(c)), ldc, batch_count);
+  cudaDataType_t data_type = CUDADataType<T>::type;
 
-  if (ok) {
+#if CUDA_VERSION >= 9010
+  int cc_major, cc_minor;
+  if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
+          &cc_major, &cc_minor) &&
+      cc_major >= 5) {
+    bool use_tensor_ops = TensorOpMathEnabled() && data_type == CUDA_R_16F;
+    cublasGemmAlgo_t algo =
+        (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
+    cudaDataType_t compute_type =
+        (data_type == CUDA_R_16F ? CUDA_R_32F : data_type);
+    const void **a_void_ptrs = reinterpret_cast<const void **>(
+        const_cast<const CUDA_T **>(CUDAMemory(a)));
+    const void **b_void_ptrs = reinterpret_cast<const void **>(
+        const_cast<const CUDA_T **>(CUDAMemory(b)));
+    void **c_void_ptrs =
+        reinterpret_cast<void **>(const_cast<CUDA_T **>(CUDAMemory(c)));
+    bool ok;
+    ok = DoBlasInternalImpl(
+        wrap::cublasGemmBatchedEx, stream, true /* = pointer_mode_host */,
+        true /* = err_on_failure */, use_tensor_ops, CUDABlasTranspose(transa),
+        CUDABlasTranspose(transb), m, n, k, &alpha, a_void_ptrs, data_type, lda,
+        b_void_ptrs, data_type, ldb, &beta, c_void_ptrs, data_type, ldc,
+        batch_count, compute_type, algo);
+    if (ok) {
+      return port::Status::OK();
+    }
+    return port::Status(port::error::INTERNAL,
+                        "failed BLAS call, see log for details");
+  }
+#endif
+  // either CUDA_VERSION < 9.1 or SM < 5.0
+  if (data_type != CUDA_R_16F) {
+    bool ok = DoBlasInternal(
+        cublas_func, stream, true /* = pointer_mode_host */,
+        CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+        CUDAComplex(&alpha), const_cast<const CUDA_T **>(CUDAMemory(a)), lda,
+        const_cast<const CUDA_T **>(CUDAMemory(b)), ldb, CUDAComplex(&beta),
+        const_cast<CUDA_T **>(CUDAMemory(c)), ldc, batch_count);
+    if (ok) {
+      return port::Status::OK();
+    }
+    return port::Status(port::error::INTERNAL,
+                        "failed BLAS call, see log for details");
+  } else {
+    // Fall back to a loop for fp16
+    for (int b = 0; b < batch_count; ++b) {
+      const DeviceMemory<T> &a_matrix = *a_ptrs_to_wrappers[b];
+      const DeviceMemory<T> &b_matrix = *b_ptrs_to_wrappers[b];
+      DeviceMemory<T> *c_matrix = c_ptrs_to_wrappers[b];
+      bool ok = DoBlasGemm(stream, transa, transb, m, n, k, alpha, a_matrix,
+                           lda, b_matrix, ldb, beta, c_matrix, ldc);
+      if (!ok) {
+        return port::Status(port::error::INTERNAL,
+                            "failed BLAS call, see log for details");
+      }
+    }
     return port::Status::OK();
   }
-  return port::Status(port::error::INTERNAL,
-                      "failed BLAS call, see log for details");
+}
+
+bool CUDABlas::DoBlasGemmBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, float alpha,
+    const port::ArraySlice<DeviceMemory<Eigen::half> *> &a_array, int lda,
+    const port::ArraySlice<DeviceMemory<Eigen::half> *> &b_array, int ldb,
+    float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c_array,
+    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+  // Note: The func passed here (cublasSgemmBatched) is not actually called,
+  // due to special handling of fp16 inside DoBlasGemmBatchedInternal.
+  port::Status status = DoBlasGemmBatchedInternal(
+      wrap::cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array,
+      lda, b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
+  if (!status.ok()) {
+    LOG(ERROR) << status;
+  }
+  return status.ok();
 }
 
 bool CUDABlas::DoBlasGemmBatched(
index 12dc5e4..42b3fde 100644 (file)
@@ -107,12 +107,12 @@ class CUDABlas : public blas::BlasSupport {
 
   // A helper function to implement DoBlasGemmBatched interfaces for generic
   // types.
-  template <typename T, typename FuncT>
+  template <typename T, typename Scalar, typename FuncT>
   port::Status DoBlasGemmBatchedInternal(
       FuncT cublas_func, Stream *stream, blas::Transpose transa,
-      blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
+      blas::Transpose transb, uint64 m, uint64 n, uint64 k, Scalar alpha,
       const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda,
-      const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, T beta,
+      const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, Scalar beta,
       const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
       int batch_count, ScratchAllocator *scratch_allocator);
 
index 093f0c9..330320c 100644 (file)
@@ -4482,6 +4482,40 @@ Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
 
 Stream &Stream::ThenBlasGemmBatched(
     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+    uint64 k, float alpha,
+    const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+    const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
+    const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
+    int batch_count) {
+  return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
+                                        b, ldb, beta, c, ldc, batch_count,
+                                        /*scratch_allocator=*/nullptr);
+}
+
+Stream &Stream::ThenBlasGemmBatchedWithScratch(
+    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+    uint64 k, float alpha,
+    const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+    const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
+    const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
+    int batch_count, ScratchAllocator *scratch_allocator) {
+  VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+            PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+            PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
+
+  ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
+               const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
+               const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
+               float, const port::ArraySlice<DeviceMemory<Eigen::half> *> &,
+               int, int, ScratchAllocator *>
+      impl;
+  return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
+              k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
+              scratch_allocator);
+}
+
+Stream &Stream::ThenBlasGemmBatched(
+    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
     uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
     int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
     float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
index 3d1b011..99d27b5 100644 (file)
@@ -1474,6 +1474,13 @@ class Stream {
       blas::ProfileResult *output_profile_result);
 
   // See BlasSupport::DoBlasGemmBatched.
+  Stream &ThenBlasGemmBatched(
+      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+      uint64 k, float alpha,
+      const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+      const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
+      float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
+      int ldc, int batch_count);
   Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
                               uint64 m, uint64 n, uint64 k, float alpha,
                               const port::ArraySlice<DeviceMemory<float> *> &a,
@@ -1508,6 +1515,13 @@ class Stream {
       int batch_count);
   Stream &ThenBlasGemmBatchedWithScratch(
       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+      uint64 k, float alpha,
+      const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+      const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
+      float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
+      int ldc, int batch_count, ScratchAllocator *scratch_allocator);
+  Stream &ThenBlasGemmBatchedWithScratch(
+      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
       uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
       int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
       float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,