1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
25 #include "mkldnn_common.hpp"
26 #include "mkldnn_debug.hpp"
27 #include "conv/conv.hpp"
// Converts an algorithm name to alg_t. Matching is case-insensitive
// (strcasecmp) against the enumerator name produced by STRINGIFY(_alg); the
// CASE macro returns that enumerator on a match, and a name that matches
// nothing trips the assert below.
// NOTE(review): the CASE(...) uses, the matching #undef and the fallback
// return after the assert are not visible in this listing (its embedded
// numbering jumps from 32 to 37) - confirm they are present.
31 alg_t str2alg(const char *str) {
32 #define CASE(_alg) if (!strcasecmp(STRINGIFY(_alg), str)) return _alg
37 assert(!"unknown algorithm");
// Returns the lowercase name of an alg_t value: "auto", "direct" or "wino".
// Any other value fails the assert; with asserts compiled out (NDEBUG) the
// placeholder "unknown algorithm" is returned instead. The result is always
// a string literal, so callers must not free or modify it.
// NOTE(review): the closing brace of this function is not visible in this
// listing - confirm it is present.
41 const char *alg2str(alg_t alg) {
42 if (alg == AUTO) return "auto";
43 if (alg == DIRECT) return "direct";
44 if (alg == WINO) return "wino";
45 assert(!"unknown algorithm");
46 return "unknown algorithm";
// Translates an MKL-DNN convolution algorithm kind into benchdnn's alg_t:
// mkldnn_convolution_auto -> AUTO, mkldnn_convolution_direct -> DIRECT,
// mkldnn_convolution_winograd -> WINO. Any other kind fails the assert.
// NOTE(review): no statement after the assert (fallback return) and no
// closing brace are visible in this listing; if the return is really
// missing, a release build would fall off the end of a non-void function (UB).
49 alg_t alg_kind2alg(mkldnn_alg_kind_t alg) {
50 if (alg == mkldnn_convolution_auto) return AUTO;
51 if (alg == mkldnn_convolution_direct) return DIRECT;
52 if (alg == mkldnn_convolution_winograd) return WINO;
53 assert(!"unknown algorithm");
// Parses a convolution problem descriptor string (format described below)
// into a desc_t, working on a local copy `d`.
//   desc      - output descriptor; the copy from `d` into *desc is not
//               visible in this listing (presumably near the end).
//   str       - descriptor text; d.name may end up pointing into it, so it
//               presumably must outlive the descriptor (TODO confirm).
//   is_deconv - selects the deconvolution variants of the output-size and
//               padding formulas below.
// Returns FAIL when no field prefix matched at the parse position (`ok`), on
// zero ic/oc, on a non-positive stride, or when a spatial dimension lacks
// its input or kernel size (the guarding conditions for those checks are
// not visible here). The success return is not visible in this listing.
// NOTE(review): the format description below has no visible opening `/*`
// or closing `*/` in this listing - confirm the comment delimiters exist.
57 int str2desc(desc_t *desc, const char *str, bool is_deconv) {
61 * dYgXmbXicXihXiwXocXohXowXkhXkwXshXswXphXpwXdhXdwXnS
63 * where: Y = {fb, fd, bd, bw, bb}, X is number, S - string
64 * note: symbol `_` is ignored
68 * mb = 2, g = 1, d = fd, sh = sw = 1, dh = dw = 0, S="wip"
69 * - if H is undefined => H = W
70 * - if W is undefined => W = H
71 * - if `output` is undefined => compute output
72 * - if padding is undefined => compute trivial padding
// Defaults. The default name is the text "wip" including the double quotes.
// Padding starts at -1, presumably as a "not specified" marker so that it
// can be derived later ("if padding is undefined" above).
75 d.g = 1; d.mb = 2; d.sd = d.sh = d.sw = 1; d.dd = d.dh = d.dw = 0;
76 d.has_groups = false, d.name = "\"wip\"";
77 d.pw = -1; d.ph = -1; d.pd = -1;
// CASE_NN(p, c): if the parse cursor `s` starts with prefix `p`, set `ok`,
// skip the prefix, parse the following base-10 integer into d.c with strtol
// and advance `s` past the digits. The "g" prefix also sets has_groups.
// If no digits follow the prefix, strtol leaves end_s == s and d.c becomes 0.
82 # define CASE_NN(p, c) do { \
83 if (!strncmp(p, s, strlen(p))) { \
84 ok = 1; s += strlen(p); \
85 char *end_s; d. c = strtol(s, &end_s, 10); s += (end_s - s); \
86 if (!strncmp(p, "g", 1)) d.has_groups = true; \
87 /* printf("@@@debug: %s: %d\n", p, d. c); */ \
90 # define CASE_N(c) CASE_NN(#c, c)
// Try every field prefix at the current position. Each match advances `s`,
// so one pass of this sequence may consume several fields.
93 CASE_N(g); CASE_N(mb);
94 CASE_N(ic); CASE_N(id); CASE_N(ih); CASE_N(iw);
95 CASE_N(oc); CASE_N(od); CASE_N(oh); CASE_N(ow);
96 CASE_N(kd); CASE_N(kh); CASE_N(kw);
97 CASE_N(sd); CASE_N(sh); CASE_N(sw);
98 CASE_N(pd); CASE_N(ph); CASE_N(pw);
99 CASE_N(dd); CASE_N(dh); CASE_N(dw);
// 'n' starts the name: the rest of the string becomes d.name (a pointer
// into the input, not a copy) and `break` leaves the parse loop.
100 if (*s == 'n') { d.name = s + 1; break; }
102 if (!ok) return FAIL;
// Basic sanity checks on the parsed values.
107 if (d.ic == 0 || d.oc == 0) return FAIL;
108 if (d.sd <= 0 || d.sh <= 0 || d.sw <= 0) return FAIL;
// Output size from input size. Inside these lambdas `d` is the zero-based
// dilation (default 0 = dense), not the descriptor, so the effective kernel
// extent is (k - 1) * (d + 1) + 1. The first return is the transposed
// (deconvolution) form and the second the forward-convolution form; the
// condition choosing between them is not visible in this listing.
// NOTE(review): the deconvolution form adds `+ 2 * p`, but the matching
// deconvolution branch of compute_pad below solves
// o = (i - 1) * s + extent - 2 * p, and the usual transposed-convolution
// output size also subtracts the padding. This looks like a sign bug
// (should likely be `- 2 * p`) - confirm and fix.
110 auto compute_out = [](bool is_deconv, int i, int k, int s, int p, int d) {
112 return (i - 1) * s + (k - 1) * (d + 1) + 2 * p + 1;
114 return (i - ((k - 1) * (d + 1) + 1) + 2 * p) / s + 1;
// Symmetric padding that makes `o` consistent with `i`, `k`, `s` and `d`,
// solved from the size relation (integer division by 2 truncates). The
// first return is again the deconvolution form, the second the forward form.
116 auto compute_pad = [](bool is_deconv, int o, int i, int k, int s, int d) {
118 return ((i - 1) * s - o + ((k - 1) * (d + 1) + 1)) / 2;
120 return ((o - 1) * s - i + ((k - 1) * (d + 1) + 1)) / 2;
// A spatial dimension counts as "not given" when none of its sizes or its
// dilation were set, its stride is still the default 1, and its padding
// was not set to a positive value.
123 const bool no_d = (d.id | d.kd | d.od | d.dd) == 0 && d.sd == 1 && d.pd < 1;
124 const bool no_h = (d.ih | d.kh | d.oh | d.dh) == 0 && d.sh == 1 && d.ph < 1;
125 const bool no_w = (d.iw | d.kw | d.ow | d.dw) == 0 && d.sw == 1 && d.pw < 1;
// Per-dimension completion (H, then W, then D): input and kernel size are
// required; output size and padding are filled in via compute_out /
// compute_pad. The conditions guarding each step are not visible here.
127 if (!d.ih || !d.kh) return FAIL;
130 d.oh = compute_out(is_deconv, d.ih, d.kh, d.sh, d.ph, d.dh);
132 d.ph = compute_pad(is_deconv, d.oh, d.ih, d.kh, d.sh, d.dh);
136 if (!d.iw || !d.kw) return FAIL;
139 d.ow = compute_out(is_deconv, d.iw, d.kw, d.sw, d.pw, d.dw);
141 d.pw = compute_pad(is_deconv, d.ow, d.iw, d.kw, d.sw, d.dw);
145 if (!d.id || !d.kd) return FAIL;
148 d.od = compute_out(is_deconv, d.id, d.kd, d.sd, d.pd, d.dd);
150 d.pd = compute_pad(is_deconv, d.od, d.id, d.kd, d.sd, d.dd);
// NOTE(review): the body of this branch and the lines that follow it are
// not visible in this listing.
152 if (no_w && no_h && d.id) {
// Without a depth size, collapse D to a trivial size-1 dimension (unit
// kernel/stride, no padding or dilation), so the problem is spatially 2D.
174 if (d.id<1) {d.id = 1; d.kd = 1; d.od = 1; d.sd = 1; d.pd = 0; d.dd = 0;}
// Serializes *d into `buffer` using the same field prefixes that str2desc
// parses (g, mb, ic, ih, ..., n). `buffer` is assumed to hold at least
// max_desc_len bytes. With `canonical` set, every field is printed.
// Otherwise defaults are omitted (g unless has_groups, mb == 2, unit
// strides, zero padding and dilation), and a square 2D problem is printed
// in a short form using only the H parameters.
181 void desc2str(const desc_t *d, char *buffer, bool canonical) {
182 int rem_len = max_desc_len;
// DPRINT appends formatted text at `buffer` and advances it.
// NOTE(review): snprintf returns the length it *would* have written. On
// truncation `l` exceeds rem_len, so `buffer` moves past the end and
// rem_len goes negative; the next snprintf then receives a huge size_t and
// can overflow the buffer. A negative `l` (encoding error) is not handled
// either. Consider clamping l to rem_len - 1 / stopping once rem_len <= 1.
183 # define DPRINT(...) do { \
184 int l = snprintf(buffer, rem_len, __VA_ARGS__); \
185 buffer += l; rem_len -= l; \
188 if (canonical || d->has_groups) DPRINT("g%d", d->g);
189 if (canonical || d->mb != 2) DPRINT("mb%d", d->mb);
// "Half form": H and W parameters are identical and depth is trivial, so
// only the H values need to be printed (str2desc copies H into W).
191 const bool half_form = (d->ih == d->iw && d->kh == d->kw && d->oh == d->ow
192 && d->sh == d->sw && d->ph == d->pw && d->dh == d->dw) && d->id == 1;
194 if (!canonical && half_form) {
195 DPRINT("ic%dih%doc%doh%dkh%d", d->ic, d->ih, d->oc, d->oh, d->kh);
196 if (d->sh != 1) DPRINT("sh%d", d->sh);
197 if (d->ph != 0) DPRINT("ph%d", d->ph);
198 if (d->dh != 0) DPRINT("dh%d", d->dh);
// Full 2D form (the condition selecting it is not visible in this listing).
202 DPRINT("ic%dih%diw%doc%doh%dow%dkh%dkw%d",
203 d->ic, d->ih, d->iw, d->oc, d->oh, d->ow, d->kh, d->kw);
204 if (canonical || d->sh != 1 || d->sw != 1)
205 DPRINT("sh%dsw%d", d->sh, d->sw);
206 if (canonical || d->ph != 0 || d->pw != 0)
207 DPRINT("ph%dpw%d", d->ph, d->pw);
208 if (canonical || d->dh != 0 || d->dw != 0)
209 DPRINT("dh%ddw%d", d->dh, d->dw);
// Full 3D form, including the depth parameters.
211 DPRINT("ic%did%dih%diw%doc%dod%doh%dow%dkd%dkh%dkw%d",
212 d->ic, d->id, d->ih, d->iw, d->oc, d->od, d->oh, d->ow,
213 d->kd, d->kh, d->kw);
214 if (canonical || d->sh != 1 || d->sw != 1 || d->sd != 1)
215 DPRINT("sd%dsh%dsw%d", d->sd, d->sh, d->sw);
216 if (canonical || d->ph != 0 || d->pw != 0 || d->pd != 0)
217 DPRINT("pd%dph%dpw%d", d->pd, d->ph, d->pw);
218 if (canonical || d->dh != 0 || d->dw != 0 || d->dd != 0)
219 DPRINT("dd%ddh%ddw%d", d->dd, d->dh, d->dw);
// The name is always appended last.
223 DPRINT("n%s", d->name);
// Computes the problem's operation count into `ops`. For each output point
// it visits every kernel tap and maps it to the input coordinate
// o * stride - pad + k * (dilation + 1). Taps that land outside the input,
// i.e. in padding, are skipped. The in-bounds taps are presumably
// accumulated into sp_ops in lines not visible in this listing. The total
// is 2 * mb * oc * ic / g * sp_ops (2 = one multiply plus one add per tap).
// NOTE(review): the declarations of `is_deconv` and `sp_ops` (and any early
// return) are in lines not visible here.
228 void prb_t::count_ops() {
// For deconvolution the loop nest runs with input and output spatial sizes
// swapped, so the same convolution-style traversal can be reused.
231 int od_t = is_deconv ? this->id : this->od;
232 int oh_t = is_deconv ? this->ih : this->oh;
233 int ow_t = is_deconv ? this->iw : this->ow;
234 int id_t = is_deconv ? this->od : this->id;
235 int ih_t = is_deconv ? this->oh : this->ih;
236 int iw_t = is_deconv ? this->ow : this->iw;
// Locals od/oh/ow/id/ih/iw shadow the members of the same name, hence the
// explicit this-> on member accesses.
238 for (int od = 0; od < od_t; ++od) {
239 for (int oh = 0; oh < oh_t; ++oh) {
240 for (int ow = 0; ow < ow_t; ++ow) {
241 for (int kd = 0; kd < this->kd; ++kd) {
242 const int id = od * this->sd - this->pd + kd * (this->dd + 1);
243 if (id < 0 || id >= id_t) continue;
244 for (int kh = 0; kh < this->kh; ++kh) {
245 const int ih = oh * this->sh - this->ph + kh * (this->dh + 1);
246 if (ih < 0 || ih >= ih_t) continue;
247 for (int kw = 0; kw < this->kw; ++kw) {
248 const int iw = ow * this->sw - this->pw + kw * (this->dw + 1);
249 if (iw < 0 || iw >= iw_t) continue;
// NOTE(review): if mb/oc/ic are int, 2 * mb * oc * ic is evaluated in int
// before sp_ops takes part, and can overflow for large problems
// (e.g. mb=256, oc=ic=2048 gives 2^31 > INT_MAX). Consider promoting the
// leading operand (e.g. 2.0 or a 64-bit type).
258 ops = 2 * this->mb * this->oc * this->ic / this->g * sp_ops;
// Only for the PER_OC output-scale policy: allocates `scales` (oc floats via
// zmalloc; the 64 is presumably the alignment - TODO confirm) and produces
// per-output-channel scales. Two running values start at
// attr.oscale.scale and attr.oscale.scale / 2, and channel i works with
// s[i % 2]. Per the comment below, the values are kept within [1/K .. K] by
// wrapping around with a factor of K*K.
// NOTE(review): the definition of K, the store into scales[i] and the
// per-step update of s[si] are in lines not visible in this listing.
// NOTE(review): SAFE_V is defined elsewhere; presumably it reports and
// bails out when the allocation failed.
261 void prb_t::generate_oscales() {
262 if (attr.oscale.policy != attr_t::scale_t::policy_t::PER_OC) return;
264 scales = (float *)zmalloc(sizeof(float) * oc, 64);
265 SAFE_V(scales != NULL ? OK : FAIL);
268 /* scale in [1/K .. K], with starting point at oscale.scale */
269 float s[2] = {attr.oscale.scale, attr.oscale.scale/2};
270 for (int i = 0; i < oc; ++i) {
271 int si = i % 2; // 0 -> left, 1 -> right
275 if (s[si] < 1./K) s[si] *= K*K; // turn around to become ~K
278 if (s[si] > K) s[si] /= K*K; // turn around to become ~1/K
283 void prb2str(const prb_t *p, char *buffer, bool canonical) {
284 char desc_buf[max_desc_len], attr_buf[max_attr_len];
285 char dir_str[32] = {0}, cfg_str[32] = {0}, alg_str[32] = {0};
286 desc2str(p, desc_buf, canonical);
287 snprintf(dir_str, sizeof(dir_str), "--dir=%s ", dir2str(p->dir));
288 snprintf(cfg_str, sizeof(cfg_str), "--cfg=%s ", cfg2str(p->cfg));
289 snprintf(alg_str, sizeof(alg_str), "--alg=%s ", alg2str(p->alg));
290 bool is_attr_def = p->attr.is_def();
292 int len = snprintf(attr_buf, max_attr_len, "--attr=\"");
293 SAFE_V(len >= 0 ? OK : FAIL);
294 attr2str(&p->attr, attr_buf + len);
295 len = (int)strnlen(attr_buf, max_attr_len);
296 snprintf(attr_buf + len, max_attr_len - len, "\" ");
298 snprintf(buffer, max_prb_len, "%s%s%s%s%s",
299 p->dir == FWD_B ? "" : dir_str,
300 p->cfg == conf_f32 ? "" : cfg_str,
301 p->alg == DIRECT ? "" : alg_str,
302 is_attr_def ? "" : attr_buf,