Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / extension / ext_interp.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ext_list.hpp"
6 #include "ext_base.hpp"
7 #include <vector>
8 #if defined(HAVE_SSE) || defined(HAVE_AVX2) || defined(HAVE_AVX512F)
9 #include <immintrin.h>
10 #endif
11 #include "ie_parallel.hpp"
12
13 namespace InferenceEngine {
14 namespace Extensions {
15 namespace Cpu {
16
17 class InterpImpl: public ExtLayerBase {
18 public:
19     explicit InterpImpl(const CNNLayer* layer) {
20         try {
21             if (layer->insData.size() != 1 || layer->outData.empty())
22                 THROW_IE_EXCEPTION << "Incorrect number of input/output edges!";
23
24             if (layer->insData[0].lock()->dims.size() != 4)
25                 THROW_IE_EXCEPTION << "Interp supports only 4d blobs!";
26
27             // We don't read other parameters since they are needed only for dst reshape in caffe
28             pad_beg = layer->GetParamAsInt("pad_beg");
29             pad_end = layer->GetParamAsInt("pad_end");
30             align_corners = layer->GetParamsAsBool("align_corners", true);
31
32 #if defined(HAVE_AVX512F)
33             auto blk_layout = ConfLayout::BLK16;
34 #else
35             auto blk_layout = ConfLayout::BLK8;
36 #endif
37
38             addConfig(layer,  {DataConfigurator(blk_layout)}, {DataConfigurator(blk_layout)});
39         } catch (InferenceEngine::details::InferenceEngineException &ex) {
40             errorMsg = ex.what();
41         }
42     }
43
44     StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs,
45                        ResponseDesc *resp) noexcept override {
46         int IN = static_cast<int>(inputs[0]->getTensorDesc().getDims()[0]);
47         int IC = static_cast<int>(
48                 inputs[0]->getTensorDesc().getBlockingDesc().getBlockDims()[1] *
49                 inputs[0]->getTensorDesc().getBlockingDesc().getBlockDims()[4]);
50         int IH = static_cast<int>(inputs[0]->getTensorDesc().getDims()[2]);
51         int IW = static_cast<int>(inputs[0]->getTensorDesc().getDims()[3]);
52
53         int OH = static_cast<int>(outputs[0]->getTensorDesc().getDims()[2]);
54         int OW = static_cast<int>(outputs[0]->getTensorDesc().getDims()[3]);
55
56         int IH_pad = IH + pad_beg + pad_end;
57         int IW_pad = IW + pad_beg + pad_end;
58
59         const auto *src_data = inputs[0]->buffer().as<const float *>();
60         auto *dst_data = outputs[0]->buffer().as<float *>();
61
62         interpolate(IN, IC, src_data, -pad_beg, -pad_beg, IH_pad, IW_pad, IH, IW, dst_data, 0, 0, OH, OW, OH, OW);
63         return OK;
64     }
65
66 private:
67     int pad_beg;
68     int pad_end;
69     bool align_corners;
70
71     void interpolate(const int N, const int C,
72                      const float *src, const int x1, const int y1,
73                      const int IH_pad, const int IW_pad, const int IH, const int IW,
74                      float *dst, const int x2, const int y2,
75                      const int OH_pad, const int OW_pad, const int OH, const int OW) {
76         if (IH_pad == OH_pad && IW_pad == OW_pad) {
77             for (int i = 0; i < N * C * OH * OW; i++) {
78                 dst[i] = src[i];
79             }
80             return;
81         }
82
83         float rh;
84         float rw;
85         if (align_corners) {
86             rh = (OH_pad > 1) ? static_cast<float>(IH_pad - 1) / (OH_pad - 1) : 0.0f;
87             rw = (OW_pad > 1) ? static_cast<float>(IW_pad - 1) / (OW_pad - 1) : 0.0f;
88         } else {
89             rh = static_cast<float>(IH_pad) / (OH_pad);
90             rw = static_cast<float>(IW_pad) / (OW_pad);
91         }
92
93 #if defined(HAVE_AVX512F)
94         const int block_size = 16;
95 #else
96         const int block_size = 8;
97 #endif
98
99         // Align channel number to block size to deal with channels padding in IE with multiple blobs
100         int CB = (C + block_size - 1) & (-block_size);
101
102         int CH = (C + block_size - 1) / block_size;
103
104         parallel_for3d(N, CH, OH_pad, [&](int n, int cb, int h) {
105                     const float *psrc = src + n * CB * IH * IW;
106
107                     float fh = rh * h;
108                     int ih0 = static_cast<int>(fh);
109                     int ih1 = (ih0 < IH_pad - 1) ? ih0 + 1 : ih0;
110
111                     float h_lambda0 = fh - ih0;
112                     float h_lambda1 = 1.0f - h_lambda0;
113
114                     for (int w = 0; w < OW_pad; ++w) {
115                         float fw = rw * w;
116                         int iw0 = static_cast<int>(fw);
117                         int iw1 = (iw0 < IW_pad - 1) ? iw0 + 1 : iw0;
118
119                         float w_lambda0 = fw - iw0;
120                         float w_lambda1 = 1.0f - w_lambda0;
121
122                         const float *psrc00 =
123                                 psrc + cb * block_size * IW * IH + (y1 + ih0) * IW * block_size + (x1 + iw0) * block_size;
124                         const float *psrc01 =
125                                 psrc + cb * block_size * IW * IH + (y1 + ih0) * IW * block_size + (x1 + iw1) * block_size;
126                         const float *psrc10 =
127                                 psrc + cb * block_size * IW * IH + (y1 + ih1) * IW * block_size + (x1 + iw0) * block_size;
128                         const float *psrc11 =
129                                 psrc + cb * block_size * IW * IH + (y1 + ih1) * IW * block_size + (x1 + iw1) * block_size;
130
131                         float *pdst = dst + n * CB * OH * OW + cb * block_size * OW * OH + (y2 + h) * OW * block_size +
132                                       (x2 + w) * block_size;
133
134 #if defined(HAVE_AVX512F)
135                         __m512 vwl0 = _mm512_set1_ps(w_lambda0);
136                         __m512 vwl1 = _mm512_set1_ps(w_lambda1);
137                         __m512 vhl0 = _mm512_set1_ps(h_lambda0);
138                         __m512 vhl1 = _mm512_set1_ps(h_lambda1);
139                         __m512 vsrc00 = _mm512_loadu_ps(psrc00);
140                         __m512 vsrc01 = _mm512_loadu_ps(psrc01);
141                         __m512 vsrc10 = _mm512_loadu_ps(psrc10);
142                         __m512 vsrc11 = _mm512_loadu_ps(psrc11);
143
144                         __m512 vdst0 = _mm512_fmadd_ps(vwl1, vsrc00, _mm512_mul_ps(vwl0, vsrc01));
145                         __m512 vdst1 = _mm512_fmadd_ps(vwl1, vsrc10, _mm512_mul_ps(vwl0, vsrc11));
146                         __m512 vdst  = _mm512_fmadd_ps(vhl1, vdst0, _mm512_mul_ps(vhl0, vdst1));
147
148                         _mm512_storeu_ps(pdst, vdst);
149 #elif defined(HAVE_AVX2)
150                         __m256 vwl0 = _mm256_set1_ps(w_lambda0);
151                         __m256 vwl1 = _mm256_set1_ps(w_lambda1);
152                         __m256 vhl0 = _mm256_set1_ps(h_lambda0);
153                         __m256 vhl1 = _mm256_set1_ps(h_lambda1);
154                         __m256 vsrc00 = _mm256_loadu_ps(psrc00);
155                         __m256 vsrc01 = _mm256_loadu_ps(psrc01);
156                         __m256 vsrc10 = _mm256_loadu_ps(psrc10);
157                         __m256 vsrc11 = _mm256_loadu_ps(psrc11);
158
159                        __m256 vdst0 = _mm256_fmadd_ps(vwl1, vsrc00, _mm256_mul_ps(vwl0, vsrc01));
160                        __m256 vdst1 = _mm256_fmadd_ps(vwl1, vsrc10, _mm256_mul_ps(vwl0, vsrc11));
161                        __m256 vdst  = _mm256_fmadd_ps(vhl1, vdst0, _mm256_mul_ps(vhl0, vdst1));
162
163                        _mm256_storeu_ps(pdst, vdst);
164 #elif defined(HAVE_SSE)
165                         __m128 vwl0 = _mm_set1_ps(w_lambda0);
166                         __m128 vwl1 = _mm_set1_ps(w_lambda1);
167                         __m128 vhl0 = _mm_set1_ps(h_lambda0);
168                         __m128 vhl1 = _mm_set1_ps(h_lambda1);
169                         for (int i = 0; i < block_size/4; i++) {
170                             __m128 vsrc00 = _mm_loadu_ps(psrc00 + i*block_size/2);
171                             __m128 vsrc01 = _mm_loadu_ps(psrc01 + i*block_size/2);
172                             __m128 vsrc10 = _mm_loadu_ps(psrc10 + i*block_size/2);
173                             __m128 vsrc11 = _mm_loadu_ps(psrc11 + i*block_size/2);
174
175                            __m128 vdst00 = _mm_mul_ps(vwl1, vsrc00);
176                            __m128 vdst01 = _mm_mul_ps(vwl0, vsrc01);
177                            __m128 vdst10 = _mm_mul_ps(vwl1, vsrc10);
178                            __m128 vdst11 = _mm_mul_ps(vwl0, vsrc11);
179
180                            __m128 vdst0 = _mm_add_ps(vdst00, vdst01);
181                            __m128 vdst1 = _mm_add_ps(vdst10, vdst11);
182
183                             __m128 vdst = _mm_add_ps(_mm_mul_ps(vhl1, vdst0), _mm_mul_ps(vhl0, vdst1));
184
185                            _mm_storeu_ps(pdst + i*block_size/2, vdst);
186                         }
187 #else
188                         for (int c = 0; c < block_size; ++c) {
189                             pdst[c] = h_lambda1 * (w_lambda1 * psrc00[c] + w_lambda0 * psrc01[c]) +
190                                       h_lambda0 * (w_lambda1 * psrc10[c] + w_lambda0 * psrc11[c]);
191                         }
192 #endif
193             }
194         });
195     }
196 };
197
198 REG_FACTORY_FOR(ImplFactory<InterpImpl>, Interp);
199
200 }  // namespace Cpu
201 }  // namespace Extensions
202 }  // namespace InferenceEngine