Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_inner_product.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_REF_INNER_PRODUCT_HPP
18 #define CPU_REF_INNER_PRODUCT_HPP
19
20 #include <assert.h>
21
22 #include "c_types_map.hpp"
23 #include "cpu_inner_product_pd.hpp"
24 #include "cpu_engine.hpp"
25 #include "type_helpers.hpp"
26 #include "utils.hpp"
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 template <impl::data_type_t src_type, impl::data_type_t wei_type = src_type,
33          impl::data_type_t dst_type = src_type,
34          impl::data_type_t acc_type = dst_type>
35 struct ref_inner_product_fwd_t: public cpu_primitive_t {
36     struct pd_t: public cpu_inner_product_fwd_pd_t {
37         pd_t(engine_t *engine, const inner_product_desc_t *adesc,
38                 const primitive_attr_t *attr,
39                 const inner_product_fwd_pd_t *hint_fwd_pd)
40             : cpu_inner_product_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) {}
41
42         DECLARE_COMMON_PD_T("ref:any", ref_inner_product_fwd_t);
43
44         virtual status_t init() override {
45             using namespace prop_kind;
46             using namespace data_type;
47             assert(engine()->kind() == engine_kind::cpu);
48             bool ok = true
49                 && this->set_default_params() == status::success
50                 && utils::one_of(desc()->prop_kind, forward_training,
51                         forward_inference)
52                 && desc()->src_desc.data_type == src_type
53                 && desc()->weights_desc.data_type == wei_type
54                 && desc()->accum_data_type == acc_type
55                 && desc()->dst_desc.data_type == dst_type
56                 && IMPLICATION(with_bias(),
57                             utils::one_of(desc()->bias_desc.data_type,
58                                 f32, s32, s8, u8))
59                 && attr()->output_scales_.has_default_values()
60                 && attr()->post_ops_.len_ <= 1
61                 && IMPLICATION(attr()->post_ops_.len_ == 1,
62                         attr()->post_ops_.entry_[0].is_relu(true, false));
63             return ok ? status::success : status::unimplemented;
64         }
65     };
66
67     ref_inner_product_fwd_t(const pd_t *apd, const input_vector &inputs,
68             const output_vector &outputs)
69         : cpu_primitive_t(apd, inputs, outputs) {}
70
71     typedef typename prec_traits<src_type>::type src_data_t;
72     typedef typename prec_traits<wei_type>::type wei_data_t;
73     typedef typename prec_traits<dst_type>::type dst_data_t;
74     typedef typename prec_traits<acc_type>::type acc_data_t;
75
76     virtual void execute(event_t *e) const {
77         switch (pd()->desc()->prop_kind) {
78         case prop_kind::forward_training:
79         case prop_kind::forward_inference:
80             execute_forward();
81             break;
82         default:
83             assert(!"invalid prop_kind");
84         }
85         e->set_state(event_t::ready);
86     }
87
88 private:
89     void execute_forward() const;
90     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
91 };
92
93 template <impl::data_type_t diff_src_type, impl::data_type_t wei_type,
94          impl::data_type_t diff_dst_type,
95          impl::data_type_t acc_type = diff_src_type>
96 struct ref_inner_product_bwd_data_t: public cpu_primitive_t {
97     struct pd_t: public cpu_inner_product_bwd_data_pd_t {
98         pd_t(engine_t *engine, const inner_product_desc_t *adesc,
99                 const primitive_attr_t *attr,
100                 const inner_product_fwd_pd_t *hint_fwd_pd)
101             : cpu_inner_product_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
102         {}
103
104         DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_data_t);
105
106         virtual status_t init() override {
107             using namespace prop_kind;
108             assert(engine()->kind() == engine_kind::cpu);
109             bool ok = true
110                 && this->set_default_params() == status::success
111                 && utils::one_of(this->desc()->prop_kind, backward,
112                         backward_data)
113                 && desc()->diff_src_desc.data_type == diff_src_type
114                 && desc()->weights_desc.data_type == wei_type
115                 && desc()->accum_data_type == acc_type
116                 && desc()->diff_dst_desc.data_type == diff_dst_type
117                 && attr()->has_default_values();
118             return ok ? status::success : status::unimplemented;
119         }
120     };
121
122     ref_inner_product_bwd_data_t(const pd_t *apd, const input_vector &inputs,
123             const output_vector &outputs)
124         : cpu_primitive_t(apd, inputs, outputs) {}
125
126     typedef typename prec_traits<diff_src_type>::type diff_src_data_t;
127     typedef typename prec_traits<wei_type>::type wei_data_t;
128     typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
129     typedef typename prec_traits<acc_type>::type acc_data_t;
130
131     virtual void execute(event_t *e) const {
132         switch (pd()->desc()->prop_kind) {
133         case prop_kind::backward:
134         case prop_kind::backward_data:
135             execute_backward_data();
136             break;
137         default:
138             assert(!"invalid prop_kind");
139         }
140         e->set_state(event_t::ready);
141     }
142
143 private:
144     void execute_backward_data() const;
145     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
146 };
147
148 template <impl::data_type_t data_type>
149 struct ref_inner_product_bwd_weights_t: public cpu_primitive_t {
150     struct pd_t: public cpu_inner_product_bwd_weights_pd_t {
151         pd_t(engine_t *engine, const inner_product_desc_t *adesc,
152                 const primitive_attr_t *attr,
153                 const inner_product_fwd_pd_t *hint_fwd_pd)
154             : cpu_inner_product_bwd_weights_pd_t(engine, adesc, attr,
155                     hint_fwd_pd) {}
156
157         DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_weights_t);
158
159         virtual status_t init() override {
160             using namespace prop_kind;
161             assert(engine()->kind() == engine_kind::cpu);
162             bool ok = true
163                 && this->set_default_params() == status::success
164                 && utils::one_of(this->desc()->prop_kind, backward,
165                         backward_weights)
166                 && utils::everyone_is(data_type,
167                         this->desc()->src_desc.data_type,
168                         this->desc()->diff_dst_desc.data_type,
169                         this->desc()->diff_weights_desc.data_type)
170                 && IMPLICATION(this->with_bias(),
171                         data_type == this->desc()->diff_bias_desc.data_type)
172                 && attr()->has_default_values();
173             return ok ? status::success : status::unimplemented;
174         }
175     };
176
177     ref_inner_product_bwd_weights_t(const pd_t *apd, const input_vector &inputs,
178             const output_vector &outputs)
179         : cpu_primitive_t(apd, inputs, outputs) {}
180     typedef typename prec_traits<data_type>::type data_t;
181
182     virtual void execute(event_t *e) const {
183         switch (pd()->desc()->prop_kind) {
184         case prop_kind::backward:
185         case prop_kind::backward_weights:
186             execute_backward_weights();
187             break;
188         default:
189             assert(!"invalid prop_kind");
190         }
191         e->set_state(event_t::ready);
192     }
193
194 private:
195     void execute_backward_weights() const;
196     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
197 };
198
199 }
200 }
201 }
202
203 #endif
204
205 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s