Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_lrn.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18 #include <math.h>
19
20 #include "c_types_map.hpp"
21 #include "mkldnn_thread.hpp"
22 #include "type_helpers.hpp"
23
24 #include "ref_lrn.hpp"
25
26 namespace mkldnn {
27 namespace impl {
28 namespace cpu {
29
30 static inline float fast_negative_powf(float omega, float beta) {
31     float Y;
32 /*
33  * Y = omega^(-3/4) =
34  * = 1.0f / sqrtf(omega) * sqrtf(1.0f / sqrtf(omega))
35  * = sqrtf(1.0f / sqrtf(omega)) * 1.0f / sqrtf(omega)
36  * = sqrtf(1.0f / sqrtf(omega)) / sqrtf(omega)
37  * = sqrtf(1.0f / sqrtf(omega) / omega)
38  * = sqrtf(1.0f / (sqrtf(omega) * omega))
39  */
40     if (beta == 0.75f) {
41         Y = sqrtf(1.0f / (sqrtf(omega) * omega));
42     } else {
43         Y = 1.0f / powf(omega, beta);
44     }
45     return Y;
46 };
47
48 template <impl::data_type_t data_type>
49 template <mkldnn_memory_format_t fmt>
50 void ref_lrn_fwd_t<data_type>::execute_forward() const {
51     using namespace alg_kind;
52     using namespace memory_format;
53
54     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
55     auto dst = reinterpret_cast<data_t*>(this->memory(0));
56     auto ws = reinterpret_cast<data_t*>(this->memory(1));
57
58     const memory_desc_wrapper data_d(pd()->src_pd());
59     const memory_desc_wrapper ws_d(pd()->workspace_pd());
60     MAYBE_UNUSED(ws_d);
61
62     const int C = pd()->C();
63     const int H = pd()->H();
64     const int W = pd()->W();
65     const size_t stride_mb = data_d.blocking_desc().strides[0][0];
66     const bool across_channels = pd()->desc()->alg_kind == lrn_across_channels;
67     constexpr int blksize = fmt == nChw16c ? 16 : 8;
68
69     auto data_off = [&](int mb, int c, int h, int w) -> size_t {
70         switch (fmt) {
71         case nChw16c:
72         case nChw8c: return mb * stride_mb + c / blksize * H * W * blksize
73                      + h * W * blksize + w * blksize + c % blksize;
74         case nchw: return mb * stride_mb + c * H * W + h * W + w;
75         case nhwc: return mb * stride_mb + h * W * C + w * C + c;
76         default: return data_d.off(mb, c, h, w);
77         }
78     };
79
80     auto ker = [=](data_t *d, int mb, int oc, int oh, int ow) {
81         const float alpha = static_cast<float>(pd()->desc()->lrn_alpha);
82         const float beta = static_cast<float>(pd()->desc()->lrn_beta);
83         const float k = static_cast<float>(pd()->desc()->lrn_k);
84
85         const int size = pd()->desc()->local_size;
86         const int half_size = (size - 1) / 2;
87
88         float sum = 0;
89         if (across_channels) {
90             const int c_st = nstl::max(oc - half_size + 0, 0);
91             const int c_en = nstl::min(oc + half_size + 1, C);
92
93             for (int c = c_st; c < c_en; ++c) {
94                 const float s = src[data_off(mb, c, oh, ow)];
95                 sum += s * s;
96             }
97         } else {
98             int h_st = nstl::max(oh - half_size + 0, 0);
99             int h_en = nstl::min(oh + half_size + 1, H);
100             int w_st = nstl::max(ow - half_size + 0, 0);
101             int w_en = nstl::min(ow + half_size + 1, W);
102             for (int h = h_st; h < h_en; ++h) {
103                 for (int w = w_st; w < w_en; ++w) {
104                     const float s = src[data_off(mb, oc, h, w)];
105                     sum += s * s;
106                 }
107             }
108         }
109         const int summands = across_channels ? size : size * size;
110         sum = k + alpha * sum / summands;
111         size_t off = data_off(mb, oc, oh, ow);
112         if (ws)
113             ws[off] = static_cast<data_t>(sum);
114         d[0] = static_cast<data_t>(src[off] * fast_negative_powf(sum, beta));
115     };
116
117     const int MB = pd()->MB();
118     if (fmt == nChw16c || fmt == nChw8c) {
119         parallel_nd(MB, utils::div_up(C, blksize), H, W,
120             [&](int mb, int c_blk, int h, int w) {
121             int c = c_blk * blksize;
122             const size_t off = mb * stride_mb + c * H * W
123                 + (h * W + w) * blksize;
124             PRAGMA_OMP_SIMD()
125             for (int cc = 0; cc < nstl::min(blksize, C - c); ++cc)
126                 ker(&dst[off + cc], mb, c + cc, h, w);
127         });
128     } else if (fmt == nhwc) {
129         parallel_nd(MB, H, W, C,
130             [&](int mb, int h, int w, int c) {
131             const size_t off = mb * stride_mb + h * W * C + w * C + c;
132             ker(&dst[off], mb, c, h, w);
133         });
134     } else {
135         parallel_nd(MB, C, H, W,
136             [&](int mb, int c, int h, int w) {
137             const size_t off = data_off(mb, c, h, w);
138             ker(&dst[off], mb, c, h, w);
139         });
140     }
141 }
142
143 template <impl::data_type_t data_type>
144 template <mkldnn_memory_format_t fmt>
145 void ref_lrn_bwd_t<data_type>::execute_backward() const {
146     using namespace alg_kind;
147     using namespace memory_format;
148
149     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
150     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
151     auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
152
153     const memory_desc_wrapper data_d(pd()->src_pd());
154     const memory_desc_wrapper diff_data_d(pd()->diff_dst_pd());
155     MAYBE_UNUSED(diff_data_d);
156
157     const int MB = pd()->MB();
158     const int C = pd()->C();
159     const int H = pd()->H();
160     const int W = pd()->W();
161     const size_t stride_mb = data_d.blocking_desc().strides[0][0];
162     constexpr int blksize = fmt == nChw16c ? 16 : 8;
163
164     const float alpha = static_cast<float>(pd()->desc()->lrn_alpha);
165     const float beta = static_cast<float>(pd()->desc()->lrn_beta);
166     const float k = static_cast<float>(pd()->desc()->lrn_k);
167     const int kernel_size = pd()->desc()->local_size;
168     const int half_ksize = (kernel_size - 1) / 2;
169
170     auto data_off = [&](int mb, int c, int h, int w) -> size_t {
171         switch (fmt) {
172         case nChw16c:
173         case nChw8c: return mb * stride_mb + c/blksize * H * W * blksize
174                      + h * W * blksize + w * blksize + c%blksize;
175         case nchw: return mb * stride_mb + c * H * W + h * W + w;
176         case nhwc: return mb * stride_mb + h * W * C + w * C + c;
177         default: return data_d.off(mb, c, h, w);
178         }
179     };
180
181     auto ker = [=](data_t *d, int mb, int oc, int oh, int ow) {
182         const int c_st = nstl::max(oc - half_ksize + 0, 0);
183         const int c_en = nstl::min(oc + half_ksize + 1, C);
184
185         float A = 0, B = 0, omega_mid = 0;
186         for (int c = c_st; c < c_en; c++) {
187             float sum = 0.0;
188             const int i_st = nstl::max(c - half_ksize, 0);
189             const int i_en = nstl::min(c + kernel_size - half_ksize, C);
190
191             for (int i = i_st; i < i_en; ++i) {
192                 const float value = src[data_off(mb, i, oh, ow)];
193                 sum += value * value;
194             }
195             const float omega = static_cast<float>(k + sum * alpha / kernel_size);
196             if (c == oc) omega_mid = omega;
197             float t = src[data_off(mb, c, oh, ow)]
198                    * fast_negative_powf(omega, beta);
199             B += 1.0f / omega * t * diff_dst[data_off(mb, c, oh, ow)];
200         }
201
202         const size_t off = data_off(mb, oc, oh, ow);
203         A = fast_negative_powf(omega_mid, beta) * diff_dst[off];
204         B *= src[off];
205         B *= (2.0f * alpha * beta) / kernel_size;
206         *d = static_cast<data_t>(A - B); // final cast down to data_t
207     };
208
209     if (fmt == nChw16c || fmt == nChw8c) {
210         parallel_nd(MB, utils::div_up(C, blksize), H, W,
211             [&](int mb, int c_blk, int h, int w) {
212             int c = c_blk * blksize;
213             const size_t off = mb * stride_mb + c * H * W +
214                 (h * W + w) * blksize;
215             PRAGMA_OMP_SIMD()
216             for (int cc = 0; cc < nstl::min(blksize, C - c); ++cc)
217                 ker(&diff_src[off + cc], mb, c + cc, h, w);
218         });
219     } else if (fmt == nhwc) {
220         parallel_nd(MB, H, W, C,
221             [&](int mb, int h, int w, int c) {
222             const size_t off = mb * stride_mb + h * W * C + w * C + c;
223             ker(&diff_src[off], mb, c, h, w);
224         });
225     } else {
226         parallel_nd(MB, C, H, W,
227             [&](int mb, int c, int h, int w) {
228             const size_t off = data_off(mb, c, h, w);
229             ker(&diff_src[off], mb, c, h, w);
230         });
231     }
232 }
233
234 template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::nChw16c>() const;
235 template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::nChw8c>() const;
236 template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::nchw>() const;
237 template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::nhwc>() const;
238 template void ref_lrn_fwd_t<data_type::f32>::execute_forward<memory_format::any>() const;
239 template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::nChw16c>() const;
240 template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::nChw8c>() const;
241 template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::nchw>() const;
242 template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::nhwc>() const;
243 template void ref_lrn_bwd_t<data_type::f32>::execute_backward<memory_format::any>() const;
244
245 }
246 }
247 }
248
249 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s