1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
18 #include "type_helpers.hpp"
19 #include "mkldnn_thread.hpp"
20 #include "mkldnn_traits.hpp"
21 #include "math_utils.hpp"
23 #include "ref_inner_product.hpp"
// Reference (non-optimized) forward inner product.
// For every (mb, oc) pair the visible code accumulates src * weights over the
// input channels (and over the kernel positions when src has spatial dims),
// combines it with a bias value obtained through get_bias(), and stores the
// result to dst through saturate<dst_data_t>().
// NOTE(review): this listing appears to be missing several short lines
// (accumulator declarations, the if/else heads that pick between the 4-D and
// 5-D offset forms, return statements and closing braces). Comments below
// describe only what is visible; confirm against the complete file.
32 template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type,
34 void ref_inner_product_fwd_t<src_type, wei_type, dst_type, acc_type>
35 ::execute_forward() const {
// Inputs 0/1/2 are src, weights and bias; the output is dst. bias is kept as
// a raw char pointer because it is read through get_bias() together with the
// bias data type taken from the primitive descriptor (see the parallel loop).
36 auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
37 auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
38 auto bias = reinterpret_cast<const char *>(this->input_memory(2));
39 auto dst = reinterpret_cast<dst_data_t *>(this->memory());
// Descriptor wrappers used to translate logical indices into element offsets.
41 const memory_desc_wrapper src_d(pd()->src_pd());
42 const memory_desc_wrapper dst_d(pd()->dst_pd());
43 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
44 const memory_desc_wrapper bias_d(pd()->weights_pd(1));
46 const int MB = pd()->MB();
47 const int OC = pd()->OC();
48 const int IC = pd()->IC();
// src carries kernel/spatial dimensions when it is 4-D or 5-D; 5-D means the
// five-index offset form (with kd) is used.
50 const bool src_has_spatial = utils::one_of(src_d.ndims(), 4, 5);
52 const bool is_3d = src_d.ndims() == 5;
// A post-op is applied only when exactly one entry is present; nslope is the
// alpha of that entry's eltwise parameters.
// NOTE(review): presumably the primitive descriptor accepts only a single
// eltwise (ReLU-like) post-op, so entry_[0].eltwise is valid - confirm.
54 const auto &post_ops = pd()->attr()->post_ops_;
55 const bool do_relu = post_ops.len_ == 1;
56 const float nslope = do_relu ? post_ops.entry_[0].eltwise.alpha : 0.f;
// Dot product for one (mb, oc) over ic and all kernel positions (kd, kh, kw).
// Each src element is cast to acc_data_t before the multiplication.
58 auto ker_has_spatial = [=](int mb, int oc) {
60 const int KD = pd()->KD();
61 const int KH = pd()->KH();
62 const int KW = pd()->KW();
63 for (int ic = 0; ic < IC; ++ic) {
64 for (int kd = 0; kd < KD; ++kd) {
65 for (int kh = 0; kh < KH; ++kh) {
66 for (int kw = 0; kw < KW; ++kw) {
68 d += (acc_data_t)src[src_d.off(mb, ic, kd, kh, kw)]
69 * weights[weights_d.off(oc, ic, kd, kh, kw)];
71 d += (acc_data_t)src[src_d.off(mb, ic, kh, kw)]
72 * weights[weights_d.off(oc, ic, kh, kw)];
// Dot product for one (mb, oc) over ic only (src without spatial dims).
80 auto ker_no_spatial = [=](int mb, int oc) {
82 for (int ic = 0; ic < IC; ++ic) {
83 d += (acc_data_t)src[src_d.off(mb, ic)]
84 * weights[weights_d.off(oc, ic)];
// One work item per (mb, oc) output element.
89 parallel_nd(MB, OC, [&](int mb, int oc) {
91 ? get_bias(bias, bias_d.off(oc), pd()->desc()->bias_desc.data_type)
94 a += ker_has_spatial(mb, oc);
96 a += ker_no_spatial(mb, oc);
// Negative results get post-op treatment; the statement controlled by this
// condition is not visible in this listing (presumably scales a by nslope -
// TODO confirm).
97 if (do_relu && a < (acc_data_t)0)
99 dst[dst_d.off(mb, oc)] = saturate<dst_data_t>(a);
// Explicit instantiations of the forward primitive. The using-directive
// makes the short data type names (f32, s16, s32, u8, s8) visible here.
// Template argument order is <src, weights, dst, accumulator>; the single
// argument form relies on the defaults declared with the class template
// (declared outside this file - see ref_inner_product.hpp).
102 using namespace data_type;
103 template struct ref_inner_product_fwd_t<f32>;
104 template struct ref_inner_product_fwd_t<s16, s16, s32, s32>;
105 template struct ref_inner_product_fwd_t<u8, s8, f32, s32>;
106 template struct ref_inner_product_fwd_t<u8, s8, s32, s32>;
107 template struct ref_inner_product_fwd_t<u8, s8, s8, s32>;
108 template struct ref_inner_product_fwd_t<u8, s8, u8, s32>;
// Reference backward-by-data inner product.
// For every (mb, ic) pair (and every kernel position when diff_src has
// spatial dims) the visible code accumulates diff_dst * weights over the
// output channels and writes the sum to diff_src.
// NOTE(review): this listing appears to be missing several short lines
// (if/else heads, the right-hand sides of the spatial stores, closing
// braces). Comments describe only what is visible; confirm against the
// complete file.
110 template <data_type_t diff_src_type, data_type_t wei_type,
111 data_type_t diff_dst_type, data_type_t acc_type>
112 void ref_inner_product_bwd_data_t<diff_src_type, wei_type, diff_dst_type,
113 acc_type>::execute_backward_data() const {
// Inputs 0/1 are diff_dst and weights; the output is diff_src.
114 auto diff_dst = reinterpret_cast<const diff_dst_data_t *>(
115 this->input_memory(0));
116 auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
117 auto diff_src = reinterpret_cast<diff_src_data_t*>(this->memory());
// Descriptor wrappers used to translate logical indices into element offsets.
119 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
120 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
121 const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
123 const int MB = pd()->MB();
124 const int OC = pd()->OC();
125 const int IC = pd()->IC();
// diff_src carries kernel/spatial dimensions when it is 4-D or 5-D; 5-D
// selects the five-index offset form (with kd).
127 const bool diff_src_has_spatial = utils::one_of(diff_src_d.ndims(), 4, 5);
129 const bool is_3d = diff_src_d.ndims() == 5;
// One work item per (mb, ic); kernel positions are walked inside the item.
131 parallel_nd(MB, IC, [&](int mb, int ic) {
132 if (diff_src_has_spatial) {
133 const int KD = pd()->KD();
134 const int KH = pd()->KH();
135 const int KW = pd()->KW();
136 for (int kd = 0; kd < KD; ++kd)
137 for (int kh = 0; kh < KH; ++kh)
138 for (int kw = 0; kw < KW; ++kw) {
// Fresh accumulator per diff_src element. Note the cast to acc_data_t wraps
// the whole product, i.e. the multiplication happens before the cast.
139 acc_data_t ds = acc_data_t(0);
140 for (int oc = 0; oc < OC; ++oc) {
142 ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)]
143 * weights[weights_d.off(oc, ic, kd, kh, kw)]);
145 ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)]
146 * weights[weights_d.off(oc, ic, kh, kw)]);
148 if (is_3d) diff_src[diff_src_d.off(mb, ic, kd, kh, kw)] =
150 else diff_src[diff_src_d.off(mb, ic, kh, kw)] =
// Path without spatial dims: a single sum over oc per (mb, ic), stored with
// an explicit cast to diff_src_data_t.
154 acc_data_t ds = acc_data_t(0);
155 for (int oc = 0; oc < OC; ++oc) {
156 ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] *
157 weights[weights_d.off(oc, ic)]);
159 diff_src[diff_src_d.off(mb, ic)] = (diff_src_data_t)ds;
// Explicit instantiations of the backward-by-data primitive; template
// argument order is <diff_src, weights, diff_dst, accumulator>.
164 template struct ref_inner_product_bwd_data_t<f32, f32, f32, f32>;
165 template struct ref_inner_product_bwd_data_t<s32, s16, s16, s32>;
// Reference backward-by-weights inner product.
// For every (oc, ic) pair (and every kernel position when src has spatial
// dims) the visible code accumulates diff_dst * src over the minibatch into
// diff_weights; a second parallel loop accumulates diff_dst over the
// minibatch into diff_bias.
// NOTE(review): this listing appears to be missing several short lines
// (if/else heads, the start of the dw pointer declarations, accumulator
// initialization, any guard around the bias section, closing braces).
// Comments describe only what is visible; confirm against the complete file.
167 template <impl::data_type_t data_type>
168 void ref_inner_product_bwd_weights_t<data_type>::execute_backward_weights() const {
// Inputs 0/1 are src and diff_dst; outputs 0/1 are diff_weights and diff_bias.
169 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
170 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
171 auto diff_weights = reinterpret_cast<data_t*>(this->memory(0));
172 auto diff_bias = reinterpret_cast<data_t*>(this->memory(1));
// Descriptor wrappers used to translate logical indices into element offsets.
174 const memory_desc_wrapper src_d(pd()->src_pd());
175 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
176 const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
177 const memory_desc_wrapper diff_bias_d(pd()->diff_weights_pd(1));
179 const int MB = pd()->MB();
180 const int OC = pd()->OC();
181 const int IC = pd()->IC();
// src carries kernel/spatial dimensions when it is 4-D or 5-D; 5-D selects
// the five-index offset form (with kd).
183 const bool src_has_spatial = utils::one_of(src_d.ndims(), 4 ,5);
185 const bool is_3d = src_d.ndims() == 5;
// One work item per (oc, ic); kernel positions are walked inside the item.
187 parallel_nd(OC, IC, [&](int oc, int ic) {
188 if (src_has_spatial) {
189 const int KD = pd()->KD();
190 const int KH = pd()->KH();
191 const int KW = pd()->KW();
192 for (int kd = 0; kd < KD; ++kd) {
193 for (int kh = 0; kh < KH; ++kh) {
194 for (int kw = 0; kw < KW; ++kw) {
197 diff_weights_d.off(oc, ic, kd, kh, kw)]
199 diff_weights_d.off(oc, ic, kh, kw)];
// Accumulate directly into the diff_weights element through dw.
// NOTE(review): the zero-initialization of *dw is not visible in this
// listing - confirm it precedes this loop.
201 for (int mb = 0; mb < MB; ++mb) {
203 *dw += diff_dst[diff_dst_d.off(mb, oc)] *
204 src[src_d.off(mb, ic, kd, kh, kw)];
206 *dw += diff_dst[diff_dst_d.off(mb, oc)] *
207 src[src_d.off(mb, ic, kh, kw)];
// Path without spatial dims: one diff_weights element per (oc, ic).
213 data_t *dw = &diff_weights[diff_weights_d.off(oc, ic)];
215 for (int mb = 0; mb < MB; ++mb) {
216 *dw += diff_dst[diff_dst_d.off(mb, oc)] *
217 src[src_d.off(mb, ic)];
// Bias gradient: shift the pointer by the descriptor's offset_padding, then
// index it directly by oc (not through diff_bias_d.off()).
// NOTE(review): whether this section is guarded against a null diff_bias,
// and where *db is zeroed, is not visible in this listing - confirm.
223 diff_bias += diff_bias_d.blocking_desc().offset_padding;
225 parallel_nd(OC, [&](int oc) {
226 data_t *db = &diff_bias[oc];
228 for (int mb = 0; mb < MB; ++mb)
229 *db += diff_dst[diff_dst_d.off(mb, oc)];
// Explicit instantiation of the backward-by-weights primitive for f32, the
// only data type instantiated in this file.
234 template struct ref_inner_product_bwd_weights_t<data_type::f32>;
240 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s