1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
7 #include <description_buffer.hpp>
8 #include "ie_built_in_impl.hpp"
9 #include "precision_utils.h"
10 #include <ie_layers.h>
18 namespace InferenceEngine {
19 namespace ShapeInfer {
22 *@brief Implementation of Shape inference for Reshape layer
24 class ReshapeShapeProp : public BuiltInShapeInferImpl {
26 explicit ReshapeShapeProp(const std::string& type) : BuiltInShapeInferImpl(type) {}
28 void inferShapesImpl(const std::vector<Blob::CPtr>& inBlobs,
29 const std::map<std::string, std::string>& params,
30 const std::map<std::string, Blob::Ptr>& blobs,
31 std::vector<SizeVector>& outShapes) override {
33 ReshapeLayer reshapeLayer(lp);
34 reshapeLayer.params = params;
35 reshapeLayer.type = _type;
36 validate(&reshapeLayer, inBlobs, params, blobs);
39 std::vector<int> reshapeMask;
40 if (inBlobs.size() == 2) {
41 if (inBlobs[1]->precision() == Precision::FP32) {
42 auto* buffer = inBlobs[1]->cbuffer().as<float*>();
43 if (buffer != nullptr) {
44 for (int i = 0; i < inBlobs[1]->size(); i++) {
45 reshapeMask.push_back(static_cast<int>(buffer[i]));
48 THROW_IE_EXCEPTION << "Second input must have allocated data";
50 } else if (inBlobs[1]->precision() == Precision::FP16) {
51 auto* buffer = inBlobs[1]->cbuffer().as<uint16_t*>();
52 if (buffer != nullptr) {
53 for (int i = 0; i < inBlobs[1]->size(); i++) {
54 reshapeMask.push_back(static_cast<int>(PrecisionUtils::f16tof32(buffer[i])));
57 THROW_IE_EXCEPTION << "Second input must have allocated data";
60 THROW_IE_EXCEPTION << "Second input has unsupported precision";
63 reshapeMask = reshapeLayer.shape;
65 auto inputShape = inShapes[0];
66 size_t inputShapeTotal = std::accumulate(inputShape.begin(), inputShape.end(), 1lu,
67 std::multiplies<size_t>());
69 if (reshapeMask.empty()) {
70 outShape = {inputShapeTotal};
73 for (int i = 0; i < reshapeMask.size(); i++) {
74 if (reshapeMask[i] == 0) {
76 } else if (reshapeMask[i] != -1) {
77 res *= reshapeMask[i];
80 size_t newDim = inputShapeTotal / res;
81 for (int i = 0; i < reshapeMask.size(); i++) {
82 if (reshapeMask[i] == 0) {
83 outShape.push_back(inputShape[i]);
84 } else if (reshapeMask[i] == -1) {
85 outShape.push_back(newDim);
87 outShape.push_back(reshapeMask[i]);
90 size_t outputShapeTotal = std::accumulate(outShape.begin(), outShape.end(), 1lu,
91 std::multiplies<size_t>());
92 if (inputShapeTotal != outputShapeTotal)
93 THROW_IE_EXCEPTION << "Invalid reshape mask (dim attribute): number of elements in input: "
94 << details::dumpVec(inputShape) << " and output: " << details::dumpVec(outShape)
97 outShapes.emplace_back(outShape);
101 } // namespace ShapeInfer
102 } // namespace InferenceEngine