1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
26 #include "dnn_types.hpp"
27 #include "mkldnn_common.hpp"
28 #include "mkldnn_debug.hpp"
30 dir_t str2dir(const char *str) {
31 #define CASE(x) if (!strcasecmp(STRINGIFY(x), str)) return x
40 assert(!"unknown dir");
44 const char *dir2str(dir_t dir) {
45 #define CASE(x) if (dir == x) return STRINGIFY(x)
54 assert(!"unknown dir");
58 const char *data_kind2str(data_kind_t kind) {
60 case SRC: return "SRC";
61 case WEI: return "WEI";
62 case BIA: return "BIA";
63 case DST: return "DST";
64 case ACC: return "ACC";
65 case DATA: return "DATA";
66 case MEAN: return "MEAN";
67 case VAR: return "VAR";
69 case GWEI: return "GWEI";
71 assert(!"incorrect data kind");
72 return "incorrect data kind";
75 data_kind_t fmt2data_kind(mkldnn_memory_format_t fmt) {
98 case mkldnn_gOIw16i16o:
99 case mkldnn_gOIw16o16i:
102 case mkldnn_gOIw8i16o2i:
106 case mkldnn_hwigo_s8s8:
107 case mkldnn_gOIhw8i8o:
108 case mkldnn_gOIhw16i16o:
109 case mkldnn_gOIhw4i16o4i:
110 case mkldnn_gOIhw4i16o4i_s8s8:
111 case mkldnn_gOIhw8i16o2i:
112 case mkldnn_gOIdhw8i16o2i:
113 case mkldnn_gOIhw8o16i2o:
114 case mkldnn_gOIhw8o8i:
115 case mkldnn_gOIhw16o16i:
116 case mkldnn_gIOhw16o16i:
118 case mkldnn_gOihw16o:
120 case mkldnn_gOhwi16o:
122 case mkldnn_Goihw16g:
123 case mkldnn_Goihw16g_s8s8:
124 case mkldnn_gOhIw16o4i:
126 case mkldnn_gOIdhw16i16o:
127 case mkldnn_gOIdhw16o16i:
128 case mkldnn_gOidhw16o:
129 case mkldnn_gOdhwi16o:
136 attr_t::scale_t::policy_t attr_t::scale_t::str2policy(const char *str) {
137 #define CASE(_plc) if (!strcasecmp(STRINGIFY(_plc), str)) return _plc
142 assert(!"unknown attr::scale::policy");
146 const char *attr_t::scale_t::policy2str(attr_t::scale_t::policy_t policy) {
147 if (policy == NONE) return "none";
148 if (policy == COMMON) return "common";
149 if (policy == PER_OC) return "per_oc";
150 assert(!"unknown attr::scale::policy");
151 return "unknown attr::scale::policy";
154 int attr_t::scale_t::str2scale(const char *str, const char **end_s) {
155 *this = attr_t::scale_t();
157 if (str == NULL) return FAIL;
160 const char * &s = end_s ? *end_s : s_;
163 for (policy_t p = NONE; true; p = (policy_t)((int)p + 1)) {
164 if (p == POLICY_TOTAL) return FAIL;
166 const char *ps = policy2str(p);
167 if (!strncasecmp(ps, s, strlen(ps))) {
174 if (*s != ':') return OK;
178 this->scale = strtof(s, &end);
179 if (this->scale < 0 || end == s) return FAIL;
182 assert(*s == '\0' || *s == ';');
187 void attr_t::scale_t::scale2str(char *buffer, char **end_b) const {
189 buffer += sprintf(buffer, "%s:%g", policy2str(this->policy), this->scale);
190 if (end_b) *end_b = buffer;
193 attr_t::post_ops_t::kind_t attr_t::post_ops_t::str2kind(const char *str) {
194 #define CASE(_knd) if (!strcasecmp(STRINGIFY(_knd), str)) return _knd
207 assert(!"unknown attr::post_ops::kind");
211 const char *attr_t::post_ops_t::kind2str(attr_t::post_ops_t::kind_t kind) {
212 #define CASE(_knd, str) if (kind == _knd) return str
217 CASE(SQUARE, "square");
220 CASE(LINEAR, "linear");
221 CASE(BRELU, "brelu");
222 CASE(SRELU, "srelu");
223 CASE(LOGISTIC, "logistic");
225 assert(!"unknown attr::post_ops::kind");
226 return "unknown attr::post_ops::kind";
229 mkldnn_alg_kind_t attr_t::post_ops_t::kind2mkldnn_kind(
230 attr_t::post_ops_t::kind_t kind) {
231 #define CASE(_knd, _mknd) if (kind == _knd) return _mknd
232 CASE(RELU, mkldnn_eltwise_relu);
233 CASE(TANH, mkldnn_eltwise_tanh);
234 CASE(ELU, mkldnn_eltwise_elu);
235 CASE(SQUARE, mkldnn_eltwise_square);
236 CASE(ABS, mkldnn_eltwise_abs);
237 CASE(SQRT, mkldnn_eltwise_sqrt);
238 CASE(LINEAR, mkldnn_eltwise_linear);
239 CASE(BRELU, mkldnn_eltwise_bounded_relu);
240 CASE(SRELU, mkldnn_eltwise_soft_relu);
241 CASE(LOGISTIC, mkldnn_eltwise_logistic);
243 assert(!"unknown attr::post_ops::kind");
244 return mkldnn_alg_kind_undef;
247 int attr_t::post_ops_t::from_str(const char *str, const char **end_s) {
248 *this = post_ops_t();
250 if (str == NULL || *str != '\'') return FAIL;
253 const char * &s = end_s ? *end_s : s_;
258 if (*s == '\'') { ++s; return OK; }
259 if (len == capacity) return FAIL;
261 for (kind_t k = SUM; true; k = (kind_t)((int)k + 1)) {
262 if (k == KIND_TOTAL) return FAIL;
264 const char *ks = kind2str(k);
265 if (!strncasecmp(ks, s, strlen(ks))) {
266 auto &e = entry[len];
273 e.sum.scale = strtof(++s, &end);
274 if (e.sum.scale <= 0 || end == s) return FAIL;
280 e.eltwise.alg = kind2mkldnn_kind(k);
281 e.eltwise.scale = 1.f;
282 e.eltwise.alpha = e.eltwise.beta = 0.f;
284 for (int i = 0; i < 3; ++i) {
286 float &val = i == 0 ? e.eltwise.alpha
287 : i == 1 ? e.eltwise.beta : e.eltwise.scale;
290 val = strtof(++s, &end);
291 if (end == s) return FAIL;
298 if (e.eltwise.scale <= 0) return FAIL;
309 return FAIL; /* unreachable */
312 void attr_t::post_ops_t::to_str(char *buffer, char **end_b) const {
315 buffer += sprintf(buffer, "'");
316 for (int idx = 0; idx < len; ++idx) {
317 buffer += sprintf(buffer, "%s", idx > 0 ? ";" : "");
318 const auto &e = entry[idx];
322 buffer += sprintf(buffer, "%s:%g", kind2str(e.kind), e.sum.scale);
334 buffer += sprintf(buffer, "%s:%g", kind2str(e.kind), e.eltwise.alpha);
335 if (e.eltwise.beta != 0.f || e.eltwise.scale != 1.f)
336 buffer += sprintf(buffer, ":%g:%g", e.eltwise.beta, e.eltwise.scale);
339 assert(!"unknown kind");
340 buffer += sprintf(buffer, "unknown_kind");
343 buffer += sprintf(buffer, "'");
344 if (end_b) *end_b = buffer;
347 bool attr_t::is_def() const {
349 && irmode == round_mode_t::NEAREST
351 && post_ops.is_def();
354 int str2attr(attr_t *attr, const char *str) {
355 if (attr == NULL || str == NULL) return FAIL;
365 if (!strncasecmp(param, s, strlen(param))) {
367 attr->irmode = (attr_t::round_mode_t)str2rmode(s);
368 s += strlen(rmode2str((mkldnn_round_mode_t)attr->irmode));
373 if (!strncasecmp(param, s, strlen(param))) {
375 rc = attr->oscale.str2scale(s, &s);
376 if (rc != OK) return rc;
380 if (!strncasecmp(param, s, strlen(param))) {
382 rc = attr->post_ops.from_str(s, &s);
383 if (rc != OK) return rc;
386 if (rc != OK) return FAIL;
393 void attr2str(const attr_t *attr, char *buffer) {
394 buffer += sprintf(buffer, "irmode=%s",
395 rmode2str((mkldnn_round_mode_t)attr->irmode));
396 buffer += sprintf(buffer, ";oscale=");
397 attr->oscale.scale2str(buffer, &buffer);
398 buffer += sprintf(buffer, ";post_ops=");
399 attr->post_ops.to_str(buffer, &buffer);
402 mkldnn_primitive_attr_t create_mkldnn_attr(const attr_t &attr, int scale_cnt,
403 int scale_mask, const float *scales) {
404 mkldnn_primitive_attr_t mkldnn_attr = NULL;
405 DNN_SAFE_V(mkldnn_primitive_attr_create(&mkldnn_attr));
407 if (attr.irmode != attr_t::round_mode_t::NEAREST)
408 DNN_SAFE_V(mkldnn_primitive_attr_set_int_output_round_mode(mkldnn_attr,
409 (mkldnn_round_mode_t)attr.irmode));
411 if (!attr.oscale.is_def()) {
412 using P = attr_t::scale_t::policy_t;
413 int count = attr.oscale.policy == P::COMMON ? 1 : scale_cnt;
414 if (scale_mask == -1)
415 scale_mask = attr.oscale.policy == P::PER_OC ? 1 << 1 : 0;
417 float *gen_scs = NULL;
418 if (scales == NULL) {
419 gen_scs = (float *)zmalloc(count * sizeof(float), 64);
420 SAFE_V(gen_scs != NULL ? OK : FAIL);
421 for (int i = 0; i < count; ++i)
422 gen_scs[i] = attr.oscale.scale;
426 DNN_SAFE_V(mkldnn_primitive_attr_set_output_scales(mkldnn_attr, count,
427 scale_mask, scales));
433 if (!attr.post_ops.is_def()) {
434 mkldnn_post_ops_t ops;
435 DNN_SAFE_V(mkldnn_post_ops_create(&ops));
436 for (int idx = 0; idx < attr.post_ops.len; ++idx) {
437 const auto &e = attr.post_ops.entry[idx];
438 switch (attr.post_ops.entry[idx].kind) {
439 case attr_t::post_ops_t::SUM:
440 DNN_SAFE_V(mkldnn_post_ops_append_sum(ops, e.sum.scale));
442 case attr_t::post_ops_t::RELU:
443 case attr_t::post_ops_t::TANH:
444 case attr_t::post_ops_t::ELU:
445 case attr_t::post_ops_t::SQUARE:
446 case attr_t::post_ops_t::ABS:
447 case attr_t::post_ops_t::SQRT:
448 case attr_t::post_ops_t::LINEAR:
449 case attr_t::post_ops_t::BRELU:
450 case attr_t::post_ops_t::SRELU:
451 case attr_t::post_ops_t::LOGISTIC:
452 DNN_SAFE_V(mkldnn_post_ops_append_eltwise(ops, e.eltwise.scale,
453 e.eltwise.alg, e.eltwise.alpha, e.eltwise.beta));
456 assert(!"unknown attr::post_ops::kind");
459 DNN_SAFE_V(mkldnn_primitive_attr_set_post_ops(mkldnn_attr, ops));
461 const_mkldnn_post_ops_t c_ops;
462 DNN_SAFE_V(mkldnn_primitive_attr_get_post_ops(mkldnn_attr, &c_ops));
463 SAFE_V(mkldnn_post_ops_len(c_ops) == attr.post_ops.len ? OK : FAIL);
465 DNN_SAFE_V(mkldnn_post_ops_destroy(ops));
471 mkldnn_memory_format_t get_default_format(int ndims, data_kind_t kind) {
473 case DATA: return (ndims == 5)
478 case GWEI: return (ndims == 6)
483 case WEI: return (ndims == 5)
489 assert(!"unknown kind");
491 return mkldnn_format_undef;