1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
18 #include "type_helpers.hpp"
19 #include "mkldnn_thread.hpp"
21 #include "gemm_inner_product.hpp"
27 using namespace mkldnn::impl::status;
28 using namespace mkldnn::impl::prop_kind;
29 using namespace mkldnn::impl::data_type;
30 using namespace mkldnn::impl::memory_format;
31 using namespace mkldnn::impl::primitive_kind;
// gemm_inner_product_fwd_t::execute_forward
// Computes the forward inner product with a single extended_sgemm call over
// `weights` and `src`, writing into `dst` (alpha = 1, beta = 0) and passing
// `bias` as the trailing argument. It then optionally applies a one-entry
// post-op that scales dst elements by the post-op's eltwise alpha (`nslope`).
// NOTE(review): this listing is missing lines (see the gaps in the embedded
// numbering): the rest of the weights-format list on the wei_tr line, the
// guard(s) around the scaling loop, and the closing braces. Confirm against
// the complete source before relying on these comments.
33 template <impl::data_type_t data_type>
34 void gemm_inner_product_fwd_t<data_type>::execute_forward() const {
35 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
36 auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
37 auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
38 auto dst = reinterpret_cast<data_t*>(this->memory());
// IC is the padded total input-channel count; it is used both as the GEMM
// reduction size and as the leading dimension of src.
40 const int MB = pd()->MB();
41 const int OC = pd()->OC();
42 const int IC = pd()->IC_total_padded();
// wei_tr selects the "T" (transposed) weights operand unless the weights
// format is one of the listed formats (the list is not visible here).
44 bool wei_tr = !utils::one_of(pd()->weights_pd()->desc()->format,
47 const auto &post_ops = pd()->attr()->post_ops_;
// do_relu is set whenever exactly one post-op is attached; the post-op kind
// is not checked in this function. Presumably the primitive descriptor only
// admits an eltwise (ReLU) entry. TODO confirm.
48 const bool do_relu = post_ops.len_ == 1;
// Assuming BLAS-style argument order (transa, transb, M, N, K, alpha, A, lda,
// B, ldb, beta, C, ldc): M = OC, N = MB, K = IC, and beta = 0 so dst is
// overwritten rather than accumulated into.
50 float alpha = 1.0, beta = 0.0;
51 extended_sgemm(wei_tr ? "T" : "N", "N", &OC, &MB, &IC, &alpha, weights,
52 wei_tr ? &IC : &OC, src, &IC, &beta, dst, &OC, bias);
// nslope is the eltwise alpha of the single post-op entry.
55 float nslope = post_ops.entry_[0].eltwise.alpha;
56 parallel_nd(MB, OC, [&](int mb, int oc) {
// NOTE(review): `mb * OC` is evaluated in int before widening to size_t, so
// it can overflow when MB * OC > INT_MAX; consider `(size_t)mb * OC + oc`.
57 size_t dst_off = mb * OC + oc;
59 dst[dst_off] *= nslope;
// gemm_inner_product_bwd_data_t::execute_backward_data
// Computes diff_src with a single extended_sgemm call over `weights` and
// `diff_dst` (alpha = 1, beta = 0).
// NOTE(review): the rest of the weights-format list on the wei_tr line and
// the closing brace are missing from this listing. Confirm against the full
// source.
64 template <impl::data_type_t data_type>
65 void gemm_inner_product_bwd_data_t<data_type>::execute_backward_data() const {
66 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
67 auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
68 auto diff_src = reinterpret_cast<data_t*>(this->memory());
70 const int MB = pd()->MB();
71 const int OC = pd()->OC();
72 const int IC = pd()->IC_total_padded();
// Opposite polarity to the forward pass: here "T" is used when the weights
// format IS one of the listed formats (forward uses `!utils::one_of(...)`).
74 bool wei_tr = utils::one_of(pd()->weights_pd()->desc()->format,
// Assuming BLAS-style argument order: M = IC, N = MB, K = OC, and beta = 0 so
// diff_src is overwritten rather than accumulated into.
77 float alpha = 1.0, beta = 0.0;
78 extended_sgemm(wei_tr ? "T" : "N", "N", &IC, &MB, &OC, &alpha, weights,
79 wei_tr ? &OC : &IC, diff_dst, &OC, &beta, diff_src, &IC);
// gemm_inner_product_bwd_weights_t::execute_backward_weights
// Computes diff_weights with extended_sgemm (reduction over the minibatch,
// alpha = 1, beta = 0), using one of two operand orders. It then sums
// diff_dst over the minibatch into diff_bias, one value per output channel.
// NOTE(review): this listing drops lines, including the condition that picks
// between the two extended_sgemm calls (presumably wei_tr), the guard around
// the bias reduction, and several braces. Confirm against the full source.
82 template <impl::data_type_t data_type>
83 void gemm_inner_product_bwd_weights_t<data_type>::execute_backward_weights() const {
84 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
85 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
86 auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
87 auto diff_bias = reinterpret_cast<data_t *>(this->memory(1));
89 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
90 const memory_desc_wrapper diff_bias_d(pd()->diff_weights_pd(1));
// Only diff_dst (and later diff_bias) are advanced by offset_padding; src and
// diff_weights are used as-is. Presumably their offsets are always zero. Verify.
92 diff_dst += diff_dst_d.blocking_desc().offset_padding;
94 const int MB = pd()->MB();
95 const int OC = pd()->OC();
96 const int IC = pd()->IC_total_padded();
98 bool wei_tr = utils::one_of(pd()->diff_weights_pd()->desc()->format,
101 float alpha = 1.0, beta = 0.0;
// First operand order: M = OC, N = IC, K = MB, with diff_weights leading
// dimension OC (BLAS-style argument order assumed).
103 extended_sgemm("N", "T", &OC, &IC, &MB, &alpha, diff_dst, &OC, src, &IC,
104 &beta, diff_weights, &OC);
// Alternative operand order: M = IC, N = OC, K = MB, with diff_weights
// leading dimension IC.
106 extended_sgemm("N", "T", &IC, &OC, &MB, &alpha, src, &IC, diff_dst, &OC,
107 &beta, diff_weights, &IC);
// Bias gradient: diff_bias[oc] = sum over mb of diff_dst[mb * OC + oc].
// Output channels are grouped into blocks of blksize = 8, and whole blocks are
// split across threads with balance211. The OC % blksize leftover channels are
// reduced by the last thread only (ithr == nthr - 1).
110 diff_bias += diff_bias_d.blocking_desc().offset_padding;
111 constexpr int blksize = 8;
112 const int OC_blocks = OC / blksize;
113 const int rem_OC = OC % blksize;
// NOTE(review): `mb * OC + oc` below is computed in int; it can overflow
// when MB * OC > INT_MAX.
114 parallel(0, [&](const int ithr, const int nthr) {
115 int oc_st{0}, oc_e{0};
116 balance211(OC_blocks, nthr, ithr, oc_st, oc_e);
117 oc_st = oc_st * blksize;
118 oc_e = oc_e * blksize;
121 for (int oc = oc_st; oc < oc_e; ++oc) {
122 diff_bias[oc] = diff_dst[oc];
125 for (int mb = 1; mb < MB; ++mb) {
127 for (int oc = oc_st; oc < oc_e; ++oc) {
128 diff_bias[oc] += diff_dst[mb * OC + oc];
// Tail channels [OC_blocks * blksize, OC) that do not fill a whole block.
132 if (rem_OC != 0 && ithr == nthr-1) {
133 for (int oc = OC_blocks * blksize; oc < OC; oc++)
134 diff_bias[oc] = diff_dst[oc];
135 for (int mb = 1; mb < MB; ++mb) {
136 for (int oc = OC_blocks * blksize; oc < OC; oc++) {
137 diff_bias[oc] += diff_dst[mb * OC + oc];
// Explicit instantiations: this translation unit provides only the f32
// specializations of the forward, backward-data and backward-weights primitives.
145 template struct gemm_inner_product_fwd_t<data_type::f32>;
146 template struct gemm_inner_product_bwd_data_t<data_type::f32>;
147 template struct gemm_inner_product_bwd_weights_t<data_type::f32>;
153 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s