[IE][NGRAPH][BUILD] Enable UNITY build for more targets (#2592)
[platform/upstream/dldt.git] / inference-engine / src / vpu / common / src / ngraph / operations / out_shape_of_reshape.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <vpu/utils/error.hpp>
6 #include "vpu/ngraph/operations/out_shape_of_reshape.hpp"
7
8 namespace ngraph { namespace vpu { namespace op {
9
10 constexpr NodeTypeInfo OutShapeOfReshape::type_info;
11
12 OutShapeOfReshape::OutShapeOfReshape(
13         const Output<Node>& inDataShape,
14         const Output<Node>& outShapeDescriptor,
15         bool specialZero) : Op({inDataShape, outShapeDescriptor}), m_specialZero(specialZero) {
16     constructor_validate_and_infer_types();
17 }
18
19 void OutShapeOfReshape::validate_and_infer_types() {
20     NODE_VALIDATION_CHECK(this, get_input_size() == 2,
21                           "OutShapeOfReshape (", get_friendly_name(),
22                           ") must have only 2 inputs, provided: ", get_input_size());
23
24     const auto& inDataShapeTensorShape = get_input_partial_shape(0);
25     NODE_VALIDATION_CHECK(this, inDataShapeTensorShape.is_static(),
26                           "OutShapeOfReshape (", get_friendly_name(),
27                           ") doesn't support dynamic input data shape");
28     NODE_VALIDATION_CHECK(this, inDataShapeTensorShape.rank().get_length() == 1,
29                           "OutShapeOfReshape (", get_friendly_name(),
30                           ") must have input data shape tensor with rank 1, provided: ",
31                           inDataShapeTensorShape.rank().get_length());
32
33     const auto& outShapeDescriptorTensorShape = get_input_partial_shape(1);
34     NODE_VALIDATION_CHECK(this, outShapeDescriptorTensorShape.is_static(),
35                           "OutShapeOfReshape (", get_friendly_name(),
36                           ") doesn't support dynamic output shape descriptor");
37     NODE_VALIDATION_CHECK(this, outShapeDescriptorTensorShape.rank().get_length() == 1,
38                           "OutShapeOfReshape (", get_friendly_name(),
39                           ") must have output shape descriptor tensor with rank 1, provided: ",
40                           outShapeDescriptorTensorShape.rank().get_length());
41
42     const auto& inDataShapeTensorType = get_input_element_type(0);
43     NODE_VALIDATION_CHECK(this,
44                           inDataShapeTensorType.is_static() &&
45                           inDataShapeTensorType.is_integral_number(),
46                           "OutShapeOfReshape (", get_friendly_name(),
47                           ") input data type needs to be an integral type. Got: ",
48                           inDataShapeTensorType);
49     const auto& outShapeDescriptorTensorType = get_input_element_type(1);
50     NODE_VALIDATION_CHECK(this,
51                           outShapeDescriptorTensorType.is_static() &&
52                           outShapeDescriptorTensorType.is_integral_number(),
53                           "OutShapeOfReshape (", get_friendly_name(),
54                           ") shape descriptor type needs to be an integral type. Got: ",
55                           outShapeDescriptorTensorType);
56
57     set_output_type(0, element::i64, outShapeDescriptorTensorShape);
58 }
59
60 std::shared_ptr<Node> OutShapeOfReshape::clone_with_new_inputs(const OutputVector& new_args) const {
61     check_new_args_count(this, new_args);
62     return std::make_shared<OutShapeOfReshape>(new_args.at(0), new_args.at(1), m_specialZero);
63 }
64
65 bool OutShapeOfReshape::visit_attributes(ngraph::AttributeVisitor& visitor) {
66     visitor.on_attribute("special_zero", m_specialZero);
67     return true;
68 }
69
70 namespace out_shape {
71
72 template<element::Type_t ET>
73 bool getShapeFromHostTensorData(const HostTensorPtr& data, Shape& result) {
74     using T = typename element_type_traits<ET>::value_type;
75     T* dataPtr = data->get_data_ptr<ET>();
76     if (!dataPtr) {
77         return false;
78     }
79     size_t outputRank = data->get_shape()[0];
80
81     for (int i = 0; i < outputRank; i++) {
82         result.push_back(dataPtr[i]);
83     }
84
85     return true;
86 }
87
88 template<element::Type_t ET>
89 bool setShapeToHostTensorData(const HostTensorPtr& data, const Shape& shape) {
90     using T = typename element_type_traits<ET>::value_type;
91     T* dataPtr = data->get_data_ptr<ET>();
92     if (!dataPtr) {
93         return false;
94     }
95     size_t outputRank = data->get_shape()[0];
96     if (shape.size() != outputRank) {
97         return false;
98     }
99
100     for (int i = 0; i < outputRank; i++) {
101         dataPtr[i] = shape[i];
102     }
103     return true;
104 }
105
106 bool getShapeFromHostTensorData(const HostTensorPtr& data, Shape& shape) {
107     bool rc = false;
108     switch (data->get_element_type()) {
109         case element::Type_t::i8:
110             rc = getShapeFromHostTensorData<element::Type_t::i8>(data, shape);
111             break;
112         case element::Type_t::i16:
113             rc = getShapeFromHostTensorData<element::Type_t::i16>(data, shape);
114             break;
115         case element::Type_t::i32:
116             rc = getShapeFromHostTensorData<element::Type_t::i32>(data, shape);
117             break;
118         case element::Type_t::i64:
119             rc = getShapeFromHostTensorData<element::Type_t::i64>(data, shape);
120             break;
121         case element::Type_t::u8:
122             rc = getShapeFromHostTensorData<element::Type_t::u8>(data, shape);
123             break;
124         case element::Type_t::u16:
125             rc = getShapeFromHostTensorData<element::Type_t::u16>(data, shape);
126             break;
127         case element::Type_t::u32:
128             rc = getShapeFromHostTensorData<element::Type_t::u32>(data, shape);
129             break;
130         case element::Type_t::u64:
131             rc = getShapeFromHostTensorData<element::Type_t::u64>(data, shape);
132             break;
133         default: rc = false;
134     }
135     return rc;
136 }
137
138 bool setShapeToHostTensorData(const HostTensorPtr& data, const Shape& shape) {
139     bool rc = false;
140     switch (data->get_element_type()) {
141         case element::Type_t::i8:
142             rc = setShapeToHostTensorData<element::Type_t::i8>(data, shape);
143             break;
144         case element::Type_t::i16:
145             rc = setShapeToHostTensorData<element::Type_t::i16>(data, shape);
146             break;
147         case element::Type_t::i32:
148             rc = setShapeToHostTensorData<element::Type_t::i32>(data, shape);
149             break;
150         case element::Type_t::i64:
151             rc = setShapeToHostTensorData<element::Type_t::i64>(data, shape);
152             break;
153         case element::Type_t::u8:
154             rc = setShapeToHostTensorData<element::Type_t::u8>(data, shape);
155             break;
156         case element::Type_t::u16:
157             rc = setShapeToHostTensorData<element::Type_t::u16>(data, shape);
158             break;
159         case element::Type_t::u32:
160             rc = setShapeToHostTensorData<element::Type_t::u32>(data, shape);
161             break;
162         case element::Type_t::u64:
163             rc = setShapeToHostTensorData<element::Type_t::u64>(data, shape);
164             break;
165         default: rc = false;
166     }
167     return rc;
168 }
169
170 bool evaluateOutShapeOfReshape(
171         const HostTensorPtr& inDataShapeTensor,
172         const HostTensorPtr& outShapeDescriptorTensor,
173         bool specialZero,
174         const HostTensorPtr& outShapeTensor) {
175     if (!inDataShapeTensor || !outShapeDescriptorTensor || !outShapeTensor) {
176         return false;
177     }
178     Shape inputShape;
179     Shape outputShape;
180
181     if (!getShapeFromHostTensorData(inDataShapeTensor, inputShape)) {
182         return false;
183     }
184     if (!getShapeFromHostTensorData(outShapeDescriptorTensor, outputShape)) {
185         return false;
186     }
187
188     if (std::any_of(outputShape.begin(), outputShape.end(), [](int64_t value) { return value < -1; })) {
189         return false;
190     }
191
192     int zeroDimsCount = std::count_if(outputShape.begin(), outputShape.end(),
193                                       [](int64_t value) { return value == 0; });
194     int negativeDimsCount = std::count_if(outputShape.begin(), outputShape.end(),
195                                           [](int64_t value) { return value == -1; });
196     if (negativeDimsCount > 1) {
197         return false;
198     }
199
200     size_t outputRank = outputShape.size();
201
202     if (!(zeroDimsCount && specialZero) && !negativeDimsCount) {
203         if (shape_size(inputShape) != shape_size(outputShape)) {
204             return false;
205         }
206     } else {
207         int negativeDimIdx = -1;
208
209         size_t inputTotalDimCount = shape_size(inputShape);
210         size_t outputTotalDimCount = 1;
211
212
213         // compute the output shape
214         for (size_t i = 0; i < outputRank; i++) {
215             if (outputShape[i] == 0 && specialZero) {
216                 // Copy input_shape[i] for zero values
217                 if (i > inputShape.size() - 1) {
218                     return false;
219                 }
220                 outputShape[i] = inputShape[i];
221                 outputTotalDimCount *= inputShape[i];
222             } else if (outputShape[i] == -1) {
223                 negativeDimIdx = i;
224             } else {
225                 outputTotalDimCount *= outputShape[i];
226             }
227         }
228
229         if (negativeDimIdx != -1) {
230             // Infer size such that number of output elements matches
231             // input elements
232             if (outputTotalDimCount == 0) {
233                 if (inputTotalDimCount != 0) {
234                     return false;
235                 }
236                 outputShape[negativeDimIdx] = 0;
237             } else {
238                 if (inputTotalDimCount % outputTotalDimCount != 0) {
239                     return false;
240                 }
241                 outputShape[negativeDimIdx] = inputTotalDimCount / outputTotalDimCount;
242             }
243         }
244     }
245
246     if (!setShapeToHostTensorData(outShapeTensor, outputShape)) {
247         return false;
248     }
249
250     return true;
251 }
252
253 }  // namespace out_shape
254
255 bool OutShapeOfReshape::evaluate(const HostTensorVector& outputs,
256                                  const HostTensorVector& inputs) const {
257     return out_shape::evaluateOutShapeOfReshape(inputs[0], inputs[1], m_specialZero, outputs[0]);
258 }
259
260
261 }  // namespace op
262 }  // namespace vpu
263 }  // namespace ngraph