c_ptrs.push_back(&c_device_memory.back());
}
+ typedef Scalar Coefficient;
+
// Cublas does
// C = A x B
// where A, B and C are assumed to be in column major.
bool blas_launch_status =
stream
->ThenBlasGemv(gemv_trans_a, adj_x ? m : k, adj_x ? k : m,
- static_cast<Scalar>(1.0), *(a_ptrs[0]),
+ static_cast<Coefficient>(1.0), *(a_ptrs[0]),
adj_x ? m : k, *(b_ptrs[0]), 1,
- static_cast<Scalar>(0.0), c_ptrs[0], 1)
+ static_cast<Coefficient>(0.0), c_ptrs[0], 1)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal(
bool blas_launch_status =
stream
->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
- static_cast<Scalar>(1.0), *(b_ptrs[0]),
+ static_cast<Coefficient>(1.0), *(b_ptrs[0]),
adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
- static_cast<Scalar>(0.0), c_ptrs[0], n)
+ static_cast<Coefficient>(0.0), c_ptrs[0], n)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal(
stream
->ThenBlasGemmBatchedWithScratch(
blas_transpose_b, blas_transpose_a, n, m, k,
- static_cast<Scalar>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
- adj_x ? m : k, static_cast<Scalar>(0.0), c_ptrs, n,
+ static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
+ adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
batch_size, &scratch_allocator)
.ok();
if (!blas_launch_status) {
}
};
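
As a quick standalone check of the dimension logic used by both launchers (the `adj_x ? ... : ...` indexing that reappears in the fp16 specialization below), here is a minimal sketch in plain C++; shapes are {batch, rows, cols} and the array names are illustrative only:

#include <cassert>
#include <cstdio>

int main() {
  // x: [8, 2, 3] times y: [8, 3, 4] with no adjoints: each C is 2 x 4.
  const long long x_dims[3] = {8, 2, 3}, y_dims[3] = {8, 3, 4};
  bool adj_x = false, adj_y = false;
  long long m = x_dims[adj_x ? 2 : 1];  // rows of op(A)
  long long k = x_dims[adj_x ? 1 : 2];  // cols of op(A) == rows of op(B)
  long long n = y_dims[adj_y ? 1 : 2];  // cols of op(B)
  assert(m == 2 && k == 3 && n == 4);

  // With adj_x set, x is logically transposed, so a [8, 3, 2] tensor
  // yields the same m = 2, k = 3.
  const long long xt_dims[3] = {8, 3, 2};
  adj_x = true;
  m = xt_dims[adj_x ? 2 : 1];
  k = xt_dims[adj_x ? 1 : 2];
  assert(m == 2 && k == 3);
  std::printf("m=%lld k=%lld n=%lld\n", m, k, n);
  return 0;
}
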
+template <>
+struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
+ static void Launch(OpKernelContext* context, const Tensor& in_x,
+ const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
+ typedef Eigen::half Scalar;
+ constexpr perftools::gputools::blas::Transpose kTranspose =
+ is_complex<Scalar>::value
+ ? perftools::gputools::blas::Transpose::kConjugateTranspose
+ : perftools::gputools::blas::Transpose::kTranspose;
+ perftools::gputools::blas::Transpose trans[] = {
+ perftools::gputools::blas::Transpose::kNoTranspose, kTranspose};
+ const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
+ const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
+ const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
+ const uint64 batch_size = in_x.dim_size(0);
+ auto blas_transpose_a = trans[adj_x];
+ auto blas_transpose_b = trans[adj_y];
+
+ auto* stream = context->op_device_context()->stream();
+ OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+
+ typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
+ std::vector<DeviceMemoryType> a_device_memory;
+ std::vector<DeviceMemoryType> b_device_memory;
+ std::vector<DeviceMemoryType> c_device_memory;
+ std::vector<DeviceMemoryType*> a_ptrs;
+ std::vector<DeviceMemoryType*> b_ptrs;
+ std::vector<DeviceMemoryType*> c_ptrs;
+ a_device_memory.reserve(batch_size);
+ b_device_memory.reserve(batch_size);
+ c_device_memory.reserve(batch_size);
+ a_ptrs.reserve(batch_size);
+ b_ptrs.reserve(batch_size);
+ c_ptrs.reserve(batch_size);
+ auto* a_base_ptr = in_x.template flat<Scalar>().data();
+ auto* b_base_ptr = in_y.template flat<Scalar>().data();
+ auto* c_base_ptr = out->template flat<Scalar>().data();
+ for (int64 i = 0; i < batch_size; ++i) {
+ a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
+ b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
+ c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
+ a_ptrs.push_back(&a_device_memory.back());
+ b_ptrs.push_back(&b_device_memory.back());
+ c_ptrs.push_back(&c_device_memory.back());
+ }
+
+ typedef float Coefficient;
+
+ // Cublas does
+ // C = A x B
+ // where A, B and C are assumed to be in column major.
+ // We want the output to be in row-major, so we can compute
+ // C' = B' x A', where ' stands for transpose (not adjoint).
+ // TODO(yangzihao): Choose the best of the three strategies using autotune.
+ if (batch_size == 1) {
+ // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
+ // overhead of the scratch allocator and the batch interface.
+ // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
+ static_cast<Coefficient>(1.0), *(b_ptrs[0]),
+ adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
+ static_cast<Coefficient>(0.0), c_ptrs[0], n)
+ .ok();
+ if (!blas_launch_status) {
+ context->SetStatus(errors::Internal(
+ "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
+ ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
+ ", k=", k));
+ }
+ } else {
+ CublasScratchAllocator scratch_allocator(context);
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemmBatchedWithScratch(
+ blas_transpose_b, blas_transpose_a, n, m, k,
+ static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
+ adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
+ batch_size, &scratch_allocator)
+ .ok();
+ if (!blas_launch_status) {
+ context->SetStatus(
+ errors::Internal("Blas xGEMMBatched launch failed : a.shape=",
+ in_x.shape().DebugString(), ", b.shape=",
+ in_y.shape().DebugString(), ", m=", m, ", n=", n,
+ ", k=", k, ", batch_size=", batch_size));
+ }
+ }
+ }
+};
+
#endif // GOOGLE_CUDA
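
The column-major comment above (cuBLAS computes C = A x B in column-major, so the kernel requests C' = B' x A' to obtain a row-major C) can be verified on the host. A self-contained sketch, not part of the patch; gemm_colmajor is a hypothetical reference GEMM written only for this illustration:

#include <cstdio>

// Plain column-major GEMM: C(m x n) = A(m x k) * B(k x n), with leading
// dimensions equal to the respective row counts.
void gemm_colmajor(int m, int n, int k, const float* a, int lda,
                   const float* b, int ldb, float* c, int ldc) {
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) {
      float sum = 0.f;
      for (int p = 0; p < k; ++p) sum += a[i + p * lda] * b[p + j * ldb];
      c[i + j * ldc] = sum;
    }
}

int main() {
  // Row-major A (2x3) and B (3x2); we want row-major C (2x2) = A * B.
  const float a[6] = {1, 2, 3, 4, 5, 6};
  const float b[6] = {7, 8, 9, 10, 11, 12};
  float c[4];
  const int m = 2, n = 2, k = 3;
  // A row-major m x k buffer is exactly a column-major k x m matrix (A'),
  // and likewise for B and C.  So asking the column-major GEMM for
  // C' = B' * A' (dims n, m, k; B first; leading dims n, k, n) fills the
  // C buffer in row-major order.
  gemm_colmajor(n, m, k, b, n, a, k, c, n);
  // Expect C = [[58, 64], [139, 154]].
  std::printf("%g %g\n%g %g\n", c[0], c[1], c[2], c[3]);
  return 0;
}
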
#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
+#if GOOGLE_CUDA
+#include "cuda/include/cuda.h"
+#endif // GOOGLE_CUDA
+
namespace tensorflow {
#if !defined(INTEL_MKL)
virtual bool DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, float alpha,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
+ float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
+ int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
+ virtual bool DoBlasGemmBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha,
const port::ArraySlice<DeviceMemory<float> *> &a, int lda,
const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,
const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
bool DoBlasGemmBatched( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64 m, uint64 n, uint64 k, float alpha, \
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda, \
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, \
+ float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, \
+ int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
+ bool DoBlasGemmBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, float alpha, \
const port::ArraySlice<DeviceMemory<float> *> &a, int lda, \
const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta, \
const port::ArraySlice<DeviceMemory<float> *> &c, int ldc, \
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSetMathMode)
#endif
+#if CUDA_VERSION >= 9010
+PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasGemmBatchedEx)
+#endif
+
} // namespace wrap
static string ToString(cublasStatus_t status) {
computation_type, algorithm, output_profile_result);
}
-template <typename T, typename FuncT>
+template <typename T>
+struct HalfAsFloat {
+ typedef T type;
+};
+
+template <>
+struct HalfAsFloat<Eigen::half> {
+ typedef float type;
+};
+
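As a quick illustration: HalfAsFloat is a type-function that maps Eigen::half to float and leaves every other type unchanged, so host-side alpha/beta stay fp32 even when the matrices are fp16. A compile-time sketch of the mapping (a stand-in `half` type replaces Eigen::half so the snippet builds without Eigen):

#include <type_traits>

struct half {};  // hypothetical stand-in for Eigen::half

template <typename T>
struct HalfAsFloat {
  typedef T type;
};

template <>
struct HalfAsFloat<half> {
  typedef float type;
};

static_assert(std::is_same<HalfAsFloat<half>::type, float>::value,
              "fp16 storage promotes to float coefficients");
static_assert(std::is_same<HalfAsFloat<double>::type, double>::value,
              "other scalar types pass through unchanged");

int main() { return 0; }
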
+template <typename T, typename Scalar, typename FuncT>
port::Status CUDABlas::DoBlasGemmBatchedInternal(
FuncT cublas_func, Stream *stream, blas::Transpose transa,
- blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
+ blas::Transpose transb, uint64 m, uint64 n, uint64 k, Scalar alpha,
const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
- T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
+ Scalar beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
std::vector<T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
for (int i = 0; i < batch_count; ++i) {
c_raw_ptrs.push_back(static_cast<T *>(c_ptrs_to_wrappers[i]->opaque()));
}
- typedef typename CUDAComplexT<T>::type CUDA_T;
+ typedef typename HalfAsFloat<typename CUDAComplexT<T>::type>::type CUDA_T;
const size_t size = batch_count * sizeof(CUDA_T *);
"CUDABlas::DoBlasGemmBatched");
}
-  bool ok = DoBlasInternal(
-      cublas_func, stream, true /* = pointer_mode_host */,
-      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
-      CUDAComplex(&alpha), const_cast<const CUDA_T **>(CUDAMemory(a)), lda,
-      const_cast<const CUDA_T **>(CUDAMemory(b)), ldb, CUDAComplex(&beta),
-      const_cast<CUDA_T **>(CUDAMemory(c)), ldc, batch_count);
-  if (ok) {
+  cudaDataType_t data_type = CUDADataType<T>::type;
+
+#if CUDA_VERSION >= 9010
+ int cc_major, cc_minor;
+ if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
+ &cc_major, &cc_minor) &&
+ cc_major >= 5) {
+ bool use_tensor_ops = TensorOpMathEnabled() && data_type == CUDA_R_16F;
+ cublasGemmAlgo_t algo =
+ (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
+ cudaDataType_t compute_type =
+ (data_type == CUDA_R_16F ? CUDA_R_32F : data_type);
+ const void **a_void_ptrs = reinterpret_cast<const void **>(
+ const_cast<const CUDA_T **>(CUDAMemory(a)));
+ const void **b_void_ptrs = reinterpret_cast<const void **>(
+ const_cast<const CUDA_T **>(CUDAMemory(b)));
+ void **c_void_ptrs =
+ reinterpret_cast<void **>(const_cast<CUDA_T **>(CUDAMemory(c)));
+    bool ok = DoBlasInternalImpl(
+ wrap::cublasGemmBatchedEx, stream, true /* = pointer_mode_host */,
+ true /* = err_on_failure */, use_tensor_ops, CUDABlasTranspose(transa),
+ CUDABlasTranspose(transb), m, n, k, &alpha, a_void_ptrs, data_type, lda,
+ b_void_ptrs, data_type, ldb, &beta, c_void_ptrs, data_type, ldc,
+ batch_count, compute_type, algo);
+ if (ok) {
+ return port::Status::OK();
+ }
+ return port::Status(port::error::INTERNAL,
+ "failed BLAS call, see log for details");
+ }
+#endif  // CUDA_VERSION >= 9010
+  // Fall back if CUDA_VERSION < 9.1 or the GPU's compute capability is < 5.0.
+ if (data_type != CUDA_R_16F) {
+ bool ok = DoBlasInternal(
+ cublas_func, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+ CUDAComplex(&alpha), const_cast<const CUDA_T **>(CUDAMemory(a)), lda,
+ const_cast<const CUDA_T **>(CUDAMemory(b)), ldb, CUDAComplex(&beta),
+ const_cast<CUDA_T **>(CUDAMemory(c)), ldc, batch_count);
+ if (ok) {
+ return port::Status::OK();
+ }
+ return port::Status(port::error::INTERNAL,
+ "failed BLAS call, see log for details");
+ } else {
+ // Fall back to a loop for fp16
+ for (int b = 0; b < batch_count; ++b) {
+ const DeviceMemory<T> &a_matrix = *a_ptrs_to_wrappers[b];
+ const DeviceMemory<T> &b_matrix = *b_ptrs_to_wrappers[b];
+ DeviceMemory<T> *c_matrix = c_ptrs_to_wrappers[b];
+ bool ok = DoBlasGemm(stream, transa, transb, m, n, k, alpha, a_matrix,
+ lda, b_matrix, ldb, beta, c_matrix, ldc);
+ if (!ok) {
+ return port::Status(port::error::INTERNAL,
+ "failed BLAS call, see log for details");
+ }
+ }
return port::Status::OK();
}
- return port::Status(port::error::INTERNAL,
- "failed BLAS call, see log for details");
+}
+
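For reference, the call that the CUDA_VERSION >= 9010 path wraps can also be driven directly. A minimal sketch, assuming the CUDA 9.x/10.x prototype in which computeType is a cudaDataType_t (CUDA 11+ takes a cublasComputeType_t such as CUBLAS_COMPUTE_32F instead); error checks and cleanup are elided:

#include <cstdio>
#include <vector>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

int main() {
  const int m = 2, n = 2, k = 2, batch = 2;
  const int elems = m * k;  // every matrix here is 2x2, column-major

  // All-ones fp16 inputs; each output matrix should be filled with 2.0.
  std::vector<__half> ones(elems, __float2half(1.f));

  std::vector<void*> ha(batch), hb(batch), hc(batch);
  for (int i = 0; i < batch; ++i) {
    cudaMalloc(&ha[i], elems * sizeof(__half));
    cudaMalloc(&hb[i], elems * sizeof(__half));
    cudaMalloc(&hc[i], elems * sizeof(__half));
    cudaMemcpy(ha[i], ones.data(), elems * sizeof(__half),
               cudaMemcpyHostToDevice);
    cudaMemcpy(hb[i], ones.data(), elems * sizeof(__half),
               cudaMemcpyHostToDevice);
  }

  // The batched-Ex API consumes *device* arrays of matrix pointers; this is
  // what the patch stages through the scratch allocator.
  void **da, **db, **dc;
  cudaMalloc(reinterpret_cast<void **>(&da), batch * sizeof(void *));
  cudaMalloc(reinterpret_cast<void **>(&db), batch * sizeof(void *));
  cudaMalloc(reinterpret_cast<void **>(&dc), batch * sizeof(void *));
  cudaMemcpy(da, ha.data(), batch * sizeof(void *), cudaMemcpyHostToDevice);
  cudaMemcpy(db, hb.data(), batch * sizeof(void *), cudaMemcpyHostToDevice);
  cudaMemcpy(dc, hc.data(), batch * sizeof(void *), cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);

  // fp16 storage with fp32 accumulation: alpha/beta are float because the
  // compute type is CUDA_R_32F, the same promotion HalfAsFloat expresses.
  const float alpha = 1.f, beta = 0.f;
  cublasStatus_t status = cublasGemmBatchedEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
      (const void **)da, CUDA_R_16F, /*lda=*/m,
      (const void **)db, CUDA_R_16F, /*ldb=*/k,
      &beta, dc, CUDA_R_16F, /*ldc=*/m,
      batch, CUDA_R_32F, CUBLAS_GEMM_DFALT);
  std::printf("cublasGemmBatchedEx: %s\n",
              status == CUBLAS_STATUS_SUCCESS ? "ok" : "failed");

  cublasDestroy(handle);
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
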
+bool CUDABlas::DoBlasGemmBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &a_array, int lda,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &b_array, int ldb,
+ float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c_array,
+ int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+ // Note: The func passed here (cublasSgemmBatched) is not actually called,
+ // due to special handling of fp16 inside DoBlasGemmBatchedInternal.
+ port::Status status = DoBlasGemmBatchedInternal(
+ wrap::cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array,
+ lda, b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
+ if (!status.ok()) {
+ LOG(ERROR) << status;
+ }
+ return status.ok();
+}
+
bool CUDABlas::DoBlasGemmBatched(
// A helper function to implement DoBlasGemmBatched interfaces for generic
// types.
- template <typename T, typename FuncT>
+ template <typename T, typename Scalar, typename FuncT>
port::Status DoBlasGemmBatchedInternal(
FuncT cublas_func, Stream *stream, blas::Transpose transa,
- blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
+ blas::Transpose transb, uint64 m, uint64 n, uint64 k, Scalar alpha,
const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda,
- const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, T beta,
+ const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, Scalar beta,
const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
int batch_count, ScratchAllocator *scratch_allocator);
Stream &Stream::ThenBlasGemmBatched(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
+ int batch_count) {
+ return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
+ b, ldb, beta, c, ldc, batch_count,
+ /*scratch_allocator=*/nullptr);
+}
+
+Stream &Stream::ThenBlasGemmBatchedWithScratch(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
+ float, const port::ArraySlice<DeviceMemory<Eigen::half> *> &,
+ int, int, ScratchAllocator *>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
+ k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
+ scratch_allocator);
+}
+
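The two Stream methods above are thin dispatchers. A toy model of the ThenBlasImpl forwarding pattern they rely on, with illustrative names only, not the actual stream_executor implementation:

#include <cstdio>

struct BlasSupport {
  bool DoBlasGemmBatched(int m, int n, float alpha) {
    std::printf("gemm_batched m=%d n=%d alpha=%g\n", m, n, alpha);
    return true;
  }
};

struct Stream {
  BlasSupport blas;

  // Forwards a Stream method call to the BlasSupport member function with
  // the matching argument list.
  template <typename... Args>
  struct ThenBlasImpl {
    Stream &operator()(Stream *stream, bool (BlasSupport::*func)(Args...),
                       Args... args) {
      // The real ThenBlasImpl also checks the stream's error state and
      // logs a warning when the BLAS call fails.
      (stream->blas.*func)(args...);
      return *stream;
    }
  };

  Stream &ThenBlasGemmBatched(int m, int n, float alpha) {
    ThenBlasImpl<int, int, float> impl;
    return impl(this, &BlasSupport::DoBlasGemmBatched, m, n, alpha);
  }
};

int main() {
  Stream s;
  s.ThenBlasGemmBatched(2, 3, 1.0f);
  return 0;
}
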
+Stream &Stream::ThenBlasGemmBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
blas::ProfileResult *output_profile_result);
// See BlasSupport::DoBlasGemmBatched.
+ Stream &ThenBlasGemmBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
+ float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
+ int ldc, int batch_count);
Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, float alpha,
const port::ArraySlice<DeviceMemory<float> *> &a,
int batch_count);
Stream &ThenBlasGemmBatchedWithScratch(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
+ float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
+ int ldc, int batch_count, ScratchAllocator *scratch_allocator);
+ Stream &ThenBlasGemmBatchedWithScratch(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,