Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_inner_product.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_REF_INNER_PRODUCT_HPP
18 #define CPU_REF_INNER_PRODUCT_HPP
19
20 #include <assert.h>
21
22 #include "c_types_map.hpp"
23 #include "cpu_inner_product_pd.hpp"
24 #include "cpu_engine.hpp"
25 #include "type_helpers.hpp"
26 #include "utils.hpp"
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 template <impl::data_type_t src_type, impl::data_type_t wei_type = src_type,
33          impl::data_type_t dst_type = src_type,
34          impl::data_type_t acc_type = dst_type>
35 struct ref_inner_product_fwd_t: public cpu_primitive_t {
36     struct pd_t: public cpu_inner_product_fwd_pd_t {
37         pd_t(engine_t *engine, const inner_product_desc_t *adesc,
38                 const primitive_attr_t *attr,
39                 const inner_product_fwd_pd_t *hint_fwd_pd)
40             : cpu_inner_product_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) {}
41
42         DECLARE_COMMON_PD_T("ref:any", ref_inner_product_fwd_t);
43
44         virtual status_t init() override {
45             using namespace prop_kind;
46             assert(engine()->kind() == engine_kind::cpu);
47             bool ok = true
48                 && this->set_default_params() == status::success
49                 && utils::one_of(desc()->prop_kind, forward_training,
50                         forward_inference)
51                 && desc()->src_desc.data_type == src_type
52                 && desc()->weights_desc.data_type == wei_type
53                 && desc()->accum_data_type == acc_type
54                 && desc()->dst_desc.data_type == dst_type
55                 && utils::implication(this->with_bias(),
56                         desc()->bias_desc.data_type == dst_type)
57                 && attr()->has_default_values();
58             return ok ? status::success : status::unimplemented;
59         }
60     };
61
62     ref_inner_product_fwd_t(const pd_t *pd, const input_vector &inputs,
63             const output_vector &outputs)
64         : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd) {}
65
66     typedef typename prec_traits<src_type>::type src_data_t;
67     typedef typename prec_traits<wei_type>::type wei_data_t;
68     typedef typename prec_traits<dst_type>::type dst_data_t;
69     typedef typename prec_traits<acc_type>::type acc_data_t;
70
71     virtual void execute(event_t *e) {
72         switch (conf_.desc()->prop_kind) {
73         case prop_kind::forward_training:
74         case prop_kind::forward_inference:
75             execute_forward();
76             break;
77         default:
78             assert(!"invalid prop_kind");
79         }
80         e->set_state(event_t::ready);
81     }
82
83 private:
84     void execute_forward();
85     pd_t conf_;
86 };
87
88 template <impl::data_type_t diff_src_type, impl::data_type_t wei_type,
89          impl::data_type_t diff_dst_type,
90          impl::data_type_t acc_type = diff_src_type>
91 struct ref_inner_product_bwd_data_t: public cpu_primitive_t {
92     struct pd_t: public cpu_inner_product_bwd_data_pd_t {
93         pd_t(engine_t *engine, const inner_product_desc_t *adesc,
94                 const primitive_attr_t *attr,
95                 const inner_product_fwd_pd_t *hint_fwd_pd)
96             : cpu_inner_product_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
97         {}
98
99         DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_data_t);
100
101         virtual status_t init() override {
102             using namespace prop_kind;
103             assert(engine()->kind() == engine_kind::cpu);
104             bool ok = true
105                 && this->set_default_params() == status::success
106                 && utils::one_of(this->desc()->prop_kind, backward,
107                         backward_data)
108                 && desc()->diff_src_desc.data_type == diff_src_type
109                 && desc()->weights_desc.data_type == wei_type
110                 && desc()->accum_data_type == acc_type
111                 && desc()->diff_dst_desc.data_type == diff_dst_type
112                 && attr()->has_default_values();
113             return ok ? status::success : status::unimplemented;
114         }
115     };
116
117     ref_inner_product_bwd_data_t(const pd_t *pd, const input_vector &inputs,
118             const output_vector &outputs)
119         : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd) {}
120
121     typedef typename prec_traits<diff_src_type>::type diff_src_data_t;
122     typedef typename prec_traits<wei_type>::type wei_data_t;
123     typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
124     typedef typename prec_traits<acc_type>::type acc_data_t;
125
126     virtual void execute(event_t *e) {
127         switch (conf_.desc()->prop_kind) {
128         case prop_kind::backward:
129         case prop_kind::backward_data:
130             execute_backward_data();
131             break;
132         default:
133             assert(!"invalid prop_kind");
134         }
135         e->set_state(event_t::ready);
136     }
137
138 private:
139     void execute_backward_data();
140     pd_t conf_;
141 };
142
143 template <impl::data_type_t data_type>
144 struct ref_inner_product_bwd_weights_t: public cpu_primitive_t {
145     struct pd_t: public cpu_inner_product_bwd_weights_pd_t {
146         pd_t(engine_t *engine, const inner_product_desc_t *adesc,
147                 const primitive_attr_t *attr,
148                 const inner_product_fwd_pd_t *hint_fwd_pd)
149             : cpu_inner_product_bwd_weights_pd_t(engine, adesc, attr,
150                     hint_fwd_pd) {}
151
152         DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_weights_t);
153
154         virtual status_t init() override {
155             using namespace prop_kind;
156             assert(engine()->kind() == engine_kind::cpu);
157             bool ok = true
158                 && this->set_default_params() == status::success
159                 && utils::one_of(this->desc()->prop_kind, backward,
160                         backward_weights)
161                 && utils::everyone_is(data_type,
162                         this->desc()->src_desc.data_type,
163                         this->desc()->diff_dst_desc.data_type,
164                         this->desc()->diff_weights_desc.data_type)
165                 && utils::implication(this->with_bias(),
166                         data_type == this->desc()->diff_bias_desc.data_type)
167                 && attr()->has_default_values();
168             return ok ? status::success : status::unimplemented;
169         }
170     };
171
172     ref_inner_product_bwd_weights_t(const pd_t *pd, const input_vector &inputs,
173             const output_vector &outputs)
174         : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd) {}
175     typedef typename prec_traits<data_type>::type data_t;
176
177     virtual void execute(event_t *e) {
178         switch (conf_.desc()->prop_kind) {
179         case prop_kind::backward:
180         case prop_kind::backward_weights:
181             execute_backward_weights();
182             break;
183         default:
184             assert(!"invalid prop_kind");
185         }
186         e->set_state(event_t::ready);
187     }
188
189 private:
190     void execute_backward_weights();
191     pd_t conf_;
192 };
193
194 }
195 }
196 }
197
198 #endif
199
200 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s