1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
18 #include "type_helpers.hpp"
19 #include "mkldnn_thread.hpp"
21 #include "gemm_inner_product.hpp"
// Bring the names of these mkldnn::impl namespaces into scope so they can be
// used unqualified in the rest of this file.
27 using namespace mkldnn::impl::status;
28 using namespace mkldnn::impl::prop_kind;
29 using namespace mkldnn::impl::data_type;
30 using namespace mkldnn::impl::memory_format;
31 using namespace mkldnn::impl::primitive_kind;
// Forward pass of the GEMM-based inner product: a single extended_sgemm call
// combines weights and src into dst, and a pp_kernel_ pass over dst is split
// across threads.
// NOTE(review): the numbering embedded in this listing jumps (44 -> 47,
// 52 -> 55, 58 -> 63), so some lines of this function are not visible here,
// including the rest of the one_of(...) argument list and the closing braces.
// Confirm against the complete file before changing any code.
33 template <impl::data_type_t data_type>
34 void gemm_inner_product_fwd_t<data_type>::execute_forward() const {
// Inputs 0, 1 and 2 are read as src, weights and bias; the primitive's output
// memory is used as dst. All are viewed as data_t arrays.
35 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
36 auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
37 auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
38 auto dst = reinterpret_cast<data_t*>(this->memory());
// Problem sizes as reported by the primitive descriptor.
40 const int MB = pd()->MB();
41 const int OC = pd()->OC();
42 const int IC = pd()->IC_total_padded();
// wei_tr is true when the weights format is NOT one of the formats passed to
// one_of(...); the format list itself is on a line not shown in this listing.
44 bool wei_tr = !utils::one_of(pd()->weights_pd()->desc()->format,
// Output scales taken from the primitive attributes; forwarded to pp_kernel_.
47 const float *scales = pd()->attr()->output_scales_.scales_;
// alpha = 1 and beta = 0 are passed to extended_sgemm.
// The first flag is "T" when wei_tr is set and "N" otherwise, and the value
// passed right after weights switches between IC and OC accordingly.
// The trailing argument is nullptr when postops_in_ip_ is set, otherwise bias.
// NOTE(review): presumably extended_sgemm follows the BLAS sgemm argument
// order (transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc) with
// an extra trailing bias pointer -- TODO confirm against its declaration.
49 float alpha = 1.0, beta = 0.0;
50 extended_sgemm(wei_tr ? "T" : "N", "N", &OC, &MB, &IC, &alpha, weights,
51 wei_tr ? &IC : &OC, src, &IC, &beta, dst, &OC,
52 postops_in_ip_ ? nullptr : bias);
// Split the index range [0, OC * MB) among threads with balance211 and run
// pp_kernel_ on each thread's [start, end) range. dst is passed as both the
// first and the second argument, and bias is passed as a char pointer.
// NOTE(review): listing lines 53-54 are not visible; presumably this pass is
// conditional on postops_in_ip_ -- verify in the complete file.
55 parallel(0, (size_t)OC * MB, [&](int ithr, int nthr) {
56 size_t start = 0, end = 0;
57 balance211((size_t)OC * MB, nthr, ithr, start, end);
58 (*pp_kernel_)(dst, dst, (char *)bias, scales, start, end);
// Backward-by-data pass of the GEMM-based inner product: a single
// extended_sgemm call combines weights and diff_dst into diff_src.
// NOTE(review): the numbering embedded in this listing jumps (73 -> 76,
// 78 -> 81), so the rest of the one_of(...) argument list and the closing
// brace of this function are not visible here. Confirm against the complete
// file before changing any code.
63 template <impl::data_type_t data_type>
64 void gemm_inner_product_bwd_data_t<data_type>::execute_backward_data() const {
// Inputs 0 and 1 are read as diff_dst and weights; the primitive's output
// memory is used as diff_src. All are viewed as data_t arrays.
65 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
66 auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
67 auto diff_src = reinterpret_cast<data_t*>(this->memory());
// Problem sizes as reported by the primitive descriptor.
69 const int MB = pd()->MB();
70 const int OC = pd()->OC();
71 const int IC = pd()->IC_total_padded();
// wei_tr is true when the weights format IS one of the formats passed to
// one_of(...); the format list itself is on a line not shown in this listing.
73 bool wei_tr = utils::one_of(pd()->weights_pd()->desc()->format,
// alpha = 1 and beta = 0 are passed to extended_sgemm.
// The first flag is "T" when wei_tr is set and "N" otherwise, and the value
// passed right after weights switches between OC and IC accordingly.
// NOTE(review): presumably extended_sgemm follows the BLAS sgemm argument
// order (transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc)
// -- TODO confirm against its declaration.
76 float alpha = 1.0, beta = 0.0;
77 extended_sgemm(wei_tr ? "T" : "N", "N", &IC, &MB, &OC, &alpha, weights,
78 wei_tr ? &OC : &IC, diff_dst, &OC, &beta, diff_src, &IC);
// Backward-by-weights pass of the GEMM-based inner product: extended_sgemm
// writes diff_weights from src and diff_dst, and diff_bias is filled by
// summing diff_dst over the minibatch index for every output channel.
// NOTE(review): the numbering embedded in this listing jumps (97 -> 100,
// 100 -> 102, 103 -> 105, 106 -> 109, 117 -> 120, 136 -> 144), so some lines
// of this function are not visible here: the rest of the one_of(...) argument
// list, whatever statement selects between the two extended_sgemm calls,
// anything guarding the bias reduction, and all closing braces. Confirm
// against the complete file before changing any code.
81 template <impl::data_type_t data_type>
82 void gemm_inner_product_bwd_weights_t<data_type>::execute_backward_weights() const {
// Inputs 0 and 1 are read as src and diff_dst; outputs 0 and 1 are used as
// diff_weights and diff_bias. All are viewed as data_t arrays.
83 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
84 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
85 auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
86 auto diff_bias = reinterpret_cast<data_t *>(this->memory(1));
// Descriptor wrappers, used below only to read each buffer's offset_padding.
88 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
89 const memory_desc_wrapper diff_bias_d(pd()->diff_weights_pd(1));
// Advance diff_dst by its offset_padding. Note this happens before the
// extended_sgemm calls, so they receive the adjusted pointer as well.
91 diff_dst += diff_dst_d.blocking_desc().offset_padding;
// Problem sizes as reported by the primitive descriptor.
93 const int MB = pd()->MB();
94 const int OC = pd()->OC();
95 const int IC = pd()->IC_total_padded();
// wei_tr is true when the diff_weights format IS one of the formats passed to
// one_of(...); the format list itself is on a line not shown in this listing.
97 bool wei_tr = utils::one_of(pd()->diff_weights_pd()->desc()->format,
// alpha = 1 and beta = 0 are passed to extended_sgemm.
100 float alpha = 1.0, beta = 0.0;
// Two alternative GEMM calls: the first passes (diff_dst, src) with OC as the
// first size and as the value after diff_weights; the second passes
// (src, diff_dst) with IC in those positions.
// NOTE(review): presumably the choice between them depends on wei_tr; the
// selecting statement is not visible here -- verify in the complete file.
102 extended_sgemm("N", "T", &OC, &IC, &MB, &alpha, diff_dst, &OC, src, &IC,
103 &beta, diff_weights, &OC);
105 extended_sgemm("N", "T", &IC, &OC, &MB, &alpha, src, &IC, diff_dst, &OC,
106 &beta, diff_weights, &IC);
// Bias gradient: diff_bias[oc] = sum over mb of diff_dst[mb * OC + oc].
// NOTE(review): listing lines 107-108 are not visible; presumably this part
// only runs when diff_bias is non-null -- verify in the complete file.
// Advance diff_bias by its offset_padding before writing to it.
109 diff_bias += diff_bias_d.blocking_desc().offset_padding;
// Output channels are distributed among threads in blocks of blksize (8).
// rem_OC is the number of trailing channels that do not fill a whole block.
110 constexpr int blksize = 8;
111 const int OC_blocks = OC / blksize;
112 const int rem_OC = OC % blksize;
113 parallel(0, (size_t)OC_blocks, [&](const int ithr, const int nthr) {
// balance211 yields this thread's range of blocks; multiplying by blksize
// converts it into a channel range [oc_st, oc_e).
114 int oc_st{0}, oc_e{0};
115 balance211(OC_blocks, nthr, ithr, oc_st, oc_e);
116 oc_st = oc_st * blksize;
117 oc_e = oc_e * blksize;
// Initialise each diff_bias entry from the mb == 0 row of diff_dst ...
120 for (int oc = oc_st; oc < oc_e; ++oc) {
121 diff_bias[oc] = diff_dst[oc];
// ... then accumulate the rows mb = 1 .. MB-1 on top of it.
124 for (int mb = 1; mb < MB; ++mb) {
126 for (int oc = oc_st; oc < oc_e; ++oc) {
127 diff_bias[oc] += diff_dst[mb * OC + oc];
// The trailing channels [OC_blocks * blksize, OC) are handled only by the
// thread with ithr == nthr-1, using the same initialise-then-accumulate scheme.
131 if (rem_OC != 0 && ithr == nthr-1) {
132 for (int oc = OC_blocks * blksize; oc < OC; oc++)
133 diff_bias[oc] = diff_dst[oc];
134 for (int mb = 1; mb < MB; ++mb) {
135 for (int oc = OC_blocks * blksize; oc < OC; oc++) {
136 diff_bias[oc] += diff_dst[mb * OC + oc];
// Explicit instantiations of the forward, backward-data and backward-weights
// primitives for data_type::f32.
144 template struct gemm_inner_product_fwd_t<data_type::f32>;
145 template struct gemm_inner_product_bwd_data_t<data_type::f32>;
146 template struct gemm_inner_product_bwd_weights_t<data_type::f32>;
152 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s