1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
19 #include "dnn_types.hpp"
20 #include "mkldnn_common.hpp"
21 #include "mkldnn_memory.hpp"
23 #include "reorder.hpp"
/* Derives the per-dimension bit-mask for output scales from the scale policy
 * stored in `attr.oscale` and from the kind of data that `md` describes.
 * NOTE(review): this listing omits several lines of the function (the switch
 * header, some case labels, the start of the mask expression and the return);
 * confirm the notes below against the complete file. */
27 int get_scale_mask(const mkldnn_memory_desc_t &md, const attr_t &attr) {
28 using P = attr_t::scale_t::policy_t;
29 const auto policy = attr.oscale.policy;
/* classify the memory format; these two flags pick the mask further down */
31 const bool is_data = fmt2data_kind(md.format) == DATA;
32 const bool is_gwei = fmt2data_kind(md.format) == GWEI;
/* a descriptor with fewer than two dimensions triggers SAFE_V(FAIL) here */
38 if (md.ndims < 2) SAFE_V(FAIL);
/* GWEI formats select bits 0 and 1; the other alternative selects bit 0 only */
41 : (is_gwei ? (1 << 0) + (1 << 1) : 1 << 0);
/* policy NONE: empty mask */
44 case P::NONE: scale_mask = 0; break;
/* any policy without its own case triggers SAFE_V(FAIL) */
45 default: SAFE_V(FAIL);
/* Determines the scale mask for `memory` and the number of distinct scales
 * that the mask implies.
 * mask  - optional output; receives the scale mask when non-null
 * count - presumably receives the number of scales (the store itself is not
 *         visible in this listing -- confirm)
 * NOTE(review): the remainder of the parameter list, the declaration of
 * `uniq_scales` and the return statement are missing from this listing. */
51 int scales_count(int *count, int *mask, const dnn_mem_t &memory,
53 const mkldnn_memory_desc_t &md = memory.md_;
54 const int scale_mask = get_scale_mask(md, attr);
55 if (mask) *mask = scale_mask;
/* multiply together the sizes of all dimensions whose bit is set in the mask */
58 for(int d = 0; d < md.ndims; ++d) {
59 if (scale_mask & (1 << d))
60 uniq_scales *= md.dims[d];
/* Writes the output scale from the problem attributes into every entry of
 * `scales[0..count)` and then changes the final entry so that the scales are
 * not all equal.
 * NOTE(review): the return statement and closing brace are not visible in
 * this listing. */
66 int fill_scales(const prb_t *p, float *scales, int count) {
67 const float scale_value = p->attr.oscale.scale;
69 for (int i = 0; i < count; ++i)
70 scales[i] = scale_value;
/* when there is more than one scale, make the last one differ from the rest.
 * NOTE(review): count == 0 also takes this branch and would write
 * scales[-1]; presumably callers always pass count >= 1 -- confirm */
72 if (count != 1) scales[count - 1] = scale_value + 1.1;
/* Limits `value` to the closed interval [min, max].
 * Assumes MAX2/MIN2 are the usual two-argument max/min macros -- TODO confirm. */
77 inline float saturate(float value, float min, float max) {
78 return MAX2(min, MIN2(max, value));
/* Fills every element of `mem` with a value taken from a repeating table of
 * seven candidates, limited to the range of the input data-type configuration.
 * scales - scale table, indexed with the result of mem.get_scale_idx()
 * NOTE(review): the second line of the signature, three entries of `gen` and
 * the return statement are not visible in this listing. */
81 int fill_memory(const prb_t *p, dnn_mem_t &mem, const float *scales,
83 const dt_conf_t c_src = p->conf_in;
84 const int range = c_src->range;
/* upper end of the configured input range, inclusive */
85 const int max = c_src->min + range - 1;
86 int scale_mask = get_scale_mask(mem.md_, attr);
88 const size_t nelems = mem.nelems();
90 for (size_t idx = 0; idx < nelems; ++idx) {
/* look up the scale that belongs to this element */
91 const size_t mask_idx = mem.get_scale_idx(idx, scale_mask);
92 const float scale = scales[mask_idx];
/* candidate values; the ones divided by `scale` are expressed relative to the
 * scale of this element */
94 const float gen[7] = {
95 (float)max, /* saturate to max of output data type */
96 (float)c_src->min, /* saturate to min of output data type */
97 (float)1.6 / scale, /* rounding check */
98 (float)0.2 / scale, /* saturate to 0 */
/* cycle through the seven candidates and keep the result inside [min, max] */
104 float value = saturate(gen[idx % 7], c_src->min, max);
105 mem.set_elem(idx, value);
/* benchdnn-side reorder: for each element of `src`, multiplies it by its scale,
 * saturates the product to the numeric range of the destination data type,
 * rounds it when the destination is not f32, and stores it in `dst`.
 * scales - scale table, indexed with the result of dst.get_scale_idx()
 * NOTE(review): closing braces and the return statement are not visible in
 * this listing. */
112 int reorder(const prb_t *p, dnn_mem_t &dst, const dnn_mem_t &src,
113 const float *scales) {
114 auto dst_dt = dst.dt();
116 size_t nelems = src.nelems();
118 /* calculate min max for data_type */
119 /* TODO: add dst range support */
120 // const auto c_dst = p->conf_out;
121 // const float dst_conf_min = c_dst.min;
122 // const float dst_conf_max = dst_conf_min + c_dst.range - 1;
/* width of the destination data type in bits */
124 auto dst_width = dst.sizeof_dt() * 8;
/* u8 uses [0, 255]; every other type uses the signed range of its bit width.
 * NOTE(review): `1l << (dst_width - 1)` needs `long` to be wider than the
 * destination type; with a 32-bit `long` and a 32-bit destination the shift
 * overflows -- confirm on LLP64 targets */
126 const float dst_dt_min = dst_dt == mkldnn_u8
127 ? 0.f : -(float)(1l << (dst_width - 1));
128 const float dst_dt_max = dst_dt == mkldnn_u8
129 ? 255.f : (float)((1l << (dst_width - 1)) - 1);
131 /* TODO: add dst range support */
132 // const float dst_max = MIN2(dst_conf_max, dst_dt_max);
133 // const float dst_min = MAX2(dst_conf_min, dst_dt_min);
134 const float dst_max = dst_dt_max;
135 const float dst_min = dst_dt_min;
/* the mask is computed from the source descriptor while the scale index below
 * is taken from `dst` */
137 const int scale_mask = get_scale_mask(src.md_, p->attr);
139 for (size_t idx = 0; idx < nelems; ++idx) {
140 float src_ = src.get_elem(idx);
141 const size_t scale_idx = dst.get_scale_idx(idx, scale_mask);
143 const float scale = scales[scale_idx];
/* scale and saturate before rounding */
145 float dst_ = saturate(src_ * scale, dst_min, dst_max);
147 /* parse round mode and round value*/
148 if (dst_dt != mkldnn_f32) {
149 switch (p->attr.irmode) {
150 case attr_t::NEAREST: dst_ = rint(dst_); break;
151 case attr_t::DOWN: dst_ = floorf(dst_); break;
152 default: assert(!"unknown round_mode");
/* saturate again after rounding */
154 dst_ = saturate(dst_, dst_min, dst_max);
157 dst.set_elem(idx, dst_);
/* Compares `mem_computed` against `mem_expected` element by element and
 * updates the result record `r`.
 * Besides reporting differences, it counts how many expected values sit at the
 * data type's maximum, at its minimum and at zero; a run in which an expected
 * category never occurred is marked MISTRUSTED.
 * Returns FAIL when r->state is FAILED and OK otherwise.
 * NOTE(review): the error accounting inside the loop and the increment of
 * `reg` are not visible in this listing. */
163 int compare(const prb_t *p, dnn_mem_t &mem_expected, dnn_mem_t &mem_computed,
164 const float *scales, int count, res_t *r){
165 size_t nelems = mem_expected.nelems();
166 assert(nelems == mem_computed.nelems());
171 /* TODO: range support */
172 const auto dt = mem_expected.dt();
/* width of the data type in bits */
173 const size_t width = mem_expected.sizeof_dt()*8;
/* u8 uses [0, 255]; every other type uses the signed range of its bit width.
 * NOTE(review): `1l << (width - 1)` needs `long` to be wider than the data
 * type; with a 32-bit `long` and a 32-bit type the shift overflows -- confirm
 * on LLP64 targets */
175 const float dt_min = dt == mkldnn_u8
176 ? 0.f : -(float)(1l << (width - 1));
177 const float dt_max = dt == mkldnn_u8
178 ? 255.f : (float)((1l << (width - 1)) - 1);
/* counters: expected value at dt_max, at dt_min, equal to zero, and `reg`
 * (presumably every other value -- its increment is not visible here) */
180 size_t inf_p = 0, inf_n = 0, zeros = 0, reg = 0;
182 for (size_t i = 0; i < nelems; ++i) {
183 const float expected = mem_expected.get_elem(i);
184 const float computed = mem_computed.get_elem(i);
185 const float diff = fabsf(computed - expected);
/* classify the expected value */
187 if (expected == dt_max) inf_p++;
188 else if (expected == dt_min) inf_n++;
189 else if (expected == 0.0) zeros++;
/* print a mismatch only while fewer than 10 errors have been recorded */
193 if (r->errors < 10 && diff != 0.0) {
194 printf("idx: %zu exp: %f com:%f\n", i, expected, computed);
202 if (r->state == UNTESTED)
203 r->state = PASSED; /* optimism */
/* largest scale in the table */
205 float max_scale = scales[0];
206 for (int i = 1; i < count; ++i) {
207 if (scales[i] > max_scale) max_scale = scales[i];
210 dt_conf_t c_src = p->conf_in;
211 dt_conf_t c_dst = p->conf_out;
/* upper ends of the configured source and destination ranges, inclusive */
212 const int c_src_max = c_src->min + c_src->range - 1;
213 const int c_dst_max = c_dst->min + c_dst->range - 1;
/* saturation is only required to have occurred for types other than f32 and
 * s32, and only when the scaled source range goes beyond the destination
 * range; `&&` binds tighter than `?:`, so each line reads (a && b) ? true : false */
215 bool check_inf_p = (dt != mkldnn_f32 && dt != mkldnn_s32)
216 && (c_src_max * max_scale > c_dst_max) ? true : false;
217 bool check_inf_n = (dt != mkldnn_f32 && dt != mkldnn_s32)
218 && (c_src->min * max_scale < c_dst->min) ? true : false;
/* zeros are only required when zero is neither the type minimum nor the type
 * maximum, which excludes u8 because its minimum is 0 */
219 bool check_zeros = (dt != mkldnn_f32)
220 && (dt_min != 0 && dt_max != 0) ? true : false;
/* the data is not trusted when `reg` stayed at zero or when a required
 * category was never seen */
222 bool mistrusted = reg == 0
223 || (check_inf_p && inf_p == 0)
224 || (check_inf_n && inf_n == 0)
225 || (check_zeros && zeros == 0);
226 if (mistrusted) r->state = MISTRUSTED;
228 return r->state == FAILED ? FAIL : OK;
/* Runs one reorder problem end to end: creates the memories, fills scales and
 * input, executes the MKL-DNN reorder, optionally checks it against the
 * benchdnn reference (bench_mode & CORR) and optionally measures performance
 * (bench_mode & PERF). The outcome is recorded in `res`.
 * NOTE(review): several lines of this function are not visible in this
 * listing, including the end of the diagram comment, parts of the ndims
 * switch, the timing loop and the final return. */
231 int check_reorder(const prb_t *p, res_t *res) {
232 /* ___________________
234 * | performance timer |
235 * |___________________|
237 * _______________ ______________ V ________________
238 * | | MKL-DNN | | MKL-DNN | |
239 * | dt_in fmt_ref |-------->| dt_in fmt_in |-------->| dt_out fmt_out |
240 * |_______________| |______________| ^ |________________|
242 * benchdnn |<-------------------------------- scales | MKL-DNN
243 * ________V_______ _______V________
245 * | dt_out fmt_ref | <= compare => | dt_out fmt_ref |
246 * |________________| |________________|
251 * 3. fill input memory
252 * 4. execute mkl-dnn: reorder->q10n->reorder
253 * 5. execute benchdnn: q10n
255 * 7. performance measurement
259 const reorder_conf_t &r = p->reorder;
260 const int ndims = (int)r.dims.size();
261 const ptrdiff_t *dims = &r.dims[0];
// choose the reference memory format from the number of dimensions and the
// kind of the input format
263 mkldnn_memory_format_t fmt_ref;
264 const bool is_data = fmt2data_kind(r.fmt_in) == DATA;
265 const bool is_gwei = fmt2data_kind(r.fmt_in) == GWEI;
268 case 1: assert(is_data); fmt_ref = mkldnn_x; break;
269 case 2: fmt_ref = is_data ? mkldnn_nc : mkldnn_oi; break;
270 case 3: assert(is_data); fmt_ref = mkldnn_tnc; break;
271 case 4: fmt_ref = is_data ? mkldnn_nchw : mkldnn_oihw; break;
275 : (is_gwei ? mkldnn_goihw : mkldnn_oidhw);
277 case 6: assert(!is_data);
278 fmt_ref = is_gwei ? mkldnn_goidhw : mkldnn_ldigo;
// any other dimension count is rejected
280 default: assert(!"bad ndims"); return FAIL;
283 /* Step 1: create memory */
284 dnn_mem_t mem_dt_in_fmt_ref(ndims, dims, p->conf_in->dt, fmt_ref);
285 dnn_mem_t mem_dt_in_fmt_in(ndims, dims, p->conf_in->dt, r.fmt_in);
286 dnn_mem_t mem_dt_out_fmt_out(ndims, dims, p->conf_out->dt, r.fmt_out);
287 dnn_mem_t mem_dt_out_fmt_ref(ndims, dims, p->conf_out->dt, fmt_ref);
288 dnn_mem_t mem_test_dt_out_fmt_ref(ndims, dims, p->conf_out->dt, fmt_ref);
290 /* Step 2: fill scales */
291 int count = 0, mask = 0;
292 SAFE(scales_count(&count, &mask, mem_dt_out_fmt_out, p->attr), WARN);
293 float *scales = (float *)zmalloc(sizeof(float) * count, 64);
294 SAFE(scales != NULL ? OK : FAIL, CRIT);
/* NOTE(review): if SAFE returns from the function on failure (presumably it
 * does -- confirm), every SAFE after this point leaves without releasing
 * `scales`, and later ones without releasing `mkldnn_attr` */
295 SAFE(fill_scales(p, scales, count), WARN);
296 /* Step 3: fill input memory */
297 SAFE(fill_memory(p, mem_dt_in_fmt_ref, scales, p->attr), WARN);
299 /* Step 4: execute mkl-dnn */
/* bring the reference-format input into the input format under test */
300 SAFE(mem_dt_in_fmt_in.reorder(mem_dt_in_fmt_ref), WARN);
302 auto mkldnn_attr = create_mkldnn_attr(p->attr, count, mask, scales);
/* create a reorder primitive descriptor up front so that an unimplemented
 * combination is recorded as UNIMPLEMENTED */
304 mkldnn_primitive_desc_t check_rpd;
305 mkldnn_status_t init_status = mkldnn_reorder_primitive_desc_create_v2(
306 &check_rpd, mem_dt_in_fmt_in.mpd_, mem_dt_out_fmt_out.mpd_,
308 if (init_status == mkldnn_unimplemented) {
309 res->state = UNIMPLEMENTED;
312 mkldnn_primitive_desc_destroy(check_rpd);
313 SAFE(init_status, WARN);
/* the reorder under test: input format -> output format with the attributes */
315 SAFE(mem_dt_out_fmt_out.reorder(mem_dt_in_fmt_in, mkldnn_attr), WARN);
317 /* Step 5: check correctness */
318 if (bench_mode & CORR) {
319 /* Step 5a: reorder output from mkldnn to ref format using mkldnn */
320 SAFE(mem_dt_out_fmt_ref.reorder(mem_dt_out_fmt_out), WARN);
322 /* Step 5b: execute benchdnn reorder */
323 SAFE(reorder(p, mem_test_dt_out_fmt_ref, mem_dt_in_fmt_ref, scales), WARN);
325 /* Step 5c: compare benchdnn and mkldnn output */
326 SAFE(compare(p, mem_test_dt_out_fmt_ref, mem_dt_out_fmt_ref,
327 scales, count, res), WARN);
330 /* Step 6: performance measurement */
331 if (bench_mode & PERF) {
332 mkldnn_primitive_desc_t perf_r_pd;
333 mkldnn_primitive_t perf_r;
/* build a separate primitive for timing; its descriptor is destroyed as soon
 * as the primitive has been created */
335 DNN_SAFE(mkldnn_reorder_primitive_desc_create_v2(&perf_r_pd,
336 mem_dt_in_fmt_in.mpd_, mem_dt_out_fmt_out.mpd_,
338 mkldnn_primitive_at_t i = {mem_dt_in_fmt_in.p_, 0};
339 const_mkldnn_primitive_t o = mem_dt_out_fmt_out.p_;
340 DNN_SAFE(mkldnn_primitive_create(&perf_r, perf_r_pd, &i, &o), WARN);
341 DNN_SAFE_V(mkldnn_primitive_desc_destroy(perf_r_pd));
343 auto &t = res->timer;
346 SAFE(execute(perf_r), WARN);
/* stop after a fixed number of runs when fix_times_per_prb is set; otherwise
 * stop once both the time budget and the minimum number of runs are reached */
348 const bool stop = false
349 || (fix_times_per_prb && t.times() >= fix_times_per_prb)
350 || (!fix_times_per_prb
351 && t.total_ms() >= max_ms_per_prb
352 && t.times() >= min_times_per_prb);
356 DNN_SAFE_V(mkldnn_primitive_destroy(perf_r));
359 /* Step 7: clean up */
361 mkldnn_primitive_attr_destroy(mkldnn_attr);
/* Runs problem `p` by forwarding to check_reorder() and returns its status;
 * the result is recorded in `r`. */
367 int doit(const prb_t *p, res_t *r) {
368 return check_reorder(p, r);