1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef _DNN_TYPES_HPP
18 #define _DNN_TYPES_HPP
25 #include "mkldnn_types.h"
29 FLAG_DAT = 1, FLAG_WEI = 2, FLAG_BIA = 4,
30 FLAG_FWD = 32, FLAG_BWD = 64,
32 FWD_D = FLAG_FWD + FLAG_DAT,
33 FWD_I = FLAG_FWD + FLAG_DAT + FLAG_INF,
34 FWD_B = FLAG_FWD + FLAG_DAT + FLAG_BIA,
35 BWD_D = FLAG_BWD + FLAG_DAT,
36 BWD_DW = FLAG_BWD + FLAG_DAT + FLAG_WEI,
37 BWD_W = FLAG_BWD + FLAG_WEI,
38 BWD_WB = FLAG_BWD + FLAG_WEI + FLAG_BIA,
40 dir_t str2dir(const char *str);
41 const char *dir2str(dir_t dir);
43 typedef int data_kind_t;
45 SRC = 0, WEI, BIA, DST, ACC,
49 const char *data_kind2str(data_kind_t kind);
50 data_kind_t fmt2data_kind(mkldnn_memory_format_t fmt);
54 NEAREST = (int)mkldnn_round_nearest,
55 DOWN = (int)mkldnn_round_down,
59 enum policy_t { NONE = 0, COMMON, PER_OC, POLICY_TOTAL };
60 static policy_t str2policy(const char *str);
61 static const char *policy2str(policy_t policy);
63 int str2scale(const char *str, const char **end_s);
64 void scale2str(char *buffer, char **end_b) const;
66 bool is_def() const { return this->policy == NONE; }
68 policy_t policy = NONE;
73 enum kind_t { SUM, RELU, TANH, ELU, SQUARE, ABS, SQRT, LINEAR, BRELU,
74 SRELU, LOGISTIC, KIND_TOTAL };
75 static kind_t str2kind(const char *str);
76 static const char *kind2str(kind_t kind);
77 static mkldnn_alg_kind_t kind2mkldnn_kind(kind_t kind);
82 struct { float scale; } sum;
84 mkldnn_alg_kind_t alg;
85 float scale, alpha, beta;
90 post_ops_t(): len(0) {}
92 int from_str(const char *str, const char **end_s);
93 void to_str(char *buffer, char **end_b) const;
95 bool is_def() const { return len == 0; }
97 enum { capacity = 4 };
102 round_mode_t irmode = NEAREST;
109 const size_t max_attr_len = 128;
110 int str2attr(attr_t *attr, const char *str);
111 void attr2str(const attr_t *attr, char *buffer);
113 mkldnn_memory_format_t get_default_format(int ndims, data_kind_t kind);
114 mkldnn_primitive_attr_t create_mkldnn_attr(const attr_t &attr, int scale_cnt,
115 int scale_mask, const float *scales);
116 inline mkldnn_primitive_attr_t create_mkldnn_attr(const attr_t &attr,
117 int scale_cnt, const float *scales)
118 { return create_mkldnn_attr(attr, scale_cnt, -1, scales); }