Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / conv / conv_aux.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <stdlib.h>
18 #include <string.h>
19 #include <stdio.h>
20 #include <float.h>
21 #include <math.h>
22
23 #include "mkldnn.h"
24
25 #include "mkldnn_common.hpp"
26 #include "mkldnn_debug.hpp"
27 #include "conv/conv.hpp"
28
29 namespace conv {
30
31 alg_t str2alg(const char *str) {
32 #define CASE(_alg) if (!strcasecmp(STRINGIFY(_alg), str)) return _alg
33     CASE(AUTO);
34     CASE(DIRECT);
35     CASE(WINO);
36 #undef CASE
37     assert(!"unknown algorithm");
38     return DIRECT;
39 }
40
41 const char *alg2str(alg_t alg) {
42     if (alg == AUTO) return "auto";
43     if (alg == DIRECT) return "direct";
44     if (alg == WINO) return "wino";
45     assert(!"unknown algorithm");
46     return "unknown algorithm";
47 }
48
49 alg_t alg_kind2alg(mkldnn_alg_kind_t alg) {
50     if (alg == mkldnn_convolution_auto) return AUTO;
51     if (alg == mkldnn_convolution_direct) return DIRECT;
52     if (alg == mkldnn_convolution_winograd) return WINO;
53     assert(!"unknown algorithm");
54     return DIRECT;
55 }
56
57 int str2desc(desc_t *desc, const char *str, bool is_deconv) {
58     desc_t d{0};
59
60     /* canonical form:
61      * dYgXmbXicXihXiwXocXohXowXkhXkwXshXswXphXpwXdhXdwXnS
62      *
63      * where: Y = {fb, fd, bd, bw, bb}, X is number, S - string
64      * note: symbol `_` is ignored
65      *
66      * implicit rules:
67      *  - default values:
68      *      mb = 2, g = 1, d = fd, sh = sw = 1, dh = dw = 0, S="wip"
69      *  - if H is undefined => H = W
70      *  - if W is undefined => W = H
71      *  - if `output` is undefined => compute output
72      *  - if padding is undefined => compute trivial padding
73      */
74
75     d.g = 1; d.mb = 2; d.sd = d.sh = d.sw = 1; d.dd = d.dh = d.dw = 0;
76     d.has_groups = false, d.name = "\"wip\"";
77     d.pw = -1; d.ph = -1; d.pd = -1;
78
79     const char *s = str;
80     assert(s);
81
82 #   define CASE_NN(p, c) do { \
83         if (!strncmp(p, s, strlen(p))) { \
84             ok = 1; s += strlen(p); \
85             char *end_s; d. c = strtol(s, &end_s, 10); s += (end_s - s); \
86             if (!strncmp(p, "g", 1)) d.has_groups = true; \
87             /* printf("@@@debug: %s: %d\n", p, d. c); */ \
88         } \
89     } while (0)
90 #   define CASE_N(c) CASE_NN(#c, c)
91     while (*s) {
92         int ok = 0;
93         CASE_N(g); CASE_N(mb);
94         CASE_N(ic); CASE_N(id); CASE_N(ih); CASE_N(iw);
95         CASE_N(oc); CASE_N(od); CASE_N(oh); CASE_N(ow);
96         CASE_N(kd); CASE_N(kh); CASE_N(kw);
97         CASE_N(sd); CASE_N(sh); CASE_N(sw);
98         CASE_N(pd); CASE_N(ph); CASE_N(pw);
99         CASE_N(dd); CASE_N(dh); CASE_N(dw);
100         if (*s == 'n') { d.name = s + 1; break; }
101         if (*s == '_') ++s;
102         if (!ok) return FAIL;
103     }
104 #   undef CASE_NN
105 #   undef CASE_N
106
107     if (d.ic == 0 || d.oc == 0) return FAIL;
108     if (d.sd <= 0 || d.sh <= 0 || d.sw <= 0) return FAIL;
109
110     auto compute_out = [](bool is_deconv, int i, int k, int s, int p, int d) {
111         if (is_deconv)
112             return (i - 1) * s + (k - 1) * (d + 1) + 2 * p + 1;
113         else
114             return (i - ((k - 1) * (d + 1) + 1) + 2 * p) / s + 1;
115     };
116     auto compute_pad = [](bool is_deconv, int o, int i, int k, int s, int d) {
117         if (is_deconv)
118             return ((i - 1) * s - o + ((k - 1) * (d + 1) + 1)) / 2;
119         else
120             return ((o - 1) * s - i + ((k - 1) * (d + 1) + 1)) / 2;
121     };
122
123     const bool no_d = (d.id | d.kd | d.od | d.dd) == 0 && d.sd == 1 && d.pd < 1;
124     const bool no_h = (d.ih | d.kh | d.oh | d.dh) == 0 && d.sh == 1 && d.ph < 1;
125     const bool no_w = (d.iw | d.kw | d.ow | d.dw) == 0 && d.sw == 1 && d.pw < 1;
126     if (!no_h) {
127         if (!d.ih || !d.kh) return FAIL;
128         if (!d.oh) {
129             d.ph = 0;
130             d.oh = compute_out(is_deconv, d.ih, d.kh, d.sh, d.ph, d.dh);
131         } else if (d.ph < 0)
132             d.ph = compute_pad(is_deconv, d.oh, d.ih, d.kh, d.sh, d.dh);
133     }
134
135     if (!no_w) {
136         if (!d.iw || !d.kw) return FAIL;
137         if (!d.ow) {
138             d.pw = 0;
139             d.ow = compute_out(is_deconv, d.iw, d.kw, d.sw, d.pw, d.dw);
140         } else if (d.pw < 0)
141             d.pw = compute_pad(is_deconv, d.ow, d.iw, d.kw, d.sw, d.dw);
142     }
143
144     if (!no_d && d.id) {
145         if (!d.id || !d.kd) return FAIL;
146         if (!d.od) {
147             d.pd = 0;
148             d.od = compute_out(is_deconv, d.id, d.kd, d.sd, d.pd, d.dd);
149         } else if (d.pd < 0)
150             d.pd = compute_pad(is_deconv, d.od, d.id, d.kd, d.sd, d.dd);
151     }
152     if (no_w && no_h && d.id) {
153         d.iw = d.ih = d.id;
154         d.kw = d.kh = d.kd;
155         d.ow = d.oh = d.od;
156         d.pw = d.ph = d.pd;
157         d.sw = d.sh = d.sd;
158         d.dw = d.dh = d.dd;
159     } else if (no_w) {
160         d.iw = d.ih;
161         d.kw = d.kh;
162         d.ow = d.oh;
163         d.pw = d.ph;
164         d.sw = d.sh;
165         d.dw = d.dh;
166     } else if (no_h) {
167         d.ih = 1;
168         d.kh = 1;
169         d.oh = 1;
170         d.ph = 0;
171         d.sh = 1;
172         d.dh = 0;
173     }
174     if (d.id<1) {d.id = 1; d.kd = 1; d.od = 1; d.sd = 1; d.pd = 0; d.dd = 0;}
175
176     *desc = d;
177
178     return OK;
179 }
180
181 void desc2str(const desc_t *d, char *buffer, bool canonical) {
182     int rem_len = max_desc_len;
183 #   define DPRINT(...) do { \
184         int l = snprintf(buffer, rem_len, __VA_ARGS__); \
185         buffer += l; rem_len -= l; \
186     } while(0)
187
188     if (canonical || d->has_groups) DPRINT("g%d", d->g);
189     if (canonical || d->mb != 2) DPRINT("mb%d", d->mb);
190
191     const bool half_form = (d->ih == d->iw && d->kh == d->kw && d->oh == d->ow
192         && d->sh == d->sw && d->ph == d->pw && d->dh == d->dw) && d->id == 1;
193
194     if (!canonical && half_form) {
195         DPRINT("ic%dih%doc%doh%dkh%d", d->ic, d->ih, d->oc, d->oh, d->kh);
196         if (d->sh != 1) DPRINT("sh%d", d->sh);
197         if (d->ph != 0) DPRINT("ph%d", d->ph);
198         if (d->dh != 0) DPRINT("dh%d", d->dh);
199     } else {
200         if( d->id == 1 )
201         {
202             DPRINT("ic%dih%diw%doc%doh%dow%dkh%dkw%d",
203                 d->ic, d->ih, d->iw, d->oc, d->oh, d->ow, d->kh, d->kw);
204             if (canonical || d->sh != 1 || d->sw != 1)
205                 DPRINT("sh%dsw%d", d->sh, d->sw);
206             if (canonical || d->ph != 0 || d->pw != 0)
207                 DPRINT("ph%dpw%d", d->ph, d->pw);
208             if (canonical || d->dh != 0 || d->dw != 0)
209                 DPRINT("dh%ddw%d", d->dh, d->dw);
210         } else {
211             DPRINT("ic%did%dih%diw%doc%dod%doh%dow%dkd%dkh%dkw%d",
212                 d->ic, d->id, d->ih, d->iw, d->oc, d->od, d->oh, d->ow,
213                 d->kd, d->kh, d->kw);
214             if (canonical || d->sh != 1 || d->sw != 1 || d->sd != 1)
215                 DPRINT("sd%dsh%dsw%d", d->sd, d->sh, d->sw);
216             if (canonical || d->ph != 0 || d->pw != 0 || d->pd != 0)
217                 DPRINT("pd%dph%dpw%d", d->pd, d->ph, d->pw);
218             if (canonical || d->dh != 0 || d->dw != 0 || d->dd != 0)
219                 DPRINT("dd%ddh%ddw%d", d->dd, d->dh, d->dw);
220         }
221     }
222
223     DPRINT("n%s", d->name);
224
225 #   undef DPRINT
226 }
227
228 void prb_t::count_ops() {
229     if (ops > 0) return;
230
231     int od_t = is_deconv ? this->id : this->od;
232     int oh_t = is_deconv ? this->ih : this->oh;
233     int ow_t = is_deconv ? this->iw : this->ow;
234     int id_t = is_deconv ? this->od : this->id;
235     int ih_t = is_deconv ? this->oh : this->ih;
236     int iw_t = is_deconv ? this->ow : this->iw;
237     double sp_ops = 0;
238     for (int od = 0; od < od_t; ++od) {
239     for (int oh = 0; oh < oh_t; ++oh) {
240     for (int ow = 0; ow < ow_t; ++ow) {
241         for (int kd = 0; kd < this->kd; ++kd) {
242             const int id = od * this->sd - this->pd + kd * (this->dd + 1);
243             if (id < 0 || id >= id_t) continue;
244             for (int kh = 0; kh < this->kh; ++kh) {
245                 const int ih = oh * this->sh - this->ph + kh * (this->dh + 1);
246                 if (ih < 0 || ih >= ih_t) continue;
247                 for (int kw = 0; kw < this->kw; ++kw) {
248                     const int iw = ow * this->sw - this->pw + kw * (this->dw + 1);
249                     if (iw < 0 || iw >= iw_t) continue;
250                     sp_ops += 1;
251                 }
252             }
253         }
254     }
255     }
256     }
257
258     ops = 2 * this->mb * this->oc * this->ic / this->g * sp_ops;
259 }
260
261 void prb_t::generate_oscales() {
262     if (attr.oscale.policy != attr_t::scale_t::policy_t::PER_OC) return;
263
264     scales = (float *)zmalloc(sizeof(float) * oc, 64);
265     SAFE_V(scales != NULL ? OK : FAIL);
266
267     const float K = 32;
268     /* scale in [1/K .. K], with starting point at oscale.scale */
269     float s[2] = {attr.oscale.scale, attr.oscale.scale/2};
270     for (int i = 0; i < oc; ++i) {
271         int si = i % 2; // 0 -> left, 1 -> right
272         scales[i] = s[si];
273         if (si == 0) {
274             s[si] /= 2.;
275             if (s[si] < 1./K) s[si] *= K*K; // turn around to become ~K
276         } else {
277             s[si] *= 2.;
278             if (s[si] > K) s[si] /= K*K; // turn around to become ~K
279         }
280     }
281 }
282
283 void prb2str(const prb_t *p, char *buffer, bool canonical) {
284     char desc_buf[max_desc_len], attr_buf[max_attr_len];
285     char dir_str[32] = {0}, cfg_str[32] = {0}, alg_str[32] = {0};
286     desc2str(p, desc_buf, canonical);
287     snprintf(dir_str, sizeof(dir_str), "--dir=%s ", dir2str(p->dir));
288     snprintf(cfg_str, sizeof(cfg_str), "--cfg=%s ", cfg2str(p->cfg));
289     snprintf(alg_str, sizeof(alg_str), "--alg=%s ", alg2str(p->alg));
290     bool is_attr_def = p->attr.is_def();
291     if (!is_attr_def) {
292         int len = snprintf(attr_buf, max_attr_len, "--attr=\"");
293         SAFE_V(len >= 0 ? OK : FAIL);
294         attr2str(&p->attr, attr_buf + len);
295         len = (int)strnlen(attr_buf, max_attr_len);
296         snprintf(attr_buf + len, max_attr_len - len, "\" ");
297     }
298     snprintf(buffer, max_prb_len, "%s%s%s%s%s",
299             p->dir == FWD_B ? "" : dir_str,
300             p->cfg == conf_f32 ? "" : cfg_str,
301             p->alg == DIRECT ? "" : alg_str,
302             is_attr_def ? "" : attr_buf,
303             desc_buf);
304 }
305
306 }