Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / f32 / jit_avx_gemm_f32.cpp
 * limitations under the License.
 *******************************************************************************/
 
-#include <math.h>
+#include <cmath>
+#include <mutex>
 
 #include "mkldnn_thread.hpp"
 #include "utils.hpp"
-#include "gemm_utils.hpp"
+
+#include "ref_gemm_f32.hpp"
+#include "gemm_utils_f32.hpp"
 #include "jit_avx_gemm_f32.hpp"
 
-#define CACHE_LINE_SIZE 64
+#include "jit_generator.hpp"
 
 namespace mkldnn {
 namespace impl {
 namespace cpu {
 
-using namespace mkldnn::impl::memory_format;
-using namespace mkldnn::impl::utils;
+#define CACHE_LINE_SIZE 64
 
-using namespace Xbyak;
 #define STACKSIZE get_size_of_abi_save_regs()
 #if _WIN32
 #define STACK_K_CAPACITY 128
@@ -42,22 +43,25 @@ using namespace Xbyak;
 #define BASE_SHIFT 2
 #define SECOND_FETCH 14
 
-struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
+namespace avx_gemm_f32 {
+using namespace gemm_utils;
+
+struct xbyak_gemm : public jit_generator {
     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_gemm_f32_xbyak_gemm)
 
-    xbyak_gemm(char transa, char transb, float beta, bool hasBias = false,
+    xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false,
             void *code_ptr = nullptr,
             size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE)
         : jit_generator(code_ptr, code_size)
     {
+        using namespace Xbyak;
+
         const bool is_avx2 = mayiuse(avx2);
         assert(IMPLICATION(!is_avx2, mayiuse(avx)));
 
         const int UNROLL_M = is_avx2 ? 16 : 8;
         const int UNROLL_N = 6;
 
-        bool isTransA = (transa == 'T' || transa == 't');
-        bool isTransB = (transb == 'T' || transb == 't');
         bool isBeta0 = (beta == 0.0);
         bool isBetaN = (!isBeta0 && beta != 1.0);
 
@@ -2275,38 +2279,60 @@ struct jit_avx_gemm_f32::xbyak_gemm : public jit_generator {
 
         L(main999);
         // Restore original stack
-        mov(rax, ORIG_SP);
-        mov(rsp, rax);
+        mov(rsp, ORIG_SP);
 
         vzeroupper();
         postamble();
 
-        ker_ = reinterpret_cast<decltype(ker_)>(
-                const_cast<uint8_t *>(this->getCode()));
+        ker_ = this->getCode<ker_t>();
     }
 
-    void operator()(long long int m, long long int n, long long int k,
-            const float *alpha, const float *a, long long int lda,
-            const float *b, long long int ldb, const float *beta, float *c,
-            long long int ldc, const float *bias, float *ws)
+    typedef void (*ker_t)(dim_t m, dim_t n, dim_t k,
+            const float *alpha, const float *a, dim_t lda,
+            const float *b, dim_t ldb, const float *beta, float *c,
+            dim_t ldc, const float *bias, float *ws);
+
+    void operator()(dim_t  m, dim_t n, dim_t k,
+            const float *alpha, const float *a, dim_t lda,
+            const float *b, dim_t ldb, const float *beta, float *c,
+            dim_t ldc, const float *bias, float *ws) const
     {
-        (*ker_)(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws);
+        ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws);
     }
 
 private:
-    void (*ker_)(long long int m, long long int n, long long int k,
-            const float *alpha, const float *a, long long int lda,
-            const float *b, long long int ldb, const float *beta, float *c,
-            long long int ldc, const float *bias, float *ws);
+    ker_t ker_;
 };
 
-typedef void (*ker)(long long int, long long int, long long int, float *,
-        float *, long long int, float *, long long int, float *, float *,
-        long long int, float *);
-void jit_avx_gemm_f32::sgemm_nocopy_driver(const char *transa,
+const xbyak_gemm *get_xbyak_gemm(
+        bool isTransA, bool isTransB, float beta, bool hasBias) {
+    auto beta_idx = [](float beta) {
+        return (beta == 0.0) ? 0 : (beta == 1.0 ? 1 : 2);
+    };
+
+    // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)]
+    static xbyak_gemm *kernel_table[2][2][2][3];
+    static std::once_flag initialized;
+    std::call_once(initialized, [=]{
+            for (bool isTransA: {false, true})
+            for (bool isTransB: {false, true})
+            for (bool hasBias: {false, true})
+            for (float beta: {0.0f, 1.0f, 2.0f}) {
+                // nocopy sgemm with bias for beta != 0.0 is not supported
+                if (hasBias && beta != 0.0)
+                    continue;
+                kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] =
+                    new xbyak_gemm(isTransA, isTransB, beta, hasBias);
+            }
+    });
+
+    return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)];
+}
+
+void sgemm_nocopy_driver(const char *transa,
         const char *transb, int m, int n, int k, const float *alpha,
-        const float *a, int lda, const float *b, int ldb, const float *beta,
-        float *c, int ldc, const float *bias, float *ws)
+        const float *a, dim_t lda, const float *b, dim_t ldb, const float *beta,
+        float *c, dim_t ldc, const float *bias, float *ws)
 {
     bool isTransA = (*transa == 'T' || *transa == 't');
     bool isTransB = (*transb == 'T' || *transb == 't');
@@ -2333,6 +2359,15 @@ void jit_avx_gemm_f32::sgemm_nocopy_driver(const char *transa,
         return;
     }
 
+    assert(IMPLICATION(bias != nullptr, *beta == 0.0));
+
+    // XXX: this happens on every thread...
+    bool hasBias = (bias != nullptr);
+    auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias);
+    auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false);
+    auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false);
+    assert(ker_bn && ker_b1 && ker_b0);
+
     int BM = 4032;
     int BN = isTransA ? 96 : 48;
     int BK = isTransB ? 96 : 256;
@@ -2367,14 +2402,14 @@ void jit_avx_gemm_f32::sgemm_nocopy_driver(const char *transa,
                 }
 
                 if (!isTransA) {
-                    curA = a + Bm + (size_t)Bk * lda;
+                    curA = a + Bm + Bk * lda;
                 } else {
-                    curA = a + Bk + (size_t)Bm * lda;
+                    curA = a + Bk + Bm * lda;
                 }
                 if (!isTransB) {
-                    curB = b + Bk + (size_t)Bn * ldb;
+                    curB = b + Bk + Bn * ldb;
                 } else {
-                    curB = b + Bn + (size_t)Bk * ldb;
+                    curB = b + Bn + Bk * ldb;
                 }
                 curC = c + Bm + (size_t)Bn * ldc;
                 if (bias != nullptr) {
@@ -2386,51 +2421,54 @@ void jit_avx_gemm_f32::sgemm_nocopy_driver(const char *transa,
                 }
                 if (Bk == 0) {
                     if (*beta == 0.0 && bias == nullptr)
-                        (*ker_b0_)((long long int)sizeM, (long long int)sizeN,
-                                (long long int)sizeK, alpha, curA,
-                                (long long int)lda, curB, (long long int)ldb,
-                                beta, curC, (long long int)ldc, curBias, ws);
+                        (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
+                                alpha, curA, lda, curB, ldb, beta, curC, ldc,
+                                curBias, ws);
                     else
-                        (*ker_bn_)((long long int)sizeM, (long long int)sizeN,
-                                (long long int)sizeK, alpha, curA,
-                                (long long int)lda, curB, (long long int)ldb,
-                                beta, curC, (long long int)ldc, curBias, ws);
+                        (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
+                                alpha, curA, lda, curB, ldb, beta, curC, ldc,
+                                curBias, ws);
                 } else {
-                    (*ker_b1_)((long long int)sizeM, (long long int)sizeN,
-                            (long long int)sizeK, alpha, curA,
-                            (long long int)lda, curB, (long long int)ldb, beta,
-                            curC, (long long int)ldc, curBias, ws);
+                    (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
+                            alpha, curA, lda, curB, ldb, beta, curC, ldc,
+                            curBias, ws);
                 }
             }
         }
     }
-    return;
 }
-void jit_avx_gemm_f32::sgemm(const char *transa, const char *transb,
+
+}
+
+mkldnn_status_t jit_avx_gemm_f32(
+        const char *transa, const char *transb,
         const int *p_m, const int *p_n, const int *p_k, const float *p_alpha,
         const float *A, const int *p_lda, const float *B, const int *p_ldb,
         const float *p_beta, float *C, const int *p_ldc, const float *bias)
 {
-    if (beta_ == 0. || beta_ == 1.)
-        assert(*p_beta == beta_);
-    assert((one_of(*transa, 'T', 't') == one_of(transa_, 'T', 't')));
+    using namespace mkldnn::impl::utils;
+    using namespace avx_gemm_f32;
+    using namespace gemm_utils;
+
+    if (*p_beta != 0 && bias)
+        return ref_gemm(transa, transb, p_m, p_n, p_k,
+                p_alpha, A, p_lda, B, p_lda, p_beta, C, p_ldc, bias);
+
+    int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
 
-    int nthr = mkldnn_in_parallel() ? 1 : mkldnn_get_max_threads();
     int m = *p_m;
     int n = *p_n;
     int k = *p_k;
-    int lda = *p_lda;
-    int ldb = *p_ldb;
-    int ldc = *p_ldc;
+    dim_t lda = *p_lda;
+    dim_t ldb = *p_ldb;
+    dim_t ldc = *p_ldc;
     float beta = *p_beta;
     int MB, NB, KB;
 
     int nthr_m, nthr_n, nthr_k, nthr_mn;
 
-    assert(nthr <= nthrs_);
-
     // Determine threading partitioning
-    gemm_utils::calc_nthr_nocopy_avx(
+    calc_nthr_nocopy_avx(
             m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
     assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
 
@@ -2460,14 +2498,14 @@ void jit_avx_gemm_f32::sgemm(const char *transa, const char *transb,
                 * sizeof(float), PAGE_4K);
     }
 
-    const size_t ws_elems_per_thr = k * 16 + 64;
+    const size_t ws_elems_per_thr = (size_t)k * 16 + 64;
     const size_t ws_size_per_thr
-            = utils::rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
+            = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
     if (k > STACK_K_CAPACITY) {
         ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K);
     }
 
-    parallel(nthr, [&](const int ithr, const int nthr) {
+    parallel_nd(nthr, [&](const int ithr) {
         int ithr_m, ithr_n, ithr_k, ithr_mn;
         int m_from, m_to, myM;
         int n_from, n_to, myN;
@@ -2477,7 +2515,9 @@ void jit_avx_gemm_f32::sgemm(const char *transa, const char *transb,
         float *myC = C, myBeta;
         float *ws = ws_buffers ?
                 ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0;
-        int ld = ldc;
+        dim_t ld = ldc;
+
+        int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k);
 
         if (ithr < nthr_m * nthr_n * nthr_k) {
 
@@ -2529,10 +2569,10 @@ void jit_avx_gemm_f32::sgemm(const char *transa, const char *transb,
                     myC = &(C[m_from + n_from * ldc]);
                     myBeta = beta;
                     ld = ldc;
-                    if (hasBias_)
+                    if (bias)
                         myBias = &(bias[m_from]);
                 } else {
-                    myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
+                    myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1);
                     myBeta = 0.0;
                     ld = MB;
                     myBias = nullptr;
@@ -2541,40 +2581,40 @@ void jit_avx_gemm_f32::sgemm(const char *transa, const char *transb,
                 sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA,
                         lda, myB, ldb, &myBeta, myC, ld, myBias, ws);
 
-                if (nthr_k > 1)
+                if (nthr_k > 1 && !sum_later)
                     ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1;
             }
 
-            if (nthr_k > 1) {
+            if (nthr_k > 1 && !sum_later) {
 
                 // sum matrices partitioned along K dimension
                 int n1, n2;
 
-                gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
+                partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
 
                 if (ithr_k > 0) {
 
-                    myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
-                    myC = myC + n1 * MB;
+                    myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
+                        + (dim_t)n1 * MB;
                     /* need to wait until main thread finishes */
                     while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) {
                     };
 
                     /* my cache is hot */
-                    gemm_utils::sum_two_matrices(myM, n2, myC, MB,
+                    sum_two_matrices(myM, n2, myC, MB,
                             &C[m_from + (n_from + n1) * ldc], ldc);
                 }
 
                 for (int ik = 1; ik < nthr_k; ++ik) {
                     if (ik != ithr_k) {
 
-                        myC = c_buffers + MB * NB * (cbase + ik - 1);
-                        myC = myC + n1 * MB;
+                        myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
+                            + (dim_t)n1 * MB;
 
                         while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) {
                         };
 
-                        gemm_utils::sum_two_matrices(myM, n2, myC, MB,
+                        sum_two_matrices(myM, n2, myC, MB,
                                 &C[m_from + (n_from + n1) * ldc], ldc);
                     }
                 }
@@ -2582,42 +2622,80 @@ void jit_avx_gemm_f32::sgemm(const char *transa, const char *transb,
         }
     });
 
+    // handle C summation later
+    if (nthr_k > 1 && ompstatus[0] == 0) {
+
+        parallel_nd(nthr, [&](const int ithr) {
+            int ithr_m, ithr_n, ithr_k, ithr_mn;
+            int m_from, m_to, myM;
+            int n_from, n_to, myN;
+            int cbase;
+            float *myC = C;
+
+            if (ithr < nthr_m * nthr_n * nthr_k) {
+
+                ithr_mn = ithr % nthr_mn;
+                ithr_m = ithr_mn % nthr_m;
+                ithr_n = ithr_mn / nthr_m;
+                ithr_k = ithr / nthr_mn;
+
+                /* swap ithr_k for performance improvement */
+                if (ithr_k == 0)
+                    ithr_k = nthr_k - 1;
+                else if (ithr_k == nthr_k - 1)
+                    ithr_k = 0;
+
+                m_from = MB * (ithr_m);
+                m_to = MB * (ithr_m + 1);
+                if (m_to > m)
+                    m_to = m;
+                myM = m_to - m_from;
+
+                n_from = NB * (ithr_n);
+                n_to = NB * (ithr_n + 1);
+                if (n_to > n)
+                    n_to = n;
+                myN = n_to - n_from;
+
+                cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
+
+                if (nthr_k > 1) {
+                    // sum matrices partitioned along K dimension
+                    int n1, n2;
+
+                    partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
+
+                    if (ithr_k > 0) {
+
+                        myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
+                            + (dim_t)n1 * MB;
+
+                        /* my cache is hot */
+                        sum_two_matrices(myM, n2, myC, MB,
+                                         &C[m_from + (n_from + n1) * ldc], ldc);
+                    }
+
+                    for (int ik = 1; ik < nthr_k; ++ik) {
+                        if (ik != ithr_k) {
+
+                            myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
+                                + (dim_t)n1 * MB;
+
+                            sum_two_matrices(myM, n2, myC, MB,
+                                             &C[m_from + (n_from + n1) * ldc], ldc);
+                        }
+                    }
+                }
+            }
+        });
+    }
+
+
     free(c_buffers);
     free(ompstatus_);
     free(ws_buffers);
-}
-
-jit_avx_gemm_f32::jit_avx_gemm_f32(
-        char transa, char transb, float beta, bool hasBias)
-{
-    transa_ = transa;
-    transb_ = transb;
-    beta_ = beta;
-    hasBias_ = hasBias;
-    if (hasBias) {
-        assert(beta == 0.0);
-    }
-    ker_bn_ = new xbyak_gemm(transa, transb, beta, hasBias);
-    if (beta != 1.0) {
-        ker_b1_ = new xbyak_gemm(transa, transb, 1.0);
-    } else {
-        ker_b1_ = ker_bn_;
-    }
-    if (beta != 0.0 || (beta == 0.0 && hasBias)) {
-        ker_b0_ = new xbyak_gemm(transa, transb, 0.0);
-    } else {
-        ker_b0_ = ker_bn_;
-    }
-    nthrs_ = mkldnn_get_max_threads();
-}
 
-jit_avx_gemm_f32::~jit_avx_gemm_f32()
-{
-    delete ker_bn_;
-    if (beta_ != 1.0)
-        delete ker_b1_;
-    if (beta_ != 0.0 || (beta_ == 0.0 && hasBias_))
-        delete ker_b0_;
+    return mkldnn_success;
 }
 
 }