0fe2c1f0021fd01f421f7df50fd3b2e9ea20bd5d
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_gemm.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_test_common.hpp"
18 #include "gtest/gtest.h"
19
20 #include "mkldnn_types.h"
21 #include "mkldnn.h"
22
23 namespace mkldnn {
24 struct test_params {
25     char transA;
26     char transB;
27     int M;
28     int N;
29     int K;
30     float alpha;
31     float beta;
32     int lda;
33     int ldb;
34     int ldc;
35
36     bool expect_to_fail;
37     mkldnn_status_t expected_status;
38 };
39
40 void ref_gemm(const char *transa, const char *transb, int m, int n, int k,
41         const float alpha, const float *a, int lda, const float *b,
42         int ldb, float beta, float *c, int ldc) {
43
44     const bool tr_a = transa && (*transa == 'T' || *transa == 't');
45     const bool tr_b = transb && (*transb == 'T' || *transb == 't');
46
47     auto pa = [=] (int i, int j) { return a[j*lda + i]; };
48     auto pb = [=] (int i, int j) { return b[j*ldb + i]; };
49     auto pc = [=] (int i, int j) { return c[j*ldc + i]; };
50
51     mkldnn::impl::parallel_nd(m, n, [&](int im, int in) {
52         float c_elem = (beta == 0.) ? 0. : pc(im, in) * beta;
53         for (int ik = 0; ik < k; ik++) {
54             const float a_elem = tr_a ? pa(ik, im) : pa(im, ik);
55             const float b_elem = tr_b ? pb(in, ik) : pb(ik, in);
56             c_elem += alpha * a_elem * b_elem;
57         }
58         c[in*ldc + im] = c_elem;
59     });
60 }
61
62 void compare(int M, int N, int ldc, float *C, float *C_ref) {
63     mkldnn::impl::parallel_nd(N, ldc, [&](int i, int j) {
64         float ref = C_ref[i*ldc + j];
65         float got = C[i*ldc + j];
66         float diff = got - ref;
67         float e = (std::abs(ref) > 1e-4) ? diff / ref : diff;
68         EXPECT_NEAR(e, 0.0, 1e-4)
69             << "Row: " << j << " Column: " << i;
70     });
71 }
72
73 class sgemm_test: public ::testing::TestWithParam<test_params> {
74 protected:
75     virtual void SetUp() {
76         test_params p
77             = ::testing::TestWithParam<test_params>::GetParam();
78         catch_expected_failures([=](){Test();}, p.expect_to_fail,
79                     p.expected_status);
80     }
81     virtual void Test() {
82         mkldnn_status_t status;
83         test_params p
84             = ::testing::TestWithParam<test_params>::GetParam();
85         const bool tr_a = (p.transA == 'T' || p.transA == 't');
86         const bool tr_b = (p.transB == 'T' || p.transB == 't');
87         size_t sizeA = !tr_a ? p.lda * p.K : p.lda * p.M,
88                 sizeB = !tr_b ? p.ldb * p.N : p.ldb * p.K,
89                 sizeC = p.ldc * p.N;
90         float *A = nullptr, *B = nullptr, *C = nullptr, *C_ref = nullptr;
91         A = (float *)test_malloc(sizeA*sizeof(float));
92         B = (float *)test_malloc(sizeB*sizeof(float));
93         C = (float *)test_malloc(sizeC*sizeof(float));
94         C_ref = (float *)test_malloc(sizeC*sizeof(float));
95
96         fill_data<float>(sizeA, A);
97         fill_data<float>(sizeB, B);
98         fill_data<float>(sizeC, C);
99
100         mkldnn::impl::parallel_nd(p.N * p.ldc, [&](int i) { C_ref[i] = C[i]; });
101
102         status = mkldnn_sgemm(&p.transA, &p.transB, &p.M, &p.N, &p.K, &p.alpha, A,
103                 &p.lda, B, &p.ldb, &p.beta, C, &p.ldc);
104         if (status != mkldnn_success)
105             throw error(status, "mkldnn_sgemm returned error");
106
107         ref_gemm(&p.transA, &p.transB, p.M, p.N, p.K, p.alpha, A, p.lda,
108                 B, p.ldb, p.beta, C_ref, p.ldc);
109         compare(p.M, p.N, p.ldc, C, C_ref);
110
111         test_free((char *)A);
112         test_free((char *)B);
113         test_free((char *)C);
114         test_free((char *)C_ref);
115     }
116 };
117 TEST_P(sgemm_test, TestSGEMM) {}
118 INSTANTIATE_TEST_CASE_P(TestSGEMM, sgemm_test, ::testing::Values(
119     test_params{'n', 'n', 3, 2, 1, 1.0, 0.0, 2, 5, 8, true, mkldnn_invalid_arguments},
120     test_params{'t', 'n', 3, 2, 2, 1.0, 0.0, 1, 5, 8, true, mkldnn_invalid_arguments},
121     test_params{'n', 't', 3, 2, 1, 1.0, 0.0, 3, 1, 8, true, mkldnn_invalid_arguments},
122     test_params{'n', 'd', 3, 2, 1, 1.0, 0.0, 3, 3, 3, true, mkldnn_invalid_arguments},
123
124     test_params{'N', 'n', 30, 20, 10, 2.0, 1.0, 60, 50, 80, false},
125     test_params{'n', 'T', 30, 20, 10, 2.0, 1.0, 60, 50, 80, false},
126     test_params{'T', 'N', 30, 20, 10, 2.0, 1.0, 60, 50, 80, false},
127     test_params{'t', 't', 30, 20, 10, 2.0, 1.0, 60, 50, 80, false},
128     test_params{'n', 'n', 100, 100, 2, 1.0, 2.0, 100, 100, 100, false},
129     test_params{'n', 't', 100, 2, 100, 1.0, 2.0, 100, 100, 100, false},
130     test_params{'t', 'n', 2, 100, 100, 1.0, 2.0, 100, 100, 100, false},
131     test_params{'t', 't', 2, 100, 100, 1.0, 2.0, 100, 100, 100, false},
132     test_params{'n', 'n', 2, 2, 10000, 1.0, 2.0, 2, 10000, 2, false},
133
134     test_params{'n', 'n', 2000, 2000, 2000, 1.0, 0.0, 2000, 2000, 2000, false},
135     test_params{'n', 'n', 3000, 3000, 3000, 1.0, 0.0, 3000, 3000, 3000, false},
136     test_params{'t', 'n', 2000, 2000, 2000, 1.0, 0.0, 2000, 2000, 2000, false},
137     test_params{'t', 'n', 3000, 3000, 3000, 1.0, 0.0, 3000, 3000, 3000, false},
138     test_params{'n', 't', 2000, 2000, 2000, 1.0, 0.0, 2000, 2000, 2000, false},
139     test_params{'n', 't', 3000, 3000, 3000, 1.0, 0.0, 3000, 3000, 3000, false},
140     test_params{'t', 't', 2000, 2000, 2000, 1.0, 0.0, 2000, 2000, 2000, false},
141     test_params{'t', 't', 3000, 3000, 3000, 1.0, 0.0, 3000, 3000, 3000, false}
142 ));
143 }