Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm_inner_product.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "type_helpers.hpp"
19 #include "mkldnn_thread.hpp"
20
21 #include "gemm_inner_product.hpp"
22
23 namespace mkldnn {
24 namespace impl {
25 namespace cpu {
26
27 using namespace mkldnn::impl::status;
28 using namespace mkldnn::impl::prop_kind;
29 using namespace mkldnn::impl::data_type;
30 using namespace mkldnn::impl::memory_format;
31 using namespace mkldnn::impl::primitive_kind;
32
33 template <impl::data_type_t data_type>
34 void gemm_inner_product_fwd_t<data_type>::execute_forward() const {
35     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
36     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
37     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
38     auto dst = reinterpret_cast<data_t*>(this->memory());
39
40     const int MB = pd()->MB();
41     const int OC = pd()->OC();
42     const int IC = pd()->IC_total_padded();
43
44     bool wei_tr = !utils::one_of(pd()->weights_pd()->desc()->format,
45              hwio, dhwio, io);
46
47     const auto &post_ops = pd()->attr()->post_ops_;
48     const bool do_relu = post_ops.len_ == 1;
49
50     float alpha = 1.0, beta = 0.0;
51     extended_sgemm(wei_tr ? "T" : "N", "N", &OC, &MB, &IC, &alpha, weights,
52             wei_tr ? &IC : &OC, src, &IC, &beta, dst, &OC, bias);
53
54     if (do_relu) {
55         float nslope = post_ops.entry_[0].eltwise.alpha;
56         parallel_nd(MB, OC, [&](int mb, int oc) {
57             size_t dst_off = mb * OC + oc;
58             if (dst[dst_off] < 0)
59                 dst[dst_off] *= nslope;
60         });
61     }
62 }
63
64 template <impl::data_type_t data_type>
65 void gemm_inner_product_bwd_data_t<data_type>::execute_backward_data() const {
66     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
67     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
68     auto diff_src = reinterpret_cast<data_t*>(this->memory());
69
70     const int MB = pd()->MB();
71     const int OC = pd()->OC();
72     const int IC = pd()->IC_total_padded();
73
74     bool wei_tr = utils::one_of(pd()->weights_pd()->desc()->format,
75              hwio, dhwio, io);
76
77     float alpha = 1.0, beta = 0.0;
78     extended_sgemm(wei_tr ? "T" : "N", "N", &IC, &MB, &OC, &alpha, weights,
79             wei_tr ? &OC : &IC, diff_dst, &OC, &beta, diff_src, &IC);
80 }
81
82 template <impl::data_type_t data_type>
83 void gemm_inner_product_bwd_weights_t<data_type>::execute_backward_weights() const {
84     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
85     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
86     auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
87     auto diff_bias = reinterpret_cast<data_t *>(this->memory(1));
88
89     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
90     const memory_desc_wrapper diff_bias_d(pd()->diff_weights_pd(1));
91
92     diff_dst += diff_dst_d.blocking_desc().offset_padding;
93
94     const int MB = pd()->MB();
95     const int OC = pd()->OC();
96     const int IC = pd()->IC_total_padded();
97
98     bool wei_tr = utils::one_of(pd()->diff_weights_pd()->desc()->format,
99              hwio, dhwio, io);
100
101     float alpha = 1.0, beta = 0.0;
102     if (wei_tr)
103         extended_sgemm("N", "T", &OC, &IC, &MB, &alpha, diff_dst, &OC, src, &IC,
104                 &beta, diff_weights, &OC);
105     else
106         extended_sgemm("N", "T", &IC, &OC, &MB, &alpha, src, &IC, diff_dst, &OC,
107                 &beta, diff_weights, &IC);
108
109     if (diff_bias) {
110         diff_bias += diff_bias_d.blocking_desc().offset_padding;
111         constexpr int blksize = 8;
112         const int OC_blocks = OC / blksize;
113         const int rem_OC = OC % blksize;
114         parallel(0, [&](const int ithr, const int nthr) {
115             int oc_st{0}, oc_e{0};
116             balance211(OC_blocks, nthr, ithr, oc_st, oc_e);
117             oc_st = oc_st * blksize;
118             oc_e = oc_e * blksize;
119
120             PRAGMA_OMP_SIMD()
121             for (int oc = oc_st; oc < oc_e; ++oc) {
122                 diff_bias[oc] = diff_dst[oc];
123             }
124
125             for (int mb = 1; mb < MB; ++mb) {
126                 PRAGMA_OMP_SIMD()
127                 for (int oc = oc_st; oc < oc_e; ++oc) {
128                     diff_bias[oc] += diff_dst[mb * OC + oc];
129                 }
130             }
131
132             if (rem_OC != 0 && ithr == nthr-1) {
133                 for (int oc = OC_blocks * blksize; oc < OC; oc++)
134                     diff_bias[oc] = diff_dst[oc];
135                 for (int mb = 1; mb < MB; ++mb) {
136                     for (int oc = OC_blocks * blksize; oc < OC; oc++) {
137                         diff_bias[oc] += diff_dst[mb * OC + oc];
138                     }
139                 }
140             }
141         });
142     }
143 }
144
145 template struct gemm_inner_product_fwd_t<data_type::f32>;
146 template struct gemm_inner_product_bwd_data_t<data_type::f32>;
147 template struct gemm_inner_product_bwd_weights_t<data_type::f32>;
148
149 }
150 }
151 }
152
153 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s