1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
26 #include "mkldnn_common.hpp"
27 #include "mkldnn_memory.hpp"
30 #include "shuffle/shuffle.hpp"
/* Clamps `value` into the interval bounded by `min` and `max`.
 *
 * The result is MAX2(min, MIN2(max, value)): `value` is first capped from
 * above by `max`, then raised to at least `min`. When min > max the outer
 * MAX2 wins, so `min` is returned.
 *
 * NOTE(review): MAX2/MIN2 are defined outside this excerpt; presumably the
 * usual two-argument max/min helpers -- confirm. */
34 inline float saturate(float value, float min, float max) {
35 return MAX2(min, MIN2(max, value));
/* Fills `mem` with a deterministic, repeating pattern of values.
 *
 * A value-range configuration `c_src` is chosen from conf_u8 / conf_s8 /
 * conf_s32, falling back to conf_f32 for any other case. Element `idx` then
 * receives (idx % c_src.range), clamped by saturate() to
 * [c_src.min, c_src.min + c_src.range - 1].
 *
 * NOTE(review): the declaration of `c_src`, the switch header and the
 * function's return statement are not visible in this excerpt; the switch
 * presumably selects on p->dt -- confirm. */
38 int fill_memory(const prb_t *p, dnn_mem_t &mem) {
41 case mkldnn_u8: c_src = conf_u8; break;
42 case mkldnn_s8: c_src = conf_s8; break;
43 case mkldnn_s32: c_src = conf_s32; break;
44 default: c_src = conf_f32; break;
/* `max` is the inclusive upper bound of the configured range:
 * the range starts at c_src.min and holds `range` consecutive values. */
46 const int range = c_src.range;
47 const int max = c_src.min + range - 1;
49 const size_t nelems = mem.nelems();
/* The pattern repeats with period c_src.range; the int bounds are
 * converted to float when passed to saturate(). */
51 for (size_t idx = 0; idx < nelems; ++idx) {
52 float value = saturate((float)(idx % c_src.range), c_src.min, max);
53 mem.set_elem(idx, value);
/* Compares two memories element by element and records the verdict in `r`.
 *
 * `fp_mem` and `dt_mem` must hold the same number of elements (asserted).
 * For each index the absolute difference |fp - dt| is computed; while
 * r->errors is below 10, every non-zero difference is printed together with
 * its index and both values.
 *
 * If r->state is still UNTESTED after the loop it is set to PASSED.
 * Returns FAIL when r->state == FAILED, OK otherwise.
 *
 * NOTE(review): the statements that update r->errors / r->state on a
 * mismatch are not visible in this excerpt. `p` is not used in the visible
 * lines. */
59 static int compare(const prb_t *p, const dnn_mem_t &fp_mem,
60 const dnn_mem_t &dt_mem, res_t *r) {
61 size_t nelems = fp_mem.nelems();
62 assert(nelems == dt_mem.nelems());
65 for (size_t i = 0; i < nelems; ++i) {
66 const float fp = fp_mem.get_elem(i);
67 const float dt = dt_mem.get_elem(i);
68 const float diff = fabsf(fp - dt);
/* Exact comparison: any non-zero difference is reported, but the printout
 * is limited to the period while fewer than 10 errors are recorded. */
69 if (r->errors < 10 && diff != 0.0) {
70 printf("idx: %zu fp: %f dt:%f\n", i, fp, dt);
/* Nothing above marked the run as failed or skipped, so call it passed. */
78 if (r->state == UNTESTED)
79 r->state = PASSED; /* optimism */
81 return r->state == FAILED ? FAIL : OK;
/* Initializes the shuffle operation descriptor `sd` and creates the
 * primitive descriptor `spd` for problem `p`.
 *
 * Steps visible in this excerpt:
 *  - a memory descriptor is initialized from p->dims, p->dt and p->fmt;
 *  - for FWD_D a forward shuffle descriptor is initialized with
 *    mkldnn_forward_training, passing p->a and p->g;
 *  - for BWD_D a backward shuffle descriptor is initialized, and a separate
 *    forward descriptor plus primitive descriptor is created to serve as the
 *    hint for the backward primitive descriptor;
 *  - `spd` is created from `sd`, `engine` and the hint, and the hint is
 *    destroyed right afterwards;
 *  - if creation reports mkldnn_unimplemented, r->state is set to
 *    UNIMPLEMENTED and OK is returned; any other status goes through SAFE.
 *
 * NOTE(review): p->a / p->g are presumably the shuffle axis and group size
 * expected by the mkldnn shuffle API -- confirm. `engine`, DNN_SAFE, SAFE and
 * WARN come from outside this excerpt, and several macro-call continuation
 * lines as well as the final return are not visible here. */
84 static int init_pd(const prb_t *p, mkldnn_shuffle_desc_t &sd,
85 mkldnn_primitive_desc_t &spd, res_t *r) {
87 mkldnn_memory_desc_t data_d;
88 mkldnn_dims_t data_dims;
89 const int ndims = (int)p->dims.size();
/* Copy the problem dimensions into the fixed-size mkldnn dims array.
 * NOTE(review): no visible check that ndims fits into mkldnn_dims_t. */
91 for (int i = 0; i < ndims; ++i) data_dims[i] = p->dims[i];
92 DNN_SAFE(mkldnn_memory_desc_init(&data_d, ndims, data_dims, p->dt, p->fmt),
95 mkldnn_status_t init_status = mkldnn_success;
96 mkldnn_primitive_desc_t hint_fwd_pd = NULL;
/* Only the backward path assigns hint_fwd_pd; it stays NULL for FWD_D. */
97 if (p->dir == FWD_D) {
98 auto prop = mkldnn_forward_training;
99 DNN_SAFE(mkldnn_shuffle_forward_desc_init(&sd, prop,
100 &data_d, p->a, p->g), WARN);
101 } else if (p->dir == BWD_D) {
102 DNN_SAFE(mkldnn_shuffle_backward_desc_init(&sd, &data_d, p->a,
104 mkldnn_shuffle_desc_t sd_fwd;
105 DNN_SAFE(mkldnn_shuffle_forward_desc_init(&sd_fwd,
106 mkldnn_forward_training, &data_d, p->a, p->g), WARN);
107 DNN_SAFE(mkldnn_primitive_desc_create(&hint_fwd_pd, &sd_fwd, engine,
110 init_status = mkldnn_primitive_desc_create(&spd, &sd, engine, hint_fwd_pd);
/* The hint is released immediately after spd creation, whether or not the
 * creation succeeded (it is still NULL here on the FWD_D path). */
111 mkldnn_primitive_desc_destroy(hint_fwd_pd);
/* "Unimplemented" is reported through r->state and is not treated as an
 * error of this function. */
113 if (init_status == mkldnn_unimplemented)
114 return r->state = UNIMPLEMENTED, OK;
116 SAFE(init_status, WARN);
/* Log which implementation was selected for this primitive descriptor. */
118 const char *impl_str = query_impl_info(spd);
119 print(5, "mkldnn implementation: %s\n", impl_str);
/* Runs a single shuffle test case described by `p`, reporting into `r`.
 *
 * Flow visible in this excerpt:
 *  1. init_pd() builds the descriptor `sd` and primitive descriptor `spd`;
 *     a SKIPPED or UNIMPLEMENTED state is checked right afterwards.
 *  2. Reference memories (src_fp, dst_fp) are created with data type p->dt
 *     and `src_format`; library memories (src_dt, dst_dt) are created from
 *     sd.data_desc.
 *  3. src_fp is filled with the test pattern and reordered into src_dt.
 *  4. The primitive is created with src_dt as input and dst_dt as output,
 *     `spd` is destroyed, and the primitive is executed once.
 *  5. With CORR set in bench_mode, compute_shuffle() produces the reference
 *     result in dst_fp, which is compared against dst_dt's contents.
 *  6. With PERF set in bench_mode, the primitive is executed repeatedly
 *     until the stop condition holds.
 *  7. The primitive is destroyed.
 *
 * NOTE(review): the early-return statement after the state check, the arms
 * of the src_format ternary, the timer `t` declaration, the benchmark loop
 * header and the final return are not visible in this excerpt. */
124 int doit(const prb_t *p, res_t *r) {
129 mkldnn_shuffle_desc_t sd;
130 mkldnn_primitive_desc_t spd;
131 mkldnn_primitive_t s{};
133 SAFE(init_pd(p, sd, spd, r), WARN);
134 if (r->state == SKIPPED || r->state == UNIMPLEMENTED)
137 const auto fp = p->dt;
138 auto &src_dt_d = sd.data_desc;
140 const int ndims = (int)p->dims.size();
141 const auto src_format = (ndims == 1)
145 : get_default_format(ndims, fmt2data_kind(p->fmt));
/* *_fp memories use data type `fp` and `src_format`; *_dt memories take the
 * primitive's own descriptor unchanged. */
147 dnn_mem_t src_fp(src_dt_d, fp, src_format), src_dt(src_dt_d);
148 dnn_mem_t dst_fp(src_dt_d, fp, src_format), dst_dt(src_dt_d);
150 SAFE(fill_memory(p, src_fp), WARN);
152 mkldnn_primitive_at_t inputs[1];
153 const_mkldnn_primitive_t outputs[1];
/* Move the generated pattern from src_fp into src_dt, which is the
 * primitive's input. */
154 SAFE(src_dt.reorder(src_fp), WARN);
155 inputs[0] = {src_dt.p_, 0};
156 outputs[0] = dst_dt.p_;
/* NOTE(review): if DNN_SAFE returns early on a creation failure, spd is not
 * destroyed on that path -- confirm whether that leak matters here. */
157 DNN_SAFE(mkldnn_primitive_create(&s, spd, inputs, outputs), WARN);
158 DNN_SAFE_V(mkldnn_primitive_desc_destroy(spd));
159 SAFE(execute(s), WARN);
/* Correctness check: dst_dt is turned into a memory with data type `fp` and
 * `src_format` before being compared with the reference result in dst_fp.
 * NOTE(review): presumably the dnn_mem_t constructor performs a reorder --
 * confirm. */
160 if (bench_mode & CORR) {
161 compute_shuffle(p, src_fp, dst_fp);
162 dnn_mem_t data(dst_dt, fp, src_format);
163 SAFE(compare(p, dst_fp, data, r), WARN);
166 if (bench_mode & PERF) {
170 SAFE(execute(s), WARN);
/* Stop when a fixed repetition count is requested and has been reached, or,
 * with no fixed count, once both the time budget (max_ms_per_prb) and the
 * minimum repetition count (min_times_per_prb) are satisfied. */
172 const bool stop = false
173 || (fix_times_per_prb && t.times() >= fix_times_per_prb)
174 || (!fix_times_per_prb
175 && t.total_ms() >= max_ms_per_prb
176 && t.times() >= min_times_per_prb);
181 DNN_SAFE_V(mkldnn_primitive_destroy(s));