Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / dnn_types.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18 #include <stdlib.h>
19 #include <stddef.h>
20 #include <string.h>
21 #include <math.h>
22
23 #include "mkldnn.h"
24
25 #include "common.hpp"
26 #include "dnn_types.hpp"
27 #include "mkldnn_common.hpp"
28 #include "mkldnn_debug.hpp"
29
30 dir_t str2dir(const char *str) {
31 #define CASE(x) if (!strcasecmp(STRINGIFY(x), str)) return x
32     CASE(FWD_D);
33     CASE(FWD_I);
34     CASE(FWD_B);
35     CASE(BWD_D);
36     CASE(BWD_W);
37     CASE(BWD_WB);
38     CASE(BWD_DW);
39 #undef CASE
40     assert(!"unknown dir");
41     return DIR_UNDEF;
42 }
43
44 const char *dir2str(dir_t dir) {
45 #define CASE(x) if (dir == x) return STRINGIFY(x)
46     CASE(FWD_D);
47     CASE(FWD_I);
48     CASE(FWD_B);
49     CASE(BWD_D);
50     CASE(BWD_DW);
51     CASE(BWD_W);
52     CASE(BWD_WB);
53 #undef CASE
54     assert(!"unknown dir");
55     return "DIR_UNDEF";
56 }
57
58 const char *data_kind2str(data_kind_t kind) {
59     switch (kind) {
60     case SRC: return "SRC";
61     case WEI: return "WEI";
62     case BIA: return "BIA";
63     case DST: return "DST";
64     case ACC: return "ACC";
65     case DATA: return "DATA";
66     case MEAN: return "MEAN";
67     case VAR: return "VAR";
68     case SS: return "SS";
69     case GWEI: return "GWEI";
70     }
71     assert(!"incorrect data kind");
72     return "incorrect data kind";
73 }
74
75 data_kind_t fmt2data_kind(mkldnn_memory_format_t fmt) {
76     switch (fmt) {
77     case mkldnn_x:
78     case mkldnn_nc:
79     case mkldnn_tnc:
80     case mkldnn_ntc:
81
82     case mkldnn_ncw:
83     case mkldnn_nwc:
84     case mkldnn_nCw16c:
85
86     case mkldnn_nchw:
87     case mkldnn_nhwc:
88     case mkldnn_chwn:
89     case mkldnn_nChw8c:
90     case mkldnn_nChw16c:
91
92     case mkldnn_ncdhw:
93     case mkldnn_ndhwc:
94     case mkldnn_nCdhw16c:
95         return DATA;
96
97     case mkldnn_goiw:
98     case mkldnn_gOIw16i16o:
99     case mkldnn_gOIw16o16i:
100     case mkldnn_gOiw16o:
101     case mkldnn_gOwi16o:
102     case mkldnn_gOIw8i16o2i:
103     case mkldnn_goihw:
104     case mkldnn_hwigo:
105     case mkldnn_giohw:
106     case mkldnn_hwigo_s8s8:
107     case mkldnn_gOIhw8i8o:
108     case mkldnn_gOIhw16i16o:
109     case mkldnn_gOIhw4i16o4i:
110     case mkldnn_gOIhw4i16o4i_s8s8:
111     case mkldnn_gOIhw8i16o2i:
112     case mkldnn_gOIdhw8i16o2i:
113     case mkldnn_gOIhw8o16i2o:
114     case mkldnn_gOIhw8o8i:
115     case mkldnn_gOIhw16o16i:
116     case mkldnn_gIOhw16o16i:
117     case mkldnn_gOihw8o:
118     case mkldnn_gOihw16o:
119     case mkldnn_gOhwi8o:
120     case mkldnn_gOhwi16o:
121     case mkldnn_Goihw8g:
122     case mkldnn_Goihw16g:
123     case mkldnn_Goihw16g_s8s8:
124     case mkldnn_gOhIw16o4i:
125     case mkldnn_goidhw:
126     case mkldnn_gOIdhw16i16o:
127     case mkldnn_gOIdhw16o16i:
128     case mkldnn_gOidhw16o:
129     case mkldnn_gOdhwi16o:
130         return GWEI;
131
132     default: return WEI;
133     }
134 }
135
136 attr_t::scale_t::policy_t attr_t::scale_t::str2policy(const char *str) {
137 #define CASE(_plc) if (!strcasecmp(STRINGIFY(_plc), str)) return _plc
138     CASE(NONE);
139     CASE(COMMON);
140     CASE(PER_OC);
141 #undef CASE
142     assert(!"unknown attr::scale::policy");
143     return NONE;
144 }
145
146 const char *attr_t::scale_t::policy2str(attr_t::scale_t::policy_t policy) {
147     if (policy == NONE) return "none";
148     if (policy == COMMON) return "common";
149     if (policy == PER_OC) return "per_oc";
150     assert(!"unknown attr::scale::policy");
151     return "unknown attr::scale::policy";
152 }
153
154 int attr_t::scale_t::str2scale(const char *str, const char **end_s) {
155     *this = attr_t::scale_t();
156
157     if (str == NULL) return FAIL;
158
159     const char *s_;
160     const char * &s = end_s ? *end_s : s_;
161     s = str;
162
163     for (policy_t p = NONE; true; p = (policy_t)((int)p + 1)) {
164         if (p == POLICY_TOTAL) return FAIL;
165
166         const char *ps = policy2str(p);
167         if (!strncasecmp(ps, s, strlen(ps))) {
168             this->policy = p;
169             s += strlen(ps);
170             break;
171         }
172     }
173
174     if (*s != ':') return OK;
175     s++;
176
177     char *end;
178     this->scale = strtof(s, &end);
179     if (this->scale < 0 || end == s) return FAIL;
180
181     s = end;
182     assert(*s == '\0' || *s == ';');
183
184     return OK;
185 }
186
187 void attr_t::scale_t::scale2str(char *buffer, char **end_b) const {
188     assert(buffer);
189     buffer += sprintf(buffer, "%s:%g", policy2str(this->policy), this->scale);
190     if (end_b) *end_b = buffer;
191 }
192
193 attr_t::post_ops_t::kind_t attr_t::post_ops_t::str2kind(const char *str) {
194 #define CASE(_knd) if (!strcasecmp(STRINGIFY(_knd), str)) return _knd
195     CASE(SUM);
196     CASE(RELU);
197     CASE(TANH);
198     CASE(ELU);
199     CASE(SQUARE);
200     CASE(ABS);
201     CASE(SQRT);
202     CASE(LINEAR);
203     CASE(BRELU);
204     CASE(SRELU);
205     CASE(LOGISTIC);
206 #undef CASE
207     assert(!"unknown attr::post_ops::kind");
208     return KIND_TOTAL;
209 }
210
211 const char *attr_t::post_ops_t::kind2str(attr_t::post_ops_t::kind_t kind) {
212 #define CASE(_knd, str) if (kind == _knd) return str
213     CASE(SUM, "sum");
214     CASE(RELU, "relu");
215     CASE(TANH, "tanh");
216     CASE(ELU, "elu");
217     CASE(SQUARE, "square");
218     CASE(ABS, "abs");
219     CASE(SQRT, "sqrt");
220     CASE(LINEAR, "linear");
221     CASE(BRELU, "brelu");
222     CASE(SRELU, "srelu");
223     CASE(LOGISTIC, "logistic");
224 #undef CASE
225     assert(!"unknown attr::post_ops::kind");
226     return "unknown attr::post_ops::kind";
227 }
228
229 mkldnn_alg_kind_t attr_t::post_ops_t::kind2mkldnn_kind(
230         attr_t::post_ops_t::kind_t kind) {
231 #define CASE(_knd, _mknd) if (kind == _knd) return _mknd
232     CASE(RELU, mkldnn_eltwise_relu);
233     CASE(TANH, mkldnn_eltwise_tanh);
234     CASE(ELU, mkldnn_eltwise_elu);
235     CASE(SQUARE, mkldnn_eltwise_square);
236     CASE(ABS, mkldnn_eltwise_abs);
237     CASE(SQRT, mkldnn_eltwise_sqrt);
238     CASE(LINEAR, mkldnn_eltwise_linear);
239     CASE(BRELU, mkldnn_eltwise_bounded_relu);
240     CASE(SRELU, mkldnn_eltwise_soft_relu);
241     CASE(LOGISTIC, mkldnn_eltwise_logistic);
242 #undef CASE
243     assert(!"unknown attr::post_ops::kind");
244     return mkldnn_alg_kind_undef;
245 }
246
247 int attr_t::post_ops_t::from_str(const char *str, const char **end_s) {
248     *this = post_ops_t();
249
250     if (str == NULL || *str != '\'') return FAIL;
251
252     const char *s_;
253     const char * &s = end_s ? *end_s : s_;
254     s = str;
255
256     ++s;
257     for (;;) {
258         if (*s == '\'') { ++s; return OK; }
259         if (len == capacity) return FAIL;
260
261         for (kind_t k = SUM; true; k = (kind_t)((int)k + 1)) {
262             if (k == KIND_TOTAL) return FAIL;
263
264             const char *ks = kind2str(k);
265             if (!strncasecmp(ks, s, strlen(ks))) {
266                 auto &e = entry[len];
267
268                 e.kind = k;
269                 s += strlen(ks);
270                 if (k == SUM) {
271                     if (*s == ':') {
272                         char *end;
273                         e.sum.scale = strtof(++s, &end);
274                         if (e.sum.scale <= 0 || end == s) return FAIL;
275                         s = end;
276                     } else {
277                         e.sum.scale = 1.f;
278                     }
279                 } else {
280                     e.eltwise.alg = kind2mkldnn_kind(k);
281                     e.eltwise.scale = 1.f;
282                     e.eltwise.alpha = e.eltwise.beta = 0.f;
283
284                     for (int i = 0; i < 3; ++i) {
285                         // :alpha:beta:scale
286                         float &val = i == 0 ? e.eltwise.alpha
287                             : i == 1 ? e.eltwise.beta : e.eltwise.scale;
288                         if (*s == ':') {
289                             char *end;
290                             val = strtof(++s, &end);
291                             if (end == s) return FAIL;
292                             s = end;
293                         } else {
294                             break;
295                         }
296                     }
297
298                     if (e.eltwise.scale <= 0) return FAIL;
299                 }
300
301                 break;
302             }
303         }
304         ++len;
305
306         if (*s == ';') ++s;
307     }
308
309     return FAIL; /* unreachable */
310 }
311
312 void attr_t::post_ops_t::to_str(char *buffer, char **end_b) const {
313     assert(buffer);
314
315     buffer += sprintf(buffer, "'");
316     for (int idx = 0; idx < len; ++idx) {
317         buffer += sprintf(buffer, "%s", idx > 0 ? ";" : "");
318         const auto &e = entry[idx];
319
320         switch (e.kind) {
321         case SUM:
322             buffer += sprintf(buffer, "%s:%g", kind2str(e.kind), e.sum.scale);
323             break;
324         case RELU:
325         case TANH:
326         case ELU:
327         case SQUARE:
328         case ABS:
329         case SQRT:
330         case LINEAR:
331         case BRELU:
332         case SRELU:
333         case LOGISTIC:
334             buffer += sprintf(buffer, "%s:%g", kind2str(e.kind), e.eltwise.alpha);
335             if (e.eltwise.beta != 0.f || e.eltwise.scale != 1.f)
336                 buffer += sprintf(buffer, ":%g:%g", e.eltwise.beta, e.eltwise.scale);
337             break;
338         default:
339             assert(!"unknown kind");
340             buffer += sprintf(buffer, "unknown_kind");
341         }
342     }
343     buffer += sprintf(buffer, "'");
344     if (end_b) *end_b = buffer;
345 }
346
347 bool attr_t::is_def() const {
348     return true
349         && irmode == round_mode_t::NEAREST
350         && oscale.is_def()
351         && post_ops.is_def();
352 }
353
354 int str2attr(attr_t *attr, const char *str) {
355     if (attr == NULL || str == NULL) return FAIL;
356     *attr = attr_t();
357
358     const char *s = str;
359
360     while (*s != '\0') {
361         int rc = FAIL;
362         const char *param;
363
364         param = "irmode=";
365         if (!strncasecmp(param, s, strlen(param))) {
366             s += strlen(param);
367             attr->irmode = (attr_t::round_mode_t)str2rmode(s);
368             s += strlen(rmode2str((mkldnn_round_mode_t)attr->irmode));
369             rc = OK;
370         }
371
372         param = "oscale=";
373         if (!strncasecmp(param, s, strlen(param))) {
374             s += strlen(param);
375             rc = attr->oscale.str2scale(s, &s);
376             if (rc != OK) return rc;
377         }
378
379         param = "post_ops=";
380         if (!strncasecmp(param, s, strlen(param))) {
381             s += strlen(param);
382             rc = attr->post_ops.from_str(s, &s);
383             if (rc != OK) return rc;
384         }
385
386         if (rc != OK) return FAIL;
387         if (*s == ';') ++s;
388     }
389
390     return OK;
391 }
392
393 void attr2str(const attr_t *attr, char *buffer) {
394     buffer += sprintf(buffer, "irmode=%s",
395             rmode2str((mkldnn_round_mode_t)attr->irmode));
396     buffer += sprintf(buffer, ";oscale=");
397     attr->oscale.scale2str(buffer, &buffer);
398     buffer += sprintf(buffer, ";post_ops=");
399     attr->post_ops.to_str(buffer, &buffer);
400 }
401
402 mkldnn_primitive_attr_t create_mkldnn_attr(const attr_t &attr, int scale_cnt,
403         int scale_mask, const float *scales) {
404     mkldnn_primitive_attr_t mkldnn_attr = NULL;
405     DNN_SAFE_V(mkldnn_primitive_attr_create(&mkldnn_attr));
406
407     if (attr.irmode != attr_t::round_mode_t::NEAREST)
408         DNN_SAFE_V(mkldnn_primitive_attr_set_int_output_round_mode(mkldnn_attr,
409                     (mkldnn_round_mode_t)attr.irmode));
410
411     if (!attr.oscale.is_def()) {
412         using P = attr_t::scale_t::policy_t;
413         int count = attr.oscale.policy == P::COMMON ? 1 : scale_cnt;
414         if (scale_mask == -1)
415             scale_mask = attr.oscale.policy == P::PER_OC ? 1 << 1 : 0;
416
417         float *gen_scs = NULL;
418         if (scales == NULL) {
419             gen_scs = (float *)zmalloc(count * sizeof(float), 64);
420             SAFE_V(gen_scs != NULL ? OK : FAIL);
421             for (int i = 0; i < count; ++i)
422                 gen_scs[i] = attr.oscale.scale;
423             scales = gen_scs;
424         }
425
426         DNN_SAFE_V(mkldnn_primitive_attr_set_output_scales(mkldnn_attr, count,
427                     scale_mask, scales));
428
429         if (gen_scs)
430             zfree(gen_scs);
431     }
432
433     if (!attr.post_ops.is_def()) {
434         mkldnn_post_ops_t ops;
435         DNN_SAFE_V(mkldnn_post_ops_create(&ops));
436         for (int idx = 0; idx < attr.post_ops.len; ++idx) {
437             const auto &e = attr.post_ops.entry[idx];
438             switch (attr.post_ops.entry[idx].kind) {
439             case attr_t::post_ops_t::SUM:
440                 DNN_SAFE_V(mkldnn_post_ops_append_sum(ops, e.sum.scale));
441                 break;
442             case attr_t::post_ops_t::RELU:
443             case attr_t::post_ops_t::TANH:
444             case attr_t::post_ops_t::ELU:
445             case attr_t::post_ops_t::SQUARE:
446             case attr_t::post_ops_t::ABS:
447             case attr_t::post_ops_t::SQRT:
448             case attr_t::post_ops_t::LINEAR:
449             case attr_t::post_ops_t::BRELU:
450             case attr_t::post_ops_t::SRELU:
451             case attr_t::post_ops_t::LOGISTIC:
452                 DNN_SAFE_V(mkldnn_post_ops_append_eltwise(ops, e.eltwise.scale,
453                             e.eltwise.alg, e.eltwise.alpha, e.eltwise.beta));
454                 break;
455             default:
456                 assert(!"unknown attr::post_ops::kind");
457             }
458         }
459         DNN_SAFE_V(mkldnn_primitive_attr_set_post_ops(mkldnn_attr, ops));
460
461         const_mkldnn_post_ops_t c_ops;
462         DNN_SAFE_V(mkldnn_primitive_attr_get_post_ops(mkldnn_attr, &c_ops));
463         SAFE_V(mkldnn_post_ops_len(c_ops) == attr.post_ops.len ? OK : FAIL);
464
465         DNN_SAFE_V(mkldnn_post_ops_destroy(ops));
466     }
467
468     return mkldnn_attr;
469 }
470
471 mkldnn_memory_format_t get_default_format(int ndims, data_kind_t kind) {
472     switch(kind) {
473     case DATA: return (ndims == 5)
474         ? mkldnn_ncdhw
475         : (ndims == 4)
476         ? mkldnn_nchw
477         : mkldnn_ncw;
478     case GWEI: return (ndims == 6)
479         ? mkldnn_goidhw
480         : (ndims == 5)
481         ? mkldnn_goihw
482         : mkldnn_goiw;
483     case WEI: return (ndims == 5)
484         ? mkldnn_oidhw
485         : (ndims == 4)
486         ? mkldnn_oihw
487         : mkldnn_oiw;
488     default:
489         assert(!"unknown kind");
490     }
491     return mkldnn_format_undef;
492 }