Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / primitive_desc.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef PRIMITIVE_DESC_HPP
18 #define PRIMITIVE_DESC_HPP
19
20 #include "mkldnn.h"
21
22 #include "c_types_map.hpp"
23 #include "memory_tracking.hpp"
24 #include "nstl.hpp"
25 #include "type_helpers.hpp"
26 #include "primitive_attr.hpp"
27 #include "verbose.hpp"
28
29 struct mkldnn_primitive_desc: public mkldnn::impl::c_compatible {
30     using memory_pd_t = mkldnn::impl::memory_pd_t;
31
32     mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
33             const mkldnn::impl::primitive_attr_t *attr,
34             mkldnn::impl::primitive_kind_t kind)
35         : engine_(engine), attr_(*attr), kind_(kind) { info_[0] = '\0'; }
36
37     mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
38             mkldnn::impl::primitive_kind_t kind)
39         : engine_(engine), kind_(kind) { info_[0] = '\0'; }
40
41     virtual mkldnn_primitive_desc *clone() const = 0;
42     virtual ~mkldnn_primitive_desc() {}
43
44     const mkldnn::impl::primitive_attr_t *attr() const { return &attr_; }
45     mkldnn::impl::engine_t *engine() const { return engine_; }
46     mkldnn::impl::primitive_kind_t kind() const { return kind_; }
47
48     virtual void init_info() {}
49     const char *info() const { return info_; }
50
51     mkldnn::impl::memory_tracking::registry_t &scratchpad_registry()
52     { return scratchpad_registry_; }
53     const mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() const
54     { return scratchpad_registry_; }
55
56     virtual const mkldnn::impl::op_desc_t *op_desc() const = 0;
57
58 #   define DECLARE_PD_STUB(stub) \
59     virtual const memory_pd_t *stub(int idx = 0) const { return nullptr; }
60
61     DECLARE_PD_STUB(input_pd); DECLARE_PD_STUB(output_pd);
62     DECLARE_PD_STUB(src_pd); DECLARE_PD_STUB(diff_src_pd);
63     DECLARE_PD_STUB(dst_pd); DECLARE_PD_STUB(diff_dst_pd);
64     DECLARE_PD_STUB(weights_pd); DECLARE_PD_STUB(diff_weights_pd);
65     DECLARE_PD_STUB(workspace_pd);
66 #   undef DECLARE_PD_STUB
67
68     virtual int n_inputs() const { return 0; }
69     virtual int n_outputs() const { return 0; }
70
71     virtual mkldnn::impl::status_t query(mkldnn::impl::query_t what, int idx,
72             void *result) const;
73
74     virtual mkldnn::impl::status_t create_primitive(
75             mkldnn::impl::primitive_t **primitive,
76             const mkldnn::impl::primitive_at_t *inputs,
77             const mkldnn::impl::primitive_t **outputs) const = 0;
78
79     virtual const char *name() const { return "mkldnn_primitive_desc"; }
80
81     /* static magic */
82
83     template<typename pd_t>
84     static mkldnn::impl::status_t create(mkldnn::impl::primitive_desc_t **pd,
85             const mkldnn::impl::op_desc_t *adesc,
86             const mkldnn::impl::primitive_attr_t *attr,
87             mkldnn::impl::engine_t *engine,
88             const mkldnn::impl::primitive_desc_t *hint_fwd) {
89         using namespace mkldnn::impl;
90         using namespace mkldnn::impl::status;
91         using pd_op_desc_t = typename pkind_traits<pd_t::base_pkind>::desc_type;
92         if (adesc->kind != pd_t::base_pkind) return invalid_arguments;
93         assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true);
94         auto hint =
95             reinterpret_cast<const typename pd_t::hint_class *>(hint_fwd);
96         auto _pd = new pd_t(engine, (const pd_op_desc_t *)adesc, attr, hint);
97         if (_pd == nullptr) return out_of_memory;
98         if (_pd->init() != success) { delete _pd; return unimplemented; }
99         _pd->init_info();
100         *pd = _pd;
101         return success;
102     }
103
104 protected:
105     mkldnn::impl::engine_t *engine_;
106     mkldnn::impl::primitive_attr_t attr_;
107     mkldnn::impl::primitive_kind_t kind_;
108
109     char info_[MKLDNN_VERBOSE_BUF_LEN];
110
111     mkldnn::impl::memory_tracking::registry_t scratchpad_registry_;
112 };
113
114 #define DECLARE_COMMON_PD_t(impl_name, ...) \
115     virtual pd_t *clone() const override { return new pd_t(*this); } \
116     virtual status_t create_primitive(primitive_t **primitive, \
117             const primitive_at_t *inputs, \
118             const primitive_t **outputs) const override { \
119         double ms = get_msec(); \
120         primitive_t::input_vector ins(inputs, inputs + this->n_inputs()); \
121         primitive_t::output_vector outs(outputs, outputs + this->n_outputs()); \
122         auto ret = safe_ptr_assign<primitive_t>(*primitive, \
123                 new (__VA_ARGS__)(this, ins, outs)); \
124         ms = get_msec() - ms; \
125         if (mkldnn_verbose()->level >= 2) { \
126             printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
127             fflush(0); \
128         } \
129         return ret; \
130     } \
131     virtual const char *name() const override { return impl_name; }
132 #define DECLARE_COMMON_PD_T(impl_name, ...) \
133     DECLARE_COMMON_PD_t(impl_name, __VA_ARGS__)
134
135 #endif
136
137 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s