updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm_inner_product.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "type_helpers.hpp"
19 #include "mkldnn_thread.hpp"
20
21 #include "gemm_inner_product.hpp"
22
23 namespace mkldnn {
24 namespace impl {
25 namespace cpu {
26
27 using namespace mkldnn::impl::status;
28 using namespace mkldnn::impl::prop_kind;
29 using namespace mkldnn::impl::data_type;
30 using namespace mkldnn::impl::memory_format;
31 using namespace mkldnn::impl::primitive_kind;
32
33 template <impl::data_type_t data_type>
34 void gemm_inner_product_fwd_t<data_type>::execute_forward() const {
35     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
36     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
37     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
38     auto dst = reinterpret_cast<data_t*>(this->memory());
39
40     const int MB = pd()->MB();
41     const int OC = pd()->OC();
42     const int IC = pd()->IC_total_padded();
43
44     bool wei_tr = !utils::one_of(pd()->weights_pd()->desc()->format,
45              hwio, dhwio, io);
46
47     const float *scales = pd()->attr()->output_scales_.scales_;
48
49     float alpha = 1.0, beta = 0.0;
50     extended_sgemm(wei_tr ? "T" : "N", "N", &OC, &MB, &IC, &alpha, weights,
51             wei_tr ? &IC : &OC, src, &IC, &beta, dst, &OC,
52             postops_in_ip_ ? nullptr : bias);
53
54     if (postops_in_ip_) {
55         parallel(0, (size_t)OC * MB, [&](int ithr, int nthr) {
56             size_t start = 0, end = 0;
57             balance211((size_t)OC * MB, nthr, ithr, start, end);
58             (*pp_kernel_)(dst, dst, (char *)bias, scales, start, end);
59         });
60     }
61 }
62
63 template <impl::data_type_t data_type>
64 void gemm_inner_product_bwd_data_t<data_type>::execute_backward_data() const {
65     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
66     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
67     auto diff_src = reinterpret_cast<data_t*>(this->memory());
68
69     const int MB = pd()->MB();
70     const int OC = pd()->OC();
71     const int IC = pd()->IC_total_padded();
72
73     bool wei_tr = utils::one_of(pd()->weights_pd()->desc()->format,
74              hwio, dhwio, io);
75
76     float alpha = 1.0, beta = 0.0;
77     extended_sgemm(wei_tr ? "T" : "N", "N", &IC, &MB, &OC, &alpha, weights,
78             wei_tr ? &OC : &IC, diff_dst, &OC, &beta, diff_src, &IC);
79 }
80
81 template <impl::data_type_t data_type>
82 void gemm_inner_product_bwd_weights_t<data_type>::execute_backward_weights() const {
83     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
84     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
85     auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
86     auto diff_bias = reinterpret_cast<data_t *>(this->memory(1));
87
88     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
89     const memory_desc_wrapper diff_bias_d(pd()->diff_weights_pd(1));
90
91     diff_dst += diff_dst_d.blocking_desc().offset_padding;
92
93     const int MB = pd()->MB();
94     const int OC = pd()->OC();
95     const int IC = pd()->IC_total_padded();
96
97     bool wei_tr = utils::one_of(pd()->diff_weights_pd()->desc()->format,
98              hwio, dhwio, io);
99
100     float alpha = 1.0, beta = 0.0;
101     if (wei_tr)
102         extended_sgemm("N", "T", &OC, &IC, &MB, &alpha, diff_dst, &OC, src, &IC,
103                 &beta, diff_weights, &OC);
104     else
105         extended_sgemm("N", "T", &IC, &OC, &MB, &alpha, src, &IC, diff_dst, &OC,
106                 &beta, diff_weights, &IC);
107
108     if (diff_bias) {
109         diff_bias += diff_bias_d.blocking_desc().offset_padding;
110         constexpr int blksize = 8;
111         const int OC_blocks = OC / blksize;
112         const int rem_OC = OC % blksize;
113         parallel(0, (size_t)OC_blocks, [&](const int ithr, const int nthr) {
114             int oc_st{0}, oc_e{0};
115             balance211(OC_blocks, nthr, ithr, oc_st, oc_e);
116             oc_st = oc_st * blksize;
117             oc_e = oc_e * blksize;
118
119             PRAGMA_OMP_SIMD()
120             for (int oc = oc_st; oc < oc_e; ++oc) {
121                 diff_bias[oc] = diff_dst[oc];
122             }
123
124             for (int mb = 1; mb < MB; ++mb) {
125                 PRAGMA_OMP_SIMD()
126                 for (int oc = oc_st; oc < oc_e; ++oc) {
127                     diff_bias[oc] += diff_dst[mb * OC + oc];
128                 }
129             }
130
131             if (rem_OC != 0 && ithr == nthr-1) {
132                 for (int oc = OC_blocks * blksize; oc < OC; oc++)
133                     diff_bias[oc] = diff_dst[oc];
134                 for (int mb = 1; mb < MB; ++mb) {
135                     for (int oc = OC_blocks * blksize; oc < OC; oc++) {
136                         diff_bias[oc] += diff_dst[mb * OC + oc];
137                     }
138                 }
139             }
140         });
141     }
142 }
143
144 template struct gemm_inner_product_fwd_t<data_type::f32>;
145 template struct gemm_inner_product_bwd_data_t<data_type::f32>;
146 template struct gemm_inner_product_bwd_weights_t<data_type::f32>;
147
148 }
149 }
150 }
151
152 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s