1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "ext_list.hpp"
6 #include "ext_base.hpp"
14 #include "ie_parallel.hpp"
15 #include "simple_copy.h"
17 namespace InferenceEngine {
18 namespace Extensions {
21 class GatherImpl: public ExtLayerBase {
23 explicit GatherImpl(const CNNLayer* layer) {
25 if (layer->insData.size() != 2 || layer->outData.empty())
26 THROW_IE_EXCEPTION << layer->name << " Incorrect number of input/output edges!";
28 Precision inIdxPrecision = layer->insData[GATHER_INDEXES].lock()->getTensorDesc().getPrecision();
29 if (inIdxPrecision != Precision::FP32 && inIdxPrecision != Precision::I32)
30 THROW_IE_EXCEPTION << layer->name << " Incorrect input precision. Only FP32 or I32 are supported!";
32 // Remove redundant dimensions
33 const SizeVector& dictionary_dims = layer->insData[GATHER_DICTIONARY].lock()->getTensorDesc().getDims();
34 SizeVector dims_actual;
35 for (size_t i = 0; i < dictionary_dims.size(); i++) {
36 if (dictionary_dims[i] > 1) {
37 for (size_t j = i; j < dictionary_dims.size(); j++)
38 dims_actual.push_back(dictionary_dims[j]);
43 if (dims_actual.size() == 0)
44 THROW_IE_EXCEPTION << layer->name << " Incorrect input parameters dimension!";
46 axis = static_cast<int>(layer->GetParamAsInt("axis"));
47 // Dictionary must be at least rank axis + 1
48 if (axis > 0 && static_cast<int>(dims_actual.size()) < (1 + axis))
49 THROW_IE_EXCEPTION << layer->name << " Incorrect input parameters dimensions and axis number!";
50 else if (axis < 0 && (static_cast<int>(dims_actual.size()) + axis) < 0)
51 THROW_IE_EXCEPTION << layer->name << " Incorrect input parameters dimensions and axis number!";
54 axis += dims_actual.size();
56 // Find number of dictionaries, index range and data length
57 for (int i = 0; i < axis; i++)
58 numDictionaries *= dims_actual[i];
59 indexRange = dims_actual[axis];
60 for (size_t i = axis + 1; i < dims_actual.size(); i++)
61 dataLength *= dims_actual[i];
64 THROW_IE_EXCEPTION << layer->name << " Incorrect input parameters dimension!";
66 addConfig(layer, { DataConfigurator(ConfLayout::PLN), DataConfigurator(ConfLayout::PLN) },
67 { DataConfigurator(ConfLayout::PLN) });
68 } catch (InferenceEngine::details::InferenceEngineException &ex) {
73 StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs, ResponseDesc *resp) noexcept override {
74 switch (inputs[GATHER_INDEXES]->precision()) {
76 gather(inputs[GATHER_INDEXES]->cbuffer().as<const float *>(), inputs[GATHER_INDEXES], inputs[GATHER_DICTIONARY], outputs[0]);
79 gather(inputs[GATHER_INDEXES]->cbuffer().as<const int32_t *>(), inputs[GATHER_INDEXES], inputs[GATHER_DICTIONARY], outputs[0]);
89 template <typename data_t>
90 void gather(data_t *src_dataIdx, Blob::Ptr indexes, Blob::Ptr dictionary, Blob::Ptr output);
93 size_t numDictionaries = 1;
94 size_t indexRange = 0;
95 size_t dataLength = 1;
96 const size_t GATHER_DICTIONARY = 0;
97 const size_t GATHER_INDEXES = 1;
100 template <typename data_t>
101 void GatherImpl::gather(data_t *src_dataIdx, Blob::Ptr indexes, Blob::Ptr dictionary, Blob::Ptr output) {
102 size_t src_dataIdxSize = indexes->size();
103 const float *src_dataDict = dictionary->cbuffer().as<const float *>() + dictionary->getTensorDesc().getBlockingDesc().getOffsetPadding();
104 float* dst_data = output->cbuffer().as<float *>() + output->getTensorDesc().getBlockingDesc().getOffsetPadding();
105 src_dataIdx += indexes->getTensorDesc().getBlockingDesc().getOffsetPadding();
108 parallel_for(src_dataIdxSize, [&](size_t i) {
109 unsigned int idx = static_cast<unsigned int>(src_dataIdx[i]);
112 if (idx < indexRange) {
113 // Copying data to destination from Dictionary
114 simple_copy(&dst_data[i * dataLength],
115 output->byteSize() - (dataLength * i),
116 &src_dataDict[dataLength * idx],
117 sizeof(float) * dataLength);
119 std::fill_n(&dst_data[i * dataLength], dataLength, 0.f);
123 parallel_for(src_dataIdxSize, [&](size_t i) {
124 unsigned int idx = static_cast<unsigned int>(src_dataIdx[i]);
127 if (idx < indexRange) {
128 // Copying data to destination from Dictionary
129 for (size_t j = 0; j < numDictionaries; j++) {
130 simple_copy(&dst_data[dataLength * (i + j * src_dataIdxSize)],
131 output->byteSize() - (dataLength * (i + j * src_dataIdxSize)),
132 &src_dataDict[dataLength * (idx + j * indexRange)],
133 sizeof(float) * dataLength);
136 for (size_t j = 0; j < numDictionaries; j++) {
137 std::fill_n(&dst_data[dataLength * (i + j * src_dataIdxSize)], dataLength, 0.f);
// Bind GatherImpl to the "Gather" layer type via an ImplFactory<GatherImpl>.
// NOTE(review): REG_FACTORY_FOR comes from the extension headers (not visible here);
// presumably it registers the factory with the extension's layer registry — confirm in ext_list.hpp.
REG_FACTORY_FOR(ImplFactory<GatherImpl>, Gather);
147 } // namespace Extensions
148 } // namespace InferenceEngine