Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_softmax_forward.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "gtest/gtest.h"
18 #include "mkldnn_test_common.hpp"
19
20 #include "mkldnn.hpp"
21
22 namespace mkldnn {
23
24 template <typename data_t>
25 void check_softmax_fwd(prop_kind aprop_kind, memory &src, memory &dst, int axis)
26 {
27     data_t *dst_ptr = (data_t *)dst.get_data_handle();
28
29     const memory::desc dst_pd = dst.get_primitive_desc().desc();
30
31     ASSERT_EQ(dst_pd.data.data_type,
32             memory::data_type::f32); // TODO: type assert
33
34     float result = 0.0f;
35     // Worst case error bound on naive summation
36     // algorithm is on the order of n*machine_precision.
37     // See e.g. N. J. Higham. Accuracy and stability of numerical algorithms.
38     //     SIAM Publications, Philadelphia, 2nd edition, 2002.
39     // So below tests will use error bound dependent
40     // on the number of elements in reduction.
41     const float eps = std::numeric_limits<float>::epsilon();
42
43     int MB = dst_pd.data.dims[0];
44     int C = dst_pd.data.dims[1];
45     if (MB*C == 0) return;
46
47     if (dst_pd.data.ndims == 2) {
48         if (axis == 1) {
49             for (int n = 0; n < MB; ++n) {
50                 result = 0.0f;
51
52                 for (int c = 0; c < C; ++c) {
53                     result += dst_ptr[map_index(dst_pd, n * C + c)];
54                 }
55                 EXPECT_NEAR(result, 1.0, eps*C);
56             }
57         }
58         else if (axis == 0) {
59             for (int c = 0; c < C; ++c) {
60                 result = 0.0f;
61
62                 for (int n = 0; n < MB; ++n) {
63                     result += dst_ptr[map_index(dst_pd, n * C + c)];
64                 }
65                 EXPECT_NEAR(result, 1.0, eps*MB);
66             }
67         }
68     } else {
69         int H = dst_pd.data.dims[2];
70         int W = dst_pd.data.dims[3];
71         if (H*W == 0) return;
72
73         auto off = [=](int n, int c, int h, int w)
74         {
75             return ((size_t)n * W * H * C + (size_t)c * W * H + (size_t)h * W + w);
76         };
77
78         if (axis == 0) {
79             for (int c = 0; c < C; ++c) {
80                 for (int h = 0; h < H; ++h) {
81                     for (int w = 0; w < W; ++w) {
82                         result = 0.0f;
83
84                         for (int n = 0; n < MB; ++n) {
85                             result += dst_ptr[map_index(dst_pd, off(n, c, h, w))];
86                         }
87                         EXPECT_NEAR(result, 1.0, eps*MB);
88                     }
89                 }
90             }
91         } else if (axis == 1) {
92             for (int n = 0; n < MB; ++n) {
93                 for (int h = 0; h < H; ++h) {
94                     for (int w = 0; w < W; ++w) {
95                         result = 0.0f;
96
97                         for (int c = 0; c < C; ++c) {
98                             result += dst_ptr[map_index(dst_pd, off(n, c, h, w))];
99                         }
100                         EXPECT_NEAR(result, 1.0, eps*C);
101                     }
102                 }
103             }
104         } else if (axis == 2) {
105             for (int n = 0; n < MB; ++n) {
106                 for (int c = 0; c < C; ++c) {
107                     for (int w = 0; w < W; ++w) {
108                         result = 0.0f;
109
110                         for (int h = 0; h < H; ++h) {
111                             result += dst_ptr[map_index(dst_pd, off(n, c, h, w))];
112                         }
113                         EXPECT_NEAR(result, 1.0, eps*H);
114                     }
115                 }
116             }
117         } else if (axis == 3) {
118             for (int n = 0; n < MB; ++n) {
119                 for (int c = 0; c < C; ++c) {
120                     for (int h = 0; h < H; ++h) {
121                         result = 0.0f;
122
123                         for (int w = 0; w < W; ++w) {
124                             result += dst_ptr[map_index(dst_pd, off(n, c, h, w))];
125                         }
126                         EXPECT_NEAR(result, 1.0, eps*W);
127                     }
128                 }
129             }
130         }
131     }
132 }
133
134 template <typename data_t>
135 struct softmax_test_params {
136     prop_kind aprop_kind;
137     engine::kind engine_kind;
138     memory::format memory_format;
139     memory::dims dims;
140     int axis;
141     bool expect_to_fail;
142     mkldnn_status_t expected_status;
143 };
144
145 template <typename data_t>
146 class softmax_test : public ::testing::TestWithParam<softmax_test_params<data_t>> {
147     softmax_test_params<data_t> p;
148 protected:
149     virtual void SetUp() {
150         p = ::testing::TestWithParam<softmax_test_params<data_t>>::GetParam();
151         catch_expected_failures([=](){Test();}, p.expect_to_fail,
152                     p.expected_status);
153     }
154
155     void Test() {
156         ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
157         ASSERT_TRUE(p.aprop_kind == prop_kind::forward_training
158                     || p.aprop_kind == prop_kind::forward_scoring
159                     || p.aprop_kind == prop_kind::forward_inference);
160         auto eng = engine(p.engine_kind, 0);
161
162         memory::data_type prec = data_traits<data_t>::data_type;
163
164         auto mem_desc = memory::desc(p.dims, prec, p.memory_format);
165         auto mem_prim_desc = memory::primitive_desc(mem_desc, eng);
166
167         auto src = memory(mem_prim_desc);
168         auto dst = memory(mem_prim_desc);
169
170         auto softmax_desc = softmax_forward::desc(p.aprop_kind, mem_desc,
171                     p.axis);
172         auto softmax_prim_desc
173             = softmax_forward::primitive_desc(softmax_desc, eng);
174         auto softmax = softmax_forward(softmax_prim_desc, src, dst);
175
176         auto test_with_given_fill = [&](data_t mean, data_t var) {
177             fill_data<data_t>(mem_prim_desc.get_size() / sizeof(data_t),
178                     (data_t *)src.get_data_handle(), mean, var);
179
180             stream(stream::kind::lazy).submit({softmax}).wait();
181             check_softmax_fwd<data_t>(p.aprop_kind, src, dst, p.axis);
182         };
183
184         test_with_given_fill(-50, 50);
185         test_with_given_fill(-200, 1);
186         test_with_given_fill(   0, 1);
187         test_with_given_fill( 200, 1);
188     }
189 };
190
191 using softmax_forward_test_float = softmax_test<float>;
192 using softmax_fwd_test_params_float = softmax_test_params<float>;
193
194 TEST_P(softmax_forward_test_float, TestsSoftmax) { }
195 INSTANTIATE_TEST_CASE_P(TestSoftmaxForward, softmax_forward_test_float,
196         ::testing::Values(
197             softmax_fwd_test_params_float{prop_kind::forward_scoring,
198             engine::kind::cpu, memory::format::nchw, {2, -2, 128, 256}, 0,
199             true, mkldnn_invalid_arguments},
200             softmax_fwd_test_params_float{prop_kind::forward_scoring,
201             engine::kind::cpu, memory::format::nchw, {2, 2, 128, 256}, 5,
202             true, mkldnn_invalid_arguments},
203             softmax_fwd_test_params_float{prop_kind::forward_scoring,
204             engine::kind::cpu, memory::format::nchw, {2, 0, 5, 5}, 0},
205             softmax_fwd_test_params_float{prop_kind::forward_scoring,
206             engine::kind::cpu, memory::format::nchw, {2, 0, 5, 5}, 1},
207             softmax_fwd_test_params_float{prop_kind::forward_scoring,
208             engine::kind::cpu, memory::format::nchw, {2, 19, 128, 256}, 0},
209             softmax_fwd_test_params_float{prop_kind::forward_scoring,
210             engine::kind::cpu, memory::format::nchw, {2, 19, 128, 256}, 1},
211             softmax_fwd_test_params_float{prop_kind::forward_scoring,
212             engine::kind::cpu, memory::format::nchw, {2, 19, 128, 256}, 2},
213             softmax_fwd_test_params_float{prop_kind::forward_scoring,
214             engine::kind::cpu, memory::format::nchw, {1, 8, 1024, 16}, 2},
215             softmax_fwd_test_params_float{prop_kind::forward_scoring,
216             engine::kind::cpu, memory::format::nchw, {2, 19, 128, 256}, 3},
217             softmax_fwd_test_params_float{prop_kind::forward_scoring,
218             engine::kind::cpu, memory::format::nc, {2, 1000}, 0},
219             softmax_fwd_test_params_float{prop_kind::forward_scoring,
220             engine::kind::cpu, memory::format::nc, {2, 1000}, 1},
221             softmax_fwd_test_params_float{prop_kind::forward_scoring,
222             engine::kind::cpu, memory::format::nc, {1, 256}, 1},
223             softmax_fwd_test_params_float{prop_kind::forward_scoring,
224             engine::kind::cpu, memory::format::nc, {1, 13}, 1}));
225 }