// Copyright (C) 2018 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
#include "ext_list.hpp"
#include "ext_base.hpp"

#include <algorithm>
#include <vector>
11 namespace InferenceEngine {
12 namespace Extensions {
15 class InterpImpl: public ExtLayerBase {
17 explicit InterpImpl(const CNNLayer* layer) {
19 if (layer->insData.size() != 1 || layer->outData.empty())
20 THROW_IE_EXCEPTION << "Incorrect number of input/output edges!";
22 if (layer->insData[0].lock()->dims.size() != 4)
23 THROW_IE_EXCEPTION << "Interp supports only 4d blobs!";
25 // We don't read other parameters since they are needed only for dst reshape in caffe
26 pad_beg = layer->GetParamAsInt("pad_beg");
27 pad_end = layer->GetParamAsInt("pad_end");
29 #if defined(HAVE_AVX512F)
30 auto blk_layout = ConfLayout::BLK16;
32 auto blk_layout = ConfLayout::BLK8;
35 addConfig(layer, {DataConfigurator(blk_layout)}, {DataConfigurator(blk_layout)});
36 } catch (InferenceEngine::details::InferenceEngineException &ex) {
41 StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs,
42 ResponseDesc *resp) noexcept override {
43 int IN = static_cast<int>(inputs[0]->getTensorDesc().getDims()[0]);
44 int IC = static_cast<int>(
45 inputs[0]->getTensorDesc().getBlockingDesc().getBlockDims()[1] *
46 inputs[0]->getTensorDesc().getBlockingDesc().getBlockDims()[4]);
47 int IH = static_cast<int>(inputs[0]->getTensorDesc().getDims()[2]);
48 int IW = static_cast<int>(inputs[0]->getTensorDesc().getDims()[3]);
50 int OH = static_cast<int>(outputs[0]->getTensorDesc().getDims()[2]);
51 int OW = static_cast<int>(outputs[0]->getTensorDesc().getDims()[3]);
53 int IH_pad = IH + pad_beg + pad_end;
54 int IW_pad = IW + pad_beg + pad_end;
56 const auto *src_data = inputs[0]->buffer().as<const float *>();
57 auto *dst_data = outputs[0]->buffer().as<float *>();
59 interpolate(IN, IC, src_data, -pad_beg, -pad_beg, IH_pad, IW_pad, IH, IW, dst_data, 0, 0, OH, OW, OH, OW);
66 void interpolate(const int N, const int C,
67 const float *src, const int x1, const int y1,
68 const int IH_pad, const int IW_pad, const int IH, const int IW,
69 float *dst, const int x2, const int y2,
70 const int OH_pad, const int OW_pad, const int OH, const int OW) {
71 if (IH_pad == OH_pad && IW_pad == OW_pad) {
72 for (int i = 0; i < N * C * OH * OW; i++) {
78 const float rh = (OH_pad > 1) ? static_cast<float>(IH_pad - 1) / (OH_pad - 1) : 0.0f;
79 const float rw = (OW_pad > 1) ? static_cast<float>(IW_pad - 1) / (OW_pad - 1) : 0.0f;
81 #if defined(HAVE_AVX512F)
82 const int block_size = 16;
84 const int block_size = 8;
87 // Align channel number to block size to deal with channels padding in IE with multiple blobs
88 int CB = (C + block_size - 1) & (-block_size);
90 int CH = (C + block_size - 1) / block_size;
92 #if _MSC_VER && !__INTEL_COMPILER
93 #pragma omp parallel for schedule(static)
95 #pragma omp parallel for collapse(3) schedule(static)
97 for (int n = 0; n < N; n++) {
98 for (int cb = 0; cb < CH; ++cb) {
99 for (int h = 0; h < OH_pad; ++h) {
100 const float *psrc = src + n * CB * IH * IW;
103 int ih0 = static_cast<int>(fh);
104 int ih1 = (ih0 < IH_pad - 1) ? ih0 + 1 : ih0;
106 float h_lambda0 = fh - ih0;
107 float h_lambda1 = 1.0f - h_lambda0;
109 for (int w = 0; w < OW_pad; ++w) {
111 int iw0 = static_cast<int>(fw);
112 int iw1 = (iw0 < IW_pad - 1) ? iw0 + 1 : iw0;
114 float w_lambda0 = fw - iw0;
115 float w_lambda1 = 1.0f - w_lambda0;
117 const float *psrc00 =
118 psrc + cb * block_size * IW * IH + (y1 + ih0) * IW * block_size + (x1 + iw0) * block_size;
119 const float *psrc01 =
120 psrc + cb * block_size * IW * IH + (y1 + ih0) * IW * block_size + (x1 + iw1) * block_size;
121 const float *psrc10 =
122 psrc + cb * block_size * IW * IH + (y1 + ih1) * IW * block_size + (x1 + iw0) * block_size;
123 const float *psrc11 =
124 psrc + cb * block_size * IW * IH + (y1 + ih1) * IW * block_size + (x1 + iw1) * block_size;
126 float *pdst = dst + n * CB * OH * OW + cb * block_size * OW * OH + (y2 + h) * OW * block_size +
127 (x2 + w) * block_size;
129 #if defined(HAVE_AVX512F)
130 __m512 vwl0 = _mm512_set1_ps(w_lambda0);
131 __m512 vwl1 = _mm512_set1_ps(w_lambda1);
132 __m512 vhl0 = _mm512_set1_ps(h_lambda0);
133 __m512 vhl1 = _mm512_set1_ps(h_lambda1);
134 __m512 vsrc00 = _mm512_loadu_ps(psrc00);
135 __m512 vsrc01 = _mm512_loadu_ps(psrc01);
136 __m512 vsrc10 = _mm512_loadu_ps(psrc10);
137 __m512 vsrc11 = _mm512_loadu_ps(psrc11);
139 __m512 vdst0 = _mm512_fmadd_ps(vwl1, vsrc00, _mm512_mul_ps(vwl0, vsrc01));
140 __m512 vdst1 = _mm512_fmadd_ps(vwl1, vsrc10, _mm512_mul_ps(vwl0, vsrc11));
141 __m512 vdst = _mm512_fmadd_ps(vhl1, vdst0, _mm512_mul_ps(vhl0, vdst1));
143 _mm512_storeu_ps(pdst, vdst);
144 #elif defined(HAVE_AVX2)
145 __m256 vwl0 = _mm256_set1_ps(w_lambda0);
146 __m256 vwl1 = _mm256_set1_ps(w_lambda1);
147 __m256 vhl0 = _mm256_set1_ps(h_lambda0);
148 __m256 vhl1 = _mm256_set1_ps(h_lambda1);
149 __m256 vsrc00 = _mm256_loadu_ps(psrc00);
150 __m256 vsrc01 = _mm256_loadu_ps(psrc01);
151 __m256 vsrc10 = _mm256_loadu_ps(psrc10);
152 __m256 vsrc11 = _mm256_loadu_ps(psrc11);
154 __m256 vdst0 = _mm256_fmadd_ps(vwl1, vsrc00, _mm256_mul_ps(vwl0, vsrc01));
155 __m256 vdst1 = _mm256_fmadd_ps(vwl1, vsrc10, _mm256_mul_ps(vwl0, vsrc11));
156 __m256 vdst = _mm256_fmadd_ps(vhl1, vdst0, _mm256_mul_ps(vhl0, vdst1));
158 _mm256_storeu_ps(pdst, vdst);
159 #elif defined(HAVE_SSE)
160 __m128 vwl0 = _mm_set1_ps(w_lambda0);
161 __m128 vwl1 = _mm_set1_ps(w_lambda1);
162 __m128 vhl0 = _mm_set1_ps(h_lambda0);
163 __m128 vhl1 = _mm_set1_ps(h_lambda1);
164 for (int i = 0; i < block_size/4; i++) {
165 __m128 vsrc00 = _mm_loadu_ps(psrc00 + i*block_size/2);
166 __m128 vsrc01 = _mm_loadu_ps(psrc01 + i*block_size/2);
167 __m128 vsrc10 = _mm_loadu_ps(psrc10 + i*block_size/2);
168 __m128 vsrc11 = _mm_loadu_ps(psrc11 + i*block_size/2);
170 __m128 vdst00 = _mm_mul_ps(vwl1, vsrc00);
171 __m128 vdst01 = _mm_mul_ps(vwl0, vsrc01);
172 __m128 vdst10 = _mm_mul_ps(vwl1, vsrc10);
173 __m128 vdst11 = _mm_mul_ps(vwl0, vsrc11);
175 __m128 vdst0 = _mm_add_ps(vdst00, vdst01);
176 __m128 vdst1 = _mm_add_ps(vdst10, vdst11);
178 __m128 vdst = _mm_add_ps(_mm_mul_ps(vhl1, vdst0), _mm_mul_ps(vhl0, vdst1));
180 _mm_storeu_ps(pdst + i*block_size/2, vdst);
183 for (int c = 0; c < block_size; ++c) {
184 pdst[c] = h_lambda1 * (w_lambda1 * psrc00[c] + w_lambda0 * psrc01[c]) +
185 h_lambda0 * (w_lambda1 * psrc10[c] + w_lambda0 * psrc11[c]);
195 REG_FACTORY_FOR(ImplFactory<InterpImpl>, Interp);
198 } // namespace Extensions
199 } // namespace InferenceEngine