Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / extension / ext_gather.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ext_list.hpp"
6 #include "ext_base.hpp"
7
8 #include <cmath>
9 #include <string>
10 #include <vector>
11 #include <cassert>
12 #include <algorithm>
13 #include <limits>
14 #include "ie_parallel.hpp"
15 #include "simple_copy.h"
16
17 namespace InferenceEngine {
18 namespace Extensions {
19 namespace Cpu {
20
21 class GatherImpl: public ExtLayerBase {
22 public:
23     explicit GatherImpl(const CNNLayer* layer) {
24         try {
25             if (layer->insData.size() != 2 || layer->outData.empty())
26                 THROW_IE_EXCEPTION << layer->name << " Incorrect number of input/output edges!";
27
28             Precision inIdxPrecision = layer->insData[GATHER_INDEXES].lock()->getTensorDesc().getPrecision();
29             if (inIdxPrecision != Precision::FP32 && inIdxPrecision != Precision::I32)
30                 THROW_IE_EXCEPTION << layer->name << " Incorrect input precision. Only FP32 or I32 are supported!";
31
32             //  Remove redundant dimensions
33             const SizeVector& dictionary_dims = layer->insData[GATHER_DICTIONARY].lock()->getTensorDesc().getDims();
34             SizeVector dims_actual;
35             for (size_t i = 0; i < dictionary_dims.size(); i++) {
36                 if (dictionary_dims[i] > 1) {
37                     for (size_t j = i; j < dictionary_dims.size(); j++)
38                         dims_actual.push_back(dictionary_dims[j]);
39                     break;
40                 }
41             }
42
43             if (dims_actual.size() == 0)
44                 THROW_IE_EXCEPTION << layer->name << " Incorrect input parameters dimension!";
45
46             axis = static_cast<int>(layer->GetParamAsInt("axis"));
47             // Dictionary must be at least rank axis + 1
48             if (axis > 0 && static_cast<int>(dims_actual.size()) < (1 + axis))
49                 THROW_IE_EXCEPTION << layer->name << " Incorrect input parameters dimensions and axis number!";
50             else if (axis < 0 && (static_cast<int>(dims_actual.size()) + axis) < 0)
51                 THROW_IE_EXCEPTION << layer->name << " Incorrect input parameters dimensions and axis number!";
52
53             if (axis < 0)
54                 axis += dims_actual.size();
55
56             //  Find number of dictionaries, index range and data length
57             for (int i = 0; i < axis; i++)
58                 numDictionaries *= dims_actual[i];
59             indexRange = dims_actual[axis];
60             for (size_t i = axis + 1; i < dims_actual.size(); i++)
61                 dataLength *= dims_actual[i];
62
63             if (dataLength == 0)
64                 THROW_IE_EXCEPTION << layer->name << " Incorrect input parameters dimension!";
65
66             addConfig(layer, { DataConfigurator(ConfLayout::PLN), DataConfigurator(ConfLayout::PLN) },
67                       { DataConfigurator(ConfLayout::PLN) });
68         } catch (InferenceEngine::details::InferenceEngineException &ex) {
69             errorMsg = ex.what();
70         }
71     }
72
73     StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs, ResponseDesc *resp) noexcept override {
74         switch (inputs[GATHER_INDEXES]->precision()) {
75             case Precision::FP32:
76                 gather(inputs[GATHER_INDEXES]->cbuffer().as<const float *>(), inputs[GATHER_INDEXES], inputs[GATHER_DICTIONARY], outputs[0]);
77                 break;
78             case Precision::I32:
79                 gather(inputs[GATHER_INDEXES]->cbuffer().as<const int32_t *>(), inputs[GATHER_INDEXES], inputs[GATHER_DICTIONARY], outputs[0]);
80                 break;
81             default:
82                 return GENERAL_ERROR;
83         }
84
85         return OK;
86     }
87
88 private:
89     template <typename data_t>
90     void gather(data_t *src_dataIdx, Blob::Ptr indexes, Blob::Ptr dictionary, Blob::Ptr output);
91
92     int axis = 0;
93     size_t numDictionaries = 1;
94     size_t indexRange = 0;
95     size_t dataLength = 1;
96     const size_t GATHER_DICTIONARY = 0;
97     const size_t GATHER_INDEXES = 1;
98 };
99
100 template <typename data_t>
101 void GatherImpl::gather(data_t *src_dataIdx, Blob::Ptr indexes, Blob::Ptr dictionary, Blob::Ptr output) {
102     size_t src_dataIdxSize = indexes->size();
103     const float *src_dataDict = dictionary->cbuffer().as<const float *>() + dictionary->getTensorDesc().getBlockingDesc().getOffsetPadding();
104     float* dst_data = output->cbuffer().as<float *>() + output->getTensorDesc().getBlockingDesc().getOffsetPadding();
105     src_dataIdx += indexes->getTensorDesc().getBlockingDesc().getOffsetPadding();
106
107     if (axis == 0) {
108         parallel_for(src_dataIdxSize, [&](size_t i) {
109             unsigned int idx = static_cast<unsigned int>(src_dataIdx[i]);
110
111             //  Index clipping
112             if (idx < indexRange) {
113                 //  Copying data to destination from Dictionary
114                 simple_copy(&dst_data[i * dataLength],
115                             output->byteSize() - (dataLength * i),
116                             &src_dataDict[dataLength * idx],
117                             sizeof(float) * dataLength);
118             } else {
119                 std::fill_n(&dst_data[i * dataLength], dataLength, 0.f);
120             }
121         });
122     } else {
123         parallel_for(src_dataIdxSize, [&](size_t i) {
124             unsigned int idx = static_cast<unsigned int>(src_dataIdx[i]);
125
126             //  Index clipping
127             if (idx < indexRange) {
128                 //  Copying data to destination from Dictionary
129                 for (size_t j = 0; j < numDictionaries; j++) {
130                     simple_copy(&dst_data[dataLength * (i + j * src_dataIdxSize)],
131                                 output->byteSize() - (dataLength * (i + j * src_dataIdxSize)),
132                                 &src_dataDict[dataLength * (idx + j * indexRange)],
133                                 sizeof(float) * dataLength);
134                 }
135             } else {
136                 for (size_t j = 0; j < numDictionaries; j++) {
137                     std::fill_n(&dst_data[dataLength * (i + j * src_dataIdxSize)], dataLength, 0.f);
138                 }
139             }
140         });
141     }
142 }
143
144 REG_FACTORY_FOR(ImplFactory<GatherImpl>, Gather);
145
146 }  // namespace Cpu
147 }  // namespace Extensions
148 }  // namespace InferenceEngine