1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef PRIMITIVE_DESC_HPP
18 #define PRIMITIVE_DESC_HPP
22 #include "c_types_map.hpp"
23 #include "memory_tracking.hpp"
25 #include "type_helpers.hpp"
26 #include "primitive_attr.hpp"
27 #include "verbose.hpp"
29 struct mkldnn_primitive_desc: public mkldnn::impl::c_compatible {
30 using memory_pd_t = mkldnn::impl::memory_pd_t;
32 mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
33 const mkldnn::impl::primitive_attr_t *attr,
34 mkldnn::impl::primitive_kind_t kind)
35 : engine_(engine), attr_(*attr), kind_(kind) { info_[0] = '\0'; }
37 mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
38 mkldnn::impl::primitive_kind_t kind)
39 : engine_(engine), kind_(kind) { info_[0] = '\0'; }
41 virtual mkldnn_primitive_desc *clone() const = 0;
42 virtual ~mkldnn_primitive_desc() {}
44 const mkldnn::impl::primitive_attr_t *attr() const { return &attr_; }
45 mkldnn::impl::engine_t *engine() const { return engine_; }
46 mkldnn::impl::primitive_kind_t kind() const { return kind_; }
48 virtual void init_info() {}
49 const char *info() const { return info_; }
51 mkldnn::impl::memory_tracking::registry_t &scratchpad_registry()
52 { return scratchpad_registry_; }
53 const mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() const
54 { return scratchpad_registry_; }
56 virtual const mkldnn::impl::op_desc_t *op_desc() const = 0;
58 # define DECLARE_PD_STUB(stub) \
59 virtual const memory_pd_t *stub(int idx = 0) const { return nullptr; }
61 DECLARE_PD_STUB(input_pd); DECLARE_PD_STUB(output_pd);
62 DECLARE_PD_STUB(src_pd); DECLARE_PD_STUB(diff_src_pd);
63 DECLARE_PD_STUB(dst_pd); DECLARE_PD_STUB(diff_dst_pd);
64 DECLARE_PD_STUB(weights_pd); DECLARE_PD_STUB(diff_weights_pd);
65 DECLARE_PD_STUB(workspace_pd);
66 # undef DECLARE_PD_STUB
68 virtual int n_inputs() const { return 0; }
69 virtual int n_outputs() const { return 0; }
71 virtual mkldnn::impl::status_t query(mkldnn::impl::query_t what, int idx,
74 virtual mkldnn::impl::status_t create_primitive(
75 mkldnn::impl::primitive_t **primitive,
76 const mkldnn::impl::primitive_at_t *inputs,
77 const mkldnn::impl::primitive_t **outputs) const = 0;
79 virtual const char *name() const { return "mkldnn_primitive_desc"; }
83 template<typename pd_t>
84 static mkldnn::impl::status_t create(mkldnn::impl::primitive_desc_t **pd,
85 const mkldnn::impl::op_desc_t *adesc,
86 const mkldnn::impl::primitive_attr_t *attr,
87 mkldnn::impl::engine_t *engine,
88 const mkldnn::impl::primitive_desc_t *hint_fwd) {
89 using namespace mkldnn::impl;
90 using namespace mkldnn::impl::status;
91 using pd_op_desc_t = typename pkind_traits<pd_t::base_pkind>::desc_type;
92 if (adesc->kind != pd_t::base_pkind) return invalid_arguments;
93 assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true);
95 reinterpret_cast<const typename pd_t::hint_class *>(hint_fwd);
96 auto _pd = new pd_t(engine, (const pd_op_desc_t *)adesc, attr, hint);
97 if (_pd == nullptr) return out_of_memory;
98 if (_pd->init() != success) { delete _pd; return unimplemented; }
/* Engine passed to the constructor; a plain pointer that the destructor of
 * this class does not delete. Returned by engine(). */
mkldnn::impl::engine_t *engine_;
/* This descriptor's own copy of the primitive attributes; attr() returns its
 * address. Default-constructed when the attribute-less constructor is used. */
mkldnn::impl::primitive_attr_t attr_;
/* Primitive kind passed to the constructor; returned by kind(). */
mkldnn::impl::primitive_kind_t kind_;

/* Buffer returned by info(). Both constructors set it to the empty string;
 * derived init_info() overrides presumably fill it. MKLDNN_VERBOSE_BUF_LEN
 * is assumed to come from verbose.hpp — TODO confirm. */
char info_[MKLDNN_VERBOSE_BUF_LEN];

/* Registry exposed via both scratchpad_registry() overloads. */
mkldnn::impl::memory_tracking::registry_t scratchpad_registry_;
114 #define DECLARE_COMMON_PD_t(impl_name, ...) \
115 virtual pd_t *clone() const override { return new pd_t(*this); } \
116 virtual status_t create_primitive(primitive_t **primitive, \
117 const primitive_at_t *inputs, \
118 const primitive_t **outputs) const override { \
119 double ms = get_msec(); \
120 primitive_t::input_vector ins(inputs, inputs + this->n_inputs()); \
121 primitive_t::output_vector outs(outputs, outputs + this->n_outputs()); \
122 auto ret = safe_ptr_assign<primitive_t>(*primitive, \
123 new (__VA_ARGS__)(this, ins, outs)); \
124 ms = get_msec() - ms; \
125 if (mkldnn_verbose()->level >= 2) { \
126 printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
131 virtual const char *name() const override { return impl_name; }
/* Forwards to DECLARE_COMMON_PD_t, passing the implementation name and the
 * variadic arguments (used as the type in the `new (...)` expression of that
 * macro) unchanged. The variadic form presumably lets template types that
 * contain commas be passed. NOTE(review): the extra forwarding level looks
 * like a workaround for __VA_ARGS__ expansion quirks in some preprocessors
 * (e.g. MSVC) — confirm before removing it. */
#define DECLARE_COMMON_PD_T(impl_name, ...) \
DECLARE_COMMON_PD_t(impl_name, __VA_ARGS__)
137 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s