1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_test_common.hpp"
18 #include "gtest/gtest.h"
20 #include "mkldnn_types.h"
37 mkldnn_status_t expected_status;
40 void ref_gemm(const char *transa, const char *transb, int m, int n, int k,
41 const float alpha, const float *a, int lda, const float *b,
42 int ldb, float beta, float *c, int ldc) {
44 const bool tr_a = transa && (*transa == 'T' || *transa == 't');
45 const bool tr_b = transb && (*transb == 'T' || *transb == 't');
47 auto pa = [=] (int i, int j) { return a[j*lda + i]; };
48 auto pb = [=] (int i, int j) { return b[j*ldb + i]; };
49 auto pc = [=] (int i, int j) { return c[j*ldc + i]; };
51 mkldnn::impl::parallel_nd(m, n, [&](int im, int in) {
52 float c_elem = (beta == 0.) ? 0. : pc(im, in) * beta;
53 for (int ik = 0; ik < k; ik++) {
54 const float a_elem = tr_a ? pa(ik, im) : pa(im, ik);
55 const float b_elem = tr_b ? pb(in, ik) : pb(ik, in);
56 c_elem += alpha * a_elem * b_elem;
58 c[in*ldc + im] = c_elem;
62 void compare(int M, int N, int ldc, float *C, float *C_ref) {
63 mkldnn::impl::parallel_nd(N, ldc, [&](int i, int j) {
64 float ref = C_ref[i*ldc + j];
65 float got = C[i*ldc + j];
66 float diff = got - ref;
67 float e = (std::abs(ref) > 1e-4) ? diff / ref : diff;
68 EXPECT_NEAR(e, 0.0, 1e-4)
69 << "Row: " << j << " Column: " << i;
// Value-parameterized gtest fixture: for one test_params value it calls
// mkldnn_sgemm and checks the result against ref_gemm via compare().
73 class sgemm_test: public ::testing::TestWithParam<test_params> {
// SetUp fetches the current parameter set and runs Test() through
// catch_expected_failures, handing it p.expect_to_fail.
75 virtual void SetUp() {
77 = ::testing::TestWithParam<test_params>::GetParam();
78 catch_expected_failures([=](){Test();}, p.expect_to_fail,
82 mkldnn_status_t status;
84 = ::testing::TestWithParam<test_params>::GetParam();
// 'T' or 't' marks an operand as transposed; anything else is treated as
// not transposed here.
85 const bool tr_a = (p.transA == 'T' || p.transA == 't');
86 const bool tr_b = (p.transB == 'T' || p.transB == 't');
// Element counts for A and B: the leading dimension times K or M (for A)
// and N or K (for B), chosen by the transpose flags.
87 size_t sizeA = !tr_a ? p.lda * p.K : p.lda * p.M,
88 sizeB = !tr_b ? p.ldb * p.N : p.ldb * p.K,
// Buffers come from test_malloc and are released with test_free at the end.
90 float *A = nullptr, *B = nullptr, *C = nullptr, *C_ref = nullptr;
91 A = (float *)test_malloc(sizeA*sizeof(float));
92 B = (float *)test_malloc(sizeB*sizeof(float));
93 C = (float *)test_malloc(sizeC*sizeof(float));
94 C_ref = (float *)test_malloc(sizeC*sizeof(float));
// Populate the inputs and the initial C with fill_data (project helper).
96 fill_data<float>(sizeA, A);
97 fill_data<float>(sizeB, B);
98 fill_data<float>(sizeC, C);
// Copy the first p.N * p.ldc entries of C into C_ref so the library call and
// the reference start from the same C contents.
100 mkldnn::impl::parallel_nd(p.N * p.ldc, [&](int i) { C_ref[i] = C[i]; });
// Library result goes into C.
102 status = mkldnn_sgemm(&p.transA, &p.transB, &p.M, &p.N, &p.K, &p.alpha, A,
103 &p.lda, B, &p.ldb, &p.beta, C, &p.ldc);
// NOTE(review): this throw leaves the function before the test_free calls
// below, so A, B, C and C_ref are not released on the error path.
104 if (status != mkldnn_success)
105 throw error(status, "mkldnn_sgemm returned error");
// Reference result goes into C_ref, then both outputs are compared.
107 ref_gemm(&p.transA, &p.transB, p.M, p.N, p.K, p.alpha, A, p.lda,
108 B, p.ldb, p.beta, C_ref, p.ldc);
109 compare(p.M, p.N, p.ldc, C, C_ref);
// Release all four buffers.
111 test_free((char *)A);
112 test_free((char *)B);
113 test_free((char *)C);
114 test_free((char *)C_ref);
// The test body is intentionally empty: sgemm_test::SetUp() already invokes
// Test() through catch_expected_failures, so the work happens during setup.
117 TEST_P(sgemm_test, TestSGEMM) {}
// Parameter sets for the fixture. The initializers are positional.
// NOTE(review): presumably the order is transA, transB, M, N, K, alpha, beta,
// lda, ldb, ldc, expect_to_fail, expected_status - confirm against the
// test_params declaration.
118 INSTANTIATE_TEST_CASE_P(TestSGEMM, sgemm_test, ::testing::Values(
// Cases flagged `true` with mkldnn_invalid_arguments: the call is expected
// to be rejected (three use small leading-dimension values, one passes 'd'
// as the second transpose character).
119 test_params{'n', 'n', 3, 2, 1, 1.0, 0.0, 2, 5, 8, true, mkldnn_invalid_arguments},
120 test_params{'t', 'n', 3, 2, 2, 1.0, 0.0, 1, 5, 8, true, mkldnn_invalid_arguments},
121 test_params{'n', 't', 3, 2, 1, 1.0, 0.0, 3, 1, 8, true, mkldnn_invalid_arguments},
122 test_params{'n', 'd', 3, 2, 1, 1.0, 0.0, 3, 3, 3, true, mkldnn_invalid_arguments},
// Cases flagged `false`: small and medium shapes across the transpose
// combinations, including mixed-case transpose characters.
124 test_params{'N', 'n', 30, 20, 10, 2.0, 1.0, 60, 50, 80, false},
125 test_params{'n', 'T', 30, 20, 10, 2.0, 1.0, 60, 50, 80, false},
126 test_params{'T', 'N', 30, 20, 10, 2.0, 1.0, 60, 50, 80, false},
127 test_params{'t', 't', 30, 20, 10, 2.0, 1.0, 60, 50, 80, false},
128 test_params{'n', 'n', 100, 100, 2, 1.0, 2.0, 100, 100, 100, false},
129 test_params{'n', 't', 100, 2, 100, 1.0, 2.0, 100, 100, 100, false},
130 test_params{'t', 'n', 2, 100, 100, 1.0, 2.0, 100, 100, 100, false},
131 test_params{'t', 't', 2, 100, 100, 1.0, 2.0, 100, 100, 100, false},
132 test_params{'n', 'n', 2, 2, 10000, 1.0, 2.0, 2, 10000, 2, false},
// Large cases: 2000 and 3000 in every size field, for all four lower-case
// transpose combinations.
134 test_params{'n', 'n', 2000, 2000, 2000, 1.0, 0.0, 2000, 2000, 2000, false},
135 test_params{'n', 'n', 3000, 3000, 3000, 1.0, 0.0, 3000, 3000, 3000, false},
136 test_params{'t', 'n', 2000, 2000, 2000, 1.0, 0.0, 2000, 2000, 2000, false},
137 test_params{'t', 'n', 3000, 3000, 3000, 1.0, 0.0, 3000, 3000, 3000, false},
138 test_params{'n', 't', 2000, 2000, 2000, 1.0, 0.0, 2000, 2000, 2000, false},
139 test_params{'n', 't', 3000, 3000, 3000, 1.0, 0.0, 3000, 3000, 3000, false},
140 test_params{'t', 't', 2000, 2000, 2000, 1.0, 0.0, 2000, 2000, 2000, false},
141 test_params{'t', 't', 3000, 3000, 3000, 1.0, 0.0, 3000, 3000, 3000, false}