Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_lrn.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "jit_generator.hpp"
19 #include "jit_uni_lrn.hpp"
20 #include "type_helpers.hpp"
21 #include "utils.hpp"
22
23 namespace mkldnn {
24 namespace impl {
25 namespace cpu {
26
27 template <cpu_isa_t isa>
28 jit_uni_lrn_fwd_t<isa>::jit_uni_lrn_fwd_t(
29     const pd_t *apd,
30     const input_vector &inputs, const output_vector &outputs)
31     : cpu_primitive_t(apd, inputs, outputs), ker_(nullptr)
32     , ker_first_(nullptr), ker_last_(nullptr)
33 {
34     using namespace alg_kind;
35
36     const int C = pd()->C();
37     const int H = pd()->H();
38     const int W = pd()->W();
39     const int ls = pd()->desc()->local_size;
40     float A = pd()->desc()->lrn_alpha / ls;
41     float K = pd()->desc()->lrn_k;
42
43     auto pk = pd()->desc()->prop_kind;
44     auto ak = pd()->desc()->alg_kind;
45     auto dfmt = pd()->src_pd()->desc()->format;
46
47     if (dfmt == nChw8c && ls == 5 && ak == lrn_across_channels) {
48         ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
49                 nchw8c_across(H, W, 0), A, K, pk);
50         ker_first_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
51                 nchw8c_across(H, W, -1), A, K, pk);
52         ker_last_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
53                 nchw8c_across(H, W, +1), A, K, pk);
54     } else if (dfmt == nChw8c && ak == lrn_within_channel) {
55         /* within channel, local_size (x) local_size */
56         A /= ls; /* XXX: why? */
57         ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
58                 nchw8c_within(H, W, ls), A, K, pk);
59     } else if (dfmt == nchw && ls == 5 && ak == lrn_across_channels) {
60         ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
61                 nchw_across(C, H*W, 0), A, K, pk);
62         int remind = (H*W) % VECTOR_LENGTH;
63         if (remind != 0) {
64             ker_last_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
65                         nchw_across(C, H*W, remind), A, K, pk);
66         }
67     } else if (true /* XXX: why */) {
68         ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(nhwc_across(C), A, K, pk);
69     }
70 }
71
72 template <cpu_isa_t isa>
73 jit_uni_lrn_fwd_t<isa>::~jit_uni_lrn_fwd_t()
74 { delete ker_; delete ker_first_; delete ker_last_; }
75
76 template <cpu_isa_t isa>
77 void jit_uni_lrn_fwd_t<isa>::execute_forward() const {
78     using namespace alg_kind;
79
80     auto src = reinterpret_cast<const data_t*>(this->input_memory(0));
81     auto dst = reinterpret_cast<data_t*>(this->memory(0));
82     auto ws = reinterpret_cast<data_t*>(this->memory(1));
83
84     const int N = pd()->MB();
85     const int C = pd()->C();
86     const int HW = pd()->H() * pd()->W();
87     const int ls = pd()->desc()->local_size;
88
89     auto ak = pd()->desc()->alg_kind;
90     auto dfmt = pd()->src_pd()->desc()->format;
91
92     if (dfmt == nChw8c && ls == 5 && ak == lrn_across_channels) {
93         parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) {
94             jit_args_fwd_t args;
95             args.src = &src[n*HW*C + c8 * HW * VECTOR_LENGTH];
96             args.dst = &dst[n*HW*C + c8 * HW * VECTOR_LENGTH];
97             args.scratch = &ws[n*HW*C + c8 * HW * VECTOR_LENGTH];
98             if (c8 == 0)
99                 (*ker_first_)(&args);
100             else if (c8 == C / VECTOR_LENGTH - 1)
101                 (*ker_last_)(&args);
102             else
103                 (*ker_)(&args);
104         });
105     }
106     else if (dfmt == nChw8c && ak == lrn_within_channel) {
107         parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) {
108             jit_args_fwd_t args;
109             args.src = &src[n*HW*C + c8 * HW * VECTOR_LENGTH];
110             args.dst = &dst[n*HW*C + c8 * HW * VECTOR_LENGTH];
111             args.scratch = &ws[n*HW*C + c8 * HW * VECTOR_LENGTH];
112             (*ker_)(&args);
113         });
114     }
115     else if (dfmt == nchw && ls == 5 && ak == lrn_across_channels) {
116         parallel_nd(N, (HW + VECTOR_LENGTH - 1) / VECTOR_LENGTH,
117             [&](int n, int hw8) {
118             jit_args_fwd_t args;
119             args.src = &src[n*HW*C + hw8 * VECTOR_LENGTH];
120             args.dst = &dst[n*HW*C + hw8 * VECTOR_LENGTH];
121             args.scratch = &ws[n*HW*C + hw8 * VECTOR_LENGTH];
122             if ((hw8 + 1)*VECTOR_LENGTH > HW)
123                 (*ker_last_)(&args);
124             else
125                 (*ker_)(&args);
126         });
127     }
128     else { // nhwc
129         parallel_nd(N, HW, [&](int n, int hw) {
130             jit_args_fwd_t args;
131             args.src = &src[n*HW*C + hw * C];
132             args.dst = &dst[n*HW*C + hw * C];
133             args.scratch = &ws[n*HW*C + hw * C];
134             (*ker_)(&args);
135         });
136     }
137 }
138
139 template <cpu_isa_t isa>
140 status_t jit_uni_lrn_fwd_t<isa>::pd_t::init() {
141     using namespace prop_kind;
142     using namespace alg_kind;
143
144     assert(engine()->kind() == engine_kind::cpu);
145
146     if (!mayiuse(isa)) return unimplemented;
147
148     const memory_desc_wrapper data_d(data_pd_.desc());
149     bool ok = true
150         && one_of(desc()->prop_kind, forward_training, forward_inference)
151         && everyone_is(data_type::f32, desc()->data_desc.data_type)
152         && !has_zero_dim_memory()
153         && data_d.ndims() == 4
154         && data_d.dims()[1] % VECTOR_LENGTH == 0
155         && data_d.dims()[1] >= 2 * VECTOR_LENGTH
156         && desc()->lrn_beta == 0.75
157         && attr()->has_default_values();
158     if (!ok) return unimplemented;
159
160     if (desc_.prop_kind == forward_training) { ws_pd_ = data_pd_; }
161
162     bool args_ok_across = true
163         && desc()->alg_kind == lrn_across_channels
164         && desc()->local_size == 5
165         && one_of(data_d.format(), nChw8c, nchw, nhwc);
166
167     const int jit_max_local_size = 5; // bigger size triggers too big code size
168     bool args_ok_within = true
169         && desc()->alg_kind == lrn_within_channel
170         && desc()->local_size <= ( jit_max_local_size <= MAX_LOCAL_SIZE
171                                  ? jit_max_local_size : MAX_LOCAL_SIZE)
172         && data_d.dims()[2] >= desc()->local_size
173         && data_d.dims()[3] >= desc()->local_size
174         && one_of(data_d.format(), nChw8c);
175
176     return args_ok_across || args_ok_within ? success : unimplemented;
177 }
178
179 template <cpu_isa_t isa>
180 jit_uni_lrn_bwd_t<isa>::jit_uni_lrn_bwd_t(const pd_t *apd,
181     const input_vector &inputs, const output_vector &outputs)
182     : cpu_primitive_t(apd, inputs, outputs)
183     , ker_(nullptr), ker_first_(nullptr), ker_last_(nullptr)
184 {
185     using namespace alg_kind;
186     const int C = pd()->C();
187     const int H = pd()->H();
188     const int W = pd()->W();
189     const int ls = pd()->desc()->local_size;
190     float A = pd()->desc()->lrn_alpha / ls;
191     float B = pd()->desc()->lrn_beta;
192
193     int use_h_parallelizm = 0;// XXX
194     if (C / VECTOR_LENGTH == 1) {
195         ker_ = new jit_uni_lrn_bwd_kernel_f32<isa>(
196             nchw8c_across(H, W, 3), A, B, use_h_parallelizm);
197     }
198     else {
199         ker_ = new jit_uni_lrn_bwd_kernel_f32<isa>(
200             nchw8c_across(H, W, 0), A, B, use_h_parallelizm);
201         ker_first_ = new jit_uni_lrn_bwd_kernel_f32<isa>(
202             nchw8c_across(H, W, -1), A, B, use_h_parallelizm);
203         ker_last_ = new jit_uni_lrn_bwd_kernel_f32<isa>(
204             nchw8c_across(H, W, +1), A, B, use_h_parallelizm);
205     }
206 }
207
208 template <cpu_isa_t isa>
209 jit_uni_lrn_bwd_t<isa>::~jit_uni_lrn_bwd_t()
210 {
211     delete ker_; delete ker_first_; delete ker_last_;
212 }
213
214 template <cpu_isa_t isa>
215 void jit_uni_lrn_bwd_t<isa>::execute_backward() const {
216     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
217     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
218     auto ws = reinterpret_cast<const data_t*>(this->input_memory(2));
219     auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
220
221     const int N = pd()->MB();
222     const int C = pd()->C();
223     const int H = pd()->H();
224     const int W = pd()->W();
225
226     int use_h_parallelizm = 0; // XXX
227     if (use_h_parallelizm) {
228         parallel_nd(N, C / VECTOR_LENGTH, H, [&](int n, int c8, int h) {
229             auto offset = n*C*H*W + c8*H*W*VECTOR_LENGTH
230                 + h*W*VECTOR_LENGTH;
231             jit_args_bwd_t args;
232             args.src = &src[offset];
233             args.diff_dst = &diff_dst[offset];
234             args.scratch = &ws[offset];
235             args.diff_src = &diff_src[offset];
236             if (C / VECTOR_LENGTH == 1)
237                 (*ker_)(&args);
238             else if (c8 == 0)
239                 (*ker_first_)(&args);
240             else if (c8 == C / VECTOR_LENGTH - 1)
241                 (*ker_last_)(&args);
242             else
243                 (*ker_)(&args);
244         });
245     }
246     else {
247         parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) {
248             auto offset = n*C*H*W + c8*H*W*VECTOR_LENGTH;
249             jit_args_bwd_t args;
250             args.src = &src[offset];
251             args.diff_dst = &diff_dst[offset];
252             args.scratch = &ws[offset];
253             args.diff_src = &diff_src[offset];
254             if (C / VECTOR_LENGTH == 1)
255                 (*ker_)(&args);
256             else if (c8 == 0)
257                 (*ker_first_)(&args);
258             else if (c8 == C / VECTOR_LENGTH - 1)
259                 (*ker_last_)(&args);
260             else
261                 (*ker_)(&args);
262         });
263     }
264 }
265
266 template <cpu_isa_t isa>
267 status_t jit_uni_lrn_bwd_t<isa>::pd_t::init() {
268     using namespace prop_kind;
269     using namespace alg_kind;
270
271     assert(engine()->kind() == engine_kind::cpu);
272
273     if (!mayiuse(isa)) return unimplemented;
274
275     const memory_desc_wrapper data_d(data_pd_.desc());
276     bool ok = true
277         && utils::one_of(desc()->prop_kind, backward, backward_data)
278         && utils::everyone_is(data_type::f32, desc()->data_desc.data_type)
279         && !has_zero_dim_memory()
280         && data_d.ndims() == 4
281         && data_d.dims()[1] % VECTOR_LENGTH == 0
282         && desc()->lrn_beta == 0.75
283         && attr()->has_default_values();
284     if (!ok) return unimplemented;
285
286     ws_pd_ = data_pd_;
287
288     auto fwd_ws_d_ = hint_fwd_pd_->workspace_pd()->desc();
289     bool ws_ok = true
290         && fwd_ws_d_->ndims == data_d.ndims()
291         && fwd_ws_d_->format == data_d.format()
292         && fwd_ws_d_->data_type == data_d.data_type();
293     if (!ws_ok) return unimplemented;
294
295     bool args_ok_across = true
296         && desc()->alg_kind == lrn_across_channels
297         && desc()->local_size == 5
298         && utils::one_of(data_d.format(), nChw8c);
299
300     return args_ok_across ? success : unimplemented;
301 }
302
303 template struct jit_uni_lrn_fwd_t<sse42>;
304 template struct jit_uni_lrn_fwd_t<avx2>;
305 template struct jit_uni_lrn_bwd_t<avx2>;
306
307 }
308 }
309 }
310
311 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s