Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / reorder / reorder.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <stdlib.h>
18
19 #include "dnn_types.hpp"
20 #include "mkldnn_common.hpp"
21 #include "mkldnn_memory.hpp"
22
23 #include "reorder.hpp"
24
25 namespace reorder {
26
27 int get_scale_mask(const mkldnn_memory_desc_t &md, const attr_t &attr) {
28     using P = attr_t::scale_t::policy_t;
29     const auto policy = attr.oscale.policy;
30
31     const bool is_data = fmt2data_kind(md.format) == DATA;
32     const bool is_gwei = fmt2data_kind(md.format) == GWEI;
33
34     int scale_mask = 0;
35
36     switch (policy) {
37     case P::PER_OC:
38         if (md.ndims < 2) SAFE_V(FAIL);
39         scale_mask = is_data
40             ? 1 << 1
41             : (is_gwei ? (1 << 0) + (1 << 1) : 1 << 0);
42         break;
43     case P::COMMON:
44     case P::NONE: scale_mask = 0; break;
45     default: SAFE_V(FAIL);
46     }
47
48     return scale_mask;
49 }
50
51 int scales_count(int *count, int *mask, const dnn_mem_t &memory,
52         const attr_t &attr) {
53     const mkldnn_memory_desc_t &md = memory.md_;
54     const int scale_mask = get_scale_mask(md, attr);
55     if (mask) *mask = scale_mask;
56
57     int uniq_scales = 1;
58     for(int d = 0; d < md.ndims; ++d) {
59         if (scale_mask & (1 << d))
60             uniq_scales *= md.dims[d];
61     }
62     *count = uniq_scales;
63     return OK;
64 }
65
66 int fill_scales(const prb_t *p, float *scales, int count) {
67     const float scale_value = p->attr.oscale.scale;
68
69     for (int i = 0; i < count; ++i)
70         scales[i] = scale_value;
71
72     if (count != 1) scales[count - 1] = scale_value + 1.1;
73
74     return OK;
75 }
76
77 inline float saturate(float value, float min, float max) {
78     return MAX2(min, MIN2(max, value));
79 }
80
81 int fill_memory(const prb_t *p, dnn_mem_t &mem, const float *scales,
82         const attr_t &attr) {
83     const dt_conf_t c_src = p->conf_in;
84     const int range = c_src->range;
85     const int max = c_src->min + range - 1;
86     int scale_mask = get_scale_mask(mem.md_, attr);
87
88     const size_t nelems = mem.nelems();
89
90     for (size_t idx = 0; idx < nelems; ++idx) {
91         const size_t mask_idx = mem.get_scale_idx(idx, scale_mask);
92         const float scale = scales[mask_idx];
93
94         const float gen[7] = {
95             (float)max, /* saturate to max of output data type */
96             (float)c_src->min, /* saturate to min of output data type */
97             (float)1.6 / scale, /* rounding check */
98             (float)0.2 / scale, /* saturate to 0 */
99             (float)1.0,
100             (float)2.0,
101             (float)scale,
102         };
103
104         float value = saturate(gen[idx % 7], c_src->min, max);
105         mem.set_elem(idx, value);
106     }
107
108     return OK;
109 }
110
111 /* TODO: Complete */
112 int reorder(const prb_t *p, dnn_mem_t &dst, const dnn_mem_t &src,
113         const float *scales) {
114     auto dst_dt = dst.dt();
115
116     size_t nelems = src.nelems();
117
118     /* calculate min max for data_type */
119     /* TODO: add dst range support */
120 //    const auto c_dst = p->conf_out;
121 //    const float dst_conf_min = c_dst.min;
122 //    const float dst_conf_max = dst_conf_min + c_dst.range - 1;
123
124     auto dst_width = dst.sizeof_dt() * 8;
125
126     const float dst_dt_min = dst_dt == mkldnn_u8
127         ? 0.f : -(float)(1l << (dst_width - 1));
128     const float dst_dt_max = dst_dt == mkldnn_u8
129         ? 255.f : (float)((1l << (dst_width - 1)) - 1);
130
131     /* TODO: add dst range support */
132 //    const float dst_max = MIN2(dst_conf_max, dst_dt_max);
133 //    const float dst_min = MAX2(dst_conf_min, dst_dt_min);
134     const float dst_max = dst_dt_max;
135     const float dst_min = dst_dt_min;
136
137     const int scale_mask = get_scale_mask(src.md_, p->attr);
138
139     for (size_t idx = 0; idx < nelems; ++idx) {
140         float src_ = src.get_elem(idx);
141         const size_t scale_idx = dst.get_scale_idx(idx, scale_mask);
142
143         const float scale = scales[scale_idx];
144
145         float dst_ = saturate(src_ * scale, dst_min, dst_max);
146
147         /* parse round mode and round value*/
148         if (dst_dt != mkldnn_f32) {
149             switch (p->attr.irmode) {
150                 case attr_t::NEAREST: dst_ = rint(dst_); break;
151                 case attr_t::DOWN: dst_ = floorf(dst_); break;
152                 default: assert(!"unknown round_mode");
153             }
154             dst_ = saturate(dst_, dst_min, dst_max);
155         }
156
157         dst.set_elem(idx, dst_);
158     }
159
160     return OK;
161 }
162
163 int compare(const prb_t *p, dnn_mem_t &mem_expected, dnn_mem_t &mem_computed,
164         const float *scales, int count, res_t *r){
165     size_t nelems = mem_expected.nelems();
166     assert(nelems == mem_computed.nelems());
167
168     r->errors = 0;
169     r->total = nelems;
170
171     /* TODO: range support */
172     const auto dt = mem_expected.dt();
173     const size_t width = mem_expected.sizeof_dt()*8;
174
175     const float dt_min = dt == mkldnn_u8
176         ? 0.f : -(float)(1l << (width - 1));
177     const float dt_max = dt == mkldnn_u8
178         ? 255.f : (float)((1l << (width - 1)) - 1);
179
180     size_t inf_p = 0, inf_n = 0, zeros = 0, reg = 0;
181
182     for (size_t i = 0; i < nelems; ++i) {
183         const float expected = mem_expected.get_elem(i);
184         const float computed = mem_computed.get_elem(i);
185         const float diff = fabsf(computed - expected);
186
187         if (expected == dt_max) inf_p++;
188         else if (expected == dt_min) inf_n++;
189         else if (expected == 0.0) zeros++;
190         else
191             reg++;
192
193         if (r->errors < 10 && diff != 0.0) {
194             printf("idx: %zu exp: %f com:%f\n", i, expected, computed);
195             r->errors++;
196         }
197     }
198
199     if (r->errors)
200         r->state = FAILED;
201
202     if (r->state == UNTESTED)
203         r->state = PASSED; /* optimism */
204
205     float max_scale = scales[0];
206     for (int i = 1; i < count; ++i) {
207         if (scales[i] > max_scale) max_scale = scales[i];
208     }
209
210     dt_conf_t c_src = p->conf_in;
211     dt_conf_t c_dst = p->conf_out;
212     const int c_src_max = c_src->min + c_src->range - 1;
213     const int c_dst_max = c_dst->min + c_dst->range - 1;
214
215     bool check_inf_p = (dt != mkldnn_f32 && dt != mkldnn_s32)
216         && (c_src_max * max_scale > c_dst_max) ? true : false;
217     bool check_inf_n = (dt != mkldnn_f32 && dt != mkldnn_s32)
218         && (c_src->min * max_scale < c_dst->min) ? true : false;
219     bool check_zeros = (dt != mkldnn_f32)
220         && (dt_min != 0 && dt_max != 0) ? true : false;
221
222     bool mistrusted = reg == 0
223         || (check_inf_p && inf_p == 0)
224         || (check_inf_n && inf_n == 0)
225         || (check_zeros && zeros == 0);
226     if (mistrusted) r->state = MISTRUSTED;
227
228     return r->state == FAILED ? FAIL : OK;
229 }
230
231 int check_reorder(const prb_t *p, res_t *res) {
232 /*                                       ___________________
233  *                                      |                   |
234  *                                      | performance timer |
235  *                                      |___________________|
236  *                                                |
237  *   _______________           ______________     V     ________________
238  *  |               | MKL-DNN |              | MKL-DNN |                |
239  *  | dt_in fmt_ref |-------->| dt_in fmt_in |-------->| dt_out fmt_out |
240  *  |_______________|         |______________|    ^    |________________|
241  *           |                                    |            |
242  *  benchdnn |<-------------------------------- scales         | MKL-DNN
243  *   ________V_______                                   _______V________
244  *  |                |                                 |                |
245  *  | dt_out fmt_ref |         <= compare =>           | dt_out fmt_ref |
246  *  |________________|                                 |________________|
247  *
248  * Steps:
249  * 1. create memory
250  * 2. fill scales
251  * 3. fill input memory
252  * 4. execute mkl-dnn: reorder->q10n->reorder
253  * 5. execute benchdnn: q10n
254  * 6. compare results
255  * 7. performance measurment
256  * 8. clean up
257  */
258
259     const reorder_conf_t &r = p->reorder;
260     const int ndims = (int)r.dims.size();
261     const ptrdiff_t *dims = &r.dims[0];
262
263     mkldnn_memory_format_t fmt_ref;
264     const bool is_data = fmt2data_kind(r.fmt_in) == DATA;
265     const bool is_gwei = fmt2data_kind(r.fmt_in) == GWEI;
266
267     switch (ndims) {
268     case 1: assert(is_data); fmt_ref = mkldnn_x; break;
269     case 2: fmt_ref = is_data ? mkldnn_nc : mkldnn_oi; break;
270     case 3: assert(is_data); fmt_ref = mkldnn_tnc; break;
271     case 4: fmt_ref = is_data ? mkldnn_nchw : mkldnn_oihw; break;
272     case 5:
273             fmt_ref = is_data
274                 ? mkldnn_ncdhw
275                 : (is_gwei ? mkldnn_goihw : mkldnn_oidhw);
276             break;
277     case 6: assert(!is_data);
278             fmt_ref = is_gwei ? mkldnn_goidhw : mkldnn_ldigo;
279             break;
280     default: assert(!"bad ndims"); return FAIL;
281     }
282
283     /* Step 1: create memory */
284     dnn_mem_t mem_dt_in_fmt_ref(ndims, dims, p->conf_in->dt, fmt_ref);
285     dnn_mem_t mem_dt_in_fmt_in(ndims, dims, p->conf_in->dt, r.fmt_in);
286     dnn_mem_t mem_dt_out_fmt_out(ndims, dims, p->conf_out->dt, r.fmt_out);
287     dnn_mem_t mem_dt_out_fmt_ref(ndims, dims, p->conf_out->dt, fmt_ref);
288     dnn_mem_t mem_test_dt_out_fmt_ref(ndims, dims, p->conf_out->dt, fmt_ref);
289
290     /* Step 2: fill scales */
291     int count = 0, mask = 0;
292     SAFE(scales_count(&count, &mask, mem_dt_out_fmt_out, p->attr), WARN);
293     float *scales = (float *)zmalloc(sizeof(float) * count, 64);
294     SAFE(scales != NULL ? OK : FAIL, CRIT);
295     SAFE(fill_scales(p, scales, count), WARN);
296     /* Step 3: fill input memory */
297     SAFE(fill_memory(p, mem_dt_in_fmt_ref, scales, p->attr), WARN);
298
299     /* Step 4: execute mkl-dnn */
300     SAFE(mem_dt_in_fmt_in.reorder(mem_dt_in_fmt_ref), WARN);
301
302     auto mkldnn_attr = create_mkldnn_attr(p->attr, count, mask, scales);
303
304     mkldnn_primitive_desc_t check_rpd;
305     mkldnn_status_t init_status = mkldnn_reorder_primitive_desc_create_v2(
306             &check_rpd, mem_dt_in_fmt_in.mpd_, mem_dt_out_fmt_out.mpd_,
307             mkldnn_attr);
308     if (init_status == mkldnn_unimplemented) {
309         res->state = UNIMPLEMENTED;
310         goto cleanup;
311     }
312     mkldnn_primitive_desc_destroy(check_rpd);
313     SAFE(init_status, WARN);
314
315     SAFE(mem_dt_out_fmt_out.reorder(mem_dt_in_fmt_in, mkldnn_attr), WARN);
316
317     /* Step 5: check corrrectness */
318     if (bench_mode & CORR) {
319         /* Step 5a: reorder output from mkldnn to ref format using mkldnn */
320         SAFE(mem_dt_out_fmt_ref.reorder(mem_dt_out_fmt_out), WARN);
321
322         /* Step 5b: execute benchdnn reorder */
323         SAFE(reorder(p, mem_test_dt_out_fmt_ref, mem_dt_in_fmt_ref, scales), WARN);
324
325         /* Step 5c: compare benchdnn and mkldnn output */
326         SAFE(compare(p, mem_test_dt_out_fmt_ref, mem_dt_out_fmt_ref,
327                     scales, count, res), WARN);
328     }
329
330     /* Step 6: performance measurement */
331     if (bench_mode & PERF) {
332         mkldnn_primitive_desc_t perf_r_pd;
333         mkldnn_primitive_t perf_r;
334
335         DNN_SAFE(mkldnn_reorder_primitive_desc_create_v2(&perf_r_pd,
336                 mem_dt_in_fmt_in.mpd_, mem_dt_out_fmt_out.mpd_,
337                 mkldnn_attr), WARN);
338         mkldnn_primitive_at_t i = {mem_dt_in_fmt_in.p_, 0};
339         const_mkldnn_primitive_t o = mem_dt_out_fmt_out.p_;
340         DNN_SAFE(mkldnn_primitive_create(&perf_r, perf_r_pd, &i, &o), WARN);
341         DNN_SAFE_V(mkldnn_primitive_desc_destroy(perf_r_pd));
342
343         auto &t = res->timer;
344         t.reset();
345         while (true) {
346             SAFE(execute(perf_r), WARN);
347             t.stamp();
348             const bool stop = false
349                 || (fix_times_per_prb && t.times() >= fix_times_per_prb)
350                 || (!fix_times_per_prb
351                         && t.total_ms() >= max_ms_per_prb
352                         && t.times() >= min_times_per_prb);
353             if (stop) break;
354         }
355
356         DNN_SAFE_V(mkldnn_primitive_destroy(perf_r));
357     }
358
359     /* Step 7: clean up */
360 cleanup:
361     mkldnn_primitive_attr_destroy(mkldnn_attr);
362     zfree(scales);
363
364     return OK;
365 }
366
367 int doit(const prb_t *p, res_t *r) {
368     return check_reorder(p, r);
369 }
370
371 }