inference-engine/src/extension/ext_resample.cpp
// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "ext_list.hpp"
#include "ext_base.hpp"
#include <vector>
#include <string>
#include <algorithm>
#if defined(HAVE_SSE) || defined(HAVE_AVX2) || defined(HAVE_AVX512F)
#include <immintrin.h>
#endif
#include <cmath>
#include <cassert>
#include <cstring>
#include "ie_parallel.hpp"
#include "simple_copy.h"

namespace InferenceEngine {
namespace Extensions {
namespace Cpu {

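// Integer ceiling division: div_up(a, b) == ceil(a / b) for positive a and b.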
inline int div_up(const int a, const int b) {
    assert(b);
    return (a + b - 1) / b;
}

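// CPU reference implementation of the Resample layer: nearest-neighbor or
// linear (triangle-filter) interpolation over 4D blobs, with vectorized fast
// paths for the common 2x and 4x upsampling factors.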
class ResampleImpl: public ExtLayerBase {
public:
    explicit ResampleImpl(const CNNLayer* layer) {
        try {
            if (layer->insData.size() != 1 || layer->outData.empty())
                THROW_IE_EXCEPTION << "Incorrect number of input/output edges!";

            if (layer->insData[0].lock()->dims.size() != 4)
                THROW_IE_EXCEPTION << "Resample supports only 4D blobs!";

            type = layer->GetParamAsString("type");
            antialias = layer->GetParamAsBool("antialias", false);

#if defined(HAVE_AVX512F)
            auto blk_layout = ConfLayout::BLK16;
#else
            auto blk_layout = ConfLayout::BLK8;
#endif
            addConfig(layer, {DataConfigurator(ConfLayout::PLN)}, {DataConfigurator(ConfLayout::PLN)});
            if (type == "caffe.ResampleParameter.NEAREST")
                addConfig(layer, {DataConfigurator(blk_layout)}, {DataConfigurator(blk_layout)});
        } catch (InferenceEngine::details::InferenceEngineException &ex) {
            errorMsg = ex.what();
        }
    }

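    // Dispatch: a plain copy when the linear path is a no-op (same input and
    // output size), specialized 2x/4x nearest upsampling, or the generic
    // nearest/linear kernels otherwise. The "type" and "antialias" values read
    // in the constructor come from the layer's IR parameters.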
    StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs,
                       ResponseDesc *resp) noexcept override {
        const auto *src_data = inputs[0]->cbuffer().as<const float *>();
        auto *dst_data = outputs[0]->buffer().as<float *>();
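        // Windows headers define IN as an empty annotation macro; undefine it
        // so the local batch-size variable IN below compiles.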
#ifdef WIN32
#undef IN
#endif
        Layout layout = inputs[0]->layout();
        Precision precision = inputs[0]->precision();

        size_t IN = inputs[0]->getTensorDesc().getDims()[0];
        size_t IC = inputs[0]->getTensorDesc().getDims()[1];
        size_t IH = inputs[0]->getTensorDesc().getDims()[2];
        size_t IW = inputs[0]->getTensorDesc().getDims()[3];

        size_t OH = outputs[0]->getTensorDesc().getDims()[2];
        size_t OW = outputs[0]->getTensorDesc().getDims()[3];

        if (IW == OW && IH == OH && type == "caffe.ResampleParameter.LINEAR") {
            size_t size = IN * IC * IH * IW;
            if (inputs[0]->getTensorDesc().getPrecision() == Precision::FP32) {
                size *= sizeof(float);
            }
            simple_copy(dst_data, outputs[0]->byteSize(), src_data, size);
            return OK;
        }

        float fx = static_cast<float>(IW) / static_cast<float>(OW);
        float fy = static_cast<float>(IH) / static_cast<float>(OH);

        bool isDownsample = (fx > 1) || (fy > 1);

        if (type == "caffe.ResampleParameter.NEAREST") {
            if (!isDownsample && fx == 0.25f && fy == 0.25f) {
                if (layout == NCHW || layout == NHWC) {
                    if (precision == Precision::FP32) {
                        Upsample_Nearest_PLN<float, 4>(src_data, dst_data, IN, IC, IH, IW, layout);
                    } else {
                        Upsample_Nearest_PLN<uint8_t, 4>(reinterpret_cast<const uint8_t*>(src_data),
                                                         reinterpret_cast<uint8_t*>(dst_data), IN, IC, IH, IW, layout);
                    }
                } else {
                    Upsample_Nearest_BLK<4>(src_data, dst_data, IN, IC, IH, IW);
                }
            } else if (!isDownsample && fx == 0.5f && fy == 0.5f) {
                if (layout == NCHW || layout == NHWC) {
                    if (precision == Precision::FP32) {
                        Upsample_Nearest_PLN<float, 2>(src_data, dst_data, IN, IC, IH, IW, layout);
                    } else {
                        Upsample_Nearest_PLN<uint8_t, 2>(reinterpret_cast<const uint8_t*>(src_data),
                                                         reinterpret_cast<uint8_t*>(dst_data), IN, IC, IH, IW, layout);
                    }
                } else {
                    Upsample_Nearest_BLK<2>(src_data, dst_data, IN, IC, IH, IW);
                }
            } else {
                if (layout == NCHW) {
                    NearestNeighborKernel_PLN(src_data, dst_data, IN, IC, IH, IW, fx, fy, OH, OW);
                } else {
                    NearestNeighborKernel_BLK(src_data, dst_data, IN, IC, IH, IW, fx, fy, OH, OW);
                }
            }
        } else if (type == "caffe.ResampleParameter.LINEAR") {
            size_t kernel_width = 2;

#if defined(HAVE_SSE) || defined(HAVE_AVX2)
            if (!isDownsample && fx == 0.25f && fy == 0.25f)
                Upsample4x_TriangleInterpolation(src_data, IW, IH, fx, fy, dst_data, OW, OH, IC, IN);
            else
#endif
                InterpolationKernel(src_data, IW, IH, fx, fy, dst_data, OW, OH, IC, IN, kernel_width, isDownsample && antialias);
        }
        return OK;
    }

private:
    std::string type;
    bool antialias;

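    // Triangle (tent) filter kernel: 1 - |x| on [-1, 1], zero elsewhere.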
    static inline float triangleCoeff(float x) {
        return std::max(0.0f, 1 - std::abs(x));
    }

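    // Generic separable triangle-filter resampling. Each output pixel (ox, oy)
    // maps to a source coordinate using the half-pixel convention:
    //     ix = (ox + 0.5) * fx - 0.5,  iy = (oy + 0.5) * fy - 0.5.
    // With antialiasing enabled (downsampling), the filter support is widened
    // by the scale factor (ax = 1/fx), so each output averages more input
    // pixels; taps falling outside the image are skipped and the result is
    // renormalized by the accumulated weight wsum.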
    static void InterpolationKernel(const float *in_ptr_,
                                    const size_t iw, const size_t ih,
                                    const float fx, const float fy,
                                    float *out_ptr_,
                                    const size_t ow, const size_t oh, const size_t channels, const size_t batch,
                                    size_t kernel_width, bool antialias) {
        for (size_t b = 0; b < batch; b++) {
            for (size_t c = 0; c < channels; c++) {
                const float *in_ptr = in_ptr_ + iw * ih * channels * b + iw * ih * c;
                float *out_ptr = out_ptr_ + ow * oh * channels * b + ow * oh * c;

                for (size_t oy = 0; oy < oh; oy++) {
                    for (size_t ox = 0; ox < ow; ox++) {
                        float ix = ox * fx + fx / 2.0f - 0.5f;
                        float iy = oy * fy + fy / 2.0f - 0.5f;

                        int ix_r = static_cast<int>(round(ix));
                        int iy_r = static_cast<int>(round(iy));

                        float sum = 0;
                        float wsum = 0;

                        float ax = 1.0f / (antialias ? fx : 1.0f);
                        float ay = 1.0f / (antialias ? fy : 1.0f);

                        int rx = (fx < 1.0f) ? 2 : static_cast<int>(ceil(static_cast<float>(kernel_width) / ax));
                        int ry = (fy < 1.0f) ? 2 : static_cast<int>(ceil(static_cast<float>(kernel_width) / ay));

                        for (int y = iy_r - ry; y <= iy_r + ry; y++) {
                            for (int x = ix_r - rx; x <= ix_r + rx; x++) {
                                if (y < 0 || x < 0 || y >= static_cast<int>(ih) || x >= static_cast<int>(iw))
                                    continue;

                                float dx = ix - x;
                                float dy = iy - y;

                                float w = ax * triangleCoeff(ax * dx) * ay * triangleCoeff(ay * dy);

                                sum += w * in_ptr[y * iw + x];
                                wsum += w;
                            }
                        }

                        out_ptr[oy * ow + ox] = (!wsum) ? 0 : (sum / wsum);
                    }
                }
            }
        }
    }

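    // Nearest-neighbor resampling over the planar (NCHW) layout: each output
    // pixel takes the value at the rounded source coordinate.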
    static void NearestNeighborKernel_PLN(const float *in_ptr_, float *out_ptr_, int B, int C, int IH, int IW, float fx, float fy, int OH, int OW) {
        for (int b = 0; b < B; b++) {
            for (int c = 0; c < C; c++) {
                const float *in_ptr = in_ptr_ + IW * IH * C * b + IW * IH * c;
                float *out_ptr = out_ptr_ + OW * OH * C * b + OW * OH * c;

                for (int oy = 0; oy < OH; oy++) {
                    for (int ox = 0; ox < OW; ox++) {
                        float ix = ox * fx + fx / 2.0f - 0.5f;
                        float iy = oy * fy + fy / 2.0f - 0.5f;

                        size_t ix_r = static_cast<size_t>(round(ix));
                        size_t iy_r = static_cast<size_t>(round(iy));

                        out_ptr[oy * OW + ox] = in_ptr[iy_r * IW + ix_r];
                    }
                }
            }
        }
    }

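    // Nearest-neighbor resampling over the channel-blocked layout: channels
    // are grouped in blocks of blk_size, so the innermost loop copies one
    // whole block per output pixel.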
    static void NearestNeighborKernel_BLK(const float *in_ptr_, float *out_ptr_, int B, int C, int IH, int IW, float fx, float fy, int OH, int OW) {
        int blk_size = 8;
        int CB = div_up(C, blk_size);

        for (int b = 0; b < B; b++) {
            for (int cb = 0; cb < CB; cb++) {
                const float *in_ptr = in_ptr_ + IW * IH * CB * blk_size * b + IW * IH * cb * blk_size;
                float *out_ptr = out_ptr_ + OW * OH * CB * blk_size * b + OW * OH * cb * blk_size;

                for (int oy = 0; oy < OH; oy++) {
                    for (int ox = 0; ox < OW; ox++) {
                        float ix = ox * fx + fx / 2.0f - 0.5f;
                        float iy = oy * fy + fy / 2.0f - 0.5f;

                        size_t ix_r = static_cast<size_t>(round(ix));
                        size_t iy_r = static_cast<size_t>(round(iy));

                        for (int c = 0; c < blk_size; c++) {
                            float value = in_ptr[iy_r * IW * blk_size + ix_r * blk_size + c];

                            out_ptr[oy * OW * blk_size + ox * blk_size + c] = value;
                        }
                    }
                }
            }
        }
    }

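    // Integer-factor nearest upsampling, specialized at compile time for
    // factor 2 or 4. The NCHW branch replicates each input pixel into a
    // factor x factor output tile; the NHWC branch treats all C channels of a
    // pixel as one contiguous block, duplicates it along the row with memcpy,
    // then replicates whole output rows.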
    template <typename T, int factor>
    static void Upsample_Nearest_PLN(const T *in_ptr_, T *out_ptr_, int B, int C, int IH, int IW, Layout layout) {
        int OH = factor * IH;
        int OW = factor * IW;

        if (layout == NCHW) {
            for (int b = 0; b < B; b++) {
                for (int c = 0; c < C; c++) {
                    const T *in_ptr = in_ptr_ + IW * IH * C * b + IW * IH * c;
                    T *out_ptr = out_ptr_ + OW * OH * C * b + OW * OH * c;

                    for (int iy = 0; iy < IH; iy++) {
                        for (int ix = 0; ix < IW; ix++) {
                            int oy = factor * iy;
                            int ox = factor * ix;
                            float value = in_ptr[iy * IW + ix];

                            for (int fh = 0; fh < factor; fh++) {
                                for (int fw = 0; fw < factor; fw++) {
                                    out_ptr[(oy + fh) * OW + ox + fw] = static_cast<T>(value);
                                }
                            }
                        }
                    }
                }
            }
        } else {
            int block_size = C;
            int block_size_bytes = block_size * sizeof(T);

            int ICIWIH = C * IW * IH;
            int OWOH = OW * OH;
            int OCOWOH = C * OWOH;

            int stepX = factor;
            int stepY = factor;

#ifdef _OPENMP
#pragma omp parallel for collapse(2)
#endif
            for (int mb = 0; mb < B; mb++) {
                for (int oh = 0; oh < OH; oh += stepY) {
                    size_t dst_off = mb * OCOWOH + (oh * OW) * block_size;
                    size_t src_off = mb * ICIWIH + (oh / stepY * IW) * block_size;

                    for (int ow = 0; ow < OW; ow += stepX) {
                        size_t dst_off_curr = dst_off + ow * block_size;
                        size_t src_off_curr = src_off + ow / stepX * block_size;

                        memcpy(&out_ptr_[dst_off_curr], &in_ptr_[src_off_curr], block_size_bytes);

                        for (int owx = 1; owx < stepX; owx++) {
                            memcpy(&out_ptr_[dst_off_curr + block_size * owx], &in_ptr_[src_off_curr], block_size_bytes);
                        }
                    }

                    for (int ohy = 1; ohy < stepY; ohy++) {
                        memcpy(&out_ptr_[dst_off + OW * block_size * ohy], &out_ptr_[dst_off], block_size_bytes * OW);
                    }
                }
            }
        }
    }

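    // Integer-factor nearest upsampling over the blocked layout. With
    // AVX2/AVX-512 a single vector register holds an entire channel block, so
    // each block is loaded once and stored to all factor x factor destination
    // positions; the scalar fallback does the same per channel.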
    template <int factor>
    static void Upsample_Nearest_BLK(const float *in_ptr_, float *out_ptr_, int B, int C, int IH, int IW) {
#if defined(HAVE_AVX512F)
        int blk_size = 16;
#else
        int blk_size = 8;
#endif

#if defined(HAVE_AVX512F)
        typedef __m512 vec_type;
#elif defined(HAVE_AVX2)
        typedef __m256 vec_type;
#endif

        int CB = div_up(C, blk_size);

        int OH = factor * IH;
        int OW = factor * IW;

        parallel_for2d(B, CB, [&](int b, int cb) {
#if defined(HAVE_AVX2) || defined(HAVE_AVX512F)
                const float *in_ptr = in_ptr_ + IW * IH * CB * blk_size * b + IW * IH * cb * blk_size;
                float *out_ptr = out_ptr_ + OW * OH * CB * blk_size * b + OW * OH * cb * blk_size;

                for (int iy = 0; iy < IH; iy++) {
                    for (int ix = 0; ix < IW; ix++) {
                        int oy = factor * iy;
                        int ox = factor * ix;

                        vec_type vsrc = _mm_uni_loadu_ps(in_ptr + iy * IW * blk_size + ix * blk_size);

                        for (int fh = 0; fh < factor; fh++) {
                            for (int fw = 0; fw < factor; fw++) {
                                _mm_uni_storeu_ps(out_ptr + (oy + fh) * OW * blk_size + (ox + fw) * blk_size, vsrc);
                            }
                        }
                    }
                }
#else
                const float *in_ptr = in_ptr_ + IW * IH * CB * blk_size * b + IW * IH * cb * blk_size;
                float *out_ptr = out_ptr_ + OW * OH * CB * blk_size * b + OW * OH * cb * blk_size;

                for (int iy = 0; iy < IH; iy++) {
                    for (int ix = 0; ix < IW; ix++) {
                        int oy = factor * iy;
                        int ox = factor * ix;

                        for (int c = 0; c < blk_size; c++) {
                            float value = in_ptr[iy * IW * blk_size + ix * blk_size + c];

                            for (int fh = 0; fh < factor; fh++) {
                                for (int fw = 0; fw < factor; fw++) {
                                    out_ptr[(oy + fh) * OW * blk_size + (ox + fw) * blk_size + c] = value;
                                }
                            }
                        }
                    }
                }
#endif
        });
    }


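    // Specialized 4x triangle-filter upsampling (fx == fy == 0.25). Under the
    // half-pixel mapping, the four output columns (and rows) that share one
    // input pixel sit at fractional offsets {1/8, 3/8, 5/8, 7/8}, giving
    // triangle weights {7/8, 5/8, 3/8, 1/8}. The tables below hold the
    // separable products of those weights (all multiples of 1/64): row i packs
    // the four per-tap weight vectors (vc0..vc3) for output row i of a 4x4
    // tile, laid out for SSE (4-wide) or AVX2 (8-wide) loads. Border columns
    // zero the out-of-image taps and renormalize by the remaining weight sum.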
#if defined(HAVE_SSE) || defined(HAVE_AVX2)
    static void Upsample4x_TriangleInterpolation(const float *in_ptr_,
                                                 const size_t iw, const size_t ih,
                                                 const float fx, const float fy,
                                                 float *out_ptr_,
                                                 const size_t ow, const size_t oh, const size_t channels, const size_t batch) {
    #if defined(HAVE_AVX2)
        static float table_avx2[4][8*4] = {
                {
                        0.140625f, 0.046875f, 0.046875f, 0.140625f, 0.140625f, 0.046875f, 0.046875f, 0.140625f,
                        0.234375f, 0.328125f, 0.328125f, 0.234375f, 0.234375f, 0.328125f, 0.328125f, 0.234375f,
                        0.234375f, 0.078125f, 0.078125f, 0.234375f, 0.234375f, 0.078125f, 0.078125f, 0.234375f,
                        0.390625f, 0.546875f, 0.546875f, 0.390625f, 0.390625f, 0.546875f, 0.546875f, 0.390625f
                },
                {
                        0.046875f, 0.015625f, 0.015625f, 0.046875f, 0.046875f, 0.015625f, 0.015625f, 0.046875f,
                        0.078125f, 0.109375f, 0.109375f, 0.078125f, 0.078125f, 0.109375f, 0.109375f, 0.078125f,
                        0.328125f, 0.109375f, 0.109375f, 0.328125f, 0.328125f, 0.109375f, 0.109375f, 0.328125f,
                        0.546875f, 0.765625f, 0.765625f, 0.546875f, 0.546875f, 0.765625f, 0.765625f, 0.546875f
                },
                {
                        0.328125f, 0.109375f, 0.109375f, 0.328125f, 0.328125f, 0.109375f, 0.109375f, 0.328125f,
                        0.546875f, 0.765625f, 0.765625f, 0.546875f, 0.546875f, 0.765625f, 0.765625f, 0.546875f,
                        0.046875f, 0.015625f, 0.015625f, 0.046875f, 0.046875f, 0.015625f, 0.015625f, 0.046875f,
                        0.078125f, 0.109375f, 0.109375f, 0.078125f, 0.078125f, 0.109375f, 0.109375f, 0.078125f
                },
                {
                        0.234375f, 0.078125f, 0.078125f, 0.234375f, 0.234375f, 0.078125f, 0.078125f, 0.234375f,
                        0.390625f, 0.546875f, 0.546875f, 0.390625f, 0.390625f, 0.546875f, 0.546875f, 0.390625f,
                        0.140625f, 0.046875f, 0.046875f, 0.140625f, 0.140625f, 0.046875f, 0.046875f, 0.140625f,
                        0.234375f, 0.328125f, 0.328125f, 0.234375f, 0.234375f, 0.328125f, 0.328125f, 0.234375f
                }
        };
    #endif

    #if defined(HAVE_SSE) || defined(HAVE_AVX2)
        static float table_sse[4][4*4] = {
            {
                0.140625f, 0.046875f, 0.046875f, 0.140625f,
                0.234375f, 0.328125f, 0.328125f, 0.234375f,
                0.234375f, 0.078125f, 0.078125f, 0.234375f,
                0.390625f, 0.546875f, 0.546875f, 0.390625f
            },
            {
                0.046875f, 0.015625f, 0.015625f, 0.046875f,
                0.078125f, 0.109375f, 0.109375f, 0.078125f,
                0.328125f, 0.109375f, 0.109375f, 0.328125f,
                0.546875f, 0.765625f, 0.765625f, 0.546875f
            },
            {
                0.328125f, 0.109375f, 0.109375f, 0.328125f,
                0.546875f, 0.765625f, 0.765625f, 0.546875f,
                0.046875f, 0.015625f, 0.015625f, 0.046875f,
                0.078125f, 0.109375f, 0.109375f, 0.078125f
            },
            {
                0.234375f, 0.078125f, 0.078125f, 0.234375f,
                0.390625f, 0.546875f, 0.546875f, 0.390625f,
                0.140625f, 0.046875f, 0.046875f, 0.140625f,
                0.234375f, 0.328125f, 0.328125f, 0.234375f
            }
        };
    #endif
        for (size_t b = 0; b < batch; b++) {
            for (size_t c = 0; c < channels; c++) {
                const float *in_ptr = in_ptr_ + b * channels * iw * ih + c * iw * ih;
                float *out_ptr = out_ptr_ + b * channels * ow * oh + c * ow * oh;

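                // Rows are processed in three bands: the top four output rows
                // (no source row above, so the vx0x taps are zeroed), the
                // middle rows, and the bottom four rows (no source row below,
                // so the vx2x taps are zeroed).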
                size_t oy = 0;
                {
                    float iy = oy * fy + fy / 2.0f - 0.5f;
                    size_t iy_r = static_cast<size_t>(round(iy));

                    size_t ox = 0;
        #if defined(HAVE_AVX2)
                    for (; ox <= ow - 8; ox += 8) {
                        float ix = (ox + 0) * fx + fx / 2.0f - 0.5f;
                        size_t ix_r = static_cast<size_t>(round(ix));

                        __m256 vx00 = _mm256_setzero_ps();
                        __m256 vx01 = _mm256_setzero_ps();
                        __m256 vx02 = _mm256_setzero_ps();

                        __m128 vx10_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r - 1);
                        __m128 vx11_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 0);
                        __m128 vx12_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 1);
                        __m128 vx13_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 2);

                        __m128 vx20_ = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r - 1);
                        __m128 vx21_ = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r + 0);
                        __m128 vx22_ = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r + 1);
                        __m128 vx23_ = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r + 2);

                        __m256 vx10 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx10_), vx11_, 1);
                        __m256 vx11 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx11_), vx12_, 1);
                        __m256 vx12 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx12_), vx13_, 1);
                        __m256 vx20 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx20_), vx21_, 1);
                        __m256 vx21 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx21_), vx22_, 1);
                        __m256 vx22 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx22_), vx23_, 1);

                        for (size_t i = 0; i < 4; i++) {
                            __m256 vc0 = i < 2 ? _mm256_setzero_ps() : _mm256_loadu_ps(table_avx2[i] + 0);
                            __m256 vc1 = i < 2 ? _mm256_setzero_ps() : _mm256_loadu_ps(table_avx2[i] + 8);
                            __m256 vc2 = _mm256_loadu_ps(table_avx2[i] + 16);
                            __m256 vc3 = _mm256_loadu_ps(table_avx2[i] + 24);

                            if (ox == 0) {
                                if (i > 1)
                                    vc0 = _mm256_insertf128_ps(vc0, _mm_shuffle_ps(_mm_setzero_ps(), _mm256_extractf128_ps(vc0, 0), 0xD0), 0);
                                vc2 = _mm256_insertf128_ps(vc2, _mm_shuffle_ps(_mm_setzero_ps(), _mm256_extractf128_ps(vc2, 0), 0xD0), 0);
                            } else if (ox == ow - 8) {
                                if (i > 1)
                                    vc0 = _mm256_insertf128_ps(vc0, _mm_shuffle_ps(_mm256_extractf128_ps(vc0, 1), _mm_setzero_ps(), 0x07), 1);
                                vc2 = _mm256_insertf128_ps(vc2, _mm_shuffle_ps(_mm256_extractf128_ps(vc2, 1), _mm_setzero_ps(), 0x07), 1);
                            }

                            __m256 vsrc0 = i < 2 ? _mm256_shuffle_ps(vx00, vx02, 0x0) : _mm256_shuffle_ps(vx10, vx12, 0x0);
                            __m256 vsrc1 = i < 2 ? _mm256_shuffle_ps(vx01, vx01, 0x0) : _mm256_shuffle_ps(vx11, vx11, 0x0);
                            __m256 vsrc2 = i < 2 ? _mm256_shuffle_ps(vx10, vx12, 0x0) : _mm256_shuffle_ps(vx20, vx22, 0x0);
                            __m256 vsrc3 = i < 2 ? _mm256_shuffle_ps(vx11, vx11, 0x0) : _mm256_shuffle_ps(vx21, vx21, 0x0);

                            __m256 res = _mm256_setzero_ps();

                            res = _mm256_fmadd_ps(vsrc0, vc0, res);
                            res = _mm256_fmadd_ps(vsrc1, vc1, res);
                            res = _mm256_fmadd_ps(vsrc2, vc2, res);
                            res = _mm256_fmadd_ps(vsrc3, vc3, res);
                            __m256 wei = _mm256_add_ps(_mm256_add_ps(vc0, vc1), _mm256_add_ps(vc2, vc3));

                            res = _mm256_div_ps(res, wei);

                            _mm256_storeu_ps(out_ptr + (oy + i) * ow + ox, res);
                        }
                    }
        #endif

        #if defined(HAVE_SSE) || defined(HAVE_AVX2)
                    for (; ox <= ow - 4; ox += 4) {
                        float ix = (ox + 0) * fx + fx / 2.0f - 0.5f;
                        size_t ix_r = static_cast<size_t>(round(ix));

                        __m128 vx00 = _mm_setzero_ps();
                        __m128 vx01 = _mm_setzero_ps();
                        __m128 vx02 = _mm_setzero_ps();

                        __m128 vx10 = _mm_load_ss(in_ptr+(iy_r+0)*iw+ix_r-1);
                        __m128 vx11 = _mm_load_ss(in_ptr+(iy_r+0)*iw+ix_r+0);
                        __m128 vx12 = _mm_load_ss(in_ptr+(iy_r+0)*iw+ix_r+1);

                        __m128 vx20 = _mm_load_ss(in_ptr+(iy_r+1)*iw+ix_r-1);
                        __m128 vx21 = _mm_load_ss(in_ptr+(iy_r+1)*iw+ix_r+0);
                        __m128 vx22 = _mm_load_ss(in_ptr+(iy_r+1)*iw+ix_r+1);

                        for (size_t i = 0; i < 4; i++) {
                            __m128 vc0 = i < 2 ? _mm_setzero_ps() : _mm_loadu_ps(table_sse[i] + 0);
                            __m128 vc1 = i < 2 ? _mm_setzero_ps() : _mm_loadu_ps(table_sse[i] + 4);
                            __m128 vc2 = _mm_loadu_ps(table_sse[i] +  8);
                            __m128 vc3 = _mm_loadu_ps(table_sse[i] + 12);

                            if (ox == 0) {
                                if (i > 1)
                                    vc0 = _mm_shuffle_ps(_mm_setzero_ps(), vc0, 0xD0);
                                vc2 = _mm_shuffle_ps(_mm_setzero_ps(), vc2, 0xD0);
                            } else if (ox == ow - 4) {
                                if (i > 1)
                                    vc0 = _mm_shuffle_ps(vc0, _mm_setzero_ps(), 0x07);
                                vc2 = _mm_shuffle_ps(vc2, _mm_setzero_ps(), 0x07);
                            }

                            __m128 vsrc0 = i < 2 ? _mm_shuffle_ps(vx00, vx02, 0x0) : _mm_shuffle_ps(vx10, vx12, 0x0);
                            __m128 vsrc1 = i < 2 ? _mm_shuffle_ps(vx01, vx01, 0x0) : _mm_shuffle_ps(vx11, vx11, 0x0);
                            __m128 vsrc2 = i < 2 ? _mm_shuffle_ps(vx10, vx12, 0x0) : _mm_shuffle_ps(vx20, vx22, 0x0);
                            __m128 vsrc3 = i < 2 ? _mm_shuffle_ps(vx11, vx11, 0x0) : _mm_shuffle_ps(vx21, vx21, 0x0);

                            __m128 vres0 = _mm_mul_ps(vsrc0, vc0);
                            __m128 vres1 = _mm_mul_ps(vsrc1, vc1);
                            __m128 vres2 = _mm_mul_ps(vsrc2, vc2);
                            __m128 vres3 = _mm_mul_ps(vsrc3, vc3);

                            __m128 res = _mm_add_ps(_mm_add_ps(vres0, vres1), _mm_add_ps(vres2, vres3));
                            __m128 wei = _mm_add_ps(_mm_add_ps(vc0, vc1), _mm_add_ps(vc2, vc3));

                            res = _mm_div_ps(res, wei);

                            _mm_storeu_ps(out_ptr + (oy+i)*ow + ox, res);
                        }
                    }
        #endif
                }

                for (oy = 4; oy <= oh - 8; oy += 4) {
                    float iy = oy * fy + fy / 2.0f - 0.5f;
                    size_t iy_r = static_cast<size_t>(round(iy));

                    size_t ox = 0;
        #if defined(HAVE_AVX2)
                    for (; ox <= ow - 8; ox += 8) {
                        float ix = (ox + 0) * fx + fx / 2.0f - 0.5f;
                        size_t ix_r = static_cast<size_t>(round(ix));

                        __m128 vx00_ = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r - 1);
                        __m128 vx01_ = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r + 0);
                        __m128 vx02_ = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r + 1);
                        __m128 vx03_ = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r + 2);

                        __m128 vx10_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r - 1);
                        __m128 vx11_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 0);
                        __m128 vx12_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 1);
                        __m128 vx13_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 2);

                        __m128 vx20_ = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r - 1);
                        __m128 vx21_ = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r + 0);
                        __m128 vx22_ = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r + 1);
                        __m128 vx23_ = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r + 2);

                        __m256 vx00 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx00_), vx01_, 1);
                        __m256 vx01 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx01_), vx02_, 1);
                        __m256 vx02 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx02_), vx03_, 1);

                        __m256 vx10 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx10_), vx11_, 1);
                        __m256 vx11 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx11_), vx12_, 1);
                        __m256 vx12 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx12_), vx13_, 1);

                        __m256 vx20 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx20_), vx21_, 1);
                        __m256 vx21 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx21_), vx22_, 1);
                        __m256 vx22 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx22_), vx23_, 1);

                        for (size_t i = 0; i < 4; i++) {
                            __m256 vc0 = _mm256_loadu_ps(table_avx2[i] + 0);
                            __m256 vc1 = _mm256_loadu_ps(table_avx2[i] + 8);
                            __m256 vc2 = _mm256_loadu_ps(table_avx2[i] + 16);
                            __m256 vc3 = _mm256_loadu_ps(table_avx2[i] + 24);

                            if (ox == 0) {
                                vc0 = _mm256_insertf128_ps(vc0, _mm_shuffle_ps(_mm_setzero_ps(), _mm256_extractf128_ps(vc0, 0), 0xD0), 0);
                                vc2 = _mm256_insertf128_ps(vc2, _mm_shuffle_ps(_mm_setzero_ps(), _mm256_extractf128_ps(vc2, 0), 0xD0), 0);
                            } else if (ox == ow - 8) {
                                vc0 = _mm256_insertf128_ps(vc0, _mm_shuffle_ps(_mm256_extractf128_ps(vc0, 1), _mm_setzero_ps(), 0x07), 1);
                                vc2 = _mm256_insertf128_ps(vc2, _mm_shuffle_ps(_mm256_extractf128_ps(vc2, 1), _mm_setzero_ps(), 0x07), 1);
                            }

                            __m256 vsrc0 = i < 2 ? _mm256_shuffle_ps(vx00, vx02, 0x0) : _mm256_shuffle_ps(vx10, vx12, 0x0);
                            __m256 vsrc1 = i < 2 ? _mm256_shuffle_ps(vx01, vx01, 0x0) : _mm256_shuffle_ps(vx11, vx11, 0x0);
                            __m256 vsrc2 = i < 2 ? _mm256_shuffle_ps(vx10, vx12, 0x0) : _mm256_shuffle_ps(vx20, vx22, 0x0);
                            __m256 vsrc3 = i < 2 ? _mm256_shuffle_ps(vx11, vx11, 0x0) : _mm256_shuffle_ps(vx21, vx21, 0x0);

                            __m256 res = _mm256_setzero_ps();

                            res = _mm256_fmadd_ps(vsrc0, vc0, res);
                            res = _mm256_fmadd_ps(vsrc1, vc1, res);
                            res = _mm256_fmadd_ps(vsrc2, vc2, res);
                            res = _mm256_fmadd_ps(vsrc3, vc3, res);

                            if (ox == 0 || ox == ow - 8) {
                                __m256 wei = _mm256_add_ps(_mm256_add_ps(vc0, vc1), _mm256_add_ps(vc2, vc3));

                                res = _mm256_div_ps(res, wei);
                            }

                            _mm256_storeu_ps(out_ptr + (oy + i) * ow + ox, res);
                        }
                    }
        #endif

        #if defined(HAVE_SSE) || defined(HAVE_AVX2)
                    for (; ox <= ow - 4; ox += 4) {
                        float ix = (ox + 0) * fx + fx / 2.0f - 0.5f;
                        size_t ix_r = static_cast<size_t>(round(ix));

                        __m128 vx00 = _mm_load_ss(in_ptr+(iy_r-1)*iw+ix_r-1);
                        __m128 vx01 = _mm_load_ss(in_ptr+(iy_r-1)*iw+ix_r+0);
                        __m128 vx02 = _mm_load_ss(in_ptr+(iy_r-1)*iw+ix_r+1);

                        __m128 vx10 = _mm_load_ss(in_ptr+(iy_r+0)*iw+ix_r-1);
                        __m128 vx11 = _mm_load_ss(in_ptr+(iy_r+0)*iw+ix_r+0);
                        __m128 vx12 = _mm_load_ss(in_ptr+(iy_r+0)*iw+ix_r+1);

                        __m128 vx20 = _mm_load_ss(in_ptr+(iy_r+1)*iw+ix_r-1);
                        __m128 vx21 = _mm_load_ss(in_ptr+(iy_r+1)*iw+ix_r+0);
                        __m128 vx22 = _mm_load_ss(in_ptr+(iy_r+1)*iw+ix_r+1);

                        for (size_t i = 0; i < 4; i++) {
                            __m128 vc0 = _mm_loadu_ps(table_sse[i] +  0);
                            __m128 vc1 = _mm_loadu_ps(table_sse[i] +  4);
                            __m128 vc2 = _mm_loadu_ps(table_sse[i] +  8);
                            __m128 vc3 = _mm_loadu_ps(table_sse[i] + 12);

                            if (ox == 0) {
                                vc0 = _mm_shuffle_ps(_mm_setzero_ps(), vc0, 0xD0);
                                vc2 = _mm_shuffle_ps(_mm_setzero_ps(), vc2, 0xD0);
                            } else if (ox == ow - 4) {
                                vc0 = _mm_shuffle_ps(vc0, _mm_setzero_ps(), 0x07);
                                vc2 = _mm_shuffle_ps(vc2, _mm_setzero_ps(), 0x07);
                            }

                            __m128 vsrc0 = i < 2 ? _mm_shuffle_ps(vx00, vx02, 0x0) : _mm_shuffle_ps(vx10, vx12, 0x0);
                            __m128 vsrc1 = i < 2 ? _mm_shuffle_ps(vx01, vx01, 0x0) : _mm_shuffle_ps(vx11, vx11, 0x0);
                            __m128 vsrc2 = i < 2 ? _mm_shuffle_ps(vx10, vx12, 0x0) : _mm_shuffle_ps(vx20, vx22, 0x0);
                            __m128 vsrc3 = i < 2 ? _mm_shuffle_ps(vx11, vx11, 0x0) : _mm_shuffle_ps(vx21, vx21, 0x0);

                            __m128 vres0 = _mm_mul_ps(vsrc0, vc0);
                            __m128 vres1 = _mm_mul_ps(vsrc1, vc1);
                            __m128 vres2 = _mm_mul_ps(vsrc2, vc2);
                            __m128 vres3 = _mm_mul_ps(vsrc3, vc3);

                            __m128 res = _mm_add_ps(_mm_add_ps(vres0, vres1), _mm_add_ps(vres2, vres3));
                            if (ox == 0 || ox == ow - 4) {
                                __m128 wei = _mm_add_ps(_mm_add_ps(vc0, vc1), _mm_add_ps(vc2, vc3));

                                res = _mm_div_ps(res, wei);
                            }

                            _mm_storeu_ps(out_ptr + (oy+i)*ow + ox, res);
                        }
                    }
        #endif
                }

                oy = oh - 4;
                {
                    float iy = oy * fy + fy / 2.0f - 0.5f;
                    size_t iy_r = static_cast<size_t>(round(iy));

                    size_t ox = 0;

        #if defined(HAVE_AVX2)
                    for (; ox <= ow - 8; ox += 8) {
                        float ix = (ox + 0) * fx + fx / 2.0f - 0.5f;
                        size_t ix_r = static_cast<size_t>(round(ix));

                        __m128 vx00_ = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r - 1);
                        __m128 vx01_ = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r + 0);
                        __m128 vx02_ = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r + 1);
                        __m128 vx03_ = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r + 2);

                        __m128 vx10_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r - 1);
                        __m128 vx11_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 0);
                        __m128 vx12_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 1);
                        __m128 vx13_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 2);

                        __m256 vx00 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx00_), vx01_, 1);
                        __m256 vx01 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx01_), vx02_, 1);
                        __m256 vx02 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx02_), vx03_, 1);

                        __m256 vx10 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx10_), vx11_, 1);
                        __m256 vx11 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx11_), vx12_, 1);
                        __m256 vx12 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx12_), vx13_, 1);

                        __m256 vx20 = _mm256_setzero_ps();
                        __m256 vx21 = _mm256_setzero_ps();
                        __m256 vx22 = _mm256_setzero_ps();

                        for (size_t i = 0; i < 4; i++) {
                            __m256 vc0 = _mm256_loadu_ps(table_avx2[i] + 0);
                            __m256 vc1 = _mm256_loadu_ps(table_avx2[i] + 8);
                            __m256 vc2 = i < 2 ? _mm256_loadu_ps(table_avx2[i] + 16) : _mm256_setzero_ps();
                            __m256 vc3 = i < 2 ? _mm256_loadu_ps(table_avx2[i] + 24) : _mm256_setzero_ps();

                            if (ox == 0) {
                                vc0 = _mm256_insertf128_ps(vc0, _mm_shuffle_ps(_mm_setzero_ps(), _mm256_extractf128_ps(vc0, 0), 0xD0), 0);
                                if (i < 2)
                                    vc2 = _mm256_insertf128_ps(vc2, _mm_shuffle_ps(_mm_setzero_ps(), _mm256_extractf128_ps(vc2, 0), 0xD0), 0);
                            } else if (ox == ow - 8) {
                                vc0 = _mm256_insertf128_ps(vc0, _mm_shuffle_ps(_mm256_extractf128_ps(vc0, 1), _mm_setzero_ps(), 0x07), 1);
                                if (i < 2)
                                    vc2 = _mm256_insertf128_ps(vc2, _mm_shuffle_ps(_mm256_extractf128_ps(vc2, 1), _mm_setzero_ps(), 0x07), 1);
                            }

                            __m256 vsrc0 = i < 2 ? _mm256_shuffle_ps(vx00, vx02, 0x0) : _mm256_shuffle_ps(vx10, vx12, 0x0);
                            __m256 vsrc1 = i < 2 ? _mm256_shuffle_ps(vx01, vx01, 0x0) : _mm256_shuffle_ps(vx11, vx11, 0x0);
                            __m256 vsrc2 = i < 2 ? _mm256_shuffle_ps(vx10, vx12, 0x0) : _mm256_shuffle_ps(vx20, vx22, 0x0);
                            __m256 vsrc3 = i < 2 ? _mm256_shuffle_ps(vx11, vx11, 0x0) : _mm256_shuffle_ps(vx21, vx21, 0x0);

                            __m256 res = _mm256_setzero_ps();

                            res = _mm256_fmadd_ps(vsrc0, vc0, res);
                            res = _mm256_fmadd_ps(vsrc1, vc1, res);
                            res = _mm256_fmadd_ps(vsrc2, vc2, res);
                            res = _mm256_fmadd_ps(vsrc3, vc3, res);

                            __m256 wei = _mm256_add_ps(_mm256_add_ps(vc0, vc1), _mm256_add_ps(vc2, vc3));

                            res = _mm256_div_ps(res, wei);

                            _mm256_storeu_ps(out_ptr + (oy + i) * ow + ox, res);
                        }
                    }
        #endif

        #if defined(HAVE_SSE) || defined(HAVE_AVX2)
                    for (; ox <= ow - 4; ox += 4) {
                        float ix = (ox + 0) * fx + fx / 2.0f - 0.5f;
                        size_t ix_r = static_cast<size_t>(round(ix));

                        __m128 vx00 = _mm_load_ss(in_ptr+(iy_r-1)*iw+ix_r-1);
                        __m128 vx01 = _mm_load_ss(in_ptr+(iy_r-1)*iw+ix_r+0);
                        __m128 vx02 = _mm_load_ss(in_ptr+(iy_r-1)*iw+ix_r+1);

                        __m128 vx10 = _mm_load_ss(in_ptr+(iy_r+0)*iw+ix_r-1);
                        __m128 vx11 = _mm_load_ss(in_ptr+(iy_r+0)*iw+ix_r+0);
                        __m128 vx12 = _mm_load_ss(in_ptr+(iy_r+0)*iw+ix_r+1);

                        __m128 vx20 = _mm_setzero_ps();
                        __m128 vx21 = _mm_setzero_ps();
                        __m128 vx22 = _mm_setzero_ps();

                        for (size_t i = 0; i < 4; i++) {
                            __m128 vc0 = _mm_loadu_ps(table_sse[i] +  0);
                            __m128 vc1 = _mm_loadu_ps(table_sse[i] +  4);
                            __m128 vc2 = i < 2 ? _mm_loadu_ps(table_sse[i] +  8) : _mm_setzero_ps();
                            __m128 vc3 = i < 2 ? _mm_loadu_ps(table_sse[i] + 12) : _mm_setzero_ps();

                            if (ox == 0) {
                                vc0 = _mm_shuffle_ps(_mm_setzero_ps(), vc0, 0xD0);
                                if (i < 2)
                                    vc2 = _mm_shuffle_ps(_mm_setzero_ps(), vc2, 0xD0);
                            } else if (ox == ow - 4) {
                                vc0 = _mm_shuffle_ps(vc0, _mm_setzero_ps(), 0x07);
                                if (i < 2)
                                    vc2 = _mm_shuffle_ps(vc2, _mm_setzero_ps(), 0x07);
                            }

                            __m128 vsrc0 = i < 2 ? _mm_shuffle_ps(vx00, vx02, 0x0) : _mm_shuffle_ps(vx10, vx12, 0x0);
                            __m128 vsrc1 = i < 2 ? _mm_shuffle_ps(vx01, vx01, 0x0) : _mm_shuffle_ps(vx11, vx11, 0x0);
                            __m128 vsrc2 = i < 2 ? _mm_shuffle_ps(vx10, vx12, 0x0) : _mm_shuffle_ps(vx20, vx22, 0x0);
                            __m128 vsrc3 = i < 2 ? _mm_shuffle_ps(vx11, vx11, 0x0) : _mm_shuffle_ps(vx21, vx21, 0x0);

                            __m128 vres0 = _mm_mul_ps(vsrc0, vc0);
                            __m128 vres1 = _mm_mul_ps(vsrc1, vc1);
                            __m128 vres2 = _mm_mul_ps(vsrc2, vc2);
                            __m128 vres3 = _mm_mul_ps(vsrc3, vc3);

                            __m128 res = _mm_add_ps(_mm_add_ps(vres0, vres1), _mm_add_ps(vres2, vres3));
                            __m128 wei = _mm_add_ps(_mm_add_ps(vc0, vc1), _mm_add_ps(vc2, vc3));

                            res = _mm_div_ps(res, wei);

                            _mm_storeu_ps(out_ptr + (oy+i)*ow + ox, res);
                        }
                    }
        #endif
                }
            }
        }
    }
#endif  // defined(HAVE_SSE) || defined(HAVE_AVX2)
};

REG_FACTORY_FOR(ImplFactory<ResampleImpl>, Resample);

}  // namespace Cpu
}  // namespace Extensions
}  // namespace InferenceEngine