Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / shape_infer / ie_reshape_io_controllers.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <set>
6 #include <string>
7 #include <vector>
8 #include <ie_layers.h>
9 #include <ie_layer_validators.hpp>
10 #include <blob_factory.hpp>
11 #include "shape_infer/ie_reshape_io_controllers.hpp"
12
13 using namespace InferenceEngine;
14 using namespace ShapeInfer;
15
16 void DefaultChecker::run(const std::vector<DataPtr>& dataVec, const std::string& layerName) {
17     std::string errorBase = "Failed to init controller for reshaping layer `" + layerName + "`";
18     if (dataVec.empty()) THROW_IE_EXCEPTION << errorBase + ": vector of data is empty";
19     for (const auto& data : dataVec) {
20         if (!data) THROW_IE_EXCEPTION << errorBase + ": pointer to the data is null";
21     }
22 }
23
24 InputController::InputController(const std::vector<DataPtr>& dataVec, const std::string& layerName,
25                                  const DefaultChecker::Ptr& checker) : _dataVec(dataVec), _layerName(layerName) {
26     checker->run(_dataVec, layerName);
27     for (const auto& data : _dataVec) {
28         if (data) {
29             _dataNames.push_back(data->name);
30             SizeVector dims = data->getTensorDesc().getDims();
31             _irShapes.push_back(dims);
32             // TODO probably need to create blobs with dimensions, not on getBlobs stage
33             _inferedData.push_back(nullptr);
34         }
35     }
36     _shapes = _irShapes;
37 }
38
39 void InputController::setShapeByName(const SizeVector& shape, const std::string& dataName) {
40     long pos = getPositionByName(dataName);
41     _shapes[pos] = shape;
42 }
43
44 SizeVector InputController::getShapeByName(const std::string& dataName) {
45     long pos = getPositionByName(dataName);
46     return _shapes[pos];
47 }
48
49 std::vector<SizeVector> InputController::getShapes(bool check) {
50     if (check) checkCorrespondence();
51     return _shapes;
52 }
53
54 void InputController::applyChanges() {
55     checkCorrespondence();
56     for (int i = 0; i < _dataVec.size(); i++) {
57         auto data = _dataVec[i];
58         if (data) data->setDims(_shapes[i]);
59     }
60 }
61
62 void InputController::checkCorrespondence() {
63     if (_shapes.size() != _dataVec.size()) {
64         THROW_IE_EXCEPTION << "ReshapeLauncher: Number of data(" << _dataVec.size()
65                            << ") doesn't match with number of shapes(" << _shapes.size() << ") for layer '"
66                            << _layerName << "'!";
67     }
68     // TODO: iterate and check for emptiness and size matching
69 }
70
71 void InputController::reset() {
72     _shapes = _irShapes;
73 }
74
75 std::vector<SizeVector> InputController::getIRShapes() {
76     return _irShapes;
77 }
78
79 SizeVector InputController::getIRShapeByName(const std::string& dataName) {
80     long pos = getPositionByName(dataName);
81     return _irShapes[pos];
82 }
83
84 long InputController::getPositionByName(const std::string& dataName) {
85     auto pos = std::distance(_dataNames.begin(), std::find(_dataNames.begin(), _dataNames.end(), dataName));
86     if (pos < 0 || pos >= _dataNames.size()) {
87         THROW_IE_EXCEPTION << "Failed to find shape that corresponds Data name=" << dataName;
88     }
89     return pos;
90 }
91
92 void InputController::setShapeByIndex(const SizeVector& shape, size_t index) {
93     size_t numShapes = _shapes.size();
94     if (index >= numShapes) {
95         THROW_IE_EXCEPTION << "Failed to set shape for index(" << index << ") that is more than number of shapes: "
96                            << numShapes;
97     }
98     _shapes[index] = shape;
99 }
100
101 bool InputController::isDataAvailable() {
102     if (_inferedData.empty()) return false;
103     for (const auto& data : _inferedData) {
104         if (!data) return false;
105         else if (data->cbuffer() == nullptr) return false;
106     }
107     return true;
108 }
109
110 std::vector<Blob::CPtr> InputController::getBlobs(bool check) {
111     if (check) checkCorrespondence();
112     for (int i = 0; i < _dataVec.size(); i++) {
113         if (_inferedData[i] == nullptr || _inferedData[i]->cbuffer() == nullptr) {
114             TensorDesc desc = _dataVec[i]->getTensorDesc();
115             desc.setDims(_shapes[i]);
116             // special case of Shape layer: no input data, but blob contains info about dimensions, layout and etc...
117             auto blob = make_blob_with_precision(desc);
118             _inferedData[i] = blob;
119         }
120     }
121     return _inferedData;
122 }
123
124 void InputController::setBlobByName(const Blob::CPtr& blob, const std::string& dataName) {
125     long pos = getPositionByName(dataName);
126     _inferedData[pos] = blob;
127 }
128
129 OutputController::OutputController(const std::vector<DataPtr>& data, const std::string& layerName,
130                                    const DefaultChecker::Ptr& checker)
131         : InputController(data, layerName, checker) {}
132
133 void OutputController::propagateShapes(const std::set<ReshapeLauncher::Ptr>& launchers) {
134     checkCorrespondence();
135     unsigned idx = 0;
136     for (auto const& outData : _dataVec) {
137         for (auto const& inputTo : outData->inputTo) {
138             CNNLayerPtr layer = inputTo.second;
139             if (layer == nullptr) {
140                 THROW_IE_EXCEPTION << "Failed to propagate shapes for layer (" << inputTo.first
141                                    << "): connected layer is null";
142             }
143             auto layerName = layer->name;
144             auto foundLauncher = std::find_if(launchers.begin(), launchers.end(),
145                                               [&layerName](const ReshapeLauncher::Ptr& launcher) {
146                                                   return launcher->getLayerName() == layerName;
147                                               });
148             if (foundLauncher == launchers.end())
149                 THROW_IE_EXCEPTION << "Failed to find ReshapeLauncher for layer: '" << layerName << "'";
150             (*foundLauncher)->setShapeByName(_shapes[idx], outData->name);
151         }
152         idx++;
153     }
154 }
155
156 // Combine with propagate shapes
157 void OutputController::propagateBlobs(const std::set<ReshapeLauncher::Ptr>& launchers) {
158     unsigned idx = 0;
159     for (auto const& outData : _dataVec) {
160         for (auto const& inputTo : outData->inputTo) {
161             CNNLayerPtr layer = inputTo.second;
162             if (layer == nullptr) {
163                 THROW_IE_EXCEPTION << "Failed to propagate shapes for layer (" << inputTo.first
164                                    << "): connected layer is null";
165             }
166             auto layerName = layer->name;
167             auto foundLauncher = std::find_if(launchers.begin(), launchers.end(),
168                                               [&layerName](const ReshapeLauncher::Ptr& launcher) {
169                                                   return launcher->getLayerName() == layerName;
170                                               });
171             if (foundLauncher == launchers.end())
172                 THROW_IE_EXCEPTION << "Failed to find ReshapeLauncher for layer: '" << layerName << "'";
173             (*foundLauncher)->setBlobByName(_inferedData[idx], outData->name);
174         }
175         idx++;
176     }
177 }
178
179 void OutputController::setShapes(const std::vector<SizeVector>& shapes) {
180     _shapes = shapes;
181 }
182
183 void OutputController::setBlobs(const std::vector<Blob::Ptr>& blobs) {
184     _inferedData.clear();
185     for (const auto& blob : blobs) {
186         _inferedData.push_back(blob);
187     }
188 }
189
190 std::vector<Blob::Ptr> OutputController::createBlobs() {
191     std::vector<Blob::Ptr> blobs;
192     for (int i = 0; i < _dataVec.size(); i++) {
193         TensorDesc desc = _dataVec[i]->getTensorDesc();
194         desc.setDims(_shapes[i]);
195         auto blob = make_blob_with_precision(desc);
196         blob->allocate();
197         blobs.push_back(blob);
198     }
199     return blobs;
200 }
201