Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_gemm_common.hpp
index fa8e683..f468d3e 100644 (file)
 
 namespace mkldnn {
 
-struct test_params {
+struct test_igemm_params {
     char offsetc;
+    bool zero_oa;
+    bool zero_ob;
+    bool zero_oc;
+};
+
+struct test_params {
     char transA;
     char transB;
     int M;
@@ -46,6 +52,7 @@ struct test_params {
     int ldb;
     int ldc;
 
+    test_igemm_params igemm_params;
     bool expect_to_fail;
     mkldnn_status_t expected_status;
 };
@@ -77,9 +84,9 @@ void ref_gemm(const char *transa, const char *transb, int m, int n, int k,
 template <typename b_dt>
 void ref_gemm_s8x8s32(const char *transa, const char *transb,
         const char *offsetc, int m, int n, int k, const float alpha,
-        int8_t *A, int lda, const int8_t *ao, b_dt *B, int ldb,
-        const int8_t *bo, const float beta, int32_t *C, int ldc,
-        const int32_t *co) {
+        int8_t *A, int lda, const int8_t *oa, b_dt *B, int ldb,
+        const int8_t *ob, const float beta, int32_t *C, int ldc,
+        const int32_t *oc) {
 
     bool OCisR = (*offsetc == 'R' || *offsetc == 'r');
     bool OCisC = (*offsetc == 'C' || *offsetc == 'c');
@@ -104,14 +111,14 @@ void ref_gemm_s8x8s32(const char *transa, const char *transb,
     const int a_cols = AisN ? k : m;
     mkldnn::impl::parallel_nd(a_cols, a_rows, [&](int j, int i) {
         da_setter(i, j,
-            static_cast<double>(ia_accessor(i, j)) + static_cast<double>(ao[0]));
+            static_cast<double>(ia_accessor(i, j)) + static_cast<double>(oa[0]));
     });
 
     const int b_rows = BisN ? k : n;
     const int b_cols = BisN ? n : k;
     mkldnn::impl::parallel_nd(b_cols, b_rows, [&](int j, int i) {
         db_setter(i, j,
-            static_cast<double>(ib_accessor(i, j)) + static_cast<double>(bo[0]));
+            static_cast<double>(ib_accessor(i, j)) + static_cast<double>(ob[0]));
     });
 
     ref_gemm(transa, transb, m, n, k, 1.0, dA, lda, dB, ldb, 0.0, dC, ldc);
@@ -120,7 +127,7 @@ void ref_gemm_s8x8s32(const char *transa, const char *transb,
     auto f2d = [=] (float v) { return static_cast<double>(v); };
 
     mkldnn::impl::parallel_nd(n, m, [&] (int j, int i) {
-        double coffset = OCisR ? i2d(co[j]) : OCisC ? i2d(co[i]) : i2d(co[0]);
+        double coffset = OCisR ? i2d(oc[j]) : OCisC ? i2d(oc[i]) : i2d(oc[0]);
         double val = ((beta == 0.0f) ? 0.0 : f2d(beta) * i2d(C[i + j * ldc]))
             + f2d(alpha) * dC[i + j * ldc] + coffset;
         C[i + j * ldc] =
@@ -132,20 +139,31 @@ void ref_gemm_s8x8s32(const char *transa, const char *transb,
     test_free((char *)dC);
 }
 
-template <typename T>
-void compare(int M, int N, int ldc, T *C, T *C_ref, int K = 1) {
-    mkldnn::impl::parallel_nd(N, ldc, [&](int i, int j) {
-        T ref = C_ref[i*ldc + j];
-        T got = C[i*ldc + j];
-        T diff = got - ref;
-        if (data_traits<T>::data_type == memory::data_type::f32) {
-            T e = (std::abs(ref) > 1e-4) ? diff / ref : diff;
-            EXPECT_NEAR(e, 0.0, 1e-4)
-                << "Row: " << j << " Column: " << i;
+template <typename b_dt, typename c_dt>
+void compare(int m, int n, const c_dt *c, const c_dt *c_ref, int ldc,
+        float alpha = 1.0f, float beta = 0.0f, int k = 1) {
+    using data_type = memory::data_type;
+    mkldnn::impl::parallel_nd(n, ldc, [&](int i, int j) {
+        c_dt ref = c_ref[i*ldc + j];
+        c_dt got = c[i*ldc + j];
+        c_dt diff = got - ref;
+
+        if (data_traits<b_dt>::data_type == data_type::f32) {
+            c_dt e = (std::abs(ref) > 1e-4) ? diff / ref : diff;
+            EXPECT_NEAR(e, 0.0, 1e-4) << "Row: " << j << " Col: " << i;
         } else {
-            T eps = K / 1000 + 1;
-            EXPECT_NEAR(diff, 0, eps)
-                << "Row: " << j << " Column: " << i;
+            // igemm
+            if (alpha == 1.0f) {
+                EXPECT_NEAR(diff, 0, 1) << "Row: " << j << " Col: " << i;
+            } else {
+                if (data_traits<b_dt>::data_type == data_type::u8) {
+                    c_dt eps = k / 1000 + 1;
+                    EXPECT_NEAR(diff, 0, eps) << "Row: " << j << " Col: " << i;
+                } else if (data_traits<b_dt>::data_type == data_type::s8) {
+                    c_dt eps = k / 500 + 1;
+                    EXPECT_NEAR(diff, 0, eps) << "Row: " << j << " Col: " << i;
+                }
+            }
         }
     });
 }
@@ -165,15 +183,23 @@ inline T* get_matrix_buffer(size_t n) {
 }
 
 template <typename a_dt, typename b_dt, typename c_dt>
-inline void fill_matrix(size_t sizeA, size_t sizeB, size_t sizeC, size_t sizeco,
-        a_dt *A, b_dt *B, c_dt *C, a_dt *ao, a_dt *bo, c_dt *co) {
+inline void fill_matrix(const test_params &p, size_t sizeA, size_t sizeB,
+        size_t sizeC, size_t sizeco, a_dt *A, b_dt *B, c_dt *C, a_dt *oa,
+        a_dt *ob, c_dt *oc) {
     fill_data<a_dt>(sizeA, A);
     fill_data<b_dt>(sizeB, B);
     fill_data<c_dt>(sizeC, C);
-    if (ao != nullptr && bo != nullptr && co != nullptr) {
-        fill_data<a_dt>(1, ao);
-        fill_data<a_dt>(1, bo);
-        fill_data<c_dt>(sizeco, co);
+    if (oa != nullptr && ob != nullptr && oc != nullptr) {
+        if (p.igemm_params.zero_oa) (*oa) = 0;
+        else fill_data<a_dt>(1, oa);
+
+        if (p.igemm_params.zero_ob) (*ob) = 0;
+        else fill_data<a_dt>(1, ob);
+
+        if (p.igemm_params.zero_oc) {
+            for (size_t i = 0; i < sizeco; i++)
+                oc[i] = 0;
+        } else fill_data<c_dt>(sizeco, oc);
     }
 }
 
@@ -190,37 +216,37 @@ void run_test_gemm<int8_t, uint8_t, int32_t>(const test_params &p) {
     int32_t *C = get_matrix_buffer<int32_t>(sizeC);
     int32_t *C_ref = get_matrix_buffer<int32_t>(sizeC);
 
-    bool OCisR = (p.offsetc == 'R' || p.offsetc == 'r');
-    bool OCisC = (p.offsetc == 'C' || p.offsetc == 'c');
+    bool OCisR = (p.igemm_params.offsetc == 'R' || p.igemm_params.offsetc == 'r');
+    bool OCisC = (p.igemm_params.offsetc == 'C' || p.igemm_params.offsetc == 'c');
     size_t sizeco = OCisR ? p.N : OCisC ? p.M : 1;
 
-    int8_t ao, bo;
-    int32_t *co = get_matrix_buffer<int32_t>(sizeco);
+    int8_t oa, ob;
+    int32_t *oc = get_matrix_buffer<int32_t>(sizeco);
 
-    fill_matrix<int8_t, uint8_t, int32_t>(sizeA, sizeB, sizeC, sizeco, A, B, C,
-        &ao, &bo, co);
+    fill_matrix<int8_t, uint8_t, int32_t>(p, sizeA, sizeB, sizeC, sizeco,
+        A, B, C, &oa, &ob, oc);
 
     mkldnn::impl::parallel_nd(p.ldc * p.N,
         [&](int i) { C_ref[i] = static_cast<int32_t>(C[i]); });
 
-    auto status = mkldnn_gemm_s8u8s32(&p.transA, &p.transB, &p.offsetc,
-        &p.M, &p.N, &p.K, &p.alpha, A, &p.lda, &ao, B, &p.ldb, &bo,
-        &p.beta, C, &p.ldc, co);
+    auto status = mkldnn_gemm_s8u8s32(&p.transA, &p.transB, &p.igemm_params.offsetc,
+        &p.M, &p.N, &p.K, &p.alpha, A, &p.lda, &oa, B, &p.ldb, &ob,
+        &p.beta, C, &p.ldc, oc);
 
     if (status != mkldnn_success)
         throw error(status, "mkldnn_gemm_s8u8s32 returned error");
 
-    ref_gemm_s8x8s32<uint8_t>(&p.transA, &p.transB, &p.offsetc, p.M, p.N,
-        p.K, p.alpha, A, p.lda, &ao, B, p.ldb, &bo, p.beta, C_ref,
-        p.ldc, co);
+    ref_gemm_s8x8s32<uint8_t>(&p.transA, &p.transB, &p.igemm_params.offsetc, p.M, p.N,
+        p.K, p.alpha, A, p.lda, &oa, B, p.ldb, &ob, p.beta, C_ref,
+        p.ldc, oc);
 
-    compare(p.M, p.N, p.ldc, C, C_ref, p.K);
+    compare<uint8_t, int32_t>(p.M, p.N, C, C_ref, p.ldc, p.alpha, p.beta, p.K);
 
     test_free((char *)A);
     test_free((char *)B);
     test_free((char *)C);
     test_free((char *)C_ref);
-    test_free((char *)co);
+    test_free((char *)oc);
 }
 
 template <>
@@ -233,37 +259,37 @@ void run_test_gemm<int8_t, int8_t, int32_t>(const test_params &p) {
     int32_t *C = get_matrix_buffer<int32_t>(sizeC);
     int32_t *C_ref = get_matrix_buffer<int32_t>(sizeC);
 
-    bool OCisR = (p.offsetc == 'R' || p.offsetc == 'r');
-    bool OCisC = (p.offsetc == 'C' || p.offsetc == 'c');
+    bool OCisR = (p.igemm_params.offsetc == 'R' || p.igemm_params.offsetc == 'r');
+    bool OCisC = (p.igemm_params.offsetc == 'C' || p.igemm_params.offsetc == 'c');
     size_t sizeco = OCisR ? p.N : OCisC ? p.M : 1;
 
-    int8_t ao, bo;
-    int32_t* co = get_matrix_buffer<int32_t>(sizeco);
+    int8_t oa, ob;
+    int32_t* oc = get_matrix_buffer<int32_t>(sizeco);
 
-    fill_matrix<int8_t, int8_t, int32_t>(sizeA, sizeB, sizeC, sizeco, A, B, C,
-        &ao, &bo, co);
+    fill_matrix<int8_t, int8_t, int32_t>(p, sizeA, sizeB, sizeC, sizeco, A, B, C,
+        &oa, &ob, oc);
 
     mkldnn::impl::parallel_nd(p.ldc * p.N,
         [&](int i) { C_ref[i] = static_cast<int32_t>(C[i]); });
 
-    auto status = mkldnn_gemm_s8s8s32(&p.transA, &p.transB, &p.offsetc,
-        &p.M, &p.N, &p.K, &p.alpha, A, &p.lda, &ao, B, &p.ldb, &bo,
-        &p.beta, C, &p.ldc, co);
+    auto status = mkldnn_gemm_s8s8s32(&p.transA, &p.transB, &p.igemm_params.offsetc,
+        &p.M, &p.N, &p.K, &p.alpha, A, &p.lda, &oa, B, &p.ldb, &ob,
+        &p.beta, C, &p.ldc, oc);
 
     if (status != mkldnn_success)
         throw error(status, "mkldnn_gemm_s8s8s32 returned error");
 
-    ref_gemm_s8x8s32<int8_t>(&p.transA, &p.transB, &p.offsetc, p.M, p.N,
-        p.K, p.alpha, A, p.lda, &ao, B, p.ldb, &bo, p.beta, C_ref,
-        p.ldc, co);
+    ref_gemm_s8x8s32<int8_t>(&p.transA, &p.transB, &p.igemm_params.offsetc, p.M, p.N,
+        p.K, p.alpha, A, p.lda, &oa, B, p.ldb, &ob, p.beta, C_ref,
+        p.ldc, oc);
 
-    compare(p.M, p.N, p.ldc, C, C_ref, p.K);
+    compare<int8_t, int32_t>(p.M, p.N, C, C_ref, p.ldc, p.alpha, p.beta, p.K);
 
     test_free((char *)A);
     test_free((char *)B);
     test_free((char *)C);
     test_free((char *)C_ref);
-    test_free((char *)co);
+    test_free((char *)oc);
 }
 
 template <>
@@ -276,7 +302,7 @@ void run_test_gemm<float, float, float>(const test_params &p) {
     float *C = get_matrix_buffer<float>(sizeC);
     float *C_ref = get_matrix_buffer<float>(sizeC);
 
-    fill_matrix<float, float, float>(sizeA, sizeB, sizeC, 0, A, B, C,
+    fill_matrix<float, float, float>(p, sizeA, sizeB, sizeC, 0, A, B, C,
         nullptr, nullptr, nullptr);
 
     mkldnn::impl::parallel_nd(p.N * p.ldc, [&](int i) { C_ref[i] = C[i]; });
@@ -286,7 +312,7 @@ void run_test_gemm<float, float, float>(const test_params &p) {
     if (status == mkldnn_success) {
         ref_gemm(&p.transA, &p.transB, p.M, p.N, p.K, p.alpha, A, p.lda, B, p.ldb,
             p.beta, C_ref, p.ldc);
-        compare(p.M, p.N, p.ldc, C, C_ref);
+        compare<float, float>(p.M, p.N, C, C_ref, p.ldc);
     }
 
     test_free((char *)A);