Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / gemm.cpp
index 146e688..ac619b1 100644 (file)
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/
-#include <mutex>
 
 #include "mkldnn.h"
 
-#include "verbose.hpp"
+#include "mkldnn_traits.hpp"
+#include "nstl.hpp"
+
+#include "jit_generator.hpp"
 
-#include "jit_avx_gemm_f32.hpp"
-#include "jit_avx512_common_gemm_f32.hpp"
 #include "gemm.hpp"
-#include "../jit_generator.hpp"
-#include "nstl.hpp"
+
+#include "f32/jit_avx512_common_gemm_f32.hpp"
+#include "f32/jit_avx_gemm_f32.hpp"
+#include "f32/ref_gemm_f32.hpp"
+
+#include "s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp"
+#include "s8x8s32/jit_avx512_core_gemm_s8s8s32.hpp"
+#include "s8x8s32/ref_gemm_s8x8s32.hpp"
+
 #include "os_blas.hpp"
-#include "math_utils.hpp"
-#include "mkldnn_traits.hpp"
 
 /* USE_MKL      USE_CBLAS       effect
  * -------      ---------       ------
 namespace mkldnn {
 namespace impl {
 namespace cpu {
-using namespace mkldnn::impl::status;
+
 mkldnn_status_t check_gemm_input(const char *transa, const char *transb,
         const int *M, const int *N, const int *K, const int *lda,
         const int *ldb, const int *ldc, const float *alpha, const float *beta,
         const bool with_bias) {
     if (utils::any_null(transa, transb, M, N, K, lda, ldb, ldc, alpha, beta))
-        return invalid_arguments;
+        return mkldnn_invalid_arguments;
     if (with_bias && *beta != 0)
-        return unimplemented;
+        return mkldnn_unimplemented;
     bool consistency = true
         && utils::one_of(*transa, 'T', 't', 'N', 'n')
         && utils::one_of(*transb, 'T', 't', 'N', 'n')
@@ -55,7 +60,8 @@ mkldnn_status_t check_gemm_input(const char *transa, const char *transb,
         && *N >= 0
         && *K >= 0;
 
-    if (!consistency) return invalid_arguments;
+    if (!consistency)
+        return mkldnn_invalid_arguments;
     bool isTransA = utils::one_of(*transa, 'T', 't');
     bool isTransB = utils::one_of(*transb, 'T', 't');
     int nrowA = isTransA ? *K : *M;
@@ -64,136 +70,65 @@ mkldnn_status_t check_gemm_input(const char *transa, const char *transb,
         && *lda >= nstl::max(1, nrowA)
         && *ldb >= nstl::max(1, nrowB)
         && *ldc >= nstl::max(1, *M);
-    if (!consistency) return invalid_arguments;
+    if (!consistency)
+        return mkldnn_invalid_arguments;
 
-    return success;
+    return mkldnn_success;
 }
 
 mkldnn_status_t check_gemm_x8x8x32_input(const char *offsetc,
         const char *transa, const char *transb, const int *M, const int *N,
         const int *K, const int *lda, const int *ldb, const int *ldc,
         const float *alpha, const float *beta, const bool with_bias) {
-
-    if (offsetc == nullptr) return invalid_arguments;
+    if (offsetc == nullptr)
+        return mkldnn_invalid_arguments;
     if (!utils::one_of(*offsetc, 'F', 'f', 'C', 'c', 'R', 'r'))
-        return invalid_arguments;
+        return mkldnn_invalid_arguments;
 
     return check_gemm_input(transa, transb, M, N, K, lda, ldb, ldc, alpha,
         beta, with_bias);
 }
 
-struct gemm_impl_t {
-    gemm_impl_t(char transa, char transb, bool zero_beta, bool with_bias) {
-        //jit kernel has three codepaths: beta is 0, 1 or arbitrary
-        //we will generate kernel for 0 and arbitrary beta
-        float zero = 0.0f, arbitrary_float = 2.0f;
-        if (mayiuse(avx512_common)) {
-            isa_ = avx512_common;
-            ker_ = (void *)new jit_avx512_common_gemm_f32(
-                    transa, transb, zero_beta ? zero : arbitrary_float,
-                    with_bias);
-        }
-        else if (mayiuse(avx)) {
-            isa_ = avx;
-            ker_ = (void *)new jit_avx_gemm_f32(
-                    transa, transb, zero_beta ? zero : arbitrary_float,
-                    with_bias);
-        }
-    }
-
-    mkldnn_status_t call(const char *transa, const char *transb, const int *M,
-            const int *N, const int *K, const float *alpha, const float *A,
-            const int *lda, const float *B, const int *ldb, const float *beta,
-            float *C, const int *ldc, const float *bias = nullptr) {
-        switch (isa_) {
-            case avx:
-                ((jit_avx_gemm_f32*)ker_)->sgemm(transa, transb, M, N, K,
-                    alpha, A, lda, B, ldb, beta, C, ldc, bias);
-                break;
-            case avx512_common:
-                ((jit_avx512_common_gemm_f32*)ker_)->sgemm(transa, transb,
-                    M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias);
-                break;
-            default:
-                ref_gemm(transa, transb, M, N, K, alpha, A, lda, B, ldb, beta,
-                        C, ldc, bias);
-                break;
-        }
-        return mkldnn_success;
-    }
-
-    void *ker_;
-    cpu_isa_t isa_;
-};
-//Gemm implementations for: zero/nonzero beta, transA, transB
-static gemm_impl_t *gemm_impl[2][2][2];
-//Gemm with bias implementations for: transA, transB
-//Gemm with bias for beta!=0. is not supported
-static gemm_impl_t *gemm_bias_impl[2][2];
-
-void initialize() {
-    for (int i = 0; i < 2; ++i) {
-        gemm_impl[i][0][0] = new gemm_impl_t('n', 'n', (bool)i, false);
-        gemm_impl[i][0][1] = new gemm_impl_t('n', 't', (bool)i, false);
-        gemm_impl[i][1][0] = new gemm_impl_t('t', 'n', (bool)i, false);
-        gemm_impl[i][1][1] = new gemm_impl_t('t', 't', (bool)i, false);
-    }
-    gemm_bias_impl[0][0] = new gemm_impl_t('n', 'n', true, true);
-    gemm_bias_impl[0][1] = new gemm_impl_t('n', 't', true, true);
-    gemm_bias_impl[1][0] = new gemm_impl_t('t', 'n', true, true);
-    gemm_bias_impl[1][1] = new gemm_impl_t('t', 't', true, true);
-}
-
 mkldnn_status_t extended_sgemm(const char *transa, const char *transb,
         const int *M, const int *N, const int *K, const float *alpha,
         const float *A, const int *lda, const float *B, const int *ldb,
         const float *beta, float *C, const int *ldc,
         const float *bias, const bool force_jit_gemm) {
-    //Check input
     mkldnn_status_t status = check_gemm_input(transa, transb, M, N, K,
             lda, ldb, ldc, alpha, beta, bias != nullptr);
     if (status != mkldnn_success)
         return status;
-    if (*M == 0 || *N == 0 || *K == 0)
-        return mkldnn_success;
-    int trA = *transa == 't' || *transa == 'T';
-    int trB = *transb == 't' || *transb == 'T';
+
 #ifdef USE_CBLAS
     if (!force_jit_gemm) {
-        //Call cblas
+        bool trA = *transa == 't' || *transa == 'T';
+        bool trB = *transb == 't' || *transb == 'T';
         CBLAS_TRANSPOSE Cblas_trA = trA ? CblasTrans : CblasNoTrans;
         CBLAS_TRANSPOSE Cblas_trB = trB ? CblasTrans : CblasNoTrans;
         cblas_sgemm(CblasColMajor, Cblas_trA, Cblas_trB,
                 *M, *N, *K, *alpha, A, *lda, B, *ldb, *beta, C, *ldc);
-        //Add bias if necessary (bias is applied to columns of C)
+
         if (bias) {
+            // Add bias if necessary (bias is applied to columns of C)
             cblas_int incx = 1, incy = 1;
             parallel_nd(*N, [&](int n) {
-                cblas_saxpy(*M, 1.0, bias, incx, C + n*(*ldc), incy);
+                ptrdiff_t offset = (ptrdiff_t)n * (*ldc);
+                cblas_saxpy(*M, 1.0, bias, incx, C + offset, incy);
             });
         }
         return mkldnn_success;
     }
 #endif
-    //Generate jit kernel and call sgemm with bias
-    volatile static int initialized = 0;
-    if (!initialized) {
-        static std::mutex mtx;
-        std::lock_guard<std::mutex> lock(mtx);
-        if (!initialized) {
-            mkldnn::impl::cpu::initialize();
-            initialized = 1;
-        }
-    }
-    if (bias)
-        gemm_bias_impl[trA][trB]->call(
-                transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc,
-                bias);
-    else
-        gemm_impl[*beta == 0.f][trA][trB]->call(
-                transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
 
-    return mkldnn_success;
+    if (mayiuse(avx512_common))
+        return jit_avx512_common_gemm_f32(transa, transb,
+                M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias);
+    else if (mayiuse(avx))
+        return jit_avx_gemm_f32(transa, transb,
+                M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias);
+    else
+        return ref_gemm<float>(transa, transb,
+                M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias);
 }
 
 template <typename b_dt>
@@ -202,22 +137,20 @@ mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb,
         const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
         const b_dt *B, const int *LDB, const int8_t *bo, const float *beta,
         int32_t *C, const int *LDC, const int32_t *co) {
-
     mkldnn_status_t status = check_gemm_x8x8x32_input(offsetc, transa, transb,
         M, N, K, LDA, LDB, LDC, alpha, beta, false);
-
     if (status != mkldnn_success)
         return status;
 
     if (*M == 0 || *N == 0 || *K == 0)
         return mkldnn_success;
 
-    bool OCisR = (*offsetc == 'R' || *offsetc == 'r');
-    bool OCisC = (*offsetc == 'C' || *offsetc == 'c');
-    bool AisN = (*transa == 'N' || *transa == 'n');
-    bool BisN = (*transb == 'N' || *transb == 'n');
+#if USE_MKL_IGEMM
+        bool OCisR = (*offsetc == 'R' || *offsetc == 'r');
+        bool OCisC = (*offsetc == 'C' || *offsetc == 'c');
+        bool AisN = (*transa == 'N' || *transa == 'n');
+        bool BisN = (*transb == 'N' || *transb == 'n');
 
-#if defined(USE_MKL) && defined(USE_CBLAS)
     if (data_traits<b_dt>::data_type == data_type::u8) {
         CBLAS_TRANSPOSE Cblas_trA = AisN ? CblasNoTrans : CblasTrans;
         CBLAS_TRANSPOSE Cblas_trB = BisN ? CblasNoTrans : CblasTrans;
@@ -228,64 +161,58 @@ mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb,
             ? CblasColOffset
             : CblasFixOffset;
         cblas_gemm_s8u8s32(CblasColMajor, Cblas_trA, Cblas_trB, Cblas_offsetc,
-            *M, *N, *K, *alpha, A, *LDA, *ao, (b_dt*)B, *LDB, *bo, *beta, C, *LDC, co);
+                *M, *N, *K, *alpha, A, *LDA, *ao, (uint8_t *)B, *LDB, *bo,
+                *beta, C, *LDC, co);
         return mkldnn_success;
+    } else {
+        assert(data_traits<b_dt>::data_type == data_type::s8);
+        // TODO CBLAS implementation of gemm_s8s8s32 goes here.
+        // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo
+        if ((mayiuse(avx512_core) || mayiuse(avx512_core_vnni))
+                && *ao == 0 && *bo == 0) {
+            return jit_avx512_core_gemm_s8s8s32(transa, transb, offsetc, M,
+                    N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta,
+                    C, LDC, co);
+        } else {
+            return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K,
+                    alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co);
+        }
     }
-#endif
-    int m = *M, n = *N, k = *K, lda = *LDA, ldb = *LDB, ldc = *LDC;
-    size_t sizeA = AisN ? lda * k : lda * m;
-    size_t sizeB = BisN ? ldb * n : ldb * k;
-    size_t sizeC = ldc * n;
-
-    double *dA = (double *)malloc(sizeA * sizeof(double), PAGE_4K);
-    double *dB = (double *)malloc(sizeB * sizeof(double), PAGE_4K);
-    double *dC = (double *)malloc(sizeC * sizeof(double), PAGE_4K);
-
-    if (utils::any_null(dA, dB, dC)) {
-        free(dA);
-        free(dB);
-        free(dC);
-        return mkldnn_out_of_memory;
+#else
+    cpu_isa_t isa = isa_any;
+    if (mayiuse(avx512_core_vnni)) {
+        isa = avx512_core_vnni;
+    } else if (mayiuse(avx512_core)) {
+        isa = avx512_core;
     }
 
-    auto da_setter = [=] (int i, int j, double v) { dA[j * lda + i] = v; };
-    auto db_setter = [=] (int i, int j, double v) { dB[j * ldb + i] = v; };
-
-    auto ia_accessor = [=] (int i, int j) { return A[j * lda + i]; };
-    auto ib_accessor = [=] (int i, int j) { return B[j * ldb + i]; };
-
-    const int a_rows = AisN ? m : k;
-    const int a_cols = AisN ? k : m;
-    mkldnn::impl::parallel_nd(a_cols, a_rows, [&](int j, int i) {
-        da_setter(i, j,
-            static_cast<double>(ia_accessor(i, j)) + static_cast<double>(ao[0]));
-    });
-
-    const int b_rows = BisN ? k : n;
-    const int b_cols = BisN ? n : k;
-    mkldnn::impl::parallel_nd(b_cols, b_rows, [&](int j, int i) {
-        db_setter(i, j,
-            static_cast<double>(ib_accessor(i, j)) + static_cast<double>(bo[0]));
-    });
-    double one = 1.0, zero = 0.0;
-    ref_gemm<double>(transa, transb, M, N, K, &one, dA, LDA, dB, LDB, &zero,
-        dC, LDC, nullptr);
-
-    auto i2d = [=] (int32_t v) { return static_cast<double>(v); };
-    auto f2d = [=] (float v)   { return static_cast<double>(v); };
-
-    mkldnn::impl::parallel_nd(n, m, [&] (int j, int i) {
-        double coffset = OCisR ? i2d(co[j]) : OCisC ? i2d(co[i]) : i2d(co[0]);
-        double val = ((*beta == 0.0f) ? 0.0 : f2d(*beta) * i2d(C[i + j * ldc]))
-            + f2d(*alpha) * dC[i + j * ldc] + coffset;
-        C[i + j * ldc] = math::out_round<int32_t>(math::saturate<int32_t>(val));
-    });
-
-    free(dA);
-    free(dB);
-    free(dC);
-    return mkldnn_success;
+    if (data_traits<b_dt>::data_type == data_type::u8) {
+        switch (isa) {
+        case avx512_core:
+        case avx512_core_vnni:
+            return jit_avx512_core_gemm_s8u8s32(transa, transb, offsetc, M,
+                    N, K, alpha, A, LDA, ao, (uint8_t *)B, LDB, bo, beta,
+                    C, LDC, co);
+        default:
+            return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K,
+                    alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co);
+        }
+    } else {
+        assert(data_traits<b_dt>::data_type == data_type::s8);
+        // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo
+        if ((mayiuse(avx512_core) || mayiuse(avx512_core_vnni))
+                && *ao == 0 && *bo == 0) {
+            return jit_avx512_core_gemm_s8s8s32(transa, transb, offsetc, M,
+                    N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta,
+                    C, LDC, co);
+        } else {
+            return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K,
+                    alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co);
+        }
+    }
+#endif
 }
+
 }
 }
 }
@@ -305,18 +232,18 @@ mkldnn_status_t mkldnn_gemm_s8u8s32(const char *transa, const char *transb,
         const char *offsetc, const int *M, const int *N, const int *K,
         const float *alpha, const int8_t *A, const int *lda, const int8_t *ao,
         const uint8_t *B, const int *ldb, const int8_t *bo, const float *beta,
-        int32_t *c, const int *ldc, const int32_t *co) {
+        int32_t *C, const int *ldc, const int32_t *co) {
     return gemm_s8x8s32(
         transa, transb, offsetc, M, N, K, alpha, A, lda, ao, B, ldb, bo,
-        beta, c, ldc, co);
+        beta, C, ldc, co);
 }
 
 mkldnn_status_t mkldnn_gemm_s8s8s32(const char *transa, const char *transb,
         const char *offsetc, const int *M, const int *N, const int *K,
         const float *alpha, const int8_t *A, const int *lda, const int8_t *ao,
         const int8_t *B, const int *ldb, const int8_t *bo, const float *beta,
-        int32_t *c, const int *ldc, const int32_t *co) {
+        int32_t *C, const int *ldc, const int32_t *co) {
     return gemm_s8x8s32(
         transa, transb, offsetc, M, N, K, alpha, A, lda, ao, B, ldb, bo,
-        beta, c, ldc, co);
+        beta, C, ldc, co);
 }