Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / shuffle / shuffle.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <stdlib.h>
18 #include <stddef.h>
19 #include <stdio.h>
20 #include <float.h>
21 #include <math.h>
22 #include <time.h>
23
24 #include "mkldnn.h"
25
26 #include "mkldnn_common.hpp"
27 #include "mkldnn_memory.hpp"
28 #include "norm.hpp"
29
30 #include "shuffle/shuffle.hpp"
31
32 namespace shuffle {
33
34 inline float saturate(float value, float min, float max) {
35     return MAX2(min, MIN2(max, value));
36 }
37
38 int fill_memory(const prb_t *p, dnn_mem_t &mem) {
39     dt_conf_t c_src;
40     switch (p->dt) {
41         case mkldnn_u8: c_src = conf_u8; break;
42         case mkldnn_s8: c_src = conf_s8; break;
43         case mkldnn_s32: c_src = conf_s32; break;
44         default: c_src = conf_f32; break;
45     }
46     const int range = c_src.range;
47     const int max = c_src.min + range - 1;
48
49     const size_t nelems = mem.nelems();
50
51     for (size_t idx = 0; idx < nelems; ++idx) {
52         float value = saturate((float)(idx % c_src.range), c_src.min, max);
53         mem.set_elem(idx, value);
54     }
55
56     return OK;
57 }
58
59 static int compare(const prb_t *p, const dnn_mem_t &fp_mem,
60         const dnn_mem_t &dt_mem, res_t *r) {
61     size_t nelems = fp_mem.nelems();
62     assert(nelems == dt_mem.nelems());
63     r->errors = 0;
64
65     for (size_t i = 0; i < nelems; ++i) {
66         const float fp = fp_mem.get_elem(i);
67         const float dt = dt_mem.get_elem(i);
68         const float diff = fabsf(fp - dt);
69         if (r->errors < 10 && diff != 0.0) {
70             printf("idx: %zu fp: %f dt:%f\n", i, fp, dt);
71             r->errors++;
72         }
73     }
74
75     if (r->errors)
76         r->state = FAILED;
77
78     if (r->state == UNTESTED)
79         r->state = PASSED; /* optimism */
80
81     return r->state == FAILED ? FAIL : OK;
82 }
83
84 static int init_pd(const prb_t *p, mkldnn_shuffle_desc_t &sd,
85         mkldnn_primitive_desc_t &spd, res_t *r) {
86
87     mkldnn_memory_desc_t data_d;
88     mkldnn_dims_t data_dims;
89     const int ndims = (int)p->dims.size();
90
91     for (int i = 0; i < ndims; ++i) data_dims[i] = p->dims[i];
92     DNN_SAFE(mkldnn_memory_desc_init(&data_d, ndims, data_dims, p->dt, p->fmt),
93            WARN);
94
95     mkldnn_status_t init_status = mkldnn_success;
96     mkldnn_primitive_desc_t hint_fwd_pd = NULL;
97     if (p->dir == FWD_D) {
98         auto prop = mkldnn_forward_training;
99         DNN_SAFE(mkldnn_shuffle_forward_desc_init(&sd, prop,
100                     &data_d, p->a, p->g), WARN);
101     } else if (p->dir == BWD_D) {
102         DNN_SAFE(mkldnn_shuffle_backward_desc_init(&sd, &data_d, p->a,
103                     p->g), WARN);
104         mkldnn_shuffle_desc_t sd_fwd;
105         DNN_SAFE(mkldnn_shuffle_forward_desc_init(&sd_fwd,
106                     mkldnn_forward_training, &data_d, p->a, p->g), WARN);
107         DNN_SAFE(mkldnn_primitive_desc_create(&hint_fwd_pd, &sd_fwd, engine,
108                     NULL), WARN);
109     }
110     init_status = mkldnn_primitive_desc_create(&spd, &sd, engine, hint_fwd_pd);
111     mkldnn_primitive_desc_destroy(hint_fwd_pd);
112
113     if (init_status == mkldnn_unimplemented)
114         return r->state = UNIMPLEMENTED, OK;
115     else
116         SAFE(init_status, WARN);
117
118     const char *impl_str = query_impl_info(spd);
119     print(5, "mkldnn implementation: %s\n", impl_str);
120
121     return OK;
122 }
123
124 int doit(const prb_t *p, res_t *r) {
125
126     res_t res_zero{};
127     *r = res_zero;
128
129     mkldnn_shuffle_desc_t sd;
130     mkldnn_primitive_desc_t spd;
131     mkldnn_primitive_t s{};
132
133     SAFE(init_pd(p, sd, spd, r), WARN);
134     if (r->state == SKIPPED || r->state == UNIMPLEMENTED)
135         return OK;
136
137     const auto fp = p->dt;
138     auto &src_dt_d = sd.data_desc;
139
140     const int ndims = (int)p->dims.size();
141     const auto src_format = (ndims == 1)
142            ? mkldnn_x
143            : (ndims == 2)
144            ? mkldnn_nc
145            : get_default_format(ndims, fmt2data_kind(p->fmt));
146
147     dnn_mem_t src_fp(src_dt_d, fp, src_format), src_dt(src_dt_d);
148     dnn_mem_t dst_fp(src_dt_d, fp, src_format), dst_dt(src_dt_d);
149
150     SAFE(fill_memory(p, src_fp), WARN);
151
152     mkldnn_primitive_at_t inputs[1];
153     const_mkldnn_primitive_t outputs[1];
154     SAFE(src_dt.reorder(src_fp), WARN);
155     inputs[0] = {src_dt.p_, 0};
156     outputs[0] = dst_dt.p_;
157     DNN_SAFE(mkldnn_primitive_create(&s, spd, inputs, outputs), WARN);
158     DNN_SAFE_V(mkldnn_primitive_desc_destroy(spd));
159     SAFE(execute(s), WARN);
160     if (bench_mode & CORR) {
161         compute_shuffle(p, src_fp, dst_fp);
162         dnn_mem_t data(dst_dt, fp, src_format);
163         SAFE(compare(p, dst_fp, data, r), WARN);
164     }
165
166     if (bench_mode & PERF) {
167         auto &t = r->timer;
168         t.reset();
169         while (true) {
170             SAFE(execute(s), WARN);
171             t.stamp();
172             const bool stop = false
173                 || (fix_times_per_prb && t.times() >= fix_times_per_prb)
174                 || (!fix_times_per_prb
175                         && t.total_ms() >= max_ms_per_prb
176                         && t.times() >= min_times_per_prb);
177             if (stop) break;
178         }
179     }
180
181     DNN_SAFE_V(mkldnn_primitive_destroy(s));
182     return OK;
183 }
184
185 }