inference-engine/src/extension/ext_interp.cpp
// Copyright (C) 2018 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0
//

#include "ext_list.hpp"
#include "ext_base.hpp"
#include <vector>
#include <immintrin.h>

namespace InferenceEngine {
namespace Extensions {
namespace Cpu {

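// InterpImpl implements the Interp layer as a CPU extension: 2D bilinear
// interpolation of a 4D (N, C, H, W) blob stored in a channel-blocked layout
// (blocks of 8 or 16 channels, depending on the available SIMD width).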
class InterpImpl: public ExtLayerBase {
public:
    explicit InterpImpl(const CNNLayer* layer) {
        try {
            if (layer->insData.size() != 1 || layer->outData.empty())
                THROW_IE_EXCEPTION << "Incorrect number of input/output edges!";

            if (layer->insData[0].lock()->dims.size() != 4)
                THROW_IE_EXCEPTION << "Interp supports only 4d blobs!";

            // We don't read the other parameters since they are needed only for the dst reshape in Caffe
            pad_beg = layer->GetParamAsInt("pad_beg");
            pad_end = layer->GetParamAsInt("pad_end");

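            // Use 16-channel blocks when AVX-512 is available, 8-channel blocks
            // otherwise; interpolate() assumes the same block size below.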
#if defined(HAVE_AVX512F)
            auto blk_layout = ConfLayout::BLK16;
#else
            auto blk_layout = ConfLayout::BLK8;
#endif

            addConfig(layer, {DataConfigurator(blk_layout)}, {DataConfigurator(blk_layout)});
        } catch (InferenceEngine::details::InferenceEngineException &ex) {
            errorMsg = ex.what();
        }
    }

    StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs,
                       ResponseDesc *resp) noexcept override {
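        // getDims() returns the logical NCHW shape; the channel count is taken from
        // the blocking descriptor instead (blockDims[1] channel blocks of
        // blockDims[4] channels each), so IC is already rounded up to the block size.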
        int IN = static_cast<int>(inputs[0]->getTensorDesc().getDims()[0]);
        int IC = static_cast<int>(
                inputs[0]->getTensorDesc().getBlockingDesc().getBlockDims()[1] *
                inputs[0]->getTensorDesc().getBlockingDesc().getBlockDims()[4]);
        int IH = static_cast<int>(inputs[0]->getTensorDesc().getDims()[2]);
        int IW = static_cast<int>(inputs[0]->getTensorDesc().getDims()[3]);

        int OH = static_cast<int>(outputs[0]->getTensorDesc().getDims()[2]);
        int OW = static_cast<int>(outputs[0]->getTensorDesc().getDims()[3]);

        int IH_pad = IH + pad_beg + pad_end;
        int IW_pad = IW + pad_beg + pad_end;

        const auto *src_data = inputs[0]->buffer().as<const float *>();
        auto *dst_data = outputs[0]->buffer().as<float *>();

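        // The source window of size IH_pad x IW_pad starts at offset (-pad_beg, -pad_beg)
        // inside the input blob; the destination window covers the whole OH x OW output.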
        interpolate(IN, IC, src_data, -pad_beg, -pad_beg, IH_pad, IW_pad, IH, IW, dst_data, 0, 0, OH, OW, OH, OW);
        return OK;
    }

private:
    int pad_beg;
    int pad_end;
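
    // Bilinear interpolation over channel-blocked data. (x1, y1) and (x2, y2) are the
    // top-left offsets of the source and destination windows, IH_pad/IW_pad and
    // OH_pad/OW_pad are the window sizes, and IH/IW and OH/OW are the full blob
    // sizes used for address arithmetic.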
    void interpolate(const int N, const int C,
                     const float *src, const int x1, const int y1,
                     const int IH_pad, const int IW_pad, const int IH, const int IW,
                     float *dst, const int x2, const int y2,
                     const int OH_pad, const int OW_pad, const int OH, const int OW) {
        if (IH_pad == OH_pad && IW_pad == OW_pad) {
            for (int i = 0; i < N * C * OH * OW; i++) {
                dst[i] = src[i];
            }
            return;
        }

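        // "Align corners" scale factors: the first and last rows/columns of the input
        // and output windows map onto each other, hence the (in - 1) / (out - 1) ratio.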
        const float rh = (OH_pad > 1) ? static_cast<float>(IH_pad - 1) / (OH_pad - 1) : 0.0f;
        const float rw = (OW_pad > 1) ? static_cast<float>(IW_pad - 1) / (OW_pad - 1) : 0.0f;

        // Must match the block size chosen via ConfLayout in the constructor
#if defined(HAVE_AVX512F)
        const int block_size = 16;
#else
        const int block_size = 8;
#endif

        // Round the channel count up to the block size to account for the channel
        // padding IE adds to blocked blobs
        int CB = (C + block_size - 1) & (-block_size);

        // Number of channel blocks
        int CH = (C + block_size - 1) / block_size;

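        // MSVC's OpenMP implementation (OpenMP 2.0) does not support the collapse
        // clause, so only the batch loop is parallelized there.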
#if _MSC_VER && !__INTEL_COMPILER
        #pragma omp parallel for schedule(static)
#else
        #pragma omp parallel for collapse(3) schedule(static)
#endif
        for (int n = 0; n < N; n++) {
            for (int cb = 0; cb < CH; ++cb) {
                for (int h = 0; h < OH_pad; ++h) {
                    const float *psrc = src + n * CB * IH * IW;

                    float fh = rh * h;
                    int ih0 = static_cast<int>(fh);
                    int ih1 = (ih0 < IH_pad - 1) ? ih0 + 1 : ih0;

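                    // h_lambda0 is the fractional offset towards row ih1;
                    // h_lambda1 = 1 - h_lambda0 is the weight of row ih0.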
                    float h_lambda0 = fh - ih0;
                    float h_lambda1 = 1.0f - h_lambda0;

                    for (int w = 0; w < OW_pad; ++w) {
                        float fw = rw * w;
                        int iw0 = static_cast<int>(fw);
                        int iw1 = (iw0 < IW_pad - 1) ? iw0 + 1 : iw0;

                        float w_lambda0 = fw - iw0;
                        float w_lambda1 = 1.0f - w_lambda0;

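                        // Pointers to the four neighbouring pixels (top-left, top-right,
                        // bottom-left, bottom-right); each one addresses a contiguous run
                        // of block_size channel values of the current channel block.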
                        const float *psrc00 =
                                psrc + cb * block_size * IW * IH + (y1 + ih0) * IW * block_size + (x1 + iw0) * block_size;
                        const float *psrc01 =
                                psrc + cb * block_size * IW * IH + (y1 + ih0) * IW * block_size + (x1 + iw1) * block_size;
                        const float *psrc10 =
                                psrc + cb * block_size * IW * IH + (y1 + ih1) * IW * block_size + (x1 + iw0) * block_size;
                        const float *psrc11 =
                                psrc + cb * block_size * IW * IH + (y1 + ih1) * IW * block_size + (x1 + iw1) * block_size;

                        float *pdst = dst + n * CB * OH * OW + cb * block_size * OW * OH + (y2 + h) * OW * block_size +
                                      (x2 + w) * block_size;

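                        // Blend the four neighbours with the bilinear weights, one whole
                        // channel block at a time; the scalar loop at the end is the
                        // fallback when no SIMD extension is available.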
#if defined(HAVE_AVX512F)
                        __m512 vwl0 = _mm512_set1_ps(w_lambda0);
                        __m512 vwl1 = _mm512_set1_ps(w_lambda1);
                        __m512 vhl0 = _mm512_set1_ps(h_lambda0);
                        __m512 vhl1 = _mm512_set1_ps(h_lambda1);
                        __m512 vsrc00 = _mm512_loadu_ps(psrc00);
                        __m512 vsrc01 = _mm512_loadu_ps(psrc01);
                        __m512 vsrc10 = _mm512_loadu_ps(psrc10);
                        __m512 vsrc11 = _mm512_loadu_ps(psrc11);

                        __m512 vdst0 = _mm512_fmadd_ps(vwl1, vsrc00, _mm512_mul_ps(vwl0, vsrc01));
                        __m512 vdst1 = _mm512_fmadd_ps(vwl1, vsrc10, _mm512_mul_ps(vwl0, vsrc11));
                        __m512 vdst  = _mm512_fmadd_ps(vhl1, vdst0, _mm512_mul_ps(vhl0, vdst1));

                        _mm512_storeu_ps(pdst, vdst);
#elif defined(HAVE_AVX2)
                        __m256 vwl0 = _mm256_set1_ps(w_lambda0);
                        __m256 vwl1 = _mm256_set1_ps(w_lambda1);
                        __m256 vhl0 = _mm256_set1_ps(h_lambda0);
                        __m256 vhl1 = _mm256_set1_ps(h_lambda1);
                        __m256 vsrc00 = _mm256_loadu_ps(psrc00);
                        __m256 vsrc01 = _mm256_loadu_ps(psrc01);
                        __m256 vsrc10 = _mm256_loadu_ps(psrc10);
                        __m256 vsrc11 = _mm256_loadu_ps(psrc11);

                        __m256 vdst0 = _mm256_fmadd_ps(vwl1, vsrc00, _mm256_mul_ps(vwl0, vsrc01));
                        __m256 vdst1 = _mm256_fmadd_ps(vwl1, vsrc10, _mm256_mul_ps(vwl0, vsrc11));
                        __m256 vdst  = _mm256_fmadd_ps(vhl1, vdst0, _mm256_mul_ps(vhl0, vdst1));

                        _mm256_storeu_ps(pdst, vdst);
#elif defined(HAVE_SSE)
                        __m128 vwl0 = _mm_set1_ps(w_lambda0);
                        __m128 vwl1 = _mm_set1_ps(w_lambda1);
                        __m128 vhl0 = _mm_set1_ps(h_lambda0);
                        __m128 vhl1 = _mm_set1_ps(h_lambda1);
                        // Process the block four floats at a time (block_size/2 == 4 here,
                        // since block_size is 8 on the SSE path).
                        for (int i = 0; i < block_size/4; i++) {
                            __m128 vsrc00 = _mm_loadu_ps(psrc00 + i*block_size/2);
                            __m128 vsrc01 = _mm_loadu_ps(psrc01 + i*block_size/2);
                            __m128 vsrc10 = _mm_loadu_ps(psrc10 + i*block_size/2);
                            __m128 vsrc11 = _mm_loadu_ps(psrc11 + i*block_size/2);

                            __m128 vdst00 = _mm_mul_ps(vwl1, vsrc00);
                            __m128 vdst01 = _mm_mul_ps(vwl0, vsrc01);
                            __m128 vdst10 = _mm_mul_ps(vwl1, vsrc10);
                            __m128 vdst11 = _mm_mul_ps(vwl0, vsrc11);

                            __m128 vdst0 = _mm_add_ps(vdst00, vdst01);
                            __m128 vdst1 = _mm_add_ps(vdst10, vdst11);

                            __m128 vdst = _mm_add_ps(_mm_mul_ps(vhl1, vdst0), _mm_mul_ps(vhl0, vdst1));

                            _mm_storeu_ps(pdst + i*block_size/2, vdst);
                        }
#else
                        for (int c = 0; c < block_size; ++c) {
                            pdst[c] = h_lambda1 * (w_lambda1 * psrc00[c] + w_lambda0 * psrc01[c]) +
                                      h_lambda0 * (w_lambda1 * psrc10[c] + w_lambda0 * psrc11[c]);
                        }
#endif
                    }
                }
            }
        }
    }
};

REG_FACTORY_FOR(ImplFactory<InterpImpl>, Interp);

}  // namespace Cpu
}  // namespace Extensions
}  // namespace InferenceEngine