Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / shape_infer / built-in / ie_squeeze_shape_infer.hpp
1 // Copyright (C) 2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include "ie_built_in_impl.hpp"
8 #include <map>
9 #include <memory>
10 #include <string>
11 #include <vector>
12
13 namespace InferenceEngine {
14 namespace ShapeInfer {
15
16 /**
17  *@brief Implementation of Shape inference for Squeeze layer
18  */
19 class SqueezeShapeProp : public BuiltInShapeInferImpl {
20 public:
21     explicit SqueezeShapeProp(const std::string& type) : BuiltInShapeInferImpl(type) {}
22
23     void inferShapesImpl(const std::vector<Blob::CPtr>& inBlobs,
24                          const std::map<std::string, std::string>& params,
25                          const std::map<std::string, Blob::Ptr>& blobs,
26                          std::vector<SizeVector>& outShapes) override {
27         LayerParams lp{};
28         SqueezeLayer layer(lp);
29         layer.params = params;
30         layer.type = _type;
31         validate(&layer, inBlobs, params, blobs);
32
33         const size_t SQUEEZE_DATA = 0;
34         const size_t SQUEEZE_INDEXES = 1;
35
36         SizeVector data_dims;
37         SizeVector idx_dims;
38
39         idx_dims = inBlobs[SQUEEZE_INDEXES]->getTensorDesc().getDims();
40         if (idx_dims.size() > 1)
41             THROW_IE_EXCEPTION << " Index vector should be 1 dimension";
42
43         if (inBlobs[SQUEEZE_INDEXES]->getTensorDesc().getPrecision() != Precision::I32 &&
44             inBlobs[SQUEEZE_INDEXES]->getTensorDesc().getPrecision() != Precision::FP32)
45             THROW_IE_EXCEPTION << " Incorrect 'indices_to_squeeze' input precision. Only FP32 and I32 are supported!";
46
47         data_dims = inBlobs[SQUEEZE_DATA]->getTensorDesc().getDims();
48
49         if (data_dims.size() <= idx_dims[0] && !(data_dims.size() == 1 && idx_dims[0] == 1))
50             THROW_IE_EXCEPTION << " Incompatible number of data dimensions and indexes vector length!";
51         SizeVector outShape;
52         switch (inBlobs[SQUEEZE_INDEXES]->precision()) {
53             case Precision::FP32: {
54                 float* idx_data = inBlobs[SQUEEZE_INDEXES]->cbuffer().as<float*>() +
55                                   inBlobs[SQUEEZE_INDEXES]->getTensorDesc().getBlockingDesc().getOffsetPadding();
56                 for (size_t i = 0; i < idx_dims[0]; i++) {
57                     float axis = idx_data[i];
58                     if (axis < 0)
59                         axis += data_dims.size();
60
61                     if (axis > data_dims.size()) {
62                         THROW_IE_EXCEPTION << "Index to squeeze exceeds data tensor dimension";
63                     } else if (data_dims[axis] != 1) {
64                         THROW_IE_EXCEPTION << "Index to squeeze of data tensor dimension is not 1";
65                     }
66                 }
67                 for (size_t j = 0; j < data_dims.size(); j++) {
68                     bool found = false;
69                     for (size_t i = 0; i < inBlobs[SQUEEZE_INDEXES]->size(); i++) {
70                         int32_t axis = idx_data[i];
71                         if (axis < 0)
72                             axis += data_dims.size();
73                         if (j == static_cast<size_t>(axis)) found = true;
74                     }
75                     if (!found) outShape.push_back(data_dims[j]);
76                 }
77             }
78                 break;
79             case Precision::I32: {
80                 int32_t* idx_data = inBlobs[SQUEEZE_INDEXES]->cbuffer().as<int32_t*>() +
81                                     inBlobs[SQUEEZE_INDEXES]->getTensorDesc().getBlockingDesc().getOffsetPadding();
82                 for (size_t i = 0; i < idx_dims[0]; i++) {
83                     int32_t axis = idx_data[i];
84                     if (axis < 0)
85                         axis += data_dims.size();
86
87                     if (axis > data_dims.size()) {
88                         THROW_IE_EXCEPTION << "Index to squeeze exceeds data tensor dimension";
89                     } else if (data_dims[axis] != 1) {
90                         THROW_IE_EXCEPTION << "Index to squeeze of data tensor dimension is not 1";
91                     }
92                 }
93                 for (size_t j = 0; j < data_dims.size(); j++) {
94                     bool found = false;
95                     for (size_t i = 0; i < inBlobs[SQUEEZE_INDEXES]->size(); i++) {
96                         int32_t axis = idx_data[i];
97                         if (axis < 0)
98                             axis += data_dims.size();
99                         if (j == static_cast<size_t>(axis)) found = true;
100                     }
101                     if (!found) outShape.push_back(data_dims[j]);
102                 }
103             }
104                 break;
105             default:
106                 THROW_IE_EXCEPTION
107                         << "Incorrect 'indices_to_squeeze' input precision. Only FP32 and I32 are supported!";
108         }
109         outShapes.push_back(outShape);
110     }
111 };
112
113 }  // namespace ShapeInfer
114 }  // namespace InferenceEngine
115