1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
24 #include "src/common/mkldnn_thread.hpp"
26 #include "mkldnn_common.hpp"
27 #include "mkldnn_memory.hpp"
/* Selects between the 5-D (depth-aware) and 4-D descriptions of the problem:
 * callers use it to pick 5 vs 4 logical dimensions and the *dhw vs *hw plain
 * formats. Presumably it tests the depth p->id -- TODO confirm against the
 * function body. */
32 inline bool is_3d(const prb_t *p) {
/* Builds the inner-product operation descriptor `ipd` and the primitive
 * descriptor `ippd` for problem `p`.
 *
 * p    - problem: shapes (mb, ic, oc, id/ih/iw), direction p->dir, per-tensor
 *        data types in p->cfg, attributes and scales.
 * ipd  - out: operation descriptor; once the primitive descriptor exists, the
 *        memory descs used by p->dir are overwritten with the layouts the
 *        library actually chose, so the caller can allocate memory in them.
 * ippd - out: primitive descriptor created here; doit() destroys it with
 *        mkldnn_primitive_desc_destroy.
 * r    - result record; r->state is set to UNIMPLEMENTED (and OK returned)
 *        when the library has no implementation for this configuration.
 *
 * Errors from the mkldnn calls go through DNN_SAFE/SAFE with WARN or CRIT
 * severity (macros defined elsewhere; presumably they report and return
 * early on failure).
 */
36 inline int init_pd(const prb_t *p, mkldnn_inner_product_desc_t &ipd,
37 mkldnn_primitive_desc_t &ippd, res_t *r) {
38 mkldnn_memory_desc_t src_d, wei_d, bia_d, dst_d;
/* Logical dims: 5-D (with depth) when is_3d(p), 4-D otherwise. The weights
 * span the whole input spatial extent (ih, iw and, for 3-D, id); the bias is
 * {oc} and the destination is always 2-D {mb, oc}. */
40 int ndims = is_3d(p) ? 5 : 4;
41 mkldnn_dims_t src_dims = {p->mb, p->ic, p->ih, p->iw};
42 mkldnn_dims_t src_3d_dims = {p->mb, p->ic, p->id, p->ih, p->iw};
43 mkldnn_dims_t wei_dims = {p->oc, p->ic, p->ih, p->iw};
44 mkldnn_dims_t wei_3d_dims = {p->oc, p->ic, p->id, p->ih, p->iw};
45 mkldnn_dims_t bia_dims = {p->oc};
46 mkldnn_dims_t dst_dims = {p->mb, p->oc};
/* Every tensor uses format mkldnn_any: the library selects the layouts, and
 * they are queried back once the primitive descriptor exists. */
48 DNN_SAFE(mkldnn_memory_desc_init(&src_d, ndims, is_3d(p) ? src_3d_dims : src_dims,
49 p->cfg[SRC].dt, mkldnn_any), WARN);
50 DNN_SAFE(mkldnn_memory_desc_init(&wei_d, ndims, is_3d(p) ? wei_3d_dims : wei_dims,
51 p->cfg[WEI].dt, mkldnn_any), WARN);
52 DNN_SAFE(mkldnn_memory_desc_init(&bia_d, 1, bia_dims, p->cfg[BIA].dt, mkldnn_any), WARN);
53 DNN_SAFE(mkldnn_memory_desc_init(&dst_d, 2, dst_dims, p->cfg[DST].dt, mkldnn_any), WARN);
/* Initialize the op descriptor for the requested direction. The bias desc is
 * passed as NULL for FWD_D and BWD_W. Any other direction is rejected with
 * CRIT severity. */
56 case FWD_D: case FWD_B:
57 DNN_SAFE(mkldnn_inner_product_forward_desc_init(&ipd, mkldnn_forward,
58 &src_d, &wei_d, p->dir == FWD_D ? NULL : &bia_d, &dst_d),
62 DNN_SAFE(mkldnn_inner_product_backward_data_desc_init(&ipd, &src_d,
63 &wei_d, &dst_d), WARN);
65 case BWD_W: case BWD_WB:
66 DNN_SAFE(mkldnn_inner_product_backward_weights_desc_init(&ipd, &src_d,
67 &wei_d, p->dir == BWD_W ? NULL : &bia_d, &dst_d), WARN);
69 default: DNN_SAFE(mkldnn_invalid_arguments, CRIT);
/* The library's accumulation type must match the one the test expects
 * (cfg[ACC]). A mismatch is reported as mkldnn_unimplemented with CRIT
 * severity. */
72 DNN_SAFE(ipd.accum_data_type == p->cfg[ACC].dt
73 ? mkldnn_success : mkldnn_unimplemented, CRIT);
/* The attribute object is only needed while creating the primitive
 * descriptor, so it is destroyed right after, whatever the outcome. */
75 auto mkldnn_attr = create_mkldnn_attr(p->attr, p->oc, p->scales);
77 mkldnn_status_t init_status = mkldnn_success;
78 init_status = mkldnn_primitive_desc_create_v2(&ippd, &ipd, mkldnn_attr,
81 mkldnn_primitive_attr_destroy(mkldnn_attr);
/* A missing implementation is not a test failure: flag the result so the
 * caller can skip this case. */
83 if (init_status == mkldnn_unimplemented)
84 return r->state = UNIMPLEMENTED, OK;
86 SAFE(init_status, WARN);
/* Returns the memory desc the primitive descriptor selected for `query`.
 * Index 1 of the (diff_)weights queries is the bias. */
88 auto q = [=](mkldnn_query_t query, int index = 0) {
89 return *mkldnn_primitive_desc_query_memory_d(
90 mkldnn_primitive_desc_query_pd(ippd, query, index));
/* Copy the selected layouts back into ipd. Whether the plain or the diff_*
 * desc of each tensor is written depends on p->dir. */
94 ipd.diff_src_desc = q(mkldnn_query_diff_src_pd);
96 ipd.src_desc = q(mkldnn_query_src_pd);
98 if (p->dir & FLAG_WEI)
99 ipd.diff_weights_desc = q(mkldnn_query_diff_weights_pd);
101 ipd.weights_desc = q(mkldnn_query_weights_pd);
103 if (p->dir & FLAG_BIA) {
104 if (p->dir & FLAG_BWD)
105 ipd.diff_bias_desc = q(mkldnn_query_diff_weights_pd, 1);
107 ipd.bias_desc = q(mkldnn_query_weights_pd, 1);
110 if (p->dir & FLAG_BWD)
111 ipd.diff_dst_desc = q(mkldnn_query_diff_dst_pd);
113 ipd.dst_desc = q(mkldnn_query_dst_pd);
/* Compares the tested output `mem_dt` with the reference `mem_fp` element by
 * element and records the outcome in `r`.
 *
 * Both buffers are read as f32 arrays of mem_dt.nelems() elements; the
 * callers in doit() pass f32 copies in plain layouts. `kind` selects the
 * tolerance settings p->cfg[kind] (dt, min, max, eps) and the tensor name
 * used in diagnostics.
 *
 * At the end it returns FAIL when r->state is FAILED and OK otherwise
 * (including the MISTRUSTED and PASSED outcomes).
 */
118 inline int compare_dat(const prb_t *p, data_kind_t kind, dnn_mem_t &mem_dt,
119 dnn_mem_t &mem_fp, res_t *r) {
120 size_t nelems = mem_dt.nelems();
122 const char *skind = data_kind2str(kind);
127 for (size_t i = 0; i < nelems; ++i) {
128 float dt = ((float*)mem_dt)[i];
129 float fp0 = ((float *)mem_fp)[i];
/* For non-f32 outputs, round the reference the way the library is
 * configured to (p->attr.irmode): DOWN -> floorf, NEAREST -> nearbyintf. */
132 if (p->cfg[kind].dt != mkldnn_f32) {
133 using R = attr_t::round_mode_t;
134 switch (p->attr.irmode) {
135 case R::DOWN: fp = floorf(fp0); break;
136 case R::NEAREST: fp = nearbyintf(fp0); break;
/* Absolute and relative error. When |fp| <= FLT_MIN the relative error
 * falls back to the absolute one (division by 1). */
142 float diff = fabsf(fp - dt);
143 float rel_diff = diff / (fabsf(fp) > FLT_MIN ? fabsf(fp) : 1);
/* Outside [cfg.min, cfg.max] the output must be saturated exactly to the
 * violated bound. Inside the range, the error must not exceed cfg.eps: the
 * relative error is used, or the absolute error when |fp| <= 1e-5. */
146 if (fp < p->cfg[kind].min)
147 ok = dt == p->cfg[kind].min;
148 else if (fp > p->cfg[kind].max)
149 ok = dt == p->cfg[kind].max;
151 ok = (fabs(fp) > 1e-5 ? rel_diff : diff) <= p->cfg[kind].eps;
/* Print at most the first 10 reported elements unless verbose >= 10
 * (presumably only mismatching elements reach this point). */
155 if (r->errors < 10 || verbose >= 10) {
156 print(0, "[%4lu][%s]"
157 "fp:%8g fp0:%8g dt:%8g diff:%8g rdiff:%8g\n",
158 (unsigned long)i, skind, fp, fp0, dt, diff, rel_diff);
/* Too few non-zero values make the comparison meaningless. If fewer than
 * 10% of the r->total values are non-zero (non_zero is presumably counted
 * in the loop above), the result is marked MISTRUSTED. */
164 const double trust_nz = (double)non_zero / r->total;
165 bool no_trust = trust_nz < 0.1;
167 r->state = MISTRUSTED;
/* NOTE(review): this re-declares `skind`, shadowing the identical local
 * defined at the top of the function. The inner declaration is redundant
 * and can be dropped. */
168 const char *skind = data_kind2str(kind);
169 print(0, "@@@ [%s] test-bug: trust is too low."
170 " Nonzeros in output: %.2f\n",
/* Still UNTESTED means no earlier step recorded a verdict; count the
 * comparison as PASSED. */
177 if (r->state == UNTESTED)
178 r->state = PASSED; /* optimism */
180 return r->state == FAILED ? FAIL : OK;
/* Fills the source tensor with deterministic pseudo-random values.
 *
 * The values are first written into an f32 staging buffer `mem_00` in plain
 * ncdhw/nchw order. They are then copied via reorder into `mem_dt` (the
 * tested data type and layout), and from `mem_dt` into `mem_fp`. This way
 * the reference works on exactly the values the library receives after
 * type conversion.
 *
 * `gen` is a linear hash of the element coordinates (mb, ic, id, ih, iw).
 * flip_coin(gen, f_sparsity) chooses between the constant f_base and a
 * generated value f_min + (gen * f_step) % range, where
 * range = f_max - f_min + 1.
 */
183 int fill_src(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *r) {
185 mem_dt.md_, mkldnn_f32, is_3d(p) ? mkldnn_ncdhw : mkldnn_nchw);
187 const auto &c = p->cfg[SRC];
188 const int range = c.f_max - c.f_min + 1;
190 mkldnn::impl::parallel_nd(p->mb, p->ic, p->id, p->ih, p->iw,
191 [&](int mb, int ic, int id, int ih, int iw) {
193 = 5 * id + 17 * ih + 13 * iw + 13 * mb + 19 * ic + 1637;
194 const bool non_base = flip_coin(gen, c.f_sparsity);
195 const float value = non_base
196 ? c.f_min + gen * c.f_step % range : c.f_base;
198 ((float *)mem_00)[src_off_f(p, mb, ic, id, ih, iw)] = value;
/* The reference copy is derived from the converted tensor, not from mem_00. */
202 SAFE(mem_dt.reorder(mem_00), WARN);
203 SAFE(mem_fp.reorder(mem_dt), WARN);
/* Fills the weights tensor with deterministic pseudo-random values.
 *
 * The values are written into an f32 staging buffer, copied via reorder into
 * `mem_dt` (the tested data type and layout), and then from `mem_dt` into
 * `mem_fp`, so the reference sees the converted values. `gen` hashes
 * (oc, ic, id, ih, iw). flip_coin(gen, f_sparsity) chooses between f_base and
 * f_min + (gen * f_step) % range, where range = f_max - f_min + 1 (cfg[WEI]).
 *
 * NOTE(review): for 3-D problems the staging buffer uses mkldnn_goihw, while
 * doit() creates the f32 reference weights with mkldnn_oidhw. Both describe
 * 5-D tensors and may share the same physical order, but goihw labels the
 * leading dim as groups. mkldnn_oidhw looks like the intended format --
 * confirm and align.
 */
207 int fill_wei(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *r) {
209 mem_dt.md_, mkldnn_f32, is_3d(p) ? mkldnn_goihw : mkldnn_oihw);
211 const auto &c = p->cfg[WEI];
212 const int range = c.f_max - c.f_min + 1;
214 mkldnn::impl::parallel_nd(p->oc, p->ic, p->id, p->ih, p->iw,
215 [&](int oc, int ic, int id, int ih, int iw) {
216 const int gen = 5 * id + 17 * ih + 13 * iw + 13 * oc + 19 * ic + 38;
217 const bool non_base = flip_coin(gen, c.f_sparsity);
218 const float value = non_base
219 ? c.f_min + gen * c.f_step % range : c.f_base;
221 ((float *)mem_00)[wei_off_f(p, oc, ic, id, ih, iw)] = value;
/* The reference copy is derived from the converted tensor, not from mem_00. */
225 SAFE(mem_dt.reorder(mem_00), WARN);
226 SAFE(mem_fp.reorder(mem_dt), WARN);
/* Fills the 1-D bias with deterministic pseudo-random values.
 *
 * A serial loop over every element of an f32 staging buffer (format
 * mkldnn_x) uses gen = 19 * i. flip_coin(gen, f_sparsity) picks between
 * f_base and f_min + (gen * f_step) % range, where
 * range = f_max - f_min + 1 (cfg[BIA]). The data is then copied via reorder
 * into `mem_dt` and from there into `mem_fp`, so the reference sees the
 * converted values.
 */
230 int fill_bia(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *r) {
231 dnn_mem_t mem_00(mem_dt.md_, mkldnn_f32, mkldnn_x);
233 const auto &c = p->cfg[BIA];
234 const int range = c.f_max - c.f_min + 1;
236 const size_t sz = mem_00.nelems();
237 for (size_t i = 0; i < sz; ++i) {
238 const int gen = (int)(19 * i);
239 const bool non_base = flip_coin(gen, c.f_sparsity);
240 const float value = non_base
241 ? c.f_min + gen * c.f_step % range : c.f_base;
243 ((float *)mem_00)[i] = value;
/* The reference copy is derived from the converted tensor, not from mem_00. */
246 SAFE(mem_dt.reorder(mem_00), WARN);
247 SAFE(mem_fp.reorder(mem_dt), WARN);
/* Fills the 2-D {mb, oc} destination tensor with deterministic pseudo-random
 * values. The data flows from an f32 staging buffer in mkldnn_nc into
 * `mem_dt` and then into `mem_fp`, using gen = 17 * mb + 13 * oc + 12 and
 * cfg[DST].
 *
 * In the backward directions doit() passes this tensor to the primitive as
 * an input, so its contents matter there. In the forward direction the
 * primitive writes it.
 */
251 int fill_dst(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *r) {
252 dnn_mem_t mem_00(mem_dt.md_, mkldnn_f32, mkldnn_nc);
254 const auto &c = p->cfg[DST];
255 const int range = c.f_max - c.f_min + 1;
257 mkldnn::impl::parallel_nd(p->mb, p->oc, [&](int mb, int oc) {
258 const int gen = 17 * mb + 13 * oc + 12;
259 const bool non_base = flip_coin(gen, c.f_sparsity);
260 const float value = non_base
261 ? c.f_min + gen * c.f_step % range : c.f_base;
263 ((float *)mem_00)[dst_off_f(p, mb, oc)] = value;
/* The reference copy is derived from the converted tensor, not from mem_00. */
266 SAFE(mem_dt.reorder(mem_00), WARN);
267 SAFE(mem_fp.reorder(mem_dt), WARN);
/* Runs one inner-product test case. It creates the primitive for `p`, fills
 * the tensors and executes the primitive once. Depending on bench_mode it
 * then validates the result against the reference implementation (CORR)
 * and/or times repeated executions (PERF).
 *
 * Each tensor has two memories:
 *   *_dt - tested data type, in the layout chosen by the primitive
 *          descriptor (taken from ipd after init_pd);
 *   *_fp - f32 in a plain layout, consumed by the compute_ref_* routines.
 * Bias memories are only created when p->dir has FLAG_BIA.
 */
272 int doit(const prb_t *p, res_t *r) {
273 mkldnn_inner_product_desc_t ipd;
274 mkldnn_primitive_desc_t ippd;
275 mkldnn_primitive_t ip;
277 SAFE(init_pd(p, ipd, ippd, r), WARN);
278 if (r->state == SKIPPED || r->state == UNIMPLEMENTED)
/* Pick each tensor's descriptor: the diff_* variant for tensors the backward
 * directions handle as gradients, the plain one otherwise. */
281 auto &src_dt_d = p->dir == BWD_D ? ipd.diff_src_desc : ipd.src_desc;
282 auto &wei_dt_d = p->dir & FLAG_WEI ? ipd.diff_weights_desc : ipd.weights_desc;
283 auto &bia_dt_d = p->dir & FLAG_BWD ? ipd.diff_bias_desc : ipd.bias_desc;
284 auto &dst_dt_d = p->dir & FLAG_BWD ? ipd.diff_dst_desc: ipd.dst_desc;
/* *_dt: tested data types in the layouts selected by the library. */
286 const auto fp = mkldnn_f32;
287 dnn_mem_t src_dt(src_dt_d, p->cfg[SRC].dt);
288 dnn_mem_t wei_dt(wei_dt_d, p->cfg[WEI].dt);
289 dnn_mem_t dst_dt(dst_dt_d, p->cfg[DST].dt);
290 dnn_mem_t bia_dt = p->dir & FLAG_BIA
291 ? dnn_mem_t(bia_dt_d, p->cfg[BIA].dt) : dnn_mem_t();
/* *_fp: f32 copies in plain layouts for the reference computation. */
293 auto src_format = is_3d(p) ? mkldnn_ncdhw : mkldnn_nchw;
294 auto wei_format = is_3d(p) ? mkldnn_oidhw : mkldnn_oihw;
295 dnn_mem_t src_fp(src_dt_d, fp, src_format);
296 dnn_mem_t wei_fp(wei_dt_d, fp, wei_format);
297 dnn_mem_t dst_fp(dst_dt_d, fp, mkldnn_nc);
298 dnn_mem_t bia_fp = p->dir & FLAG_BIA
299 ? dnn_mem_t(bia_dt_d, fp, mkldnn_x) : dnn_mem_t();
/* dst is filled in every direction: the backward branches below pass dst_dt
 * to the primitive as an input. */
301 SAFE(fill_src(p, src_dt, src_fp, r), WARN);
302 SAFE(fill_wei(p, wei_dt, wei_fp, r), WARN);
303 SAFE(fill_dst(p, dst_dt, dst_fp, r), WARN);
304 if (p->dir & FLAG_BIA)
305 SAFE(fill_bia(p, bia_dt, bia_fp, r), WARN);
/* Forward: src, weights and optional bias (NULL without FLAG_BIA) -> dst. */
307 if (p->dir & FLAG_FWD) {
308 mkldnn_primitive_at_t inputs[3] = { {src_dt.p_, 0}, {wei_dt.p_, 0},
309 {p->dir & FLAG_BIA ? bia_dt.p_ : NULL, 0}
311 const_mkldnn_primitive_t outputs[] = { dst_dt.p_ };
312 DNN_SAFE(mkldnn_primitive_create(&ip, ippd, inputs, outputs), WARN);
313 SAFE(execute(ip), WARN);
314 if (bench_mode & CORR) {
315 compute_ref_fwd(p, src_fp, wei_fp, bia_fp, dst_fp);
316 dnn_mem_t dst(dst_dt, fp, mkldnn_nc);
317 SAFE(compare_dat(p, DST, dst, dst_fp, r), WARN);
319 } else if (p->dir == BWD_D) {
/* Backward data: dst (as diff_dst) and weights -> src (as diff_src). */
320 mkldnn_primitive_at_t inputs[3] = { {dst_dt.p_, 0}, {wei_dt.p_, 0}, };
321 const_mkldnn_primitive_t outputs[] = { src_dt.p_ };
322 DNN_SAFE(mkldnn_primitive_create(&ip, ippd, inputs, outputs), WARN);
323 SAFE(execute(ip), WARN);
324 if (bench_mode & CORR) {
325 compute_ref_bwd_d(p, src_fp, wei_fp, dst_fp);
326 dnn_mem_t src(src_dt, fp, src_format);
327 SAFE(compare_dat(p, SRC, src, src_fp, r), WARN);
329 } else if (p->dir & FLAG_BWD && p->dir & FLAG_WEI) {
/* Backward weights: src and dst (as diff_dst) -> weights (as diff_weights)
 * and, with FLAG_BIA, bias (as diff_bias). */
330 mkldnn_primitive_at_t inputs[3] = { {src_dt.p_, 0}, {dst_dt.p_, 0}, };
331 const_mkldnn_primitive_t outputs[] = { wei_dt.p_,
332 p->dir & FLAG_BIA ? bia_dt.p_ : NULL,
334 DNN_SAFE(mkldnn_primitive_create(&ip, ippd, inputs, outputs), WARN);
335 SAFE(execute(ip), WARN);
336 if (bench_mode & CORR) {
337 compute_ref_bwd_w(p, src_fp, wei_fp, bia_fp, dst_fp);
338 dnn_mem_t wei(wei_dt, fp, wei_format);
/* NOTE(review): unlike the other comparisons, this one bypasses SAFE and
 * returns FAIL directly. That skips the bias comparison, the PERF section
 * and the destroy calls at the end of the function, so ippd and ip leak on a
 * weights mismatch. The SAFE(..., WARN) exits elsewhere in this function
 * presumably leak them too. Consider a single cleanup path. */
339 if (compare_dat(p, WEI, wei, wei_fp, r) != OK) return FAIL;
340 if (p->dir & FLAG_BIA) {
341 dnn_mem_t bia(bia_dt, fp, mkldnn_x);
342 SAFE(compare_dat(p, BIA, bia, bia_fp, r), WARN);
/* Timing: `stop` becomes true after fix_times_per_prb executions when that
 * is set. Otherwise it becomes true once max_ms_per_prb has elapsed and at
 * least min_times_per_prb executions were made. */
347 if (bench_mode & PERF) {
351 SAFE(execute(ip), WARN);
353 const bool stop = false
354 || (fix_times_per_prb && t.times() >= fix_times_per_prb)
355 || (!fix_times_per_prb
356 && t.total_ms() >= max_ms_per_prb
357 && t.times() >= min_times_per_prb);
/* Release the primitive descriptor from init_pd and the primitive. */
362 DNN_SAFE(mkldnn_primitive_desc_destroy(ippd), CRIT);
363 DNN_SAFE(mkldnn_primitive_destroy(ip), CRIT);