1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "gtest/gtest.h"
18 #include "mkldnn_test_common.hpp"
24 template <typename data_t>
25 struct relu_test_params {
26 engine::kind engine_kind;
27 memory::format data_format;
28 memory::format diff_format;
29 data_t negative_slope;
32 mkldnn_status_t expected_status;
35 template <typename data_t>
36 void check_relu_fwd(data_t negative_slope, const memory::desc &md,
37 const memory &src, const memory &dst)
39 data_t *src_data = (data_t *)src.get_data_handle();
40 data_t *dst_data = (data_t *)dst.get_data_handle();
42 ASSERT_EQ(md.data.ndims, 4);
43 ASSERT_EQ(md.data.data_type, memory::data_type::f32); // TODO: type assert
45 size_t N = md.data.dims[0];
46 size_t C = md.data.dims[1];
47 size_t H = md.data.dims[2];
48 size_t W = md.data.dims[3];
49 for (size_t i = 0; i < N * C * H * W; ++i) {
50 data_t s = src_data[i];
51 EXPECT_NEAR(dst_data[i], s > 0 ? s : s * negative_slope, 1.e-7);
55 template <typename data_t>
56 void check_relu_bwd(data_t negative_slope, const memory::desc &md,
57 const memory &src, const memory &diff_dst, const memory &diff_src)
59 data_t *src_data = (data_t *)src.get_data_handle();
60 data_t *diff_dst_data = (data_t *)diff_dst.get_data_handle();
61 data_t *diff_src_data = (data_t *)diff_src.get_data_handle();
63 const memory::desc data_d = src.get_primitive_desc().desc();
64 const memory::desc diff_data_d = diff_src.get_primitive_desc().desc();
66 ASSERT_EQ(md.data.ndims, 4);
67 ASSERT_EQ(md.data.data_type, memory::data_type::f32); // TODO: type assert
69 size_t N = md.data.dims[0];
70 size_t C = md.data.dims[1];
71 size_t H = md.data.dims[2];
72 size_t W = md.data.dims[3];
73 for (size_t i = 0; i < N * C * H * W; ++i) {
74 data_t ref_s = src_data[map_index(data_d, i)];
75 data_t ref_dd = diff_dst_data[map_index(diff_data_d, i)];
76 data_t ref_ds = ref_dd * ((ref_s > 0) ? data_t{1} : negative_slope);
77 EXPECT_NEAR(diff_src_data[map_index(diff_data_d, i)], ref_ds, 1.e-7);
81 template <typename data_t>
82 class relu_test : public ::testing::TestWithParam<relu_test_params<data_t>> {
84 std::shared_ptr<memory> src;
85 std::shared_ptr<memory> diff_src;
86 std::shared_ptr<memory> dst;
87 std::shared_ptr<memory> diff_dst;
88 std::shared_ptr<memory> workspace;
89 std::shared_ptr<memory::desc> data_desc;
90 std::shared_ptr<memory::desc> diff_data_desc;
91 std::shared_ptr<relu_forward::primitive_desc> relu_prim_desc;
92 relu_test_params<data_t> p;
93 std::shared_ptr<engine> eng;
94 memory::data_type data_type;
98 virtual void SetUp() {
99 p = ::testing::TestWithParam<decltype(p)>::GetParam();
100 catch_expected_failures([=](){Test();}, p.expect_to_fail,
105 p = ::testing::TestWithParam<decltype(p)>::GetParam();
107 ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
108 eng.reset(new engine(p.engine_kind, 0));
110 ASSERT_EQ(p.dims.size(), 4U);
112 data_type = data_traits<data_t>::data_type;
113 ASSERT_EQ(data_type, mkldnn::memory::data_type::f32);
115 size = p.dims[0] * p.dims[1] * p.dims[2] * p.dims[3];
122 data_desc.reset(new memory::desc(p.dims, data_type,
124 diff_data_desc.reset(new memory::desc(p.dims, data_type,
126 src.reset(new memory({*data_desc, *eng}));
127 dst.reset(new memory({*data_desc, *eng}));
129 fill_data<data_t>(size, (data_t *)src->get_data_handle(),
130 data_t(0), data_t(1));
132 auto relu_desc = relu_forward::desc(prop_kind::forward_training,
133 algorithm::eltwise_relu, *data_desc, p.negative_slope);
134 relu_prim_desc.reset(
135 new relu_forward::primitive_desc(relu_desc, *eng));
136 auto relu = relu_forward(*relu_prim_desc, *src, *dst);
138 std::vector<primitive> pipeline;
139 pipeline.push_back(relu);
140 auto s = stream(stream::kind::lazy);
141 s.submit(pipeline).wait();
143 check_relu_fwd(p.negative_slope, *data_desc,
148 diff_src.reset(new memory({*diff_data_desc, *eng}));
149 diff_dst.reset(new memory({*diff_data_desc, *eng}));
151 fill_data<data_t>(size, (data_t *)diff_dst->get_data_handle(),
152 data_t(0), data_t(1));
154 auto relu_bwd_desc = relu_backward::desc(algorithm::eltwise_relu,
155 *diff_data_desc, *data_desc, p.negative_slope);
156 auto relu_bwd_prim_desc = relu_backward::primitive_desc(
157 relu_bwd_desc, *eng, *relu_prim_desc);
158 auto relu_bwd = relu_backward(relu_bwd_prim_desc, *src, *diff_dst,
161 std::vector<primitive> pipeline;
162 pipeline.push_back(relu_bwd);
163 auto s = stream(stream::kind::lazy);
164 s.submit(pipeline).wait();
166 check_relu_bwd(p.negative_slope, *data_desc,
167 *src, *diff_dst, *diff_src);
171 using relu_test_float = relu_test<float>;
172 using relu_test_params_float = relu_test_params<float>;
174 TEST_P(relu_test_float, TestsReLU)
// Brace-initialiser for the dims member: { mb, c, h, w }.
#define EXPAND_SIZES(mb, c, h, w) { mb, c, h, w }
// Qualifies a bare layout name, e.g. nchw -> memory::format::nchw.
#define EXPAND_FORMATS(data) memory::format::data

// Every case in this file targets the CPU engine.
#define ENGINE engine::kind::cpu

// Builds one relu_test_params_float. Arguments follow the struct's member
// order: data layout, diff layout, negative slope, the four sizes, then the
// expect-to-fail flag (ef) and the expected status (es).
#define PARAMS_EF(data, diff_data, ns, mb, c, h, w, ef, es) \
    relu_test_params_float { ENGINE, \
        EXPAND_FORMATS(data), EXPAND_FORMATS(diff_data), \
        ns, EXPAND_SIZES(mb, c, h, w), ef, es}

// Shorthand for a case that is expected to succeed.
#define PARAMS(data, diff_data, ns, mb, c, h, w) \
    PARAMS_EF(data, diff_data, ns, mb, c, h, w, false, mkldnn_success)

// Instantiates relu_test_float under the name `str` with the given cases.
#define INST_TEST_CASE(str, ...) INSTANTIATE_TEST_CASE_P( \
        str, relu_test_float, ::testing::Values(__VA_ARGS__))
194 INST_TEST_CASE(SimpleZeroDim,
195 PARAMS(nchw, nchw, 0.f, 0, 8, 4, 4),
196 PARAMS(nchw, nchw, 0.f, 2, 0, 4, 4),
197 PARAMS(nchw, nchw, 0.f, 2, 8, 0, 4),
198 PARAMS(nchw, nchw, 0.f, 2, 8, 4, 0)
201 INST_TEST_CASE(SimpleEF,
202 PARAMS_EF(nchw, nchw, 0.f, -1, 8, 4, 4, true, mkldnn_invalid_arguments),
203 PARAMS_EF(nchw, nchw, 0.f, 2, -1, 4, 4, true, mkldnn_invalid_arguments),
204 PARAMS_EF(nchw, nchw, 0.f, 2, 8, -1, 4, true, mkldnn_invalid_arguments),
205 PARAMS_EF(nchw, nchw, 0.f, 2, 8, 4, -1, true, mkldnn_invalid_arguments)
208 INST_TEST_CASE(SimpleZeroNegativeSlope_NCHW,
209 //PARAMS(nchw, nchw, 0.f, 1, 8, 10000, 10000), // is a tensor of 3 Gb data ok? YES (330 s runtime, slow)
210 //PARAMS(nchw, nchw, 0.f, 1, 12, 10000, 10000), // is a tensor of >4 Gb data ok? worked once (release mode)
211 PARAMS(nchw, nchw, 0.f, 2, 8, 4, 4),
212 PARAMS(nchw, nchw, 0.f, 2, 16, 4, 4),
213 PARAMS(nchw, nchw, 0.f, 2, 16, 8, 8),
214 PARAMS(nchw, nchw, 0.f, 2, 16, 16, 8),
215 PARAMS(nchw, nchw, 0.f, 2, 16, 10, 8),
216 PARAMS(nchw, nchw, 0.f, 10, 10, 10, 10),
217 PARAMS(nchw, nchw, 0.f, 256, 64, 8, 16),
218 PARAMS(nchw, nchw, 0.f, 1, 1, 1, 1),
219 PARAMS(nchw, nchw, 0.f, 3, 5, 7, 11)
222 INST_TEST_CASE(Simple_NCHW,
223 PARAMS(nchw, nchw, 0.1f, 2, 8, 4, 4),
224 PARAMS(nchw, nchw, 0.1f, 2, 16, 4, 4),
225 PARAMS(nchw, nchw, 0.1f, 2, 16, 8, 8),
226 PARAMS(nchw, nchw, 0.1f, 2, 16, 16, 8),
227 PARAMS(nchw, nchw, 0.1f, 2, 16, 10, 8),
228 PARAMS(nchw, nchw, 0.1f, 10, 10, 10, 10),
229 PARAMS(nchw, nchw, 0.1f, 256, 64, 8, 16),
230 PARAMS(nchw, nchw, 0.1f, 1, 1, 1, 1),
231 PARAMS(nchw, nchw, 0.1f, 3, 5, 7, 11)
234 INST_TEST_CASE(Simple,
235 PARAMS(nchw, nChw8c, 0.1f, 2, 8, 4, 4),
236 PARAMS(nChw8c, nchw, 0.1f, 2, 16, 4, 4),
237 PARAMS(nchw, nchw, 0.1f, 2, 16, 8, 8),
238 PARAMS(nChw8c, nChw8c, 0.1f, 2, 16, 16, 8),
239 PARAMS(nhwc, nchw, 0.1f, 2, 16, 10, 8),
240 PARAMS(nchw, nhwc, 0.1f, 10, 10, 10, 10)
243 INST_TEST_CASE(AlexNet_NCHW,
244 PARAMS(nchw, nchw, 0.f, 2, 96, 55, 55),
245 PARAMS(nchw, nchw, 0.f, 2, 256, 27, 27),
246 PARAMS(nchw, nchw, 0.f, 2, 384, 13, 13)