1 // Copyright (C) 2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
#include "ext_list.hpp"
#include "ext_base.hpp"

#include <cstdint>
#include <string>
#include <vector>

#include "ie_parallel.hpp"
14 namespace InferenceEngine {
15 namespace Extensions {
18 class ExpandImpl: public ExtLayerBase {
20 explicit ExpandImpl(const CNNLayer* layer) {
22 if (layer->insData.empty() || layer->outData.empty())
23 THROW_IE_EXCEPTION << layer->name << " Incorrect number of input/output edges!";
25 if (layer->insData.size() != 2)
26 THROW_IE_EXCEPTION << layer->name << " Incorrect number of input edges!";
28 SizeVector shape_dims = layer->insData[EXPAND_SHAPE].lock()->getTensorDesc().getDims();
29 if (shape_dims.size() > 1)
30 THROW_IE_EXCEPTION << layer->name << " Shape vector should be 1 dimension";
32 if (layer->insData[EXPAND_SHAPE].lock()->getTensorDesc().getPrecision() != Precision::I32)
33 THROW_IE_EXCEPTION << layer->name << " Shape vector should be I32!";
35 if (!(layer->insData[EXPAND_INPUT].lock()->getTensorDesc().getPrecision() == Precision::I32 &&
36 layer->outData[0]->getTensorDesc().getPrecision() == Precision::I32) &&
37 !(layer->insData[EXPAND_INPUT].lock()->getTensorDesc().getPrecision() == Precision::FP32 &&
38 layer->outData[0]->getTensorDesc().getPrecision() == Precision::FP32)) {
39 THROW_IE_EXCEPTION << layer->name <<
40 " Input and output tensors should have same precision and only FP32 and I32 are supported!";
43 src_dims = layer->insData[EXPAND_INPUT].lock()->getTensorDesc().getDims();
44 srcStrides = layer->insData[EXPAND_INPUT].lock()->getTensorDesc().getBlockingDesc().getStrides();
45 addConfig(layer, { DataConfigurator(ConfLayout::PLN), DataConfigurator(ConfLayout::PLN) },
46 { DataConfigurator(ConfLayout::PLN) });
47 } catch (InferenceEngine::details::InferenceEngineException &ex) {
52 StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs, ResponseDesc *resp) noexcept override {
53 int32_t* shape_dims = inputs[EXPAND_SHAPE]->cbuffer().as<int32_t *>() +
54 inputs[EXPAND_SHAPE]->getTensorDesc().getBlockingDesc().getOffsetPadding();
55 size_t shape_size = (inputs[EXPAND_SHAPE]->getTensorDesc().getDims())[0];
56 SizeVector dst_dims = outputs[0]->getTensorDesc().getDims();
58 if (dst_dims.size() != shape_size) {
60 std::string errorMsg = "Output tensor dimension mismatch";
61 errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
63 return PARAMETER_MISMATCH;
66 if (src_dims.size() > dst_dims.size()) {
68 std::string errorMsg = "Output tensor dimension is smaller then input tensor dimension";
69 errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
71 return PARAMETER_MISMATCH;
75 for (i = 0; i < dst_dims.size(); i++) {
76 if (static_cast<int>(dst_dims[i]) != shape_dims[i]) {
78 std::string errorMsg = "Output tensor dimension size mismatch";
79 errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
81 return PARAMETER_MISMATCH;
85 size_t prefix_size = dst_dims.size() - src_dims.size();
86 for (i = 0; i < src_dims.size(); i++) {
87 if (src_dims[i] != 1 &&
88 static_cast<int>(src_dims[i]) != shape_dims[i + prefix_size]) {
90 std::string errorMsg = "In/Output corresponding dimension must have the same value, or Input dimension is equal to 1";
91 errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
93 return PARAMETER_MISMATCH;
97 InferenceEngine::SizeVector dstStrides = outputs[0]->getTensorDesc().getBlockingDesc().getStrides();
98 InferenceEngine::SizeVector src_aligned(dst_dims.size());
99 InferenceEngine::SizeVector srcStrides_aligned(dst_dims.size());
100 for (i = 0; i < dst_dims.size(); i++) {
101 if (i < prefix_size) {
103 srcStrides_aligned[i] = srcStrides[0];
105 src_aligned[i] = src_dims[i - prefix_size];
106 srcStrides_aligned[i] = srcStrides[i - prefix_size];
110 size_t work_amount_dst = dstStrides[0] * dst_dims[0];
112 switch (outputs[0]->precision()) {
113 case Precision::FP32: {
114 const float *src_data = inputs[EXPAND_INPUT]->cbuffer().as<const float *>() +
115 inputs[EXPAND_INPUT]->getTensorDesc().getBlockingDesc().getOffsetPadding();
116 float* dst_data = outputs[0]->cbuffer().as<float *>() +
117 outputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();
119 parallel_nt(0, [&](const int ithr, const int nthr) {
120 size_t i, src_idx, start = 0, end = 0;
121 SizeVector counters(dst_dims.size(), 0);
122 splitter(work_amount_dst, nthr, ithr, start, end);
123 for (int j = dst_dims.size() - 1, i = start; j >= 0; j--) {
124 counters[j] = i % dst_dims[j];
127 for (size_t iwork = start; iwork < end; ++iwork) {
128 for (i = 0, src_idx = 0; i < dst_dims.size(); ++i)
129 src_idx += counters[i] ? ((counters[i] % src_aligned[i]) * srcStrides_aligned[i]) : 0;
131 dst_data[iwork] = src_data[src_idx];
133 for (int j = dst_dims.size() - 1; j >= 0; j--) {
134 counters[j] = (counters[j] + 1) % dst_dims[j];
135 if (counters[j] != 0) break;
141 case Precision::I32: {
142 const int32_t *src_data = inputs[EXPAND_INPUT]->cbuffer().as<const int32_t *>() +
143 inputs[EXPAND_INPUT]->getTensorDesc().getBlockingDesc().getOffsetPadding();
144 int32_t* dst_data = outputs[0]->cbuffer().as<int32_t *>() +
145 outputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();
147 parallel_nt(0, [&](const int ithr, const int nthr) {
148 size_t i, src_idx, start = 0, end = 0;
149 SizeVector counters(dst_dims.size(), 0);
150 splitter(work_amount_dst, nthr, ithr, start, end);
151 for (int j = dst_dims.size() - 1, i = start; j >= 0; j--) {
152 counters[j] = i % dst_dims[j];
155 for (size_t iwork = start; iwork < end; ++iwork) {
156 for (i = 0, src_idx = 0; i < dst_dims.size(); ++i)
157 src_idx += counters[i] ? ((counters[i] % src_aligned[i]) * srcStrides_aligned[i]) : 0;
159 dst_data[iwork] = src_data[src_idx];
161 for (int j = dst_dims.size() - 1; j >= 0; j--) {
162 counters[j] = (counters[j] + 1) % dst_dims[j];
163 if (counters[j] != 0) break;
171 std::string errorMsg = "Incorrect output precision. Only FP32 and I32 are supported!";
172 errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
174 return GENERAL_ERROR;
181 const size_t EXPAND_INPUT = 0;
182 const size_t EXPAND_SHAPE = 1;
185 SizeVector srcStrides;
// Register the Expand kernel with the CPU extension factory so the plugin can
// instantiate it for layers of type "Expand".
REG_FACTORY_FOR(ImplFactory<ExpandImpl>, Expand);
191 } // namespace Extensions
192 } // namespace InferenceEngine