// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
#include <algorithm>

#include <vpu/utils/error.hpp>
#include "vpu/ngraph/operations/out_shape_of_reshape.hpp"
8 namespace ngraph { namespace vpu { namespace op {
10 constexpr NodeTypeInfo OutShapeOfReshape::type_info;
12 OutShapeOfReshape::OutShapeOfReshape(
13 const Output<Node>& inDataShape,
14 const Output<Node>& outShapeDescriptor,
15 bool specialZero) : Op({inDataShape, outShapeDescriptor}), m_specialZero(specialZero) {
16 constructor_validate_and_infer_types();
19 void OutShapeOfReshape::validate_and_infer_types() {
20 NODE_VALIDATION_CHECK(this, get_input_size() == 2,
21 "OutShapeOfReshape (", get_friendly_name(),
22 ") must have only 2 inputs, provided: ", get_input_size());
24 const auto& inDataShapeTensorShape = get_input_partial_shape(0);
25 NODE_VALIDATION_CHECK(this, inDataShapeTensorShape.is_static(),
26 "OutShapeOfReshape (", get_friendly_name(),
27 ") doesn't support dynamic input data shape");
28 NODE_VALIDATION_CHECK(this, inDataShapeTensorShape.rank().get_length() == 1,
29 "OutShapeOfReshape (", get_friendly_name(),
30 ") must have input data shape tensor with rank 1, provided: ",
31 inDataShapeTensorShape.rank().get_length());
33 const auto& outShapeDescriptorTensorShape = get_input_partial_shape(1);
34 NODE_VALIDATION_CHECK(this, outShapeDescriptorTensorShape.is_static(),
35 "OutShapeOfReshape (", get_friendly_name(),
36 ") doesn't support dynamic output shape descriptor");
37 NODE_VALIDATION_CHECK(this, outShapeDescriptorTensorShape.rank().get_length() == 1,
38 "OutShapeOfReshape (", get_friendly_name(),
39 ") must have output shape descriptor tensor with rank 1, provided: ",
40 outShapeDescriptorTensorShape.rank().get_length());
42 const auto& inDataShapeTensorType = get_input_element_type(0);
43 NODE_VALIDATION_CHECK(this,
44 inDataShapeTensorType.is_static() &&
45 inDataShapeTensorType.is_integral_number(),
46 "OutShapeOfReshape (", get_friendly_name(),
47 ") input data type needs to be an integral type. Got: ",
48 inDataShapeTensorType);
49 const auto& outShapeDescriptorTensorType = get_input_element_type(1);
50 NODE_VALIDATION_CHECK(this,
51 outShapeDescriptorTensorType.is_static() &&
52 outShapeDescriptorTensorType.is_integral_number(),
53 "OutShapeOfReshape (", get_friendly_name(),
54 ") shape descriptor type needs to be an integral type. Got: ",
55 outShapeDescriptorTensorType);
57 set_output_type(0, element::i64, outShapeDescriptorTensorShape);
60 std::shared_ptr<Node> OutShapeOfReshape::clone_with_new_inputs(const OutputVector& new_args) const {
61 check_new_args_count(this, new_args);
62 return std::make_shared<OutShapeOfReshape>(new_args.at(0), new_args.at(1), m_specialZero);
65 bool OutShapeOfReshape::visit_attributes(ngraph::AttributeVisitor& visitor) {
66 visitor.on_attribute("special_zero", m_specialZero);
72 template<element::Type_t ET>
73 bool getShapeFromHostTensorData(const HostTensorPtr& data, Shape& result) {
74 using T = typename element_type_traits<ET>::value_type;
75 T* dataPtr = data->get_data_ptr<ET>();
79 size_t outputRank = data->get_shape()[0];
81 for (int i = 0; i < outputRank; i++) {
82 result.push_back(dataPtr[i]);
88 template<element::Type_t ET>
89 bool setShapeToHostTensorData(const HostTensorPtr& data, const Shape& shape) {
90 using T = typename element_type_traits<ET>::value_type;
91 T* dataPtr = data->get_data_ptr<ET>();
95 size_t outputRank = data->get_shape()[0];
96 if (shape.size() != outputRank) {
100 for (int i = 0; i < outputRank; i++) {
101 dataPtr[i] = shape[i];
106 bool getShapeFromHostTensorData(const HostTensorPtr& data, Shape& shape) {
108 switch (data->get_element_type()) {
109 case element::Type_t::i8:
110 rc = getShapeFromHostTensorData<element::Type_t::i8>(data, shape);
112 case element::Type_t::i16:
113 rc = getShapeFromHostTensorData<element::Type_t::i16>(data, shape);
115 case element::Type_t::i32:
116 rc = getShapeFromHostTensorData<element::Type_t::i32>(data, shape);
118 case element::Type_t::i64:
119 rc = getShapeFromHostTensorData<element::Type_t::i64>(data, shape);
121 case element::Type_t::u8:
122 rc = getShapeFromHostTensorData<element::Type_t::u8>(data, shape);
124 case element::Type_t::u16:
125 rc = getShapeFromHostTensorData<element::Type_t::u16>(data, shape);
127 case element::Type_t::u32:
128 rc = getShapeFromHostTensorData<element::Type_t::u32>(data, shape);
130 case element::Type_t::u64:
131 rc = getShapeFromHostTensorData<element::Type_t::u64>(data, shape);
138 bool setShapeToHostTensorData(const HostTensorPtr& data, const Shape& shape) {
140 switch (data->get_element_type()) {
141 case element::Type_t::i8:
142 rc = setShapeToHostTensorData<element::Type_t::i8>(data, shape);
144 case element::Type_t::i16:
145 rc = setShapeToHostTensorData<element::Type_t::i16>(data, shape);
147 case element::Type_t::i32:
148 rc = setShapeToHostTensorData<element::Type_t::i32>(data, shape);
150 case element::Type_t::i64:
151 rc = setShapeToHostTensorData<element::Type_t::i64>(data, shape);
153 case element::Type_t::u8:
154 rc = setShapeToHostTensorData<element::Type_t::u8>(data, shape);
156 case element::Type_t::u16:
157 rc = setShapeToHostTensorData<element::Type_t::u16>(data, shape);
159 case element::Type_t::u32:
160 rc = setShapeToHostTensorData<element::Type_t::u32>(data, shape);
162 case element::Type_t::u64:
163 rc = setShapeToHostTensorData<element::Type_t::u64>(data, shape);
170 bool evaluateOutShapeOfReshape(
171 const HostTensorPtr& inDataShapeTensor,
172 const HostTensorPtr& outShapeDescriptorTensor,
174 const HostTensorPtr& outShapeTensor) {
175 if (!inDataShapeTensor || !outShapeDescriptorTensor || !outShapeTensor) {
181 if (!getShapeFromHostTensorData(inDataShapeTensor, inputShape)) {
184 if (!getShapeFromHostTensorData(outShapeDescriptorTensor, outputShape)) {
188 if (std::any_of(outputShape.begin(), outputShape.end(), [](int64_t value) { return value < -1; })) {
192 int zeroDimsCount = std::count_if(outputShape.begin(), outputShape.end(),
193 [](int64_t value) { return value == 0; });
194 int negativeDimsCount = std::count_if(outputShape.begin(), outputShape.end(),
195 [](int64_t value) { return value == -1; });
196 if (negativeDimsCount > 1) {
200 size_t outputRank = outputShape.size();
202 if (!(zeroDimsCount && specialZero) && !negativeDimsCount) {
203 if (shape_size(inputShape) != shape_size(outputShape)) {
207 int negativeDimIdx = -1;
209 size_t inputTotalDimCount = shape_size(inputShape);
210 size_t outputTotalDimCount = 1;
213 // compute the output shape
214 for (size_t i = 0; i < outputRank; i++) {
215 if (outputShape[i] == 0 && specialZero) {
216 // Copy input_shape[i] for zero values
217 if (i > inputShape.size() - 1) {
220 outputShape[i] = inputShape[i];
221 outputTotalDimCount *= inputShape[i];
222 } else if (outputShape[i] == -1) {
225 outputTotalDimCount *= outputShape[i];
229 if (negativeDimIdx != -1) {
230 // Infer size such that number of output elements matches
232 if (outputTotalDimCount == 0) {
233 if (inputTotalDimCount != 0) {
236 outputShape[negativeDimIdx] = 0;
238 if (inputTotalDimCount % outputTotalDimCount != 0) {
241 outputShape[negativeDimIdx] = inputTotalDimCount / outputTotalDimCount;
246 if (!setShapeToHostTensorData(outShapeTensor, outputShape)) {
253 } // namespace out_shape
255 bool OutShapeOfReshape::evaluate(const HostTensorVector& outputs,
256 const HostTensorVector& inputs) const {
257 return out_shape::evaluateOutShapeOfReshape(inputs[0], inputs[1], m_specialZero, outputs[0]);
263 } // namespace ngraph