1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include <mkldnn_types.h>
18 #include "gtest/gtest.h"
19 #include "mkldnn_test_common.hpp"
// Scalar reference for the depthwise scale-shift algorithm applied to one
// element: dst = src * weight + bias.
template <typename T> inline T scale_shift_fwd(T s_val, T w_val, T b_val) {
    return s_val * w_val + b_val;
}
// Scalar reference for the depthwise PReLU algorithm applied to one element:
// non-negative inputs pass through unchanged, negative inputs are multiplied
// by the slope `w_val`.
template <typename T> inline T prelu_fwd(T s_val, T w_val) {
    // Compare against T(0) so types without an implicit int comparison work too.
    return s_val >= T(0) ? s_val : w_val * s_val;
}
32 template <typename data_t>
33 struct depthwise_test_params {
34 engine::kind engine_kind;
36 memory::format data_format;
40 template <typename data_t>
41 void check_depthwise_fwd(const depthwise_test_params<data_t> &p,
42 const memory::desc &md, const memory &src, const memory &weights,
43 const memory &bias, bool with_bias, const memory &dst)
45 data_t *src_data = (data_t *)src.get_data_handle();
46 data_t *weights_data = (data_t *)weights.get_data_handle();
47 data_t *bias_data = with_bias ? (data_t *)bias.get_data_handle() : nullptr;
48 data_t *dst_data = (data_t *)dst.get_data_handle();
50 const memory::desc src_d = src.get_primitive_desc().desc();
51 const memory::desc weights_d = weights.get_primitive_desc().desc();
52 const memory::desc dst_d = dst.get_primitive_desc().desc();
54 ASSERT_EQ(md.data.data_type, memory::data_type::f32); // TODO: type assert
56 int N = md.data.ndims > 0 ? md.data.dims[0] : 1;
57 int C = md.data.ndims > 1 ? md.data.dims[1] : 1;
58 int H = md.data.ndims > 2 ? md.data.dims[2] : 1;
59 int W = md.data.ndims > 3 ? md.data.dims[3] : 1;
61 for (int n = 0; n < N; ++n) {
62 for (int c = 0; c < C; ++c) {
63 for (int h = 0; h < H; ++h) {
64 for (int w = 0; w < W; ++w) {
65 int idx = n*C*H*W + c*H*W + h*W + w;
67 data_t s_val = src_data[map_index(src_d, idx)];
68 data_t w_val = weights_data[map_index(weights_d, c)];
69 data_t b_val = with_bias ? bias_data[map_index(bias.get_primitive_desc().desc(), c)] : 0;
73 case depthwise_scale_shift:
74 ref_d = scale_shift_fwd(s_val, w_val, b_val);
77 ref_d = prelu_fwd(s_val, w_val);
80 assert(!"unknown alg_kind");
83 EXPECT_NEAR(dst_data[map_index(dst_d, idx)], ref_d, 1.e-6);
90 template <typename data_t>
91 class depthwise_test : public ::testing::TestWithParam<depthwise_test_params<data_t>> {
93 std::shared_ptr<memory> src;
94 std::shared_ptr<memory> weights;
95 std::shared_ptr<memory> bias;
96 std::shared_ptr<memory> dst;
97 std::shared_ptr<memory> workspace;
98 std::shared_ptr<memory::desc> src_desc;
99 std::shared_ptr<memory::desc> dst_desc;
100 std::shared_ptr<memory::desc> weights_desc;
101 std::shared_ptr<memory::desc> bias_desc;
102 std::shared_ptr<depthwise_forward::primitive_desc> depthwise_prim_desc;
103 depthwise_test_params<data_t> p;
104 std::shared_ptr<engine> eng;
105 memory::data_type data_type;
110 virtual void SetUp() {
111 p = ::testing::TestWithParam<depthwise_test_params<data_t>>::GetParam();
113 ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
114 eng.reset(new engine(p.engine_kind, 0));
116 data_type = data_traits<data_t>::data_type;
117 ASSERT_EQ(data_type, mkldnn::memory::data_type::f32);
119 data_size = p.dims[0] * p.dims[1] * p.dims[2] * p.dims[3];
120 weights_size = p.dims[1];
126 bool with_bias = p.alg_kind == depthwise_scale_shift;
128 memory::dims dims = p.data_format == mkldnn_nc ? memory::dims({p.dims[0], p.dims[1]}) : p.dims;
130 src_desc.reset(new memory::desc(dims, data_type, p.data_format));
131 dst_desc.reset(new memory::desc(dims, data_type, p.data_format));
132 src.reset(new memory({*src_desc, *eng}));
133 dst.reset(new memory({*dst_desc, *eng}));
134 fill_data<data_t>(data_size, (data_t *)src->get_data_handle(),
135 data_t(0), data_t(1));
137 weights_desc.reset(new memory::desc({dims[1]}, data_type, memory::format::x));
138 weights.reset(new memory({*weights_desc, *eng}));
139 fill_data<data_t>(weights_size, (data_t *)weights->get_data_handle(),
140 data_t(0), data_t(1));
143 bias_desc.reset(new memory::desc({dims[1]}, data_type, memory::format::x));
144 bias.reset(new memory({*bias_desc, *eng}));
145 fill_data<data_t>(weights_size, (data_t *) bias->get_data_handle(),
146 data_t(0), data_t(1));
149 std::vector<primitive> pipeline;
150 auto depthwise_desc = with_bias
151 ? depthwise_forward::desc(prop_kind::forward_training, p.alg_kind, *src_desc, *dst_desc, *weights_desc, *bias_desc)
152 : depthwise_forward::desc(prop_kind::forward_training, p.alg_kind, *src_desc, *dst_desc, *weights_desc);
153 depthwise_prim_desc.reset(new depthwise_forward::primitive_desc(depthwise_desc, *eng));
155 auto depthwise = with_bias
156 ? depthwise_forward(*depthwise_prim_desc, *src, *weights, *bias, *dst)
157 : depthwise_forward(*depthwise_prim_desc, *src, *weights, *dst);
159 pipeline.push_back(depthwise);
160 auto s = stream(stream::kind::lazy);
161 s.submit(pipeline).wait();
163 check_depthwise_fwd(p, *src_desc, *src, *weights, *bias, with_bias, *dst);
167 using depthwise_test_float = depthwise_test<float>;
168 using depthwise_test_params_float = depthwise_test_params<float>;
170 TEST_P(depthwise_test_float, TestsDepthwise)
// Identity macro that forces its argument to be rescanned; presumably a
// workaround for preprocessors that do not re-expand __VA_ARGS__ before
// passing it on to another macro.
#define EXPAND(args) args

// Maps a bare format name (e.g. nchw) to its memory::format enumerator.
#define EXPAND_FORMATS(layout) memory::format::layout

// All instances in this file run on the CPU engine.
#define ENGINE engine::kind::cpu

// One float parameter set: algorithm `alg_name`, data layout `layout`,
// logical dims {n, c, h, w}.
#define PARAMS(alg_name, layout, n, c, h, w)                         \
    depthwise_test_params_float { ENGINE, algorithm::alg_name,      \
        EXPAND_FORMATS(layout), {n, c, h, w} }

// The same layout/shape instantiated for scale_shift first, then prelu.
#define PARAMS_ALL_ALG(...)                                          \
    EXPAND(PARAMS(depthwise_scale_shift, __VA_ARGS__)),              \
    EXPAND(PARAMS(depthwise_prelu, __VA_ARGS__))

// Registers the parameter list under the instantiation prefix `suite`.
#define INST_TEST_CASE(suite, ...) INSTANTIATE_TEST_CASE_P(          \
    suite, depthwise_test_float, ::testing::Values(__VA_ARGS__))
191 INST_TEST_CASE(Simple_NC,
192 PARAMS_ALL_ALG(nc, 2, 8, 1, 1),
193 PARAMS_ALL_ALG(nc, 2, 16, 1, 1),
194 PARAMS_ALL_ALG(nc, 2, 16, 1, 1),
195 PARAMS_ALL_ALG(nc, 2, 16, 1, 1),
196 PARAMS_ALL_ALG(nc, 2, 16, 1, 1),
197 PARAMS_ALL_ALG(nc, 10, 10, 1, 1),
198 PARAMS_ALL_ALG(nc, 256, 64, 1, 1),
199 PARAMS_ALL_ALG(nc, 1, 1, 1, 1),
200 PARAMS_ALL_ALG(nc, 3, 5, 1, 1)
203 INST_TEST_CASE(Simple_NCHW,
204 PARAMS_ALL_ALG(nchw, 2, 8, 4, 4),
205 PARAMS_ALL_ALG(nchw, 2, 16, 4, 4),
206 PARAMS_ALL_ALG(nchw, 2, 16, 8, 8),
207 PARAMS_ALL_ALG(nchw, 2, 16, 16, 8),
208 PARAMS_ALL_ALG(nchw, 2, 16, 10, 8),
209 PARAMS_ALL_ALG(nchw, 10, 10, 10, 10),
210 PARAMS_ALL_ALG(nchw, 256, 64, 8, 16),
211 PARAMS_ALL_ALG(nchw, 1, 1, 1, 1),
212 PARAMS_ALL_ALG(nchw, 3, 5, 7, 11)
215 INST_TEST_CASE(Simple_Blocked,
216 PARAMS_ALL_ALG(nChw8c, 2, 8, 4, 4),
217 PARAMS_ALL_ALG(nChw8c, 2, 16, 4, 4),
218 PARAMS_ALL_ALG(nChw8c, 2, 16, 8, 8),
219 PARAMS_ALL_ALG(nChw8c, 2, 16, 16, 8),
220 PARAMS_ALL_ALG(nChw8c, 2, 32, 10, 8),
221 PARAMS_ALL_ALG(nChw8c, 256, 64, 8, 16)
// 16-channel-blocked nChw16c layout; every C below is a multiple of 16.
// Each shape is instantiated for both algorithms via PARAMS_ALL_ALG.
224 INST_TEST_CASE(Simple_Blocked16,
225 PARAMS_ALL_ALG(nChw16c, 2, 16, 4, 4),
226 PARAMS_ALL_ALG(nChw16c, 2, 16, 8, 8),
227 PARAMS_ALL_ALG(nChw16c, 2, 16, 16, 8),
228 PARAMS_ALL_ALG(nChw16c, 2, 32, 10, 8),
229 PARAMS_ALL_ALG(nChw16c, 256, 64, 8, 16)