Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / shape_infer / const_infer / ie_gather_const_infer.hpp
1 // Copyright (C) 2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <ie_blob.h>
8 #include <map>
9 #include <memory>
10 #include <cmath>
11 #include <string>
12 #include <vector>
13 #include <ie_layers.h>
14 #include <ie_algorithm.hpp>
15 #include "ie_const_infer_impl.hpp"
16 #include "ie_parallel.hpp"
17
18 namespace InferenceEngine {
19 namespace ShapeInfer {
20
21 struct GatherParams {
22     size_t dataLength = 1;
23     int axis = 0;
24     size_t indexRange = 0;
25     size_t numDictionaries = 1;
26 };
27
28 template<typename data_t>
29 void
30 gather(data_t* src_dataIdx, const Blob::CPtr& indexes, const Blob::CPtr& dictionary, const Blob::Ptr& output,
31        const GatherParams& p) {
32     size_t src_dataIdxSize = indexes->size();
33     size_t dataSize = sizeof(float) * p.dataLength;
34
35     const float* src_dataDict =
36             dictionary->cbuffer().as<const float*>() + dictionary->getTensorDesc().getBlockingDesc().getOffsetPadding();
37     float* dst_data = output->cbuffer().as<float*>() + output->getTensorDesc().getBlockingDesc().getOffsetPadding();
38     src_dataIdx += indexes->getTensorDesc().getBlockingDesc().getOffsetPadding();
39
40     if (p.axis == 0) {
41         parallel_for(src_dataIdxSize, [&](size_t i) {
42             int idx = static_cast<int>(src_dataIdx[i]);
43
44             //  Index clipping
45             details::clipping(&idx, 0, p.indexRange);
46
47             //  Copying data to destination from Dictionary
48             ie_memcpy(&dst_data[p.dataLength * i],
49                       output->byteSize() - (p.dataLength * i),
50                       &src_dataDict[p.dataLength * idx],
51                       dataSize);
52         });
53     } else {
54         parallel_for(src_dataIdxSize, [&](size_t i) {
55             int idx = static_cast<int>(src_dataIdx[i]);
56
57             //  Index clipping
58             details::clipping(&idx, 0, p.indexRange);
59
60             //  Copying data to destination from Dictionary
61             for (size_t j = 0; j < p.numDictionaries; j++) {
62                 ie_memcpy(&dst_data[p.dataLength * (i + j * src_dataIdxSize)],
63                           output->byteSize() - (p.dataLength * (i + j * src_dataIdxSize)),
64                           &src_dataDict[p.dataLength * (idx + j * p.indexRange)],
65                           dataSize);
66             }
67         });
68     }
69 }
70
71 /**
72  *@brief Implementation of Const inference for Gather layer
73  */
74 class GatherConstInfer : public ConstInferImpl {
75 public:
76     explicit GatherConstInfer(const std::string& type) : ConstInferImpl(type) {}
77
78     void inferImpl(const std::vector<Blob::CPtr>& inData,
79                    const std::map<std::string, std::string>& params,
80                    const std::map<std::string, Blob::Ptr>& blobs,
81                    std::vector<Blob::Ptr>& outData) override {
82         LayerParams lp{};
83         CNNLayer layer(lp);
84         layer.params = params;
85
86
87         const size_t GATHER_DICTIONARY = 0;
88         const size_t GATHER_INDEXES = 1;
89
90         if (inData.size() != 2 || outData.empty())
91             THROW_IE_EXCEPTION << " Incorrect number of input/output edges!";
92
93         Precision inIdxPrecision = inData[GATHER_INDEXES]->getTensorDesc().getPrecision();
94         if (inIdxPrecision != Precision::FP32 &&
95             inIdxPrecision != Precision::I32 &&
96             inIdxPrecision != Precision::U16 &&
97             inIdxPrecision != Precision::I16 &&
98             inIdxPrecision != Precision::U8 &&
99             inIdxPrecision != Precision::I8)
100             THROW_IE_EXCEPTION << " Incorrect input precision. Only FP32|I32|U16|I16|U8|I8 are supported!";
101
102         //  Remove redundant dimensions
103         const SizeVector& dictionary_dims = inData[GATHER_DICTIONARY]->getTensorDesc().getDims();
104         size_t actualAxis = 0;
105         SizeVector dims_actual;
106         for (size_t i = 0; i < dictionary_dims.size(); i++) {
107             if (dictionary_dims[i] > 1) {
108                 for (size_t j = i; j < dictionary_dims.size(); j++)
109                     dims_actual.push_back(dictionary_dims[j]);
110                 break;
111             }
112         }
113
114         if (dims_actual.size() == 0)
115             THROW_IE_EXCEPTION << " Incorrect input parameters dimension!";
116
117         GatherParams p;
118         p.axis = static_cast<int>(layer.GetParamAsInt("axis"));
119         // Dictionary must be at least rank axis + 1
120         if (p.axis > 0 && dims_actual.size() < (1 + p.axis))
121             THROW_IE_EXCEPTION << " Incorrect input parameters dimensions and axis number!";
122         else if (p.axis < 0 && (static_cast<int>(dims_actual.size()) + p.axis) < 0)
123             THROW_IE_EXCEPTION << " Incorrect input parameters dimensions and axis number!";
124
125         if (p.axis < 0)
126             p.axis += dims_actual.size();
127
128         //  Find number of dictionaries, index range and data length
129         for (size_t i = 0; i < p.axis; i++)
130             p.numDictionaries *= dims_actual[i];
131         p.indexRange = dims_actual[p.axis];
132         for (size_t i = p.axis + 1; i < dims_actual.size(); i++)
133             p.dataLength *= dims_actual[i];
134
135         if (p.dataLength == 0)
136             THROW_IE_EXCEPTION << " Incorrect input parameters dimension!";
137
138
139         switch (inData[GATHER_INDEXES]->precision()) {
140             case Precision::FP32:
141                 gather(inData[GATHER_INDEXES]->cbuffer().as<const float*>(), inData[GATHER_INDEXES],
142                        inData[GATHER_DICTIONARY], outData[0], p);
143                 break;
144             case Precision::I32:
145                 gather(inData[GATHER_INDEXES]->cbuffer().as<const int32_t*>(), inData[GATHER_INDEXES],
146                        inData[GATHER_DICTIONARY], outData[0], p);
147                 break;
148             case Precision::U16:
149                 gather(inData[GATHER_INDEXES]->cbuffer().as<const uint16_t*>(), inData[GATHER_INDEXES],
150                        inData[GATHER_DICTIONARY], outData[0], p);
151                 break;
152             case Precision::I16:
153                 gather(inData[GATHER_INDEXES]->cbuffer().as<const int16_t*>(), inData[GATHER_INDEXES],
154                        inData[GATHER_DICTIONARY], outData[0], p);
155                 break;
156             case Precision::U8:
157                 gather(inData[GATHER_INDEXES]->cbuffer().as<const uint8_t*>(), inData[GATHER_INDEXES],
158                        inData[GATHER_DICTIONARY], outData[0], p);
159                 break;
160             case Precision::I8:
161                 gather(inData[GATHER_INDEXES]->cbuffer().as<const int8_t*>(), inData[GATHER_INDEXES],
162                        inData[GATHER_DICTIONARY], outData[0], p);
163                 break;
164             default:
165                 THROW_IE_EXCEPTION << " Unsupported precision!";
166         }
167     }
168 };
169
170 }  // namespace ShapeInfer
171 }  // namespace InferenceEngine