Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / f32 / ref_gemm_f32.cpp
 * limitations under the License.
 *******************************************************************************/
 
+#include "mkldnn_types.h"
+
 #include "mkldnn_thread.hpp"
 #include "nstl.hpp"
 #include "utils.hpp"
 
-#include "../jit_generator.hpp"
+#include "jit_generator.hpp"
 
-#include "gemm_utils.hpp"
+#include "gemm_utils_f32.hpp"
+#include "ref_gemm_f32.hpp"
 
 namespace mkldnn {
 namespace impl {
@@ -29,13 +32,14 @@ namespace cpu {
 using namespace mkldnn::impl::utils;
 using namespace gemm_utils;
 
+namespace {
 
 template <typename data_t>
-static void copy_A(
-        bool isTransA, int K, const data_t *A, const int lda, data_t *ws) {
+void copy_A(
+        bool isTransA, int K, const data_t *A, const dim_t lda, data_t *ws) {
     for (int k = 0; k < K; k++) {
         PRAGMA_OMP_SIMD()
-        for (int i = 0; i < gemm_utils::unroll_factor<data_t>::m; i++) {
+        for (int i = 0; i < unroll_factor<data_t>::m; i++) {
             ws[i] = isTransA ? A[i * lda + k] : A[i + k * lda];
         }
         ws += unroll_factor<data_t>::m;
@@ -43,8 +47,8 @@ static void copy_A(
 }
 
 template <typename data_t, bool isTransA, bool isTransB>
-static void kernel_mxn(int K, const data_t *A, const int lda,
-        const data_t *B, const int ldb, data_t *C, const int ldc,
+void kernel_mxn(int K, const data_t *A, const dim_t lda,
+        const data_t *B, const dim_t ldb, data_t *C, const dim_t ldc,
         const data_t alpha, const data_t beta) {
     data_t c[unroll_factor<data_t>::m * unroll_factor<data_t>::n] =
         { static_cast<data_t>(0.) };
@@ -70,9 +74,9 @@ static void kernel_mxn(int K, const data_t *A, const int lda,
 }
 
 template <typename data_t, bool isTransA, bool isTransB>
-static void block_ker(const int M, const int N, const int K,
-        const data_t *A, const int lda, const data_t *B, const int ldb,
-        data_t *C, const int ldc, const data_t alpha, const data_t beta,
+void block_ker(const int M, const int N, const int K,
+        const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb,
+        data_t *C, const dim_t ldc, const data_t alpha, const data_t beta,
         data_t *ws, bool do_copy) {
     int Nu = rnd_dn(N, unroll_factor<data_t>::n);
     int Mu = rnd_dn(M, unroll_factor<data_t>::m);
@@ -124,8 +128,9 @@ static void block_ker(const int M, const int N, const int K,
 
 template <typename data_t, bool isTransA, bool isTransB>
 void gemm_ithr(const int M, const int N, const int K, const data_t alpha,
-        const data_t *A, const int lda, const data_t *B, const int ldb,
-        const data_t beta, data_t *C, const int ldc, bool do_copy, data_t *ws) {
+        const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb,
+        const data_t beta, data_t *C, const dim_t ldc, bool do_copy,
+        data_t *ws) {
     constexpr int BM = gemm_traits<data_t, isTransA, isTransB>::BM;
     constexpr int BN = gemm_traits<data_t, isTransA, isTransB>::BN;
     constexpr int BK = gemm_traits<data_t, isTransA, isTransB>::BK;
@@ -138,12 +143,12 @@ void gemm_ithr(const int M, const int N, const int K, const data_t alpha,
         return;
 
     if ((K <= 0) || (alpha == static_cast<data_t>(0))) {
-        ptrdiff_t MN = (ptrdiff_t)N * M;
+        dim_t MN = N * M;
         if (beta == static_cast<data_t>(0.)) {
-            for (ptrdiff_t j = 0; j < MN; j++)
+            for (dim_t j = 0; j < MN; j++)
                 C[j] = static_cast<data_t>(0.);
         } else if (beta != static_cast<data_t>(1.)) {
-            for (ptrdiff_t j = 0; j < MN; j++)
+            for (dim_t j = 0; j < MN; j++)
                 C[j] *= beta;
         }
         return;
@@ -171,21 +176,26 @@ void gemm_ithr(const int M, const int N, const int K, const data_t alpha,
     }
 }
 
+}
+
 template <typename data_t>
-void ref_gemm(const char *transa_, const char *transb_, const int *M_,
+mkldnn_status_t ref_gemm(
+        const char *transa_, const char *transb_, const int *M_,
         const int *N_, const int *K_, const data_t *alpha_, const data_t *A,
         const int *lda_, const data_t *B, const int *ldb_, const data_t *beta_,
         data_t *C, const int *ldc_, const data_t *bias) {
+
     bool isTransA = (*transa_ == 'T' || *transa_ == 't');
     bool isTransB = (*transb_ == 'T' || *transb_ == 't');
-    const int M = *M_, N = *N_, K = *K_, lda = *lda_, ldb = *ldb_, ldc = *ldc_;
+    const int M = *M_, N = *N_, K = *K_;
+    const dim_t lda = *lda_, ldb = *ldb_, ldc = *ldc_;
     const data_t alpha = *alpha_, beta = *beta_;
 
     int max_nthr = mkldnn_in_parallel() ? 1 : mkldnn_get_max_threads();
     int nthr_m, nthr_n, nthr_k;
     int MB, NB, KB;
     // thread balancing over M, N, K & size of blocking dimensions
-    gemm_utils::calc_nthr_nocopy_avx(
+    calc_nthr_nocopy_avx(
             M, N, K, max_nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
     assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
 
@@ -205,14 +215,23 @@ void ref_gemm(const char *transa_, const char *transb_, const int *M_,
     const int nthr = nthr_mn * nthr_k;
     const size_t ws_elems_per_thr = K * unroll_factor<data_t>::m;
     const size_t ws_size_per_thr
-            = utils::rnd_up(ws_elems_per_thr * sizeof(data_t), PAGE_4K);
+            = rnd_up(ws_elems_per_thr * sizeof(data_t), PAGE_4K);
     if (do_copy) {
         ws_buffers = (data_t*)malloc(nthr * ws_size_per_thr, PAGE_4K);
         if (!ws_buffers)
             do_copy = false;
     }
 
-    parallel(nthr, [&](const int ithr, const int nthr) {
+    auto get_thr_block = [&](int &from, int &to, int &myN, int NB, int N,
+                             int ithr) {
+        from = NB * (ithr);
+        to = NB * (ithr + 1);
+        if (to > N)
+            to = N;
+        myN = to - from;
+    };
+
+    parallel_nd(nthr, [&](const int ithr) {
         int ithr_mn = ithr % nthr_mn;
         int ithr_m = ithr_mn % nthr_m;
         int ithr_n = ithr_mn / nthr_m;
@@ -226,27 +245,20 @@ void ref_gemm(const char *transa_, const char *transb_, const int *M_,
 
         int m_from = 0, m_to = 0, myM = 0, n_from = 0, n_to = 0, myN = 0,
                 k_from = 0, k_to = 0, myK = 0;
-        auto get_thr_block = [&](int &from, int &to, int &myN, int NB, int N,
-                int ithr) {
-            from = NB * (ithr);
-            to = NB * (ithr + 1);
-            if (to > N)
-                to = N;
-            myN = to - from;
-        };
+
         get_thr_block(m_from, m_to, myM, MB, M, ithr_m);
         get_thr_block(n_from, n_to, myN, NB, N, ithr_n);
         get_thr_block(k_from, k_to, myK, KB, K, ithr_k);
 
         if (myM > 0 && myN > 0) {
             data_t myBeta, *myC;
-            int ld;
+            dim_t ld;
             if (ithr_k == 0) {
                 myC = &(C[m_from + n_from * ldc]);
                 myBeta = beta;
                 ld = ldc;
             } else {
-                myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
+                myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1);
                 myBeta = 0.0f;
                 ld = MB;
             }
@@ -275,23 +287,36 @@ void ref_gemm(const char *transa_, const char *transb_, const int *M_,
                 }
             }
         }
+    });
 
-        if (nthr_k > 1) {
-            assert(mkldnn_thr_syncable());
-            mkldnn_thr_barrier();
+    if (nthr_k > 1) {
+        parallel_nd(nthr, [&](const int ithr) {
+            int ithr_mn = ithr % nthr_mn;
+            int ithr_m = ithr_mn % nthr_m;
+            int ithr_k = ithr / nthr_mn;
+            int ithr_n = ithr_mn / nthr_m;
+
+            int n_from = 0, n_to = 0, myN = 0;
+            int m_from = 0, m_to = 0, myM = 0;
+
+            int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
+
+            get_thr_block(n_from, n_to, myN, NB, N, ithr_n);
+            get_thr_block(m_from, m_to, myM, MB, M, ithr_m);
 
             // sum matrices partitioned along K dimension
             int offset = 0, block = 0;
             gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &offset,
                     &block);
             for (int ik = 1; ik < nthr_k; ++ik) {
-                data_t *myC = c_buffers + MB * (NB * (cbase + ik - 1) + offset);
+                data_t *myC = c_buffers
+                            + MB * ((dim_t)NB * (cbase + ik - 1) + offset);
 
                 gemm_utils::sum_two_matrices(myM, block, myC, MB,
                         &C[m_from + (n_from + offset) * ldc], ldc);
             }
-        }
-    });
+        });
+    }
 
     if (bias) {
         parallel_nd(N, M, [&](int i, int j) {
@@ -301,14 +326,18 @@ void ref_gemm(const char *transa_, const char *transb_, const int *M_,
 
     free(ws_buffers);
     free(c_buffers);
+
+    return mkldnn_success;
 }
 
-template void ref_gemm<float>(const char *transa_, const char *transb_,
+template mkldnn_status_t ref_gemm<float>(
+        const char *transa_, const char *transb_,
         const int *M_, const int *N_, const int *K_, const float *alpha_,
         const float *A, const int *lda_, const float *B, const int *ldb_,
         const float *beta_, float *C, const int *ldc_, const float *bias);
 
-template void ref_gemm<double>(const char *transa_, const char *transb_,
+template mkldnn_status_t ref_gemm<double>(
+        const char *transa_, const char *transb_,
         const int *M_, const int *N_, const int *K_, const double *alpha_,
         const double *A, const int *lda_, const double *B, const int *ldb_,
         const double *beta_, double *C, const int *ldc_, const double *bias);