Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / shape_infer / built-in / ie_reshape_shape_infer.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
#include <description_buffer.hpp>
#include "ie_built_in_impl.hpp"
#include "precision_utils.h"
#include <ie_layers.h>
#include <debug.h>
#include <functional>
#include <map>
#include <memory>
#include <numeric>
#include <string>
#include <vector>
17
18 namespace InferenceEngine {
19 namespace ShapeInfer {
20
21 /**
22  *@brief Implementation of Shape inference for Reshape layer
23  */
24 class ReshapeShapeProp : public BuiltInShapeInferImpl {
25 public:
26     explicit ReshapeShapeProp(const std::string& type) : BuiltInShapeInferImpl(type) {}
27
28     void inferShapesImpl(const std::vector<Blob::CPtr>& inBlobs,
29                          const std::map<std::string, std::string>& params,
30                          const std::map<std::string, Blob::Ptr>& blobs,
31                          std::vector<SizeVector>& outShapes) override {
32         LayerParams lp{};
33         ReshapeLayer reshapeLayer(lp);
34         reshapeLayer.params = params;
35         reshapeLayer.type = _type;
36         validate(&reshapeLayer, inBlobs, params, blobs);
37
38         SizeVector outShape;
39         std::vector<int> reshapeMask;
40         if (inBlobs.size() == 2) {
41             if (inBlobs[1]->precision() == Precision::FP32) {
42                 auto* buffer = inBlobs[1]->cbuffer().as<float*>();
43                 if (buffer != nullptr) {
44                     for (int i = 0; i < inBlobs[1]->size(); i++) {
45                         reshapeMask.push_back(static_cast<int>(buffer[i]));
46                     }
47                 } else {
48                     THROW_IE_EXCEPTION << "Second input must have allocated data";
49                 }
50             } else if (inBlobs[1]->precision() == Precision::FP16) {
51                 auto* buffer = inBlobs[1]->cbuffer().as<uint16_t*>();
52                 if (buffer != nullptr) {
53                     for (int i = 0; i < inBlobs[1]->size(); i++) {
54                         reshapeMask.push_back(static_cast<int>(PrecisionUtils::f16tof32(buffer[i])));
55                     }
56                 } else {
57                     THROW_IE_EXCEPTION << "Second input must have allocated data";
58                 }
59             } else {
60                 THROW_IE_EXCEPTION << "Second input has unsupported precision";
61             }
62         } else {
63             reshapeMask = reshapeLayer.shape;
64         }
65         auto inputShape = inShapes[0];
66         size_t inputShapeTotal = std::accumulate(inputShape.begin(), inputShape.end(), 1lu,
67                                                  std::multiplies<size_t>());
68
69         if (reshapeMask.empty()) {
70             outShape = {inputShapeTotal};
71         } else {
72             size_t res = 1;
73             for (int i = 0; i < reshapeMask.size(); i++) {
74                 if (reshapeMask[i] == 0) {
75                     res *= inputShape[i];
76                 } else if (reshapeMask[i] != -1) {
77                     res *= reshapeMask[i];
78                 }
79             }
80             size_t newDim = inputShapeTotal / res;
81             for (int i = 0; i < reshapeMask.size(); i++) {
82                 if (reshapeMask[i] == 0) {
83                     outShape.push_back(inputShape[i]);
84                 } else if (reshapeMask[i] == -1) {
85                     outShape.push_back(newDim);
86                 } else {
87                     outShape.push_back(reshapeMask[i]);
88                 }
89             }
90             size_t outputShapeTotal = std::accumulate(outShape.begin(), outShape.end(), 1lu,
91                                                       std::multiplies<size_t>());
92             if (inputShapeTotal != outputShapeTotal)
93                 THROW_IE_EXCEPTION << "Invalid reshape mask (dim attribute): number of elements in input: "
94                                    << details::dumpVec(inputShape) << " and output: " << details::dumpVec(outShape)
95                                    << " mismatch";
96         }
97         outShapes.emplace_back(outShape);
98     }
99 };
100
101 }  // namespace ShapeInfer
102 }  // namespace InferenceEngine