1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "ext_list.hpp"
6 #include "ext_base.hpp"
8 #if defined(HAVE_SSE) || defined(HAVE_AVX2) || defined(HAVE_AVX512F)
11 #include "ie_parallel.hpp"
13 namespace InferenceEngine {
14 namespace Extensions {
17 class InterpImpl: public ExtLayerBase {
// Constructor: validates the layer's I/O topology, reads the Interp
// parameters, and registers a blocked-layout data configuration.
// NOTE(review): this is an elided listing — the enclosing try block, the
// #else/#endif of the layout selection, and the catch body are not shown.
19 explicit InterpImpl(const CNNLayer* layer) {
// Interp is a unary op: exactly one input edge and at least one output.
21 if (layer->insData.size() != 1 || layer->outData.empty())
22 THROW_IE_EXCEPTION << "Incorrect number of input/output edges!";
// Only 4-D (NCHW-shaped) blobs are supported by this implementation.
24 if (layer->insData[0].lock()->dims.size() != 4)
25 THROW_IE_EXCEPTION << "Interp supports only 4d blobs!";
27 // We don't read other parameters since they are needed only for dst reshape in caffe
28 pad_beg = layer->GetParamAsInt("pad_beg");
29 pad_end = layer->GetParamAsInt("pad_end");
30 align_corners = layer->GetParamsAsBool("align_corners", true);
// Pick the blocked layout matching the widest compiled SIMD ISA:
// 16-channel blocks for AVX-512, 8-channel blocks otherwise
// (the #else/#endif pair is elided from this listing).
32 #if defined(HAVE_AVX512F)
33 auto blk_layout = ConfLayout::BLK16;
35 auto blk_layout = ConfLayout::BLK8;
// Register one input/output config pair using the chosen blocked layout.
38 addConfig(layer, {DataConfigurator(blk_layout)}, {DataConfigurator(blk_layout)});
39 } catch (InferenceEngine::details::InferenceEngineException &ex) {
// Runs the interpolation over a blocked (nChw8c/nChw16c) float blob.
// Reads padded-window geometry from pad_beg/pad_end and delegates the
// actual resampling to interpolate().
// NOTE(review): the trailing status-return lines of this method are
// elided in this listing.
44 StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs,
45 ResponseDesc *resp) noexcept override {
46 int IN = static_cast<int>(inputs[0]->getTensorDesc().getDims()[0]);
// Channel count is reconstructed from the blocked dims:
// blockDims[1] = number of channel blocks, blockDims[4] = channels per block.
47 int IC = static_cast<int>(
48 inputs[0]->getTensorDesc().getBlockingDesc().getBlockDims()[1] *
49 inputs[0]->getTensorDesc().getBlockingDesc().getBlockDims()[4]);
50 int IH = static_cast<int>(inputs[0]->getTensorDesc().getDims()[2]);
51 int IW = static_cast<int>(inputs[0]->getTensorDesc().getDims()[3]);
53 int OH = static_cast<int>(outputs[0]->getTensorDesc().getDims()[2]);
54 int OW = static_cast<int>(outputs[0]->getTensorDesc().getDims()[3]);
// Effective source extent after applying the (possibly negative,
// i.e. cropping) pad values read in the constructor.
56 int IH_pad = IH + pad_beg + pad_end;
57 int IW_pad = IW + pad_beg + pad_end;
59 const auto *src_data = inputs[0]->buffer().as<const float *>();
60 auto *dst_data = outputs[0]->buffer().as<float *>();
// Sample the padded source window into the full destination plane;
// the -pad_beg offsets shift the read origin past the cropped border.
62 interpolate(IN, IC, src_data, -pad_beg, -pad_beg, IH_pad, IW_pad, IH, IW, dst_data, 0, 0, OH, OW, OH, OW);
// Bilinear resize of a blocked NC{8,16}HW float blob (plain copy when the
// source and destination windows have identical sizes).
//   N, C           : batch and logical channel counts
//   src, x1, y1    : source buffer and top-left offset of the read window
//   IH_pad, IW_pad : size of the source window being sampled
//   IH, IW         : physical source plane size (row-stride basis)
//   dst, x2, y2    : destination buffer and top-left offset of the write window
//   OH_pad, OW_pad : size of the destination window being filled
//   OH, OW         : physical destination plane size (row-stride basis)
// NOTE(review): this listing is elided — the copy-loop body, the scale
// #else branch, the fh/fw computations, several #elif/#else/#endif lines
// and the closing braces are not shown; comments below hedge accordingly.
71 void interpolate(const int N, const int C,
72 const float *src, const int x1, const int y1,
73 const int IH_pad, const int IW_pad, const int IH, const int IW,
74 float *dst, const int x2, const int y2,
75 const int OH_pad, const int OW_pad, const int OH, const int OW) {
// Fast path: same-size windows need no resampling (element copy; loop
// body elided in this listing).
76 if (IH_pad == OH_pad && IW_pad == OW_pad) {
77 for (int i = 0; i < N * C * OH * OW; i++) {
// align_corners scaling maps window corners exactly onto each other;
// the guard against a 1-wide/1-tall output avoids division by zero.
86 rh = (OH_pad > 1) ? static_cast<float>(IH_pad - 1) / (OH_pad - 1) : 0.0f;
87 rw = (OW_pad > 1) ? static_cast<float>(IW_pad - 1) / (OW_pad - 1) : 0.0f;
// Non-align_corners scaling: plain size ratio (branch structure elided).
89 rh = static_cast<float>(IH_pad) / (OH_pad);
90 rw = static_cast<float>(IW_pad) / (OW_pad);
// Channel block width mirrors the layout chosen in the constructor
// (16 for AVX-512, 8 otherwise; #else/#endif elided).
93 #if defined(HAVE_AVX512F)
94 const int block_size = 16;
96 const int block_size = 8;
99 // Align channel number to block size to deal with channels padding in IE with multiple blobs
100 int CB = (C + block_size - 1) & (-block_size);
// Number of channel blocks to iterate over.
102 int CH = (C + block_size - 1) / block_size;
// Parallelize over batch x channel-block x output row.
104 parallel_for3d(N, CH, OH_pad, [&](int n, int cb, int h) {
105 const float *psrc = src + n * CB * IH * IW;
// ih0/ih1: the two source rows bracketing the sample point (fh,
// computed on an elided line); clamped at the bottom edge.
108 int ih0 = static_cast<int>(fh);
109 int ih1 = (ih0 < IH_pad - 1) ? ih0 + 1 : ih0;
// Vertical interpolation weights (lambda0 toward ih1, lambda1 toward ih0).
111 float h_lambda0 = fh - ih0;
112 float h_lambda1 = 1.0f - h_lambda0;
114 for (int w = 0; w < OW_pad; ++w) {
// iw0/iw1: the two source columns bracketing the sample point (fw
// computed on an elided line); clamped at the right edge.
116 int iw0 = static_cast<int>(fw);
117 int iw1 = (iw0 < IW_pad - 1) ? iw0 + 1 : iw0;
// Horizontal interpolation weights.
119 float w_lambda0 = fw - iw0;
120 float w_lambda1 = 1.0f - w_lambda0;
// Pointers to the four neighbouring source pixels, each holding
// block_size consecutive channel values (blocked layout).
122 const float *psrc00 =
123 psrc + cb * block_size * IW * IH + (y1 + ih0) * IW * block_size + (x1 + iw0) * block_size;
124 const float *psrc01 =
125 psrc + cb * block_size * IW * IH + (y1 + ih0) * IW * block_size + (x1 + iw1) * block_size;
126 const float *psrc10 =
127 psrc + cb * block_size * IW * IH + (y1 + ih1) * IW * block_size + (x1 + iw0) * block_size;
128 const float *psrc11 =
129 psrc + cb * block_size * IW * IH + (y1 + ih1) * IW * block_size + (x1 + iw1) * block_size;
// Destination pixel: same blocked addressing on the output plane.
131 float *pdst = dst + n * CB * OH * OW + cb * block_size * OW * OH + (y2 + h) * OW * block_size +
132 (x2 + w) * block_size;
// AVX-512: one 16-float vector covers the whole channel block.
134 #if defined(HAVE_AVX512F)
135 __m512 vwl0 = _mm512_set1_ps(w_lambda0);
136 __m512 vwl1 = _mm512_set1_ps(w_lambda1);
137 __m512 vhl0 = _mm512_set1_ps(h_lambda0);
138 __m512 vhl1 = _mm512_set1_ps(h_lambda1);
139 __m512 vsrc00 = _mm512_loadu_ps(psrc00);
140 __m512 vsrc01 = _mm512_loadu_ps(psrc01);
141 __m512 vsrc10 = _mm512_loadu_ps(psrc10);
142 __m512 vsrc11 = _mm512_loadu_ps(psrc11);
// Horizontal blend of each row pair, then vertical blend of the rows:
// dst = hl1*(wl1*p00 + wl0*p01) + hl0*(wl1*p10 + wl0*p11).
144 __m512 vdst0 = _mm512_fmadd_ps(vwl1, vsrc00, _mm512_mul_ps(vwl0, vsrc01));
145 __m512 vdst1 = _mm512_fmadd_ps(vwl1, vsrc10, _mm512_mul_ps(vwl0, vsrc11));
146 __m512 vdst = _mm512_fmadd_ps(vhl1, vdst0, _mm512_mul_ps(vhl0, vdst1));
148 _mm512_storeu_ps(pdst, vdst);
// AVX2: one 8-float vector covers the 8-channel block.
149 #elif defined(HAVE_AVX2)
150 __m256 vwl0 = _mm256_set1_ps(w_lambda0);
151 __m256 vwl1 = _mm256_set1_ps(w_lambda1);
152 __m256 vhl0 = _mm256_set1_ps(h_lambda0);
153 __m256 vhl1 = _mm256_set1_ps(h_lambda1);
154 __m256 vsrc00 = _mm256_loadu_ps(psrc00);
155 __m256 vsrc01 = _mm256_loadu_ps(psrc01);
156 __m256 vsrc10 = _mm256_loadu_ps(psrc10);
157 __m256 vsrc11 = _mm256_loadu_ps(psrc11);
// Same two-stage blend as the AVX-512 path, on 8 lanes.
159 __m256 vdst0 = _mm256_fmadd_ps(vwl1, vsrc00, _mm256_mul_ps(vwl0, vsrc01));
160 __m256 vdst1 = _mm256_fmadd_ps(vwl1, vsrc10, _mm256_mul_ps(vwl0, vsrc11));
161 __m256 vdst = _mm256_fmadd_ps(vhl1, vdst0, _mm256_mul_ps(vhl0, vdst1));
163 _mm256_storeu_ps(pdst, vdst);
// SSE: the 8-channel block is processed in 4-float halves; with
// block_size == 8 the loop runs twice and i*block_size/2 advances by 4.
164 #elif defined(HAVE_SSE)
165 __m128 vwl0 = _mm_set1_ps(w_lambda0);
166 __m128 vwl1 = _mm_set1_ps(w_lambda1);
167 __m128 vhl0 = _mm_set1_ps(h_lambda0);
168 __m128 vhl1 = _mm_set1_ps(h_lambda1);
169 for (int i = 0; i < block_size/4; i++) {
170 __m128 vsrc00 = _mm_loadu_ps(psrc00 + i*block_size/2);
171 __m128 vsrc01 = _mm_loadu_ps(psrc01 + i*block_size/2);
172 __m128 vsrc10 = _mm_loadu_ps(psrc10 + i*block_size/2);
173 __m128 vsrc11 = _mm_loadu_ps(psrc11 + i*block_size/2);
// No FMA in SSE: multiply then add explicitly.
175 __m128 vdst00 = _mm_mul_ps(vwl1, vsrc00);
176 __m128 vdst01 = _mm_mul_ps(vwl0, vsrc01);
177 __m128 vdst10 = _mm_mul_ps(vwl1, vsrc10);
178 __m128 vdst11 = _mm_mul_ps(vwl0, vsrc11);
180 __m128 vdst0 = _mm_add_ps(vdst00, vdst01);
181 __m128 vdst1 = _mm_add_ps(vdst10, vdst11);
183 __m128 vdst = _mm_add_ps(_mm_mul_ps(vhl1, vdst0), _mm_mul_ps(vhl0, vdst1));
185 _mm_storeu_ps(pdst + i*block_size/2, vdst);
// Scalar fallback over the channel block (the preceding #else line of
// this preprocessor chain is elided in this listing).
188 for (int c = 0; c < block_size; ++c) {
189 pdst[c] = h_lambda1 * (w_lambda1 * psrc00[c] + w_lambda0 * psrc01[c]) +
190 h_lambda0 * (w_lambda1 * psrc10[c] + w_lambda0 * psrc11[c]);
198 REG_FACTORY_FOR(ImplFactory<InterpImpl>, Interp);
201 } // namespace Extensions
202 } // namespace InferenceEngine