Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / dnn_types.hpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef _DNN_TYPES_HPP
18 #define _DNN_TYPES_HPP
19
20 #include <stdlib.h>
21 #include <stddef.h>
22 #include <string.h>
23
24 #include "common.hpp"
25 #include "mkldnn_types.h"
26
27 enum dir_t {
28     DIR_UNDEF = 0,
29     FLAG_DAT = 1, FLAG_WEI = 2, FLAG_BIA = 4,
30     FLAG_FWD = 32, FLAG_BWD = 64,
31     FLAG_INF = 128,
32     FWD_D = FLAG_FWD + FLAG_DAT,
33     FWD_I = FLAG_FWD + FLAG_DAT + FLAG_INF,
34     FWD_B = FLAG_FWD + FLAG_DAT + FLAG_BIA,
35     BWD_D = FLAG_BWD + FLAG_DAT,
36     BWD_DW = FLAG_BWD + FLAG_DAT + FLAG_WEI,
37     BWD_W = FLAG_BWD + FLAG_WEI,
38     BWD_WB = FLAG_BWD + FLAG_WEI + FLAG_BIA,
39 };
40 dir_t str2dir(const char *str);
41 const char *dir2str(dir_t dir);
42
43 typedef int data_kind_t;
44 enum {
45     SRC = 0, WEI, BIA, DST, ACC,
46     DATA, MEAN, VAR, SS,
47     GWEI,
48     DAT_TOTAL };
49 const char *data_kind2str(data_kind_t kind);
50 data_kind_t fmt2data_kind(mkldnn_memory_format_t fmt);
51
52 struct attr_t {
53     enum round_mode_t {
54         NEAREST = (int)mkldnn_round_nearest,
55         DOWN = (int)mkldnn_round_down,
56     };
57
58     struct scale_t {
59         enum policy_t { NONE = 0, COMMON, PER_OC, POLICY_TOTAL };
60         static policy_t str2policy(const char *str);
61         static const char *policy2str(policy_t policy);
62
63         int str2scale(const char *str, const char **end_s);
64         void scale2str(char *buffer, char **end_b) const;
65
66         bool is_def() const { return this->policy == NONE; }
67
68         policy_t policy = NONE;
69         float scale = 1.;
70     };
71
72     struct post_ops_t {
73         enum kind_t { SUM, RELU, TANH, ELU, SQUARE, ABS, SQRT, LINEAR, BRELU,
74             SRELU, LOGISTIC, KIND_TOTAL };
75         static kind_t str2kind(const char *str);
76         static const char *kind2str(kind_t kind);
77         static mkldnn_alg_kind_t kind2mkldnn_kind(kind_t kind);
78
79         struct entry_t {
80             kind_t kind;
81             union {
82                 struct { float scale; } sum;
83                 struct {
84                     mkldnn_alg_kind_t alg;
85                     float scale, alpha, beta;
86                 } eltwise;
87             };
88         };
89
90         post_ops_t(): len(0) {}
91
92         int from_str(const char *str, const char **end_s);
93         void to_str(char *buffer, char **end_b) const;
94
95         bool is_def() const { return len == 0; }
96
97         enum { capacity = 4 };
98         int len;
99         entry_t entry[4];
100     };
101
102     round_mode_t irmode = NEAREST;
103     scale_t oscale;
104     post_ops_t post_ops;
105
106     bool is_def() const;
107 };
108
109 const size_t max_attr_len = 128;
110 int str2attr(attr_t *attr, const char *str);
111 void attr2str(const attr_t *attr, char *buffer);
112
113 mkldnn_memory_format_t get_default_format(int ndims, data_kind_t kind);
114 mkldnn_primitive_attr_t create_mkldnn_attr(const attr_t &attr, int scale_cnt,
115         int scale_mask, const float *scales);
116 inline mkldnn_primitive_attr_t create_mkldnn_attr(const attr_t &attr,
117         int scale_cnt, const float *scales)
118 { return create_mkldnn_attr(attr, scale_cnt, -1, scales); }
119
120 #endif