Publishing 2019 R1.1 content and Myriad plugin sources (#162)
[platform/upstream/dldt.git] / inference-engine / src / extension / ext_strided_slice.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ext_list.hpp"
6 #include "ext_base.hpp"
7
8 #include <cmath>
9 #include <string>
10 #include <vector>
11 #include <cassert>
12 #include <algorithm>
13 #include "ie_parallel.hpp"
14
15 namespace InferenceEngine {
16 namespace Extensions {
17 namespace Cpu {
18
19 inline void clipping(int *idx, const int min, const int max) {
20     (*idx) = ((*idx) > min) ? (*idx) : min;
21     (*idx) = ((*idx) < max) ? (*idx) : (max - 1);
22     return;
23 }
24
25 class StridedSliceImpl: public ExtLayerBase {
26 public:
27     explicit StridedSliceImpl(const CNNLayer* layer) {
28         try {
29             if (layer->insData.size() > 4 || layer->outData.size() != 1)
30                 THROW_IE_EXCEPTION << layer->name << " Incorrect number of input/output edges!";
31
32             src_dims = layer->insData[STRIDEDSLICE_DATA].lock()->getTensorDesc().getDims();
33
34             bounds_size = 0;
35             begin_dims = {};
36             if (layer->insData.size() > 1) {
37                 begin_dims = layer->insData[STRIDEDSLICE_BEGIN].lock()->getTensorDesc().getDims();
38                 if (layer->insData[STRIDEDSLICE_BEGIN].lock()->getTensorDesc().getPrecision() != Precision::I32)
39                     THROW_IE_EXCEPTION << layer->name << " Incorrect 'begin' input precision. Only I32 is supported!";
40                 if (begin_dims.size() > 1)
41                     THROW_IE_EXCEPTION << layer->name << " Begin vector should be 1 dimension";
42                 bounds_size = begin_dims[0];
43             }
44
45             if (layer->insData.size() > 2) {
46                 end_dims = layer->insData[STRIDEDSLICE_END].lock()->getTensorDesc().getDims();
47                 if (layer->insData[STRIDEDSLICE_END].lock()->getTensorDesc().getPrecision() != Precision::I32)
48                     THROW_IE_EXCEPTION << layer->name << " Incorrect 'end' input precision. Only I32 is supported!";
49                 if (end_dims.size() > 1)
50                     THROW_IE_EXCEPTION << layer->name << " End vector should be 1 dimension";
51                 if (begin_dims[0] != end_dims[0])
52                     THROW_IE_EXCEPTION << layer->name << " Begin vector size should be equal end vectror size";
53             }
54
55             if (layer->insData.size() > 3) {
56                 stride_dims = layer->insData[STRIDEDSLICE_STRIDE].lock()->getTensorDesc().getDims();
57                 if (layer->insData[STRIDEDSLICE_STRIDE].lock()->getTensorDesc().getPrecision() != Precision::I32)
58                     THROW_IE_EXCEPTION << layer->name << " Incorrect 'strides' input precision. Only I32 is supported!";
59                 if (stride_dims.size() > 1)
60                     THROW_IE_EXCEPTION << layer->name << " End vector should be 1 dimension";
61                 if (begin_dims[0] != stride_dims[0])
62                     THROW_IE_EXCEPTION << layer->name << " Stride vector size should be equal begin vectror size";
63             }
64             dst_dims = layer->outData[0]->getTensorDesc().getDims();
65
66             std::string::size_type i;
67             std::string begin_mask_str = layer->GetParamAsString("begin_mask", "");
68             for (i = 0; i < begin_mask_str.size(); ++i) {
69                 if (begin_mask_str[i] == '1') begin_mask.push_back(1);
70                 else if (begin_mask_str[i] == '0') begin_mask.push_back(0);
71             }
72             for (; i < src_dims.size(); ++i) begin_mask.push_back(1);
73
74             std::string end_mask_str = layer->GetParamAsString("end_mask", "");
75             for (i = 0; i < end_mask_str.size(); ++i) {
76                 if (end_mask_str[i] == '1') end_mask.push_back(1);
77                 else if (end_mask_str[i] == '0') end_mask.push_back(0);
78             }
79             for (; i < src_dims.size(); ++i) end_mask.push_back(1);
80
81             std::string ellipsis_mask_str = layer->GetParamAsString("ellipsis_mask", "");
82             size_t ellipsis_mask_counter = 0;
83             for (i = 0; i < ellipsis_mask_str.size(); ++i) {
84                 if (ellipsis_mask_str[i] == '1') {
85                     ellipsis_mask_counter++;
86                     ellipsis_mask.push_back(1);
87                 } else if (ellipsis_mask_str[i] == '0') {
88                     ellipsis_mask.push_back(0);
89                 }
90             }
91             if (ellipsis_mask_counter > 1)
92                 THROW_IE_EXCEPTION << layer->name << " 'Ellipsis_mask' must be a power of two (only one ellipsis)!";
93             for (; i < src_dims.size(); ++i) ellipsis_mask.push_back(0);
94
95             std::string new_axis_mask_str = layer->GetParamAsString("new_axis_mask", "");
96             for (i = 0; i < new_axis_mask_str.size(); ++i) {
97                 if (new_axis_mask_str[i] == '1') new_axis_mask.push_back(1);
98                 else if (new_axis_mask_str[i] == '0') new_axis_mask.push_back(0);
99             }
100             for (; i < src_dims.size(); ++i) new_axis_mask.push_back(0);
101
102             std::string shrink_axis_mask_str = layer->GetParamAsString("shrink_axis_mask", "");
103             for (i = 0; i < shrink_axis_mask_str.size(); ++i) {
104                 if (shrink_axis_mask_str[i] == '1') shrink_axis_mask.push_back(1);
105                 else if (shrink_axis_mask_str[i] == '0') shrink_axis_mask.push_back(0);
106             }
107             for (; i < src_dims.size(); ++i) shrink_axis_mask.push_back(0);
108
109
110             int new_axis = 0;
111             for (auto& na : new_axis_mask)
112                 new_axis += na;
113
114             shrink_axis = 0;
115             for (auto& sa : shrink_axis_mask)
116                 shrink_axis += sa;
117             max_dims = src_dims.size() + new_axis;
118
119             //  ellipsis_mask must be a power of two (only one ellipsis), so to take a first position
120             ellipsis_pos1 = ellipsis_pos2 = max_dims;
121             for (i = 0; i < ellipsis_mask.size(); i++) {
122                 if (ellipsis_mask[i] > 0) {
123                     ellipsis_pos1 = i;
124                     break;
125                 }
126             }
127             bounds_size -= ellipsis_pos1;
128             if (bounds_size > 0 && (max_dims - bounds_size) > ellipsis_pos1)
129                 ellipsis_pos2 = max_dims - bounds_size;
130
131             begin_dms.assign(max_dims, 0);
132             end_dms.assign(max_dims, -1);
133             stride_dms.assign(max_dims, 1);
134
135             srcStrides = layer->insData[STRIDEDSLICE_DATA].lock()->getTensorDesc().getBlockingDesc().getStrides();
136             dstStrides = layer->outData[0]->getTensorDesc().getBlockingDesc().getStrides();
137             if (layer->insData.size() == 1) {
138                 addConfig(layer, { DataConfigurator(ConfLayout::PLN) }, { DataConfigurator(ConfLayout::PLN) });
139             } else if (layer->insData.size() == 2) {
140                 addConfig(layer, { DataConfigurator(ConfLayout::PLN), DataConfigurator(ConfLayout::PLN) }, { DataConfigurator(ConfLayout::PLN) });
141             } else if (layer->insData.size() == 3) {
142                 addConfig(layer, { DataConfigurator(ConfLayout::PLN), DataConfigurator(ConfLayout::PLN), DataConfigurator(ConfLayout::PLN) },
143                           { DataConfigurator(ConfLayout::PLN) });
144             } else {
145                 addConfig(layer, { DataConfigurator(ConfLayout::PLN), DataConfigurator(ConfLayout::PLN), DataConfigurator(ConfLayout::PLN),
146                                    DataConfigurator(ConfLayout::PLN) }, { DataConfigurator(ConfLayout::PLN) });
147             }
148         } catch (InferenceEngine::details::InferenceEngineException &ex) {
149             errorMsg = ex.what();
150         }
151     }
152
153     StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs, ResponseDesc *resp) noexcept override {
154         const float *src_data = inputs[STRIDEDSLICE_DATA]->cbuffer().as<const float *>() +
155             inputs[STRIDEDSLICE_DATA]->getTensorDesc().getBlockingDesc().getOffsetPadding();
156         int *begin = nullptr, *end = nullptr, *stride = nullptr;
157         if (begin_dims.size())
158             begin = inputs[STRIDEDSLICE_BEGIN]->cbuffer().as<int *>() + inputs[STRIDEDSLICE_BEGIN]->getTensorDesc().getBlockingDesc().getOffsetPadding();
159         if (end_dims.size())
160             end = inputs[STRIDEDSLICE_END]->cbuffer().as<int *>() + inputs[STRIDEDSLICE_END]->getTensorDesc().getBlockingDesc().getOffsetPadding();
161         if (stride_dims.size())
162             stride = inputs[STRIDEDSLICE_STRIDE]->cbuffer().as<int *>() + inputs[STRIDEDSLICE_STRIDE]->getTensorDesc().getBlockingDesc().getOffsetPadding();
163         float* dst_data = outputs[0]->cbuffer().as<float *>() +
164             outputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();
165
166         InferenceEngine::SizeVector src_dims = inputs[STRIDEDSLICE_DATA]->getTensorDesc().getDims();
167         InferenceEngine::SizeVector srcStrides = inputs[STRIDEDSLICE_DATA]->getTensorDesc().getBlockingDesc().getStrides();
168         InferenceEngine::SizeVector dst_dims = outputs[0]->getTensorDesc().getDims();
169         InferenceEngine::SizeVector dstStrides = outputs[0]->getTensorDesc().getBlockingDesc().getStrides();
170
171         size_t i, j, k, bj, ej, sj;
172         InferenceEngine::SizeVector our_dims;
173         InferenceEngine::SizeVector out_dims;
174         for (i = 0, j = 0, k = 0, bj = 0, ej = 0, sj = 0; static_cast<int>(i) < max_dims; i++) {
175             if (static_cast<int>(i) >= ellipsis_pos1 &&
176                     static_cast<int>(i) < ellipsis_pos2) {
177                 if (new_axis_mask.size() > i && new_axis_mask[i] == 1)
178                     end_dms[i] = 0;
179                 else
180                     end_dms[i] = end_dms[i] >= 0 ? end_dms[i] : src_dims[j++] + end_dms[i];
181
182                 out_dims.push_back(static_cast<int>(ceil(static_cast<float>(abs(end_dms[i] - begin_dms[i]) + 1) / static_cast<float>(abs(stride_dms[i])))));
183                 our_dims.push_back(static_cast<int>(ceil(static_cast<float>(abs(end_dms[i] - begin_dms[i]) + 1) / static_cast<float>(abs(stride_dms[i])))));
184                 k = ellipsis_pos1;
185             } else {
186                 stride_dms[i] = (stride != nullptr && stride_dims[0] > sj && stride[sj] != 0) ? stride[sj++] : 1;
187
188                 if (begin_mask.size() > j && begin_mask[j] == 0)
189                     begin_dms[i] = stride_dms[i] > 0 ? 0 : -1;
190                 else
191                     begin_dms[i] = (begin != nullptr && begin_dims[0] > bj) ? begin[bj] : (stride_dms[i] > 0 ? 0 : -1);
192                 bj++;
193                 begin_dms[i] = begin_dms[i] >= 0 ? begin_dms[i] : src_dims[j] + begin_dms[i];
194                 //  Clipping 'begin'
195                 clipping(&begin_dms[i], 0, src_dims[j]);
196
197                 if (end_mask.size() > j && end_mask[j] == 0) {
198                     end_dms[i] = stride_dms[i] > 0 ? -1 : 0;
199                 } else {
200                     int end_dms_tmp = (end != nullptr && end_dims[0] > ej) ? (stride_dms[i] > 0 ? end[ej] - 1 : end[ej] + 1)
201                                                                      : end_dms[i];
202                     end_dms[i] = (end != nullptr && end_dims[0] > ej) ? end_dms_tmp : (stride_dms[i] > 0 ? -1 : 0);
203                 }
204                 ej++;
205                 end_dms[i] = end_dms[i] >= 0 ? end_dms[i] : src_dims[j] + end_dms[i];
206                 //  Clipping 'end'
207                 clipping(&end_dms[i], 0, src_dims[j]);
208
209                 if (new_axis_mask.size() > i && new_axis_mask[i] == 1)
210                     end_dms[i] = 0;
211                 else
212                     j++;
213
214                 if (shrink_axis_mask.size() > k && shrink_axis_mask[k] == 1)
215                     end_dms[i] = begin_dms[i];
216                 else
217                     out_dims.push_back(static_cast<int>(ceil(static_cast<float>(abs(end_dms[i] - begin_dms[i]) + 1) /
218                                                              static_cast<float>(abs(stride_dms[i])))));
219
220                 our_dims.push_back(static_cast<int>(ceil(static_cast<float>(abs(end_dms[i] - begin_dms[i]) + 1) /
221                                                          static_cast<float>(abs(stride_dms[i])))));
222                 k++;
223             }
224         }
225
226         for (i = 0; i < std::min(out_dims.size(), dst_dims.size()); i++) {
227             if (out_dims[i] != dst_dims[i])
228                 return PARAMETER_MISMATCH;
229         }
230
231         if (static_cast<int>(src_dims.size()) == max_dims && shrink_axis == 0 &&
232                 stride_dms[stride_dms.size()-1] == 1 && stride_dms.size() > 1)
233             strided_slice_vp(src_data, dst_data);
234         else if (static_cast<int>(src_dims.size()) == max_dims && shrink_axis == 0)
235             strided_slice_p(src_data, dst_data);
236         else
237             strided_slice(src_data, dst_data, our_dims);
238
239         return OK;
240     }
241
242 private:
243     const size_t STRIDEDSLICE_DATA = 0;
244     const size_t STRIDEDSLICE_BEGIN = 1;
245     const size_t STRIDEDSLICE_END = 2;
246     const size_t STRIDEDSLICE_STRIDE = 3;
247
248     void strided_slice(const float *src_data, float* dst_data, std::vector<size_t> &dims);
249     void strided_slice_vp(const float *src_data, float* dst_data);
250     void strided_slice_p(const float *src_data, float* dst_data);
251
252     SizeVector begin_dims;
253     SizeVector end_dims;
254     SizeVector stride_dims;
255
256     SizeVector begin_mask;
257     SizeVector end_mask;
258     SizeVector ellipsis_mask;
259     SizeVector new_axis_mask;
260     SizeVector shrink_axis_mask;
261     int shrink_axis;
262
263     SizeVector src_dims;
264     SizeVector dst_dims;
265     std::vector<int> begin_dms;
266     std::vector<int> end_dms;
267     std::vector<int> stride_dms;
268     SizeVector srcStrides;
269     SizeVector dstStrides;
270     int bounds_size;
271     int max_dims;
272     int ellipsis_pos1, ellipsis_pos2;
273 };
274
275 void StridedSliceImpl::strided_slice(const float *src_data, float* dst_data, std::vector<size_t> &dims) {
276     size_t work_amount_dst = dstStrides[0] * dst_dims[0];
277     parallel_nt(0, [&](const int ithr, const int nthr) {
278         int j;
279         size_t i, start = 0, end = 0;
280         SizeVector counters(max_dims, 0);
281         splitter(work_amount_dst, nthr, ithr, start, end);
282         for (j = max_dims - 1, i = start; j >= 0; j--) {
283             counters[j] = i % dims[j];
284             i /= dims[j];
285         }
286         for (size_t iwork = start; iwork < end; ++iwork) {
287             int src_idx = 0;
288             for (i = 0, j = 0; static_cast<int>(i) < max_dims; ++i) {
289                 if (!(new_axis_mask.size() > i && new_axis_mask[i] == 1))
290                     src_idx += (begin_dms[i] + counters[i] * stride_dms[i]) * srcStrides[j++];
291             }
292
293             dst_data[iwork] = src_data[src_idx];
294
295             for (j = max_dims - 1; j >= 0; j--) {
296                 counters[j]++;
297                 if (counters[j] < dims[j])
298                     break;
299                 else
300                     counters[j] = 0;
301             }
302         }
303     });
304 }
305
306 void StridedSliceImpl::strided_slice_vp(const float *src_data, float* dst_data) {
307     //  Vectorized copy
308     size_t dims_size_1 = dst_dims.size() - 1;
309     size_t dataLength = dst_dims[dims_size_1];
310     size_t work_amount_dst = dstStrides[0] * dst_dims[0] / dst_dims[dims_size_1];
311
312     parallel_nt(0, [&](const int ithr, const int nthr) {
313         size_t start = 0, end = 0;
314         SizeVector counters(dims_size_1, 0);
315         splitter(work_amount_dst, nthr, ithr, start, end);
316         size_t src_idx = begin_dms[dims_size_1];
317         for (int j = dims_size_1 - 1, i = start; j >= 0; j--) {
318             counters[j] = i % dst_dims[j];
319             src_idx += (begin_dms[j] + counters[j] * stride_dms[j]) * srcStrides[j];
320             i /= dst_dims[j];
321         }
322
323         for (size_t iwork = start, dst_idx = start * dataLength, i = 1; iwork < end; ++iwork, dst_idx += dataLength) {
324             memcpy(&dst_data[dst_idx], &src_data[src_idx], sizeof(float) * dataLength);
325             for (int j = dims_size_1 - 1; j >= 0; j--) {
326                 counters[j]++;
327                 if (counters[j] < dst_dims[j]) {
328                     src_idx += stride_dms[j] * srcStrides[j];
329                     break;
330                 } else {
331                     counters[j] = i = 0;
332                 }
333             }
334             if (!i) {
335                 for (src_idx = begin_dms[dims_size_1]; i < dims_size_1; ++i)
336                     src_idx += (begin_dms[i] + counters[i] * stride_dms[i]) * srcStrides[i];
337             }
338         }
339     });
340 }
341
342 void StridedSliceImpl::strided_slice_p(const float *src_data, float* dst_data) {
343     size_t dims_size = dst_dims.size();
344     size_t work_amount_dst = dstStrides[0] * dst_dims[0];
345
346     parallel_nt(0, [&](const int ithr, const int nthr) {
347         size_t start = 0, end = 0;
348         SizeVector counters(dims_size, 0);
349         splitter(work_amount_dst, nthr, ithr, start, end);
350         int src_idx = 0;
351         for (int j = dims_size - 1, i = start; j >= 0; j--) {
352             counters[j] = i % dst_dims[j];
353             src_idx += (begin_dms[j] + counters[j] * stride_dms[j]) * srcStrides[j];
354             i /= dst_dims[j];
355         }
356
357         for (size_t iwork = start, dst_idx = start, i = 1; iwork < end; ++iwork, dst_idx++) {
358             dst_data[dst_idx] = src_data[src_idx];
359             for (int j = dims_size - 1; j >= 0; j--) {
360                 counters[j]++;
361                 if (counters[j] < dst_dims[j]) {
362                     src_idx += stride_dms[j] * srcStrides[j];
363                     break;
364                 } else {
365                     counters[j] = i = 0;
366                 }
367             }
368             if (!i) {
369                 for (src_idx = 0; i < dims_size; ++i)
370                     src_idx += (begin_dms[i] + counters[i] * stride_dms[i]) * srcStrides[i];
371             }
372         }
373     });
374 }
375
376 REG_FACTORY_FOR(ImplFactory<StridedSliceImpl>, StridedSlice);
377
378 }  // namespace Cpu
379 }  // namespace Extensions
380 }  // namespace InferenceEngine