Improve bmm() performance on CPU when input tensor is non-contiguous (#19338)
authorMingfei Ma <mingfei.ma@intel.com>
Thu, 18 Apr 2019 13:31:24 +0000 (06:31 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 18 Apr 2019 13:34:17 +0000 (06:34 -0700)
Summary:
This PR aims to improve Transformer performance on CPU, `bmm()` is one of the major bottlenecks now.

Current logic of `bmm()` on CPU only uses MKL batch gemm when the inputs `A` and `B` are contiguous or transposed. So when `A` or `B` is a slice of a larger tensor, it falls to a slower path.

`A` and `B` are both 3D tensors. MKL is able to handle the batch matrix multiplication on occasion that `A.stride(1) == 1 || A.stride(2) == 1` and `B.stride(1) == || B.stride(2) == 1`.

From [fairseq](https://github.com/pytorch/fairseq) implementation of Transformer, multi-head attention has two places to call bmm(), [here](https://github.com/pytorch/fairseq/blob/master/fairseq/modules/multihead_attention.py#L167) and [here](https://github.com/pytorch/fairseq/blob/master/fairseq/modules/multihead_attention.py#L197), `q`, `k`, `v` are all slices from larger tensor. So the `bmm()` falls to slow path at the moment.

Results on Xeon 6148 (20*2 cores 2.5GHz) indicate this PR improves Transformer training performance by **48%** (seconds per iteration reduced from **5.48** to **3.70**), the inference performance should also be boosted.

Before:
```
| epoch 001:   0%| | 27/25337 [02:27<38:31:26,  5.48s/it, loss=16.871, nll_loss=16.862, ppl=119099.70, wps=865, ups=0, wpb=4715.778, bsz=129.481, num_updates=27, lr=4.05e-06, gnorm=9.133,
```
After:
```
| epoch 001:   0%| | 97/25337 [05:58<25:55:49,  3.70s/it, loss=14.736, nll_loss=14.571, ppl=24339.38, wps=1280, ups=0, wpb=4735.299, bsz=131.134, num_updates=97, lr=1.455e-05, gnorm=3.908,
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19338

Differential Revision: D14986346

Pulled By: soumith

fbshipit-source-id: 827106245af908b8a4fda69ed0288d322b028f08

aten/src/ATen/native/LinearAlgebra.cpp
aten/src/ATen/native/mkl/LinearAlgebra.cpp

index ee8b743..b40120c 100644 (file)
@@ -297,8 +297,8 @@ static inline Tensor& bmm_out_or_baddbmm_(Tensor& self_or_result, const Tensor&
   }
 
   auto batch_items_contiguous_or_transposed = [&](const Tensor& t) {
-    return (t.stride(2) == 1 && t.stride(1) == t.size(2))
-            || (t.stride(1) == 1 && t.stride(2) == t.size(1));
+    return (t.stride(2) == 1 && t.stride(1) >= t.size(2))
+            || (t.stride(1) == 1 && t.stride(2) >= t.size(1));
   };
 
   if (contraction_size * res_rows * res_cols < 400) {
index 809bd82..a6ecdcd 100644 (file)
@@ -34,10 +34,8 @@ namespace at { namespace native {
 
 static inline void gemm_batched(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
   const int batch_size, const int M, const int N, const int K, const float alpha,
-  const float** A, const float** B, const float beta, float** C) {
-  const int lda = (trans_A == CblasNoTrans) ? K : M;
-  const int ldb = (trans_B == CblasNoTrans) ? N : K;
-  const int ldc = N;
+  const float** A, const int lda, const float** B, const int ldb, const float beta,
+  float** C, const int ldc) {
 
   cblas_sgemm_batch(CblasRowMajor, &trans_A, &trans_B, &M, &N, &K, &alpha,
     A, &lda, B, &ldb, &beta, C, &ldc, 1, &batch_size);
@@ -45,10 +43,8 @@ static inline void gemm_batched(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANS
 
 static inline void gemm_batched(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
   const int batch_size, const int M, const int N, const int K, const double alpha,
-  const double** A, const double** B, const double beta, double** C) {
-  const int lda = (trans_A == CblasNoTrans) ? K : M;
-  const int ldb = (trans_B == CblasNoTrans) ? N : K;
-  const int ldc = N;
+  const double** A, const int lda, const double** B, const int ldb, const double beta,
+  double** C, const int ldc) {
 
   cblas_dgemm_batch(CblasRowMajor, &trans_A, &trans_B, &M, &N, &K, &alpha,
     A, &lda, B, &ldb, &beta, C, &ldc, 1, &batch_size);
@@ -57,7 +53,7 @@ static inline void gemm_batched(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANS
 template <typename scalar_t>
 static inline void baddbmm_mkl_template(const Tensor& res, const Tensor& mat1, const Tensor& mat2, Scalar beta_, Scalar alpha_) {
   auto is_transposed = [&](const Tensor& t) {
-    return t.stride(0) == 1 && t.stride(1) == t.size(0);
+    return t.stride(0) == 1 && t.stride(1) >= t.size(0);
   };
   const CBLAS_TRANSPOSE trans_A = is_transposed(mat1[0]) ? CblasTrans : CblasNoTrans;
   const CBLAS_TRANSPOSE trans_B = is_transposed(mat2[0]) ? CblasTrans : CblasNoTrans;
@@ -69,6 +65,10 @@ static inline void baddbmm_mkl_template(const Tensor& res, const Tensor& mat1, c
   scalar_t alpha = alpha_.to<scalar_t>();
   scalar_t beta = beta_.to<scalar_t>();
 
+  const int lda = is_transposed(mat1[0]) ? mat1[0].stride(1) : mat1[0].stride(0);
+  const int ldb = is_transposed(mat2[0]) ? mat2[0].stride(1) : mat2[0].stride(0);
+  const int ldc = res[0].stride(0);
+
   std::vector<const scalar_t*> A(batch_size);
   std::vector<const scalar_t*> B(batch_size);
   std::vector<scalar_t*> C(batch_size);
@@ -78,7 +78,7 @@ static inline void baddbmm_mkl_template(const Tensor& res, const Tensor& mat1, c
     C[batch] = res[batch].data<scalar_t>();
   }
 
-  gemm_batched(trans_A, trans_B, batch_size, M, N, K, alpha, A.data(), B.data(), beta, C.data());
+  gemm_batched(trans_A, trans_B, batch_size, M, N, K, alpha, A.data(), lda, B.data(), ldb, beta, C.data(), ldc);
 }
 
 Tensor& _baddbmm_mkl_(Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {