1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef _CONV_COMMON_HPP
18 #define _CONV_COMMON_HPP
25 #include "dnn_types.hpp"
26 #include "mkldnn_common.hpp"
27 #include "mkldnn_memory.hpp"
/* Declarations only: the three objects below are marked extern and carry no
 * initializer, so their definitions live elsewhere. */
30 /* some extra control parameters which shouldn't be placed in prb_t */
31 extern const char *skip_impl; /* NULL or "" means do not skip anything */
32 extern bool allow_unimpl; /* true means do not treat unimplemented as error */
33 extern const char *perf_template; /* performance output template */
/* Algorithm selector with three enumerators: DIRECT, WINO and AUTO. */
38 enum alg_t { DIRECT, WINO, AUTO };
/* Maps a C string to an alg_t. How unrecognized strings are handled is not
 * visible from this declaration. */
39 alg_t str2alg(const char *str);
/* Maps an alg_t to a C string. Ownership/lifetime of the returned pointer is
 * not visible here; presumably a string literal -- TODO confirm. */
40 const char *alg2str(alg_t alg);
/* Maps the library's mkldnn_alg_kind_t to the local alg_t. */
41 alg_t alg_kind2alg(mkldnn_alg_kind_t alg);
/* Length constant associated with the textual form of a desc_t; it also feeds
 * into max_prb_len later in this header. */
55 const size_t max_desc_len = 196;
/* Fills *desc from the string str; is_deconv is passed through to the parser.
 * The meaning of the int return value is not visible here -- presumably a
 * status code, TODO confirm. */
56 int str2desc(desc_t *desc, const char *str, bool is_deconv);
/* Writes a textual form of *d into buffer; canonical defaults to false.
 * NOTE(review): no buffer length is passed, so the caller must size buffer
 * adequately -- presumably at least max_desc_len bytes, confirm. */
57 void desc2str(const desc_t *d, char *buffer, bool canonical = false);
59 /** Configuration structure that controls initial data filling and the error check.
61 * dt defines the convolution precision.
63 * For each type (SRC, WEI, BIA, and DST) the values are filled as follows:
64 * if (rand() > f_sparsity) then:
65 * v <-- f_base // it is guaranteed each kernel window
66 * // has at least one non-zero element
68 * otherwise: v <-- f_min + rand() * f_step % (f_max - f_min)
71 * On the final check the resulting values should be in the [min .. max] range,
72 * and the relative difference should not exceed eps.
 */
/* Per-data-kind configuration record.
 * The typedef also introduces _dt_conf_t, an array type holding DAT_TOTAL
 * dt_conf_t elements -- presumably one element per data kind (SRC, WEI, BIA,
 * DST); TODO confirm against the definition of DAT_TOTAL. */
74 typedef struct dt_conf_t {
75 mkldnn_data_type_t dt;
76 double min, max; /* representative */
77 int f_min, f_max; /* fill range */
78 int f_base; /* fill base, use 0 */
79 int f_step; /* fill step, use 1 */
80 double f_sparsity; /* amount of non-zeros, default 0.25 */
81 double eps; /* acceptable error */
82 } _dt_conf_t[DAT_TOTAL];
/* Predefined configurations. Each is a const _dt_conf_t (an array of
 * DAT_TOTAL dt_conf_t records) that is only declared here; the definitions
 * live elsewhere.
 * NOTE(review): the names appear to encode the data types used per data kind
 * (e.g. u8s8s32s32) and a "_wino" / "_full" variant -- confirm the exact
 * ordering against the definitions. */
84 extern const _dt_conf_t conf_f32;
85 extern const _dt_conf_t conf_f32_full;
86 extern const _dt_conf_t conf_f32_wino;
87 extern const _dt_conf_t conf_s16s16s32s32;
88 extern const _dt_conf_t conf_s32s16s16s32;
89 extern const _dt_conf_t conf_s16s32s16s32;
90 extern const _dt_conf_t conf_u8s8s32s32;
91 extern const _dt_conf_t conf_u8s8s8s32;
92 extern const _dt_conf_t conf_u8s8u8s32;
93 extern const _dt_conf_t conf_s8s8s32s32;
94 extern const _dt_conf_t conf_s8s8s8s32;
95 extern const _dt_conf_t conf_s8s8u8s32;
96 extern const _dt_conf_t conf_u8s8f32s32_wino;
97 extern const _dt_conf_t conf_u8s8s32s32_wino;
98 extern const _dt_conf_t conf_u8s8s8s32_wino;
99 extern const _dt_conf_t conf_u8s8u8s32_wino;
/* Configurations are passed around as const dt_conf_t *, which is what a
 * const _dt_conf_t array decays to. */
/* Maps a string to a configuration pointer. Behaviour for unknown strings is
 * not visible from this declaration. */
101 const dt_conf_t *str2cfg(const char *str);
/* Maps a configuration pointer back to a C string. */
102 const char *cfg2str(const dt_conf_t *cfg);
/* Returns a configuration pointer derived from the given algorithm and
 * configuration; the selection rule is not visible here. */
103 const dt_conf_t *auto_cfg(const alg_t alg, const dt_conf_t *cfg);
/* Convolution problem: a desc_t extended with direction, data configuration,
 * algorithm and attributes.
 * NOTE(review): only part of the member list is visible in this view. The
 * constructor initializes dir, alg, attr, ops, scales and is_deconv, so those
 * members must be declared somewhere in this struct. */
105 struct prb_t: public desc_t {
/* Copy-constructs the desc_t base from desc and stores dir, cfg, alg, attr and
 * is_deconv. ops starts at 0 and scales at NULL. */
106 prb_t(const desc_t &desc, dir_t dir, const dt_conf_t *cfg, alg_t alg,
107 const attr_t &attr, int mb = 0, bool is_deconv = false)
108 : desc_t(desc), dir(dir), cfg(cfg), alg(alg), attr(attr)
109 , ops(0), scales(NULL), is_deconv(is_deconv) {
/* A non-zero mb argument overrides this->mb (presumably the value copied from
 * desc); mb == 0, the default, leaves it unchanged. */
110 if (mb) this->mb = mb;
/* Releases scales through zfree() when it is non-NULL. */
114 ~prb_t() { if (scales) zfree(scales); }
/* Data configuration pointer; the destructor shown above does not free it. */
117 const dt_conf_t *cfg;
/* Declared here, defined elsewhere; presumably populates scales -- TODO
 * confirm against the definition. */
126 void generate_oscales();
/* Copy construction and copy assignment are disabled. Presumably this guards
 * against two objects freeing the same scales pointer -- confirm. */
129 prb_t(const prb_t &) = delete;
130 prb_t &operator=(const prb_t &) = delete;
/* Length constant for the textual form of a prb_t: the attribute and
 * descriptor length constants plus 196 extra. */
132 const size_t max_prb_len = max_attr_len + max_desc_len + 196;
/* Writes a textual form of *p into buffer; canonical defaults to false.
 * NOTE(review): no buffer length is passed -- presumably the caller provides
 * at least max_prb_len bytes, confirm. */
133 void prb2str(const prb_t *p, char *buffer, bool canonical = false);
/* NOTE(review): the same three extern declarations also appear near the top of
 * this header. Repeating an extern declaration is legal; confirm whether the
 * two sets are meant to live in different scopes or one is redundant. */
135 /* some extra control parameters which shouldn't be placed in prb_t */
136 extern const char *skip_impl; /* NULL or "" means do not skip anything */
137 extern bool allow_unimpl; /* true means do not treat unimplemented as error */
138 extern const char *perf_template; /* performance output template */
/* Linear offset of the source element (mb, g, ic, id, ih, iw).
 * The channel coordinate is g * p->ic / p->g + ic, so ic is counted from the
 * first channel of group g. Dimensions nest as mb, channel, id, ih, iw with
 * iw varying fastest.
 * The (size_t) cast on mb makes the outer multiply/add chain size_t-wide; the
 * g * p->ic / p->g term is still evaluated in its operands' own type (int if
 * the desc_t fields are int) before it is added. */
140 inline size_t src_off_f(const prb_t *p, int mb, int g, int ic,
141 int id, int ih, int iw)
143 return ((((size_t)mb * p->ic + g * p->ic/p->g + ic)
144 * p->id + id) * p->ih + ih) * p->iw + iw;
/* Splits a linear source offset into (mb, g, ic, id, ih, iw), written through
 * the reference parameters. Coordinates are peeled from the fastest-varying
 * (iw) to the slowest (mb) with modulo/divide pairs; ic is taken modulo the
 * per-group channel count p->ic / p->g, mirroring src_off_f.
 * off is passed by value, so the caller's variable is not modified.
 * NOTE(review): every divisor (p->iw, p->ih, p->id, p->ic / p->g, p->g,
 * p->mb) must be non-zero; no check is visible here. */
147 inline void inv_src_off_f(const prb_t *p, size_t off, int &mb, int &g, int &ic,
148 int &id, int &ih, int &iw) {
149 iw = off % p->iw; off /= p->iw;
150 ih = off % p->ih; off /= p->ih;
151 id = off % p->id; off /= p->id;
152 ic = off % (p->ic / p->g); off /= (p->ic / p->g);
153 g = off % p->g; off /= p->g;
154 mb = off % p->mb; off /= p->mb;
/* Linear offset of the weights element (g, oc, ic, kd, kh, kw); kw varies
 * fastest. oc and ic are indices within group g.
 * NOTE(review): both group scalings multiply first and divide by p->g
 * afterwards: (g * p->oc) / p->g, then (... * p->ic) / p->g. This equals
 * (g * (OC/G) + oc) * (IC/G) + ic only when p->oc and p->ic are multiples of
 * p->g, whereas inv_wei_off_f uses the pre-divided p->oc / p->g and
 * p->ic / p->g. Confirm that callers guarantee divisibility. */
158 inline size_t wei_off_f(const prb_t *p, int g, int oc, int ic,
159 int kd, int kh, int kw)
161 return (((((size_t)g * p->oc / p->g + oc) * p->ic / p->g + ic)
162 * p->kd + kd) * p->kh + kh) * p->kw + kw;
/* Splits a linear weights offset into (g, oc, ic, kd, kh, kw), written through
 * the reference parameters. Coordinates are peeled from kw (fastest) up to g
 * (slowest); ic and oc are taken modulo the per-group counts p->ic / p->g and
 * p->oc / p->g.
 * off is passed by value, so the caller's variable is not modified.
 * NOTE(review): all divisors must be non-zero; no check is visible here. */
165 inline void inv_wei_off_f(const prb_t *p, size_t off, int &g, int &oc, int &ic,
166 int &kd, int &kh, int &kw) {
167 kw = off % p->kw; off /= p->kw;
168 kh = off % p->kh; off /= p->kh;
169 kd = off % p->kd; off /= p->kd;
170 ic = off % (p->ic / p->g); off /= (p->ic / p->g);
171 oc = off % (p->oc / p->g); off /= (p->oc / p->g);
172 g = off % p->g; off /= p->g;
/* Linear offset of the bias element for channel oc of group g:
 * (g * p->oc) / p->g + oc, with the product computed in size_t because of the
 * cast on g. */
176 inline size_t bia_off_f(const prb_t *p, int g, int oc) {
177 return (size_t)g * p->oc / p->g + oc;
/* Splits a linear bias offset into (g, oc), written through the reference
 * parameters: oc is off modulo the per-group count p->oc / p->g, g comes from
 * the quotient modulo p->g.
 * NOTE(review): p->g and p->oc / p->g must be non-zero; no check is visible. */
180 inline void inv_bia_off_f(const prb_t *p, size_t off, int &g, int &oc) {
181 oc = off % (p->oc / p->g); off /= (p->oc / p->g);
182 g = off % p->g; off /= p->g;
/* Linear offset of the destination element (mb, g, oc, od, oh, ow).
 * The channel coordinate is g * p->oc / p->g + oc, so oc is counted from the
 * first channel of group g. Dimensions nest as mb, channel, od, oh, ow with
 * ow varying fastest. Same shape as src_off_f with output sizes. */
186 inline size_t dst_off_f(const prb_t *p, int mb, int g, int oc,
187 int od, int oh, int ow)
189 return ((((size_t)mb * p->oc + g * p->oc/p->g + oc) * p->od + od)
190 * p->oh + oh) * p->ow + ow;
/* Splits a linear destination offset into (mb, g, oc, od, oh, ow), written
 * through the reference parameters. Coordinates are peeled from ow (fastest)
 * up to mb (slowest); oc is taken modulo the per-group count p->oc / p->g.
 * off is passed by value, so the caller's variable is not modified.
 * NOTE(review): all divisors must be non-zero; no check is visible here. */
193 inline void inv_dst_off_f(const prb_t *p, size_t off, int &mb, int &g, int &oc,
194 int &od, int &oh, int &ow) {
195 ow = off % p->ow; off /= p->ow;
196 oh = off % p->oh; off /= p->oh;
197 od = off % p->od; off /= p->od;
198 oc = off % (p->oc / p->g); off /= (p->oc / p->g);
199 g = off % p->g; off /= p->g;
200 mb = off % p->mb; off /= p->mb;
/* Returns a float for problem p and channel index oc. Presumably the output
 * scale applied to that channel -- TODO confirm against the definition. */
204 float oscale(const prb_t *p, int oc);
/* Reference computations, one per direction (forward, backward-by-data,
 * backward-by-weights judging by the parameter names). All memory objects are
 * taken by non-const reference, so which ones are read and which are written
 * cannot be told from the declarations alone. */
206 void compute_ref_fwd(const prb_t *p, dnn_mem_t &src_m, dnn_mem_t &wei_m,
207 dnn_mem_t &bia_m, dnn_mem_t &dst_m);
208 void compute_ref_bwd_d(const prb_t *p, dnn_mem_t &diff_src_m, dnn_mem_t &wei_m,
209 dnn_mem_t &bia_m, dnn_mem_t &diff_dst_m);
210 void compute_ref_bwd_w(const prb_t *p, dnn_mem_t &src_m, dnn_mem_t &diff_wei_m,
211 dnn_mem_t &diff_bia_m, dnn_mem_t &diff_dst_m);
/* "direct" variants of the reference computations, with the same parameter
 * lists as compute_ref_fwd / compute_ref_bwd_d / compute_ref_bwd_w.
 * Presumably these correspond to alg_t::DIRECT -- TODO confirm. */
213 void compute_ref_direct_fwd(const prb_t *p, dnn_mem_t &src_m, dnn_mem_t &wei_m,
214 dnn_mem_t &bia_m, dnn_mem_t &dst_m);
215 void compute_ref_direct_bwd_d(const prb_t *p, dnn_mem_t &diff_src_m, dnn_mem_t &wei_m,
216 dnn_mem_t &bia_m, dnn_mem_t &diff_dst_m);
217 void compute_ref_direct_bwd_w(const prb_t *p, dnn_mem_t &src_m, dnn_mem_t &diff_wei_m,
218 dnn_mem_t &diff_bia_m, dnn_mem_t &diff_dst_m);
/* "wino" variants of the reference computations, with the same parameter
 * types as the compute_ref_* functions. Presumably these correspond to
 * alg_t::WINO -- TODO confirm.
 * NOTE(review): the first memory parameter of compute_wino_ref_bwd_d is
 * spelled idiff_src_m where the sibling declarations use diff_src_m; looks
 * like a typo, harmless in a declaration. */
220 void compute_wino_ref_fwd(const prb_t *p, dnn_mem_t &src_m, dnn_mem_t &wei_m,
221 dnn_mem_t &bia_m, dnn_mem_t &dst_m);
222 void compute_wino_ref_bwd_d(const prb_t *p, dnn_mem_t &idiff_src_m,
223 dnn_mem_t &wei_m, dnn_mem_t &bia_m, dnn_mem_t &diff_dst_m);
224 void compute_wino_ref_bwd_w(const prb_t *p, dnn_mem_t &src_m,
225 dnn_mem_t &diff_wei_m, dnn_mem_t &diff_bia_m, dnn_mem_t &diff_dst_m);
/* Reports performance for problem p and result r; pstr is a C string,
 * presumably the printed problem description -- TODO confirm. */
227 void perf_report(const prb_t *p, const res_t *r, const char *pstr);
/* Per-data-kind comparison helpers (src, wei, bia, dst) sharing one signature:
 * the problem, two memory objects (mem_dt and mem_fp), a result record r and
 * a final_compare flag defaulting to false.
 * NOTE(review): presumably mem_dt holds the library output and mem_fp the
 * reference, and the int return is a status code -- confirm against the
 * definitions. */
229 int compare_src(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
230 res_t *r, bool final_compare = false);
231 int compare_wei(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
232 res_t *r, bool final_compare = false);
233 int compare_bia(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
234 res_t *r, bool final_compare = false);
235 int compare_dst(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
236 res_t *r, bool final_compare = false);
/* Per-data-kind fill helpers (src, wei, bia, dst), each taking the problem and
 * two memory objects (mem_dt and mem_fp).
 * NOTE(review): every declaration below ends on a trailing comma, so the
 * remaining parameter(s) and the closing parenthesis are not visible in this
 * view; do not rely on the parameter lists shown here being complete. */
237 int fill_src(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
239 int fill_wei(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
241 int fill_bia(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
243 int fill_dst(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
/* Returns a double for the given problem, data kind and final_compare flag.
 * Presumably a threshold on the share of non-zero values used when judging a
 * comparison -- TODO confirm against the definition. */
245 double get_trust_nz_level(const prb_t *p, data_kind_t kind, bool final_compare);
/* Helper computations split out by quantity: bias gradient, forward bias, and
 * weights gradient (judging by the parameter names). All memory objects are
 * taken by non-const reference; which are inputs and which are outputs is
 * not visible from the declarations. */
247 void compute_ref_bwd_bias(const prb_t *p, dnn_mem_t &diff_bia_m, dnn_mem_t &diff_dst_m);
248 void compute_bias_fwd(const prb_t *p, dnn_mem_t &bia_m, dnn_mem_t &dst_m);
249 void compute_ref_bwd_weights(const prb_t *p, dnn_mem_t &src_m,dnn_mem_t &diff_wei_m, dnn_mem_t &diff_dst_m);