1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
18 #include "jit_generator.hpp"
19 #include "jit_uni_lrn.hpp"
20 #include "type_helpers.hpp"
// Forward LRN primitive constructor: selects and JIT-compiles, at
// primitive-creation time, the kernel variant(s) matching the source
// memory format, local_size and algorithm kind.  Kernel slots not
// needed by the chosen path remain nullptr.
27 template <cpu_isa_t isa>
28 jit_uni_lrn_fwd_t<isa>::jit_uni_lrn_fwd_t(
30 const input_vector &inputs, const output_vector &outputs)
31 : cpu_primitive_t(apd, inputs, outputs), ker_(nullptr)
32 , ker_first_(nullptr), ker_last_(nullptr)
34 using namespace alg_kind;
36 const int C = pd()->C();
37 const int H = pd()->H();
38 const int W = pd()->W();
39 const int ls = pd()->desc()->local_size;
// The kernels take alpha pre-divided by local_size (the per-element
// scaling factor), not the raw LRN alpha.
40 float A = pd()->desc()->lrn_alpha / ls;
41 float K = pd()->desc()->lrn_k;
43 auto pk = pd()->desc()->prop_kind;
44 auto ak = pd()->desc()->alg_kind;
45 auto dfmt = pd()->src_pd()->desc()->format;
// Blocked nChw8c, across-channels, local_size == 5: three kernels --
// one for interior 8-channel blocks, plus edge variants for the first
// (-1) and last (+1) blocks, which have fewer channel neighbors.
47 if (dfmt == nChw8c && ls == 5 && ak == lrn_across_channels) {
48 ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
49 nchw8c_across(H, W, 0), A, K, pk);
50 ker_first_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
51 nchw8c_across(H, W, -1), A, K, pk);
52 ker_last_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
53 nchw8c_across(H, W, +1), A, K, pk);
54 } else if (dfmt == nChw8c && ak == lrn_within_channel) {
55 /* within channel, local_size (x) local_size */
56 A /= ls; /* XXX: why? */
57 ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
58 nchw8c_within(H, W, ls), A, K, pk);
59 } else if (dfmt == nchw && ls == 5 && ak == lrn_across_channels) {
60 ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
61 nchw_across(C, H*W, 0), A, K, pk);
// Dedicated tail kernel for the last, possibly partial, vector of the
// flattened H*W spatial dimension.
62 int remind = (H*W) % VECTOR_LENGTH;
64 ker_last_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
65 nchw_across(C, H*W, remind), A, K, pk);
// Fallback branch; pd_t::init() restricts the admissible formats, so
// this is effectively the nhwc across-channels case.
67 } else if (true /* XXX: why */) {
68 ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(nhwc_across(C), A, K, pk);
72 template <cpu_isa_t isa>
73 jit_uni_lrn_fwd_t<isa>::~jit_uni_lrn_fwd_t()
74 { delete ker_; delete ker_first_; delete ker_last_; }
// Forward execution: dispatches to the JIT kernel(s) chosen in the
// constructor.  Reads the source (input 0) and writes the destination
// (output 0) plus a workspace (output 1) consumed by the backward pass.
76 template <cpu_isa_t isa>
77 void jit_uni_lrn_fwd_t<isa>::execute_forward() const {
78 using namespace alg_kind;
80 auto src = reinterpret_cast<const data_t*>(this->input_memory(0));
81 auto dst = reinterpret_cast<data_t*>(this->memory(0));
82 auto ws = reinterpret_cast<data_t*>(this->memory(1));
84 const int N = pd()->MB();
85 const int C = pd()->C();
86 const int HW = pd()->H() * pd()->W();
87 const int ls = pd()->desc()->local_size;
89 auto ak = pd()->desc()->alg_kind;
90 auto dfmt = pd()->src_pd()->desc()->format;
// nChw8c across-channels: parallelize over batch x 8-channel blocks;
// the first and last channel blocks are routed to the edge kernels.
92 if (dfmt == nChw8c && ls == 5 && ak == lrn_across_channels) {
93 parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) {
95 args.src = &src[n*HW*C + c8 * HW * VECTOR_LENGTH];
96 args.dst = &dst[n*HW*C + c8 * HW * VECTOR_LENGTH];
97 args.scratch = &ws[n*HW*C + c8 * HW * VECTOR_LENGTH];
100 else if (c8 == C / VECTOR_LENGTH - 1)
// nChw8c within-channel: a single kernel covers every channel block.
106 else if (dfmt == nChw8c && ak == lrn_within_channel) {
107 parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) {
109 args.src = &src[n*HW*C + c8 * HW * VECTOR_LENGTH];
110 args.dst = &dst[n*HW*C + c8 * HW * VECTOR_LENGTH];
111 args.scratch = &ws[n*HW*C + c8 * HW * VECTOR_LENGTH];
// nchw across-channels: parallelize over batch x spatial chunks of
// VECTOR_LENGTH pixels; a partial final chunk is detected below and
// handled by the tail kernel built in the constructor.
115 else if (dfmt == nchw && ls == 5 && ak == lrn_across_channels) {
116 parallel_nd(N, (HW + VECTOR_LENGTH - 1) / VECTOR_LENGTH,
117 [&](int n, int hw8) {
119 args.src = &src[n*HW*C + hw8 * VECTOR_LENGTH];
120 args.dst = &dst[n*HW*C + hw8 * VECTOR_LENGTH];
121 args.scratch = &ws[n*HW*C + hw8 * VECTOR_LENGTH];
122 if ((hw8 + 1)*VECTOR_LENGTH > HW)
// nhwc fallback: one kernel invocation per (batch, pixel) pair, with
// all C channels of a pixel contiguous in memory.
129 parallel_nd(N, HW, [&](int n, int hw) {
131 args.src = &src[n*HW*C + hw * C];
132 args.dst = &dst[n*HW*C + hw * C];
133 args.scratch = &ws[n*HW*C + hw * C];
// Forward descriptor initialization: accepts exactly the configurations
// the JIT kernels support, returning unimplemented for anything else so
// the framework can fall back to a reference implementation.
139 template <cpu_isa_t isa>
140 status_t jit_uni_lrn_fwd_t<isa>::pd_t::init() {
141 using namespace prop_kind;
142 using namespace alg_kind;
144 assert(engine()->kind() == engine_kind::cpu);
146 if (!mayiuse(isa)) return unimplemented;
148 const memory_desc_wrapper data_d(data_pd_.desc());
// Common constraints: f32 forward, 4D data, channel count a multiple of
// VECTOR_LENGTH and at least two vectors wide, and the fixed
// beta == 0.75 that is baked into the generated code.
150 && one_of(desc()->prop_kind, forward_training, forward_inference)
151 && everyone_is(data_type::f32, desc()->data_desc.data_type)
152 && !has_zero_dim_memory()
153 && data_d.ndims() == 4
154 && data_d.dims()[1] % VECTOR_LENGTH == 0
155 && data_d.dims()[1] >= 2 * VECTOR_LENGTH
156 && desc()->lrn_beta == 0.75
157 && attr()->has_default_values();
158 if (!ok) return unimplemented;
// Training requires a workspace laid out identically to the data.
160 if (desc_.prop_kind == forward_training) { ws_pd_ = data_pd_; }
// Across-channels path: only local_size == 5 is JIT-ed, in any of the
// three supported memory formats.
162 bool args_ok_across = true
163 && desc()->alg_kind == lrn_across_channels
164 && desc()->local_size == 5
165 && one_of(data_d.format(), nChw8c, nchw, nhwc);
167 const int jit_max_local_size = 5; // bigger size triggers too big code size
// Within-channel path: blocked format only, local_size capped, and the
// spatial dims must fit at least one local_size x local_size window.
168 bool args_ok_within = true
169 && desc()->alg_kind == lrn_within_channel
170 && desc()->local_size <= ( jit_max_local_size <= MAX_LOCAL_SIZE
171 ? jit_max_local_size : MAX_LOCAL_SIZE)
172 && data_d.dims()[2] >= desc()->local_size
173 && data_d.dims()[3] >= desc()->local_size
174 && one_of(data_d.format(), nChw8c);
176 return args_ok_across || args_ok_within ? success : unimplemented;
// Backward LRN primitive constructor.  Only the nChw8c across-channels,
// local_size == 5 configuration reaches this point (see pd_t::init()).
179 template <cpu_isa_t isa>
180 jit_uni_lrn_bwd_t<isa>::jit_uni_lrn_bwd_t(const pd_t *apd,
181 const input_vector &inputs, const output_vector &outputs)
182 : cpu_primitive_t(apd, inputs, outputs)
183 , ker_(nullptr), ker_first_(nullptr), ker_last_(nullptr)
185 using namespace alg_kind;
186 const int C = pd()->C();
187 const int H = pd()->H();
188 const int W = pd()->W();
189 const int ls = pd()->desc()->local_size;
// As in the forward pass, alpha is pre-divided by local_size.
190 float A = pd()->desc()->lrn_alpha / ls;
191 float B = pd()->desc()->lrn_beta;
// H-parallelism is disabled here; the flag is still threaded through to
// the kernel generator so both code shapes stay buildable.
193 int use_h_parallelizm = 0;// XXX
// A single 8-channel block is both first and last; variant 3 is used
// for that case (presumably a combined first+last kernel -- confirm
// against the jit_uni_lrn_bwd_kernel_f32 generator).
194 if (C / VECTOR_LENGTH == 1) {
195 ker_ = new jit_uni_lrn_bwd_kernel_f32<isa>(
196 nchw8c_across(H, W, 3), A, B, use_h_parallelizm);
// Otherwise: interior kernel (0) plus dedicated first (-1) and
// last (+1) channel-block edge kernels.
199 ker_ = new jit_uni_lrn_bwd_kernel_f32<isa>(
200 nchw8c_across(H, W, 0), A, B, use_h_parallelizm);
201 ker_first_ = new jit_uni_lrn_bwd_kernel_f32<isa>(
202 nchw8c_across(H, W, -1), A, B, use_h_parallelizm);
203 ker_last_ = new jit_uni_lrn_bwd_kernel_f32<isa>(
204 nchw8c_across(H, W, +1), A, B, use_h_parallelizm);
// Backward LRN destructor: releases the JIT kernels compiled in the
// constructor (mirrors the forward destructor; deleting null kernel
// slots is a no-op).
208 template <cpu_isa_t isa>
209 jit_uni_lrn_bwd_t<isa>::~jit_uni_lrn_bwd_t()
211 delete ker_; delete ker_first_; delete ker_last_;
// Backward execution: inputs are src (0), diff_dst (1) and the
// workspace written by the forward pass (2); output is diff_src (0).
214 template <cpu_isa_t isa>
215 void jit_uni_lrn_bwd_t<isa>::execute_backward() const {
216 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
217 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
218 auto ws = reinterpret_cast<const data_t*>(this->input_memory(2));
219 auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
221 const int N = pd()->MB();
222 const int C = pd()->C();
223 const int H = pd()->H();
224 const int W = pd()->W();
// Must stay consistent with the constructor's flag; since it is 0, only
// the else-branch below is currently exercised.
226 int use_h_parallelizm = 0; // XXX
227 if (use_h_parallelizm) {
// Row-parallel variant: parallelize over batch x channel-block x row.
228 parallel_nd(N, C / VECTOR_LENGTH, H, [&](int n, int c8, int h) {
229 auto offset = n*C*H*W + c8*H*W*VECTOR_LENGTH
232 args.src = &src[offset];
233 args.diff_dst = &diff_dst[offset];
234 args.scratch = &ws[offset];
235 args.diff_src = &diff_src[offset];
// Single channel block uses the combined kernel; otherwise dispatch on
// the block index to the first / interior / last variant.
236 if (C / VECTOR_LENGTH == 1)
239 (*ker_first_)(&args);
240 else if (c8 == C / VECTOR_LENGTH - 1)
// Active path: parallelize over batch x channel-block only; each kernel
// call processes the whole H*W plane of its 8-channel block.
247 parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) {
248 auto offset = n*C*H*W + c8*H*W*VECTOR_LENGTH;
250 args.src = &src[offset];
251 args.diff_dst = &diff_dst[offset];
252 args.scratch = &ws[offset];
253 args.diff_src = &diff_src[offset];
254 if (C / VECTOR_LENGTH == 1)
257 (*ker_first_)(&args);
258 else if (c8 == C / VECTOR_LENGTH - 1)
// Backward descriptor initialization: f32 4D data, channel count a
// multiple of VECTOR_LENGTH, the fixed beta == 0.75, and a workspace
// from the forward hint that matches this primitive's data layout.
266 template <cpu_isa_t isa>
267 status_t jit_uni_lrn_bwd_t<isa>::pd_t::init() {
268 using namespace prop_kind;
269 using namespace alg_kind;
271 assert(engine()->kind() == engine_kind::cpu);
273 if (!mayiuse(isa)) return unimplemented;
275 const memory_desc_wrapper data_d(data_pd_.desc());
277 && utils::one_of(desc()->prop_kind, backward, backward_data)
278 && utils::everyone_is(data_type::f32, desc()->data_desc.data_type)
279 && !has_zero_dim_memory()
280 && data_d.ndims() == 4
281 && data_d.dims()[1] % VECTOR_LENGTH == 0
282 && desc()->lrn_beta == 0.75
283 && attr()->has_default_values();
284 if (!ok) return unimplemented;
// The forward workspace must agree with this data descriptor in rank,
// format and type, since the kernels index it with the same layout.
288 auto fwd_ws_d_ = hint_fwd_pd_->workspace_pd()->desc();
290 && fwd_ws_d_->ndims == data_d.ndims()
291 && fwd_ws_d_->format == data_d.format()
292 && fwd_ws_d_->data_type == data_d.data_type();
293 if (!ws_ok) return unimplemented;
// Backward is JIT-ed only for across-channels, local_size == 5, in the
// blocked nChw8c layout (narrower than the forward pass's support).
295 bool args_ok_across = true
296 && desc()->alg_kind == lrn_across_channels
297 && desc()->local_size == 5
298 && utils::one_of(data_d.format(), nChw8c);
300 return args_ok_across ? success : unimplemented;
// Explicit instantiations: the forward primitive supports SSE4.2 and
// AVX2; the backward primitive is instantiated for AVX2 only.
303 template struct jit_uni_lrn_fwd_t<sse42>;
304 template struct jit_uni_lrn_fwd_t<avx2>;
305 template struct jit_uni_lrn_bwd_t<avx2>;
311 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s