// inference-engine/src/extension/ext_mvn.cpp
// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "ext_list.hpp"
#include "ext_base.hpp"

#include <cmath>
#include <string>
#include <vector>
#include <cassert>
#include <algorithm>
#if defined(HAVE_AVX2) || defined(HAVE_AVX512F)
#include <immintrin.h>
#endif
#include "ie_parallel.hpp"

namespace InferenceEngine {
namespace Extensions {
namespace Cpu {

inline int div_up(const int a, const int b) {
    assert(b);
    return (a + b - 1) / b;
}
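// MVN (Mean-Variance Normalization) CPU reference kernel.
//
// For every batch element the layer subtracts the mean of the input and, when
// normalize_variance is set, additionally divides by sqrt(E[(x - mean)^2] + eps).
// With across_channels the statistics are shared over C, D, H and W; otherwise
// they are computed independently for each channel.
//
// A rough sketch of the corresponding IR layer (attribute names follow the
// parameters read in the constructor below; across_channels and
// normalize_variance default to 0, eps is required):
//
//     <layer name="mvn" type="MVN" ...>
//         <data across_channels="0" normalize_variance="1" eps="1e-9"/>
//         ...
//     </layer>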
class MVNImpl: public ExtLayerBase {
public:
    explicit MVNImpl(const CNNLayer* layer) {
        try {
            if (layer->insData.size() != 1 || layer->outData.empty())
                THROW_IE_EXCEPTION << "Incorrect number of input/output edges!";

            across_channels = layer->GetParamAsBool("across_channels", false);
            normalize_variance = layer->GetParamAsBool("normalize_variance", false);
            eps = layer->GetParamAsFloat("eps");

            // Register supported configurations: a blocked channel layout
            // (BLK16 with AVX-512, BLK8 otherwise) and a plain planar layout.
#if defined(HAVE_AVX512F)
            auto blk_layout = ConfLayout::BLK16;
#else
            auto blk_layout = ConfLayout::BLK8;
#endif
            addConfig(layer, {{blk_layout, false, -1}}, {{blk_layout, false, 0}});
            addConfig(layer, {{ConfLayout::PLN, false, 0}}, {{ConfLayout::PLN, false, 0}});
        } catch (InferenceEngine::details::InferenceEngineException &ex) {
            errorMsg = ex.what();
        }
    }

    StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs,
                       ResponseDesc *resp) noexcept override {
        float* src_data = inputs[0]->buffer();
        float* dst_data = outputs[0]->buffer();

        // Planar NCHW/NCDHW inputs take the scalar path; blocked layouts take mvn_blk.
        if (inputs[0]->layout() == NCHW || inputs[0]->layout() == NCDHW) {
            mvn_pln(src_data, dst_data, inputs[0]->getTensorDesc().getDims());
        } else {
            mvn_blk(src_data, dst_data, inputs[0]->getTensorDesc().getDims());
        }

        return OK;
    }
private:
    void mvn_pln(const float* src_data, float* dst_data, const SizeVector& dims);
    void mvn_blk(const float* src_data, float* dst_data, const SizeVector& dims);

    bool across_channels = false;
    bool normalize_variance = true;
    float eps = 1e-9f;
};
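// Planar (NCHW / NCDHW) implementation: scalar loops, parallelized over
// channels with parallel_for, or accumulated with parallel_sum when the
// statistics are shared across channels.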
void MVNImpl::mvn_pln(const float* src_data, float* dst_data, const SizeVector& dims) {
    size_t dims_size = dims.size();
    size_t N = (dims_size > 0) ? dims[0] : 1lu;
    size_t C = (dims_size > 1) ? dims[1] : 1lu;
    size_t D = (dims_size > 4) ? dims[dims_size - 3] : 1lu;
    size_t H = (dims_size > 3) ? dims[dims_size - 2] : 1lu;
    size_t W = (dims_size > 2) ? dims[dims_size - 1] : 1lu;

    size_t C1 = H * W;   // elements in one spatial plane
    size_t C2 = C1 * D;  // elements in one channel
    size_t C3 = C2 * C;  // elements in one batch
    for (size_t b = 0lu; b < N; b++) {
        // Calculate the mean and subtract it from the input
        size_t cb = b * C3;
        if (across_channels) {
            // Statistics are shared across all channels of this batch element
            double mean = 0.0;
            mean = parallel_sum(C, mean, [&](size_t c)->double {
                double mean_internal = 0.0;
                size_t cc = cb + c * C2;
                for (size_t d = 0lu; d < D; d++) {
                    size_t cd = cc + d * C1;
                    for (size_t h = 0lu; h < H; h++) {
                        size_t ch = cd + h * W;
                        for (size_t w = 0lu; w < W; w++) {
                            mean_internal += src_data[ch + w];
                        }
                    }
                }
                return mean_internal;
            });

            mean /= C3;
            parallel_for(C, [&](int c) {
                size_t cc = cb + c * C2;
                for (size_t d = 0lu; d < D; d++) {
                    size_t cd = cc + d * C1;
                    for (size_t h = 0lu; h < H; h++) {
                        size_t ch = cd + h * W;
                        for (size_t w = 0lu; w < W; w++) {
                            size_t cw = ch + w;
                            dst_data[cw] = src_data[cw] - static_cast<float>(mean);
                        }
                    }
                }
            });
        } else {
            // Statistics are computed independently for every channel
            parallel_for(C, [&](size_t c) {
                double mean = 0.0;
                size_t cc = cb + c * C2;
                for (size_t d = 0lu; d < D; d++) {
                    size_t cd = cc + d * C1;
                    for (size_t h = 0lu; h < H; h++) {
                        size_t ch = cd + h * W;
                        for (size_t w = 0lu; w < W; w++) {
                            mean += src_data[ch + w];
                        }
                    }
                }

                mean /= static_cast<double>(C2);

                for (size_t d = 0lu; d < D; d++) {
                    size_t cd = cc + d * C1;
                    for (size_t h = 0lu; h < H; h++) {
                        size_t ch = cd + h * W;
                        for (size_t w = 0lu; w < W; w++) {
                            size_t cw = ch + w;
                            dst_data[cw] = src_data[cw] - static_cast<float>(mean);
                        }
                    }
                }
            });
        }
    }
    if (normalize_variance) {
        for (size_t b = 0lu; b < N; b++) {
            // Calculate the variance and normalize the output by it
            size_t cb = b * C3;
            if (across_channels) {
                double variance = 0.0;
                variance = parallel_sum(C, variance, [&](size_t c)->double {
                    double variance_internal = 0.0;
                    size_t cc = cb + c * C2;
                    for (size_t d = 0lu; d < D; d++) {
                        size_t cd = cc + d * C1;
                        for (size_t h = 0lu; h < H; h++) {
                            size_t ch = cd + h * W;
                            for (size_t w = 0lu; w < W; w++) {
                                variance_internal += std::pow(dst_data[ch + w], 2);
                            }
                        }
                    }
                    return variance_internal;
                });

                variance /= C3;
                variance += eps;
                variance = std::pow(variance, 0.5f);
                parallel_for(C, [&](int c) {
                    size_t cc = cb + c * C2;
                    for (size_t d = 0lu; d < D; d++) {
                        size_t cd = cc + d * C1;
                        for (size_t h = 0lu; h < H; h++) {
                            size_t ch = cd + h * W;
                            for (size_t w = 0lu; w < W; w++) {
                                dst_data[ch + w] /= static_cast<float>(variance);
                            }
                        }
                    }
                });
            } else {
                parallel_for(C, [&](size_t c) {
                    double variance = 0.0;
                    size_t cc = cb + c * C2;
                    for (size_t d = 0lu; d < D; d++) {
                        size_t cd = cc + d * C1;
                        for (size_t h = 0lu; h < H; h++) {
                            size_t ch = cd + h * W;
                            for (size_t w = 0lu; w < W; w++) {
                                variance += std::pow(dst_data[ch + w], 2);
                            }
                        }
                    }

                    variance /= static_cast<double>(C2);
                    variance += eps;
                    variance = std::pow(variance, 0.5f);
                    for (size_t d = 0lu; d < D; d++) {
                        size_t cd = cc + d * C1;
                        for (size_t h = 0lu; h < H; h++) {
                            size_t ch = cd + h * W;
                            for (size_t w = 0lu; w < W; w++) {
                                dst_data[ch + w] /= static_cast<float>(variance);
                            }
                        }
                    }
                });
            }
        }
    }
}
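// Blocked-layout implementation (channels grouped in blocks of 8 or 16
// elements, BLK8/BLK16). With across_channels the sums are accumulated over
// channel blocks and spatial positions with parallel_sum3d, skipping the
// padded tail channels of the last block (min_cb). In the per-channel case
// each channel block is processed with AVX2/AVX-512 intrinsics when
// available, otherwise with a scalar fallback loop.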
void MVNImpl::mvn_blk(const float* src_data, float* dst_data, const SizeVector& dims) {
#if defined(HAVE_AVX512F)
    size_t blk_size = 16;
#else
    size_t blk_size = 8lu;
#endif

#if defined(HAVE_AVX512F)
    typedef __m512 vec_type;
#elif defined(HAVE_AVX2)
    typedef __m256 vec_type;
#endif
    size_t dims_size = dims.size();
    size_t N = (dims_size > 0) ? dims[0] : 1lu;
    size_t C = (dims_size > 1) ? dims[1] : 1lu;
    size_t D = (dims_size > 4) ? dims[dims_size - 3] : 1lu;
    size_t H = (dims_size > 3) ? dims[dims_size - 2] : 1lu;
    size_t W = (dims_size > 2) ? dims[dims_size - 1] : 1lu;

    // Number of channel blocks, including the (possibly partial) last one.
    int CB = div_up(static_cast<int>(C), static_cast<int>(blk_size));

    size_t C0 = W * blk_size;   // elements in one spatial row of a channel block
    size_t C1 = C0 * H;         // elements in one spatial plane of a channel block
    size_t C2 = C1 * D;         // elements in one channel block
    size_t C3 = C2 * CB;        // elements in one batch, including channel padding
    size_t C5 = C * D * H * W;  // elements in one batch, without channel padding
    if (normalize_variance) {
        for (size_t b = 0lu; b < N; b++) {
            size_t ccb = b * C3;
            if (across_channels) {
                double mean = 0.0;
                mean = parallel_sum3d(CB, D, H, mean, [&](size_t cb, size_t d, size_t h)->double {
                    size_t ccbd = ccb + cb * C2 + d * C1 + h * C0;
                    size_t min_cb = std::min(blk_size, C - cb * blk_size);
                    double mean_internal = 0.0;
                    for (size_t w = 0lu; w < W; w++) {
                        size_t cw = ccbd + w * blk_size;
                        for (size_t c = 0lu; c < min_cb; c++) {
                            mean_internal += src_data[cw + c];
                        }
                    }
                    return mean_internal;
                });

                mean /= static_cast<double>(C5);

                double variance = 0.0;
                variance = parallel_sum3d(CB, D, H, variance, [&](size_t cb, size_t d, size_t h)->double {
                    size_t ccbd = ccb + cb * C2 + d * C1 + h * C0;
                    double variance_internal = 0.0;
                    for (size_t w = 0lu, min_cb = std::min(blk_size, C - cb * blk_size); w < W; w++) {
                        size_t cw = ccbd + w * blk_size;
                        for (size_t c = 0lu; c < min_cb; c++) {
                            variance_internal += std::pow(static_cast<double>(src_data[cw + c]) - mean, 2);
                        }
                    }
                    return variance_internal;
                });

                variance /= static_cast<double>(C5);
                variance += eps;
                variance = std::pow(variance, 0.5f);

                parallel_for3d(CB, D, H, [&](size_t cb, size_t d, size_t h) {
                    size_t ccbd = ccb + cb * C2 + d * C1 + h * C0;
                    for (size_t w = 0lu, min_cb = std::min(blk_size, C - cb * blk_size); w < W; w++) {
                        size_t cw = ccbd + w * blk_size;
                        for (size_t c = 0lu; c < min_cb; c++) {
                            size_t src_offset = cw + c;

                            dst_data[src_offset] = static_cast<float>((static_cast<double>(src_data[src_offset]) - mean) / variance);
                        }
                    }
                });
            } else {
                parallel_for(CB, [&](size_t cb) {
                    size_t src_off = ccb + cb * C2;
#if defined(HAVE_AVX2) || defined(HAVE_AVX512F)
                    vec_type vmean = _mm_uni_setzero_ps();
                    for (size_t d = 0lu; d < D; d++) {
                        size_t cd = src_off + d * C1;
                        for (size_t h = 0lu; h < H; h++) {
                            size_t ch = cd + h * C0;
                            for (size_t w = 0lu; w < W; w++) {
                                vec_type vsrc = _mm_uni_loadu_ps(src_data + ch + w * blk_size);
                                vmean = _mm_uni_add_ps(vmean, vsrc);
                            }
                        }
                    }

                    vec_type vsize = _mm_uni_set1_ps(static_cast<float>(D * H * W));
                    vmean = _mm_uni_div_ps(vmean, vsize);

                    vec_type vvariance = _mm_uni_setzero_ps();
                    for (size_t d = 0lu; d < D; d++) {
                        size_t cd = src_off + d * C1;
                        for (size_t h = 0lu; h < H; h++) {
                            size_t ch = cd + h * C0;
                            for (size_t w = 0lu; w < W; w++) {
                                vec_type vsrc = _mm_uni_loadu_ps(src_data + ch + w * blk_size);
                                vsrc = _mm_uni_sub_ps(vsrc, vmean);
                                vvariance = _mm_uni_add_ps(vvariance, _mm_uni_mul_ps(vsrc, vsrc));
                            }
                        }
                    }
                    vvariance = _mm_uni_div_ps(vvariance, vsize);

                    vec_type veps = _mm_uni_set1_ps(eps);
                    vvariance = _mm_uni_add_ps(vvariance, veps);

                    vvariance = _mm_uni_sqrt_ps(vvariance);

                    for (size_t d = 0lu; d < D; d++) {
                        size_t cd = src_off + d * C1;
                        for (size_t h = 0lu; h < H; h++) {
                            size_t ch = cd + h * C0;
                            for (size_t w = 0lu; w < W; w++) {
                                size_t offset = ch + w * blk_size;
                                vec_type vsrc = _mm_uni_loadu_ps(src_data + offset);
                                vsrc = _mm_uni_sub_ps(vsrc, vmean);
                                _mm_uni_storeu_ps(dst_data + offset, _mm_uni_div_ps(vsrc, vvariance));
                            }
                        }
                    }
#else
                    size_t min_cb = std::min(blk_size, C - cb * blk_size);
                    for (size_t c = 0; c < min_cb; c++) {
                        size_t cc = src_off + c;

                        double mean = 0.0;
                        for (size_t d = 0; d < D; d++) {
                            size_t cd = cc + d * C1;
                            for (size_t h = 0; h < H; h++) {
                                size_t ch = cd + h * C0;
                                for (size_t w = 0; w < W; w++) {
                                    mean += src_data[ch + w * blk_size];
                                }
                            }
                        }

                        size_t C4 = D * H * W;
                        mean /= static_cast<double>(C4);

                        double variance = 0.0;
                        for (size_t d = 0lu; d < D; d++) {
                            size_t cd = cc + d * C1;
                            for (size_t h = 0lu; h < H; h++) {
                                size_t ch = cd + h * C0;
                                for (size_t w = 0lu; w < W; w++) {
                                    double value = static_cast<double>(src_data[ch + w * blk_size]) - mean;
                                    variance += std::pow(value, 2);
                                }
                            }
                        }

                        variance /= static_cast<double>(C4);
                        variance += eps;
                        variance = std::pow(variance, 0.5f);

                        for (size_t d = 0lu; d < D; d++) {
                            size_t cd = cc + d * C1;
                            for (size_t h = 0lu; h < H; h++) {
                                size_t ch = cd + h * C0;
                                for (size_t w = 0lu; w < W; w++) {
                                    size_t index = ch + w * blk_size;
                                    dst_data[index] = (src_data[index] - static_cast<float>(mean)) / static_cast<float>(variance);
                                }
                            }
                        }
                    }
#endif
                });
            }
        }
    } else {
        for (size_t b = 0; b < N; b++) {
            size_t ccb = b * C3;
            if (across_channels) {
                double mean = 0.0;
                mean = parallel_sum3d(CB, D, H, mean, [&](size_t cb, size_t d, size_t h)->double {
                    size_t ccbd = ccb + cb * C2 + d * C1 + h * C0;
                    double mean_internal = 0.0;
                    for (size_t w = 0lu, min_cb = std::min(blk_size, C - cb * blk_size); w < W; w++) {
                        size_t cw = ccbd + w * blk_size;
                        for (size_t c = 0lu; c < min_cb; c++) {
                            mean_internal += src_data[cw + c];
                        }
                    }
                    return mean_internal;
                });

                mean /= static_cast<double>(C5);

                parallel_for3d(CB, D, H, [&](size_t cb, size_t d, size_t h) {
                    size_t ccbd = ccb + cb * C2 + d * C1 + h * C0;
                    for (size_t w = 0lu, min_cb = std::min(blk_size, C - cb * blk_size); w < W; w++) {
                        size_t cw = ccbd + w * blk_size;
                        for (size_t c = 0lu; c < min_cb; c++) {
                            size_t src_offset = cw + c;

                            dst_data[src_offset] = src_data[src_offset] - static_cast<float>(mean);
                        }
                    }
                });
            } else {
                parallel_for(CB, [&](size_t cb) {
                    size_t src_off = ccb + cb * C2;
#if defined(HAVE_AVX2) || defined(HAVE_AVX512F)
                    vec_type vmean = _mm_uni_setzero_ps();
                    for (size_t d = 0lu; d < D; d++) {
                        size_t cd = src_off + d * C1;
                        for (size_t h = 0lu; h < H; h++) {
                            size_t ch = cd + h * C0;
                            for (size_t w = 0lu; w < W; w++) {
                                vec_type vsrc = _mm_uni_loadu_ps(src_data + ch + w * blk_size);
                                vmean = _mm_uni_add_ps(vmean, vsrc);
                            }
                        }
                    }

                    vec_type vsize = _mm_uni_set1_ps(static_cast<float>(D * H * W));
                    vmean = _mm_uni_div_ps(vmean, vsize);

                    for (size_t d = 0lu; d < D; d++) {
                        size_t cd = src_off + d * C1;
                        for (size_t h = 0lu; h < H; h++) {
                            size_t ch = cd + h * C0;
                            for (size_t w = 0lu; w < W; w++) {
                                size_t offset = ch + w * blk_size;
                                vec_type vsrc = _mm_uni_loadu_ps(src_data + offset);
                                _mm_uni_storeu_ps(dst_data + offset, _mm_uni_sub_ps(vsrc, vmean));
                            }
                        }
                    }
#else
                    size_t min_cb = std::min(blk_size, C - cb * blk_size);
                    for (size_t c = 0lu; c < min_cb; c++) {
                        size_t cc = src_off + c;
                        double mean = 0.0;
                        for (size_t d = 0lu; d < D; d++) {
                            size_t cd = cc + d * C1;
                            for (size_t h = 0lu; h < H; h++) {
                                size_t ch = cd + h * C0;
                                for (size_t w = 0lu; w < W; w++) {
                                    mean += src_data[ch + w * blk_size];
                                }
                            }
                        }

                        size_t C4 = D * H * W;
                        mean /= static_cast<double>(C4);

                        for (size_t d = 0lu; d < D; d++) {
                            size_t cd = cc + d * C1;
                            for (size_t h = 0lu; h < H; h++) {
                                size_t ch = cd + h * C0;
                                for (size_t w = 0lu; w < W; w++) {
                                    size_t index = ch + w * blk_size;
                                    dst_data[index] = src_data[index] - static_cast<float>(mean);
                                }
                            }
                        }
                    }
#endif
                });
            }
        }
    }
}
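// REG_FACTORY_FOR registers ImplFactory<MVNImpl> under the "MVN" layer type in
// the extension's factory list (see ext_list.hpp), so the CPU extensions
// library can instantiate this kernel for MVN layers.
//
// Rough usage sketch, assuming the usual cpu_extension loading flow of this
// Inference Engine generation (exact API and library name may differ):
//
//     InferenceEngine::Core ie;
//     auto extension = InferenceEngine::make_so_pointer<InferenceEngine::IExtension>("libcpu_extension.so");
//     ie.AddExtension(extension, "CPU");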
REG_FACTORY_FOR(ImplFactory<MVNImpl>, MVN);

}  // namespace Cpu
}  // namespace Extensions
}  // namespace InferenceEngine