// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "ext_list.hpp"
#include "ext_base.hpp"

#include <algorithm>
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#if defined(HAVE_SSE) || defined(HAVE_AVX2) || defined(HAVE_AVX512F)
#include <immintrin.h>
#endif
#include "ie_parallel.hpp"
#include "simple_copy.h"
namespace InferenceEngine {
namespace Extensions {
namespace Cpu {

inline int div_up(const int a, const int b) {
    return (a + b - 1) / b;
}
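
// div_up(a, b) computes the integer ceiling of a / b for non-negative a and
// positive b; the blocked kernels below use it to count channel blocks,
// e.g. div_up(20, 16) == 2 blocks for a 20-channel tensor with blk_size 16.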
class ResampleImpl: public ExtLayerBase {
public:
    explicit ResampleImpl(const CNNLayer* layer) {
        try {
            if (layer->insData.size() != 1 || layer->outData.empty())
                THROW_IE_EXCEPTION << "Incorrect number of input/output edges!";

            if (layer->insData[0].lock()->dims.size() != 4)
                THROW_IE_EXCEPTION << "Resample supports only 4D blobs!";

            type = layer->GetParamAsString("type");
            antialias = layer->GetParamAsBool("antialias", false);

#if defined(HAVE_AVX512F)
            auto blk_layout = ConfLayout::BLK16;
#else
            auto blk_layout = ConfLayout::BLK8;
#endif
            addConfig(layer, {DataConfigurator(ConfLayout::PLN)}, {DataConfigurator(ConfLayout::PLN)});
            if (type == "caffe.ResampleParameter.NEAREST")
                addConfig(layer, {DataConfigurator(blk_layout)}, {DataConfigurator(blk_layout)});
        } catch (InferenceEngine::details::InferenceEngineException &ex) {
            errorMsg = ex.what();
        }
    }
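
    // execute() dispatches to one of several kernels: specialized 2x and 4x
    // nearest-neighbor upsampling for planar and blocked layouts, a SIMD
    // triangle-interpolation path for exact 4x linear upsampling, and generic
    // scalar fallbacks for arbitrary scale factors.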
    StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs,
                       ResponseDesc *resp) noexcept override {
        const auto *src_data = inputs[0]->cbuffer().as<const float *>();
        auto *dst_data = outputs[0]->buffer().as<float *>();

        Layout layout = inputs[0]->layout();
        Precision precision = inputs[0]->precision();

        size_t IN = inputs[0]->getTensorDesc().getDims()[0];
        size_t IC = inputs[0]->getTensorDesc().getDims()[1];
        size_t IH = inputs[0]->getTensorDesc().getDims()[2];
        size_t IW = inputs[0]->getTensorDesc().getDims()[3];

        size_t OH = outputs[0]->getTensorDesc().getDims()[2];
        size_t OW = outputs[0]->getTensorDesc().getDims()[3];

        // Identity resample on the linear path degenerates to a plain copy.
        if (IW == OW && IH == OH && type == "caffe.ResampleParameter.LINEAR") {
            size_t size = IN * IC * IH * IW;
            // size is in bytes: x4 for FP32, 1:1 for U8.
            if (inputs[0]->getTensorDesc().getPrecision() == Precision::FP32) {
                size *= sizeof(float);
            }
            simple_copy(dst_data, outputs[0]->byteSize(), src_data, size);
            return OK;
        }
        float fx = static_cast<float>(IW) / static_cast<float>(OW);
        float fy = static_cast<float>(IH) / static_cast<float>(OH);

        bool isDownsample = (fx > 1) || (fy > 1);

        if (type == "caffe.ResampleParameter.NEAREST") {
            if (!isDownsample && fx == 0.25f && fy == 0.25f) {
                if (layout == NCHW || layout == NHWC) {
                    if (precision == Precision::FP32) {
                        Upsample_Nearest_PLN<float, 4>(src_data, dst_data, IN, IC, IH, IW, layout);
                    } else {
                        Upsample_Nearest_PLN<uint8_t, 4>(reinterpret_cast<const uint8_t*>(src_data),
                                                         reinterpret_cast<uint8_t*>(dst_data), IN, IC, IH, IW, layout);
                    }
                } else {
                    Upsample_Nearest_BLK<4>(src_data, dst_data, IN, IC, IH, IW);
                }
            } else if (!isDownsample && fx == 0.5f && fy == 0.5f) {
                if (layout == NCHW || layout == NHWC) {
                    if (precision == Precision::FP32) {
                        Upsample_Nearest_PLN<float, 2>(src_data, dst_data, IN, IC, IH, IW, layout);
                    } else {
                        Upsample_Nearest_PLN<uint8_t, 2>(reinterpret_cast<const uint8_t*>(src_data),
                                                         reinterpret_cast<uint8_t*>(dst_data), IN, IC, IH, IW, layout);
                    }
                } else {
                    Upsample_Nearest_BLK<2>(src_data, dst_data, IN, IC, IH, IW);
                }
            } else {
                if (layout == NCHW) {
                    NearestNeighborKernel_PLN(src_data, dst_data, IN, IC, IH, IW, fx, fy, OH, OW);
                } else {
                    NearestNeighborKernel_BLK(src_data, dst_data, IN, IC, IH, IW, fx, fy, OH, OW);
                }
            }
        } else if (type == "caffe.ResampleParameter.LINEAR") {
            // Support width of the triangle kernel (2 gives bilinear-style filtering).
            size_t kernel_width = 2;

#if defined(HAVE_SSE) || defined(HAVE_AVX2)
            if (!isDownsample && fx == 0.25f && fy == 0.25f)
                Upsample4x_TriangleInterpolation(src_data, IW, IH, fx, fy, dst_data, OW, OH, IC, IN);
            else
#endif
                InterpolationKernel(src_data, IW, IH, fx, fy, dst_data, OW, OH, IC, IN, kernel_width, isDownsample && antialias);
        }

        return OK;
    }

private:
    std::string type;
    bool antialias = false;
    // Tent (triangle) filter: w(x) = max(0, 1 - |x|), supported on [-1, 1].
    static inline float triangleCoeff(float x) {
        return std::max(0.0f, 1 - std::abs(x));
    }
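
    // InterpolationKernel is the generic scalar linear path: each output pixel
    // is mapped back to a source coordinate and convolved with a separable
    // tent filter, normalizing by the accumulated weight. With antialias
    // enabled for downsampling the filter is stretched by the scale factor
    // (ax = 1/fx), which widens the support radius rx/ry accordingly.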
    static void InterpolationKernel(const float *in_ptr_,
                                    const size_t iw, const size_t ih,
                                    const float fx, const float fy,
                                    float *out_ptr_,
                                    const size_t ow, const size_t oh, const size_t channels, const size_t batch,
                                    size_t kernel_width, bool antialias) {
        for (size_t b = 0; b < batch; b++) {
            for (size_t c = 0; c < channels; c++) {
                const float *in_ptr = in_ptr_ + iw * ih * channels * b + iw * ih * c;
                float *out_ptr = out_ptr_ + ow * oh * channels * b + ow * oh * c;

                for (size_t oy = 0; oy < oh; oy++) {
                    for (size_t ox = 0; ox < ow; ox++) {
                        // Half-pixel-centered source coordinates.
                        float ix = ox * fx + fx / 2.0f - 0.5f;
                        float iy = oy * fy + fy / 2.0f - 0.5f;

                        int ix_r = static_cast<int>(round(ix));
                        int iy_r = static_cast<int>(round(iy));

                        float sum = 0;
                        float wsum = 0;

                        float ax = 1.0f / (antialias ? fx : 1.0f);
                        float ay = 1.0f / (antialias ? fy : 1.0f);

                        int rx = (fx < 1.0f) ? 2 : static_cast<int>(ceil(static_cast<float>(kernel_width) / ax));
                        int ry = (fy < 1.0f) ? 2 : static_cast<int>(ceil(static_cast<float>(kernel_width) / ay));

                        for (int y = iy_r - ry; y <= iy_r + ry; y++) {
                            for (int x = ix_r - rx; x <= ix_r + rx; x++) {
                                if (y < 0 || x < 0 || y >= static_cast<int>(ih) || x >= static_cast<int>(iw))
                                    continue;

                                float dx = ix - x;
                                float dy = iy - y;

                                float w = ax * triangleCoeff(ax * dx) * ay * triangleCoeff(ay * dy);

                                sum += w * in_ptr[y * iw + x];
                                wsum += w;
                            }
                        }

                        out_ptr[oy * ow + ox] = (!wsum) ? 0 : (sum / wsum);
                    }
                }
            }
        }
    }
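
    // NearestNeighborKernel_PLN handles arbitrary scale factors on planar
    // (NCHW) data by rounding each output pixel's source coordinate to the
    // nearest input pixel.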
    static void NearestNeighborKernel_PLN(const float *in_ptr_, float *out_ptr_, int B, int C, int IH, int IW, float fx, float fy, int OH, int OW) {
        for (int b = 0; b < B; b++) {
            for (int c = 0; c < C; c++) {
                const float *in_ptr = in_ptr_ + IW * IH * C * b + IW * IH * c;
                float *out_ptr = out_ptr_ + OW * OH * C * b + OW * OH * c;

                for (int oy = 0; oy < OH; oy++) {
                    for (int ox = 0; ox < OW; ox++) {
                        float ix = ox * fx + fx / 2.0f - 0.5f;
                        float iy = oy * fy + fy / 2.0f - 0.5f;

                        size_t ix_r = static_cast<size_t>(round(ix));
                        size_t iy_r = static_cast<size_t>(round(iy));

                        out_ptr[oy * OW + ox] = in_ptr[iy_r * IW + ix_r];
                    }
                }
            }
        }
    }
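
    // NearestNeighborKernel_BLK does the same for blocked layouts
    // (nChw8c / nChw16c), where memory is arranged as [N][C/blk][H][W][blk]:
    // every spatial position holds blk_size consecutive channel values, so the
    // innermost loop copies one whole channel block per output pixel.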
    static void NearestNeighborKernel_BLK(const float *in_ptr_, float *out_ptr_, int B, int C, int IH, int IW, float fx, float fy, int OH, int OW) {
#if defined(HAVE_AVX512F)
        const int blk_size = 16;
#else
        const int blk_size = 8;
#endif
        int CB = div_up(C, blk_size);

        for (int b = 0; b < B; b++) {
            for (int cb = 0; cb < CB; cb++) {
                const float *in_ptr = in_ptr_ + IW * IH * CB * blk_size * b + IW * IH * cb * blk_size;
                float *out_ptr = out_ptr_ + OW * OH * CB * blk_size * b + OW * OH * cb * blk_size;

                for (int oy = 0; oy < OH; oy++) {
                    for (int ox = 0; ox < OW; ox++) {
                        float ix = ox * fx + fx / 2.0f - 0.5f;
                        float iy = oy * fy + fy / 2.0f - 0.5f;

                        size_t ix_r = static_cast<size_t>(round(ix));
                        size_t iy_r = static_cast<size_t>(round(iy));

                        for (int c = 0; c < blk_size; c++) {
                            float value = in_ptr[iy_r * IW * blk_size + ix_r * blk_size + c];
                            out_ptr[oy * OW * blk_size + ox * blk_size + c] = value;
                        }
                    }
                }
            }
        }
    }
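
    // Upsample_Nearest_PLN replicates every input pixel into a factor x factor
    // output tile. NCHW data is replicated scalar by scalar; NHWC data is
    // treated as rows of C-element pixels, so each pixel is duplicated with
    // memcpy and every finished output row is then duplicated factor - 1 times.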
    template <typename T, int factor>
    static void Upsample_Nearest_PLN(const T *in_ptr_, T *out_ptr_, int B, int C, int IH, int IW, Layout layout) {
        int OH = factor * IH;
        int OW = factor * IW;

        if (layout == NCHW) {
            for (int b = 0; b < B; b++) {
                for (int c = 0; c < C; c++) {
                    const T *in_ptr = in_ptr_ + IW * IH * C * b + IW * IH * c;
                    T *out_ptr = out_ptr_ + OW * OH * C * b + OW * OH * c;

                    for (int iy = 0; iy < IH; iy++) {
                        for (int ix = 0; ix < IW; ix++) {
                            int oy = factor * iy;
                            int ox = factor * ix;
                            T value = in_ptr[iy * IW + ix];

                            for (int fh = 0; fh < factor; fh++) {
                                for (int fw = 0; fw < factor; fw++) {
                                    out_ptr[(oy + fh) * OW + ox + fw] = value;
                                }
                            }
                        }
                    }
                }
            }
        } else {  // NHWC: one "block" is the C channel values of a single pixel
            int block_size = C;
            int block_size_bytes = block_size * sizeof(T);

            int ICIWIH = C * IW * IH;
            int OWOH = OW * OH;
            int OCOWOH = C * OWOH;

            int stepX = factor;
            int stepY = factor;

#pragma omp parallel for collapse(2)
            for (int mb = 0; mb < B; mb++) {
                for (int oh = 0; oh < OH; oh += stepY) {
                    size_t dst_off = mb * OCOWOH + (oh * OW) * block_size;
                    size_t src_off = mb * ICIWIH + (oh / stepY * IW) * block_size;

                    for (int ow = 0; ow < OW; ow += stepX) {
                        size_t dst_off_curr = dst_off + ow * block_size;
                        size_t src_off_curr = src_off + ow / stepX * block_size;

                        memcpy(&out_ptr_[dst_off_curr], &in_ptr_[src_off_curr], block_size_bytes);

                        for (int owx = 1; owx < stepX; owx++) {
                            memcpy(&out_ptr_[dst_off_curr + block_size * owx], &in_ptr_[src_off_curr], block_size_bytes);
                        }
                    }

                    // Duplicate the completed output row for the remaining
                    // stepY - 1 rows of the tile.
                    for (int ohy = 1; ohy < stepY; ohy++) {
                        memcpy(&out_ptr_[dst_off + OW * block_size * ohy], &out_ptr_[dst_off], block_size_bytes * OW);
                    }
                }
            }
        }
    }
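
    // Upsample_Nearest_BLK is the blocked-layout variant: with AVX2/AVX-512 an
    // entire channel block fits into one vector register, so each input block
    // is loaded once and stored factor * factor times; without those ISAs a
    // scalar loop over the block is used instead.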
    template <int factor>
    static void Upsample_Nearest_BLK(const float *in_ptr_, float *out_ptr_, int B, int C, int IH, int IW) {
#if defined(HAVE_AVX512F)
        const int blk_size = 16;
#else
        const int blk_size = 8;
#endif

#if defined(HAVE_AVX512F)
        typedef __m512 vec_type;
#elif defined(HAVE_AVX2)
        typedef __m256 vec_type;
#endif

        int CB = div_up(C, blk_size);

        int OH = factor * IH;
        int OW = factor * IW;

        parallel_for2d(B, CB, [&](int b, int cb) {
#if defined(HAVE_AVX2) || defined(HAVE_AVX512F)
            const float *in_ptr = in_ptr_ + IW * IH * CB * blk_size * b + IW * IH * cb * blk_size;
            float *out_ptr = out_ptr_ + OW * OH * CB * blk_size * b + OW * OH * cb * blk_size;

            for (int iy = 0; iy < IH; iy++) {
                for (int ix = 0; ix < IW; ix++) {
                    int oy = factor * iy;
                    int ox = factor * ix;

                    // _mm_uni_loadu_ps/_mm_uni_storeu_ps are width-agnostic
                    // helpers, presumably mapping to the ISA-appropriate
                    // unaligned load/store.
                    vec_type vsrc = _mm_uni_loadu_ps(in_ptr + iy * IW * blk_size + ix * blk_size);

                    for (int fh = 0; fh < factor; fh++) {
                        for (int fw = 0; fw < factor; fw++) {
                            _mm_uni_storeu_ps(out_ptr + (oy + fh) * OW * blk_size + (ox + fw) * blk_size, vsrc);
                        }
                    }
                }
            }
#else
            const float *in_ptr = in_ptr_ + IW * IH * CB * blk_size * b + IW * IH * cb * blk_size;
            float *out_ptr = out_ptr_ + OW * OH * CB * blk_size * b + OW * OH * cb * blk_size;

            for (int iy = 0; iy < IH; iy++) {
                for (int ix = 0; ix < IW; ix++) {
                    int oy = factor * iy;
                    int ox = factor * ix;

                    for (int c = 0; c < blk_size; c++) {
                        float value = in_ptr[iy * IW * blk_size + ix * blk_size + c];

                        for (int fh = 0; fh < factor; fh++) {
                            for (int fw = 0; fw < factor; fw++) {
                                out_ptr[(oy + fh) * OW * blk_size + (ox + fw) * blk_size + c] = value;
                            }
                        }
                    }
                }
            }
#endif
        });
    }
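
    // Upsample4x_TriangleInterpolation is a hand-vectorized special case of the
    // linear path for exact 4x upsampling (fx == fy == 0.25f). Because the four
    // output phases per input pixel sit at fixed fractional offsets, every
    // tent-filter weight can be precomputed (the tables below), and the rows
    // are processed in three phases: the first four output rows (no input row
    // above), the interior, and the last four rows (no input row below).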
#if defined(HAVE_SSE) || defined(HAVE_AVX2)
    static void Upsample4x_TriangleInterpolation(const float *in_ptr_,
                                                 const size_t iw, const size_t ih,
                                                 const float fx, const float fy,
                                                 float *out_ptr_,
                                                 const size_t ow, const size_t oh, const size_t channels, const size_t batch) {
#if defined(HAVE_AVX2)
        static float table_avx2[4][8 * 4] = {
            {
                0.140625f, 0.046875f, 0.046875f, 0.140625f, 0.140625f, 0.046875f, 0.046875f, 0.140625f,
                0.234375f, 0.328125f, 0.328125f, 0.234375f, 0.234375f, 0.328125f, 0.328125f, 0.234375f,
                0.234375f, 0.078125f, 0.078125f, 0.234375f, 0.234375f, 0.078125f, 0.078125f, 0.234375f,
                0.390625f, 0.546875f, 0.546875f, 0.390625f, 0.390625f, 0.546875f, 0.546875f, 0.390625f
            },
            {
                0.046875f, 0.015625f, 0.015625f, 0.046875f, 0.046875f, 0.015625f, 0.015625f, 0.046875f,
                0.078125f, 0.109375f, 0.109375f, 0.078125f, 0.078125f, 0.109375f, 0.109375f, 0.078125f,
                0.328125f, 0.109375f, 0.109375f, 0.328125f, 0.328125f, 0.109375f, 0.109375f, 0.328125f,
                0.546875f, 0.765625f, 0.765625f, 0.546875f, 0.546875f, 0.765625f, 0.765625f, 0.546875f
            },
            {
                0.328125f, 0.109375f, 0.109375f, 0.328125f, 0.328125f, 0.109375f, 0.109375f, 0.328125f,
                0.546875f, 0.765625f, 0.765625f, 0.546875f, 0.546875f, 0.765625f, 0.765625f, 0.546875f,
                0.046875f, 0.015625f, 0.015625f, 0.046875f, 0.046875f, 0.015625f, 0.015625f, 0.046875f,
                0.078125f, 0.109375f, 0.109375f, 0.078125f, 0.078125f, 0.109375f, 0.109375f, 0.078125f
            },
            {
                0.234375f, 0.078125f, 0.078125f, 0.234375f, 0.234375f, 0.078125f, 0.078125f, 0.234375f,
                0.390625f, 0.546875f, 0.546875f, 0.390625f, 0.390625f, 0.546875f, 0.546875f, 0.390625f,
                0.140625f, 0.046875f, 0.046875f, 0.140625f, 0.140625f, 0.046875f, 0.046875f, 0.140625f,
                0.234375f, 0.328125f, 0.328125f, 0.234375f, 0.234375f, 0.328125f, 0.328125f, 0.234375f
            }
        };
#endif

#if defined(HAVE_SSE) || defined(HAVE_AVX2)
        static float table_sse[4][4 * 4] = {
            {
                0.140625f, 0.046875f, 0.046875f, 0.140625f,
                0.234375f, 0.328125f, 0.328125f, 0.234375f,
                0.234375f, 0.078125f, 0.078125f, 0.234375f,
                0.390625f, 0.546875f, 0.546875f, 0.390625f
            },
            {
                0.046875f, 0.015625f, 0.015625f, 0.046875f,
                0.078125f, 0.109375f, 0.109375f, 0.078125f,
                0.328125f, 0.109375f, 0.109375f, 0.328125f,
                0.546875f, 0.765625f, 0.765625f, 0.546875f
            },
            {
                0.328125f, 0.109375f, 0.109375f, 0.328125f,
                0.546875f, 0.765625f, 0.765625f, 0.546875f,
                0.046875f, 0.015625f, 0.015625f, 0.046875f,
                0.078125f, 0.109375f, 0.109375f, 0.078125f
            },
            {
                0.234375f, 0.078125f, 0.078125f, 0.234375f,
                0.390625f, 0.546875f, 0.546875f, 0.390625f,
                0.140625f, 0.046875f, 0.046875f, 0.140625f,
                0.234375f, 0.328125f, 0.328125f, 0.234375f
            }
        };
#endif
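
        // Every table entry is the product of two 1D tent weights drawn from
        // {1/8, 3/8, 5/8, 7/8}, the weights the four output phases assign to
        // their two nearest input pixels; e.g. 0.234375f == (5/8) * (3/8).
        // Row i of table_sse holds the four weight vectors (vc0..vc3) used to
        // produce output sub-row i; table_avx2 repeats them for two 4-pixel groups.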
        for (size_t b = 0; b < batch; b++) {
            for (size_t c = 0; c < channels; c++) {
                const float *in_ptr = in_ptr_ + b * channels * iw * ih + c * iw * ih;
                float *out_ptr = out_ptr_ + b * channels * ow * oh + c * ow * oh;

                size_t oy = 0;

                // Phase 1: the first four output rows. The input row above
                // (iy_r - 1) does not exist, so its taps (vx0*) and, for the
                // top two sub-rows, the corresponding weights (vc0/vc1) stay zero.
                {
                    float iy = oy * fy + fy / 2.0f - 0.5f;
                    size_t iy_r = static_cast<size_t>(round(iy));

                    size_t ox = 0;
#if defined(HAVE_AVX2)
                    for (; ox + 8 <= ow; ox += 8) {
                        float ix = (ox + 0) * fx + fx / 2.0f - 0.5f;
                        size_t ix_r = static_cast<size_t>(round(ix));

                        __m256 vx00 = _mm256_setzero_ps();
                        __m256 vx01 = _mm256_setzero_ps();
                        __m256 vx02 = _mm256_setzero_ps();

                        __m128 vx10_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r - 1);
                        __m128 vx11_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 0);
                        __m128 vx12_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 1);
                        __m128 vx13_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 2);

                        __m128 vx20_ = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r - 1);
                        __m128 vx21_ = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r + 0);
                        __m128 vx22_ = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r + 1);
                        __m128 vx23_ = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r + 2);

                        __m256 vx10 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx10_), vx11_, 1);
                        __m256 vx11 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx11_), vx12_, 1);
                        __m256 vx12 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx12_), vx13_, 1);
                        __m256 vx20 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx20_), vx21_, 1);
                        __m256 vx21 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx21_), vx22_, 1);
                        __m256 vx22 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx22_), vx23_, 1);

                        for (size_t i = 0; i < 4; i++) {
                            __m256 vc0 = i < 2 ? _mm256_setzero_ps() : _mm256_loadu_ps(table_avx2[i] + 0);
                            __m256 vc1 = i < 2 ? _mm256_setzero_ps() : _mm256_loadu_ps(table_avx2[i] + 8);
                            __m256 vc2 = _mm256_loadu_ps(table_avx2[i] + 16);
                            __m256 vc3 = _mm256_loadu_ps(table_avx2[i] + 24);

                            // Zero the weights of the out-of-bounds column at the
                            // left/right image border; res is renormalized by wei below.
                            if (ox == 0) {
                                vc0 = _mm256_insertf128_ps(vc0, _mm_shuffle_ps(_mm_setzero_ps(), _mm256_extractf128_ps(vc0, 0), 0xD0), 0);
                                vc2 = _mm256_insertf128_ps(vc2, _mm_shuffle_ps(_mm_setzero_ps(), _mm256_extractf128_ps(vc2, 0), 0xD0), 0);
                            } else if (ox == ow - 8) {
                                vc0 = _mm256_insertf128_ps(vc0, _mm_shuffle_ps(_mm256_extractf128_ps(vc0, 1), _mm_setzero_ps(), 0x07), 1);
                                vc2 = _mm256_insertf128_ps(vc2, _mm_shuffle_ps(_mm256_extractf128_ps(vc2, 1), _mm_setzero_ps(), 0x07), 1);
                            }

                            __m256 vsrc0 = i < 2 ? _mm256_shuffle_ps(vx00, vx02, 0x0) : _mm256_shuffle_ps(vx10, vx12, 0x0);
                            __m256 vsrc1 = i < 2 ? _mm256_shuffle_ps(vx01, vx01, 0x0) : _mm256_shuffle_ps(vx11, vx11, 0x0);
                            __m256 vsrc2 = i < 2 ? _mm256_shuffle_ps(vx10, vx12, 0x0) : _mm256_shuffle_ps(vx20, vx22, 0x0);
                            __m256 vsrc3 = i < 2 ? _mm256_shuffle_ps(vx11, vx11, 0x0) : _mm256_shuffle_ps(vx21, vx21, 0x0);

                            __m256 res = _mm256_setzero_ps();

                            res = _mm256_fmadd_ps(vsrc0, vc0, res);
                            res = _mm256_fmadd_ps(vsrc1, vc1, res);
                            res = _mm256_fmadd_ps(vsrc2, vc2, res);
                            res = _mm256_fmadd_ps(vsrc3, vc3, res);

                            __m256 wei = _mm256_add_ps(_mm256_add_ps(vc0, vc1), _mm256_add_ps(vc2, vc3));

                            res = _mm256_div_ps(res, wei);

                            _mm256_storeu_ps(out_ptr + (oy + i) * ow + ox, res);
                        }
                    }
#endif
#if defined(HAVE_SSE) || defined(HAVE_AVX2)
                    for (; ox + 4 <= ow; ox += 4) {
                        float ix = (ox + 0) * fx + fx / 2.0f - 0.5f;
                        size_t ix_r = static_cast<size_t>(round(ix));

                        __m128 vx00 = _mm_setzero_ps();
                        __m128 vx01 = _mm_setzero_ps();
                        __m128 vx02 = _mm_setzero_ps();

                        __m128 vx10 = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r - 1);
                        __m128 vx11 = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 0);
                        __m128 vx12 = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 1);

                        __m128 vx20 = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r - 1);
                        __m128 vx21 = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r + 0);
                        __m128 vx22 = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r + 1);

                        for (size_t i = 0; i < 4; i++) {
                            __m128 vc0 = i < 2 ? _mm_setzero_ps() : _mm_loadu_ps(table_sse[i] + 0);
                            __m128 vc1 = i < 2 ? _mm_setzero_ps() : _mm_loadu_ps(table_sse[i] + 4);
                            __m128 vc2 = _mm_loadu_ps(table_sse[i] + 8);
                            __m128 vc3 = _mm_loadu_ps(table_sse[i] + 12);

                            if (ox == 0) {
                                vc0 = _mm_shuffle_ps(_mm_setzero_ps(), vc0, 0xD0);
                                vc2 = _mm_shuffle_ps(_mm_setzero_ps(), vc2, 0xD0);
                            } else if (ox == ow - 4) {
                                vc0 = _mm_shuffle_ps(vc0, _mm_setzero_ps(), 0x07);
                                vc2 = _mm_shuffle_ps(vc2, _mm_setzero_ps(), 0x07);
                            }

                            __m128 vsrc0 = i < 2 ? _mm_shuffle_ps(vx00, vx02, 0x0) : _mm_shuffle_ps(vx10, vx12, 0x0);
                            __m128 vsrc1 = i < 2 ? _mm_shuffle_ps(vx01, vx01, 0x0) : _mm_shuffle_ps(vx11, vx11, 0x0);
                            __m128 vsrc2 = i < 2 ? _mm_shuffle_ps(vx10, vx12, 0x0) : _mm_shuffle_ps(vx20, vx22, 0x0);
                            __m128 vsrc3 = i < 2 ? _mm_shuffle_ps(vx11, vx11, 0x0) : _mm_shuffle_ps(vx21, vx21, 0x0);

                            __m128 vres0 = _mm_mul_ps(vsrc0, vc0);
                            __m128 vres1 = _mm_mul_ps(vsrc1, vc1);
                            __m128 vres2 = _mm_mul_ps(vsrc2, vc2);
                            __m128 vres3 = _mm_mul_ps(vsrc3, vc3);

                            __m128 res = _mm_add_ps(_mm_add_ps(vres0, vres1), _mm_add_ps(vres2, vres3));
                            __m128 wei = _mm_add_ps(_mm_add_ps(vc0, vc1), _mm_add_ps(vc2, vc3));

                            res = _mm_div_ps(res, wei);

                            _mm_storeu_ps(out_ptr + (oy + i) * ow + ox, res);
                        }
                    }
#endif
                }
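
                // Phase 2: interior output rows. All three input rows
                // (iy_r - 1, iy_r, iy_r + 1) are in bounds and each lane's four
                // weights sum to 1, so the normalizing division is only needed
                // for the border column groups (ox == 0 and the last group).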
                for (oy = 4; oy + 8 <= oh; oy += 4) {
                    float iy = oy * fy + fy / 2.0f - 0.5f;
                    size_t iy_r = static_cast<size_t>(round(iy));

                    size_t ox = 0;
#if defined(HAVE_AVX2)
                    for (; ox + 8 <= ow; ox += 8) {
                        float ix = (ox + 0) * fx + fx / 2.0f - 0.5f;
                        size_t ix_r = static_cast<size_t>(round(ix));

                        __m128 vx00_ = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r - 1);
                        __m128 vx01_ = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r + 0);
                        __m128 vx02_ = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r + 1);
                        __m128 vx03_ = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r + 2);

                        __m128 vx10_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r - 1);
                        __m128 vx11_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 0);
                        __m128 vx12_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 1);
                        __m128 vx13_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 2);

                        __m128 vx20_ = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r - 1);
                        __m128 vx21_ = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r + 0);
                        __m128 vx22_ = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r + 1);
                        __m128 vx23_ = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r + 2);

                        __m256 vx00 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx00_), vx01_, 1);
                        __m256 vx01 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx01_), vx02_, 1);
                        __m256 vx02 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx02_), vx03_, 1);

                        __m256 vx10 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx10_), vx11_, 1);
                        __m256 vx11 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx11_), vx12_, 1);
                        __m256 vx12 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx12_), vx13_, 1);

                        __m256 vx20 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx20_), vx21_, 1);
                        __m256 vx21 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx21_), vx22_, 1);
                        __m256 vx22 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx22_), vx23_, 1);

                        for (size_t i = 0; i < 4; i++) {
                            __m256 vc0 = _mm256_loadu_ps(table_avx2[i] + 0);
                            __m256 vc1 = _mm256_loadu_ps(table_avx2[i] + 8);
                            __m256 vc2 = _mm256_loadu_ps(table_avx2[i] + 16);
                            __m256 vc3 = _mm256_loadu_ps(table_avx2[i] + 24);

                            if (ox == 0) {
                                vc0 = _mm256_insertf128_ps(vc0, _mm_shuffle_ps(_mm_setzero_ps(), _mm256_extractf128_ps(vc0, 0), 0xD0), 0);
                                vc2 = _mm256_insertf128_ps(vc2, _mm_shuffle_ps(_mm_setzero_ps(), _mm256_extractf128_ps(vc2, 0), 0xD0), 0);
                            } else if (ox == ow - 8) {
                                vc0 = _mm256_insertf128_ps(vc0, _mm_shuffle_ps(_mm256_extractf128_ps(vc0, 1), _mm_setzero_ps(), 0x07), 1);
                                vc2 = _mm256_insertf128_ps(vc2, _mm_shuffle_ps(_mm256_extractf128_ps(vc2, 1), _mm_setzero_ps(), 0x07), 1);
                            }

                            __m256 vsrc0 = i < 2 ? _mm256_shuffle_ps(vx00, vx02, 0x0) : _mm256_shuffle_ps(vx10, vx12, 0x0);
                            __m256 vsrc1 = i < 2 ? _mm256_shuffle_ps(vx01, vx01, 0x0) : _mm256_shuffle_ps(vx11, vx11, 0x0);
                            __m256 vsrc2 = i < 2 ? _mm256_shuffle_ps(vx10, vx12, 0x0) : _mm256_shuffle_ps(vx20, vx22, 0x0);
                            __m256 vsrc3 = i < 2 ? _mm256_shuffle_ps(vx11, vx11, 0x0) : _mm256_shuffle_ps(vx21, vx21, 0x0);

                            __m256 res = _mm256_setzero_ps();

                            res = _mm256_fmadd_ps(vsrc0, vc0, res);
                            res = _mm256_fmadd_ps(vsrc1, vc1, res);
                            res = _mm256_fmadd_ps(vsrc2, vc2, res);
                            res = _mm256_fmadd_ps(vsrc3, vc3, res);

                            if (ox == 0 || ox == ow - 8) {
                                __m256 wei = _mm256_add_ps(_mm256_add_ps(vc0, vc1), _mm256_add_ps(vc2, vc3));

                                res = _mm256_div_ps(res, wei);
                            }

                            _mm256_storeu_ps(out_ptr + (oy + i) * ow + ox, res);
                        }
                    }
#endif
#if defined(HAVE_SSE) || defined(HAVE_AVX2)
                    for (; ox + 4 <= ow; ox += 4) {
                        float ix = (ox + 0) * fx + fx / 2.0f - 0.5f;
                        size_t ix_r = static_cast<size_t>(round(ix));

                        __m128 vx00 = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r - 1);
                        __m128 vx01 = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r + 0);
                        __m128 vx02 = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r + 1);

                        __m128 vx10 = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r - 1);
                        __m128 vx11 = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 0);
                        __m128 vx12 = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 1);

                        __m128 vx20 = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r - 1);
                        __m128 vx21 = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r + 0);
                        __m128 vx22 = _mm_load_ss(in_ptr + (iy_r + 1) * iw + ix_r + 1);

                        for (size_t i = 0; i < 4; i++) {
                            __m128 vc0 = _mm_loadu_ps(table_sse[i] + 0);
                            __m128 vc1 = _mm_loadu_ps(table_sse[i] + 4);
                            __m128 vc2 = _mm_loadu_ps(table_sse[i] + 8);
                            __m128 vc3 = _mm_loadu_ps(table_sse[i] + 12);

                            if (ox == 0) {
                                vc0 = _mm_shuffle_ps(_mm_setzero_ps(), vc0, 0xD0);
                                vc2 = _mm_shuffle_ps(_mm_setzero_ps(), vc2, 0xD0);
                            } else if (ox == ow - 4) {
                                vc0 = _mm_shuffle_ps(vc0, _mm_setzero_ps(), 0x07);
                                vc2 = _mm_shuffle_ps(vc2, _mm_setzero_ps(), 0x07);
                            }

                            __m128 vsrc0 = i < 2 ? _mm_shuffle_ps(vx00, vx02, 0x0) : _mm_shuffle_ps(vx10, vx12, 0x0);
                            __m128 vsrc1 = i < 2 ? _mm_shuffle_ps(vx01, vx01, 0x0) : _mm_shuffle_ps(vx11, vx11, 0x0);
                            __m128 vsrc2 = i < 2 ? _mm_shuffle_ps(vx10, vx12, 0x0) : _mm_shuffle_ps(vx20, vx22, 0x0);
                            __m128 vsrc3 = i < 2 ? _mm_shuffle_ps(vx11, vx11, 0x0) : _mm_shuffle_ps(vx21, vx21, 0x0);

                            __m128 vres0 = _mm_mul_ps(vsrc0, vc0);
                            __m128 vres1 = _mm_mul_ps(vsrc1, vc1);
                            __m128 vres2 = _mm_mul_ps(vsrc2, vc2);
                            __m128 vres3 = _mm_mul_ps(vsrc3, vc3);

                            __m128 res = _mm_add_ps(_mm_add_ps(vres0, vres1), _mm_add_ps(vres2, vres3));
                            if (ox == 0 || ox == ow - 4) {
                                __m128 wei = _mm_add_ps(_mm_add_ps(vc0, vc1), _mm_add_ps(vc2, vc3));

                                res = _mm_div_ps(res, wei);
                            }

                            _mm_storeu_ps(out_ptr + (oy + i) * ow + ox, res);
                        }
                    }
#endif
                }
                // Phase 3: the last four output rows. The input row below
                // (iy_r + 1) does not exist, so its taps (vx2*) and, for the
                // bottom two sub-rows, the corresponding weights (vc2/vc3) stay zero.
                {
                    oy = oh - 4;

                    float iy = oy * fy + fy / 2.0f - 0.5f;
                    size_t iy_r = static_cast<size_t>(round(iy));

                    size_t ox = 0;
#if defined(HAVE_AVX2)
                    for (; ox + 8 <= ow; ox += 8) {
                        float ix = (ox + 0) * fx + fx / 2.0f - 0.5f;
                        size_t ix_r = static_cast<size_t>(round(ix));

                        __m128 vx00_ = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r - 1);
                        __m128 vx01_ = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r + 0);
                        __m128 vx02_ = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r + 1);
                        __m128 vx03_ = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r + 2);

                        __m128 vx10_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r - 1);
                        __m128 vx11_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 0);
                        __m128 vx12_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 1);
                        __m128 vx13_ = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 2);

                        __m256 vx00 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx00_), vx01_, 1);
                        __m256 vx01 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx01_), vx02_, 1);
                        __m256 vx02 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx02_), vx03_, 1);

                        __m256 vx10 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx10_), vx11_, 1);
                        __m256 vx11 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx11_), vx12_, 1);
                        __m256 vx12 = _mm256_insertf128_ps(_mm256_castps128_ps256(vx12_), vx13_, 1);

                        __m256 vx20 = _mm256_setzero_ps();
                        __m256 vx21 = _mm256_setzero_ps();
                        __m256 vx22 = _mm256_setzero_ps();

                        for (size_t i = 0; i < 4; i++) {
                            __m256 vc0 = _mm256_loadu_ps(table_avx2[i] + 0);
                            __m256 vc1 = _mm256_loadu_ps(table_avx2[i] + 8);
                            __m256 vc2 = i < 2 ? _mm256_loadu_ps(table_avx2[i] + 16) : _mm256_setzero_ps();
                            __m256 vc3 = i < 2 ? _mm256_loadu_ps(table_avx2[i] + 24) : _mm256_setzero_ps();

                            if (ox == 0) {
                                vc0 = _mm256_insertf128_ps(vc0, _mm_shuffle_ps(_mm_setzero_ps(), _mm256_extractf128_ps(vc0, 0), 0xD0), 0);
                                vc2 = _mm256_insertf128_ps(vc2, _mm_shuffle_ps(_mm_setzero_ps(), _mm256_extractf128_ps(vc2, 0), 0xD0), 0);
                            } else if (ox == ow - 8) {
                                vc0 = _mm256_insertf128_ps(vc0, _mm_shuffle_ps(_mm256_extractf128_ps(vc0, 1), _mm_setzero_ps(), 0x07), 1);
                                vc2 = _mm256_insertf128_ps(vc2, _mm_shuffle_ps(_mm256_extractf128_ps(vc2, 1), _mm_setzero_ps(), 0x07), 1);
                            }

                            __m256 vsrc0 = i < 2 ? _mm256_shuffle_ps(vx00, vx02, 0x0) : _mm256_shuffle_ps(vx10, vx12, 0x0);
                            __m256 vsrc1 = i < 2 ? _mm256_shuffle_ps(vx01, vx01, 0x0) : _mm256_shuffle_ps(vx11, vx11, 0x0);
                            __m256 vsrc2 = i < 2 ? _mm256_shuffle_ps(vx10, vx12, 0x0) : _mm256_shuffle_ps(vx20, vx22, 0x0);
                            __m256 vsrc3 = i < 2 ? _mm256_shuffle_ps(vx11, vx11, 0x0) : _mm256_shuffle_ps(vx21, vx21, 0x0);

                            __m256 res = _mm256_setzero_ps();

                            res = _mm256_fmadd_ps(vsrc0, vc0, res);
                            res = _mm256_fmadd_ps(vsrc1, vc1, res);
                            res = _mm256_fmadd_ps(vsrc2, vc2, res);
                            res = _mm256_fmadd_ps(vsrc3, vc3, res);

                            __m256 wei = _mm256_add_ps(_mm256_add_ps(vc0, vc1), _mm256_add_ps(vc2, vc3));

                            res = _mm256_div_ps(res, wei);

                            _mm256_storeu_ps(out_ptr + (oy + i) * ow + ox, res);
                        }
                    }
#endif
#if defined(HAVE_SSE) || defined(HAVE_AVX2)
                    for (; ox + 4 <= ow; ox += 4) {
                        float ix = (ox + 0) * fx + fx / 2.0f - 0.5f;
                        size_t ix_r = static_cast<size_t>(round(ix));

                        __m128 vx00 = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r - 1);
                        __m128 vx01 = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r + 0);
                        __m128 vx02 = _mm_load_ss(in_ptr + (iy_r - 1) * iw + ix_r + 1);

                        __m128 vx10 = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r - 1);
                        __m128 vx11 = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 0);
                        __m128 vx12 = _mm_load_ss(in_ptr + (iy_r + 0) * iw + ix_r + 1);

                        __m128 vx20 = _mm_setzero_ps();
                        __m128 vx21 = _mm_setzero_ps();
                        __m128 vx22 = _mm_setzero_ps();

                        for (size_t i = 0; i < 4; i++) {
                            __m128 vc0 = _mm_loadu_ps(table_sse[i] + 0);
                            __m128 vc1 = _mm_loadu_ps(table_sse[i] + 4);
                            __m128 vc2 = i < 2 ? _mm_loadu_ps(table_sse[i] + 8) : _mm_setzero_ps();
                            __m128 vc3 = i < 2 ? _mm_loadu_ps(table_sse[i] + 12) : _mm_setzero_ps();

                            if (ox == 0) {
                                vc0 = _mm_shuffle_ps(_mm_setzero_ps(), vc0, 0xD0);
                                vc2 = _mm_shuffle_ps(_mm_setzero_ps(), vc2, 0xD0);
                            } else if (ox == ow - 4) {
                                vc0 = _mm_shuffle_ps(vc0, _mm_setzero_ps(), 0x07);
                                vc2 = _mm_shuffle_ps(vc2, _mm_setzero_ps(), 0x07);
                            }

                            __m128 vsrc0 = i < 2 ? _mm_shuffle_ps(vx00, vx02, 0x0) : _mm_shuffle_ps(vx10, vx12, 0x0);
                            __m128 vsrc1 = i < 2 ? _mm_shuffle_ps(vx01, vx01, 0x0) : _mm_shuffle_ps(vx11, vx11, 0x0);
                            __m128 vsrc2 = i < 2 ? _mm_shuffle_ps(vx10, vx12, 0x0) : _mm_shuffle_ps(vx20, vx22, 0x0);
                            __m128 vsrc3 = i < 2 ? _mm_shuffle_ps(vx11, vx11, 0x0) : _mm_shuffle_ps(vx21, vx21, 0x0);

                            __m128 vres0 = _mm_mul_ps(vsrc0, vc0);
                            __m128 vres1 = _mm_mul_ps(vsrc1, vc1);
                            __m128 vres2 = _mm_mul_ps(vsrc2, vc2);
                            __m128 vres3 = _mm_mul_ps(vsrc3, vc3);

                            __m128 res = _mm_add_ps(_mm_add_ps(vres0, vres1), _mm_add_ps(vres2, vres3));
                            __m128 wei = _mm_add_ps(_mm_add_ps(vc0, vc1), _mm_add_ps(vc2, vc3));

                            res = _mm_div_ps(res, wei);

                            _mm_storeu_ps(out_ptr + (oy + i) * ow + ox, res);
                        }
                    }
#endif
                }
            }
        }
    }
#endif  // defined(HAVE_SSE) || defined(HAVE_AVX2)
};

REG_FACTORY_FOR(ImplFactory<ResampleImpl>, Resample);

}  // namespace Cpu
}  // namespace Extensions
}  // namespace InferenceEngine