Publishing R3
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / shape_infer / ie_reshape_io_controllers.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include <set>
7 #include <string>
8 #include <vector>
9 #include <ie_layers.h>
10 #include <ie_layer_validators.hpp>
11 #include "shape_infer/ie_reshape_io_controllers.hpp"
12
13 using namespace InferenceEngine;
14 using namespace ShapeInfer;
15
16 void DefaultChecker::run(const std::vector<DataPtr>& dataVec, const std::string& layerName) {
17     std::string errorBase = "Failed to init controller for reshaping layer `" + layerName + "`";
18     if (dataVec.empty()) THROW_IE_EXCEPTION << errorBase + ": vector of data is empty";
19     for (const auto& data : dataVec) {
20         if (!data) THROW_IE_EXCEPTION << errorBase + ": pointer to the data is null";
21     }
22 }
23
24 InputController::InputController(const std::vector<DataPtr>& dataVec, const std::string& layerName,
25                                  bool irShapesOnInit, const DefaultChecker::Ptr& checker)
26         : _dataVec(dataVec), _layerName(layerName), _irShapesOnInit(irShapesOnInit) {
27     checker->run(_dataVec, layerName);
28     for (const auto& data : _dataVec) {
29         if (data) {
30             _dataNames.push_back(data->name);
31             _shapes.emplace_back();
32             _irShapes.emplace_back();
33         }
34     }
35     if (_irShapesOnInit) {
36         _irShapes = getIRShapesInternal();
37     }
38 }
39
40 void InputController::setShapeByName(const SizeVector& shape, const std::string& dataName) {
41     long pos = getPositionByName(dataName);
42     _shapes[pos] = shape;
43 }
44
45 std::vector<SizeVector> InputController::getShapes(bool check) {
46     if (check) checkCorrespondence();
47     return _shapes;
48 }
49
50 void InputController::applyChanges() {
51     checkCorrespondence();
52     for (int i = 0; i < _dataVec.size(); i++) {
53         auto data = _dataVec[i];
54         if (data) data->setDims(_shapes[i]);
55     }
56 }
57
58 void InputController::checkCorrespondence() {
59     if (_shapes.size() != _dataVec.size()) {
60         THROW_IE_EXCEPTION << "ReshapeLauncher: Number of data(" << _dataVec.size()
61                            << ") doesn't match with number of shapes(" << _shapes.size() << ") for layer '"
62                            << _layerName << "'!";
63     }
64     for (const auto& shape : _shapes) {
65         if (shape.empty()) THROW_IE_EXCEPTION << "ReshapeLauncher error: shape is not set";
66     }
67     // TODO: iterate and check for emptiness and size matching
68 }
69
70 void InputController::reset() {
71     for (auto& shape : _shapes) {
72         shape.clear();
73     }
74 }
75
76 std::vector<SizeVector> InputController::getIRShapes() {
77     return _irShapesOnInit ? _irShapes : getIRShapesInternal();
78 }
79
80 std::vector<SizeVector> InputController::getIRShapesInternal() {
81     std::vector<SizeVector> shapes;
82     for (const auto& data : _dataVec) {
83         if (data) {
84             shapes.push_back(data->getTensorDesc().getDims());
85         }
86     }
87     return shapes;
88 }
89
90 SizeVector InputController::getIRShapeByName(const std::string& dataName) {
91     long pos = getPositionByName(dataName);
92     return _irShapes[pos];
93 }
94
95 long InputController::getPositionByName(const std::string& dataName) {
96     auto pos = std::distance(_dataNames.begin(), std::find(_dataNames.begin(), _dataNames.end(), dataName));
97     if (pos < 0 || pos >= _dataNames.size()) {
98         THROW_IE_EXCEPTION << "Failed to find shape that corresponds Data name=" << dataName;
99     }
100     return pos;
101 }
102
103 void InputController::setShapeByIndex(const SizeVector& shape, size_t index) {
104     size_t numShapes = _shapes.size();
105     if (index >= numShapes) {
106         THROW_IE_EXCEPTION << "Failed to set shape for index(" << index << ") that is more than number of shapes: "
107                            << numShapes;
108     }
109     _shapes[index] = shape;
110 }
111
112 OutputController::OutputController(const std::vector<DataPtr>& data, const std::string& layerName, bool irShapesOnInit,
113                                    const DefaultChecker::Ptr& checker)
114         : InputController(data, layerName, irShapesOnInit, checker) {}
115
116 void OutputController::propagateShapes(const std::set<ReshapeLauncher::Ptr>& launchers) {
117     checkCorrespondence();
118     unsigned idx = 0;
119     for (auto const& outData : _dataVec) {
120         for (auto const& inputTo : outData->inputTo) {
121             CNNLayerPtr layer = inputTo.second;
122             if (layer == nullptr) {
123                 THROW_IE_EXCEPTION << "Failed to propagate shapes for layer ("
124                             << inputTo.first
125                             << "): connected layer is null";
126             }
127             auto layerName = layer->name;
128             auto foundLauncher = std::find_if(launchers.begin(), launchers.end(),
129                                               [&layerName](const ReshapeLauncher::Ptr& launcher) {
130                                                   return launcher->getLayerName() == layerName;
131                                               });
132             if (foundLauncher == launchers.end())
133                 THROW_IE_EXCEPTION << "Failed to find ReshapeLauncher for layer: '" << layerName << "'";
134             (*foundLauncher)->setShapeByName(_shapes[idx], outData->name);
135         }
136         idx++;
137     }
138 }
139
140 void OutputController::setShapes(const std::vector<SizeVector>& shapes) {
141     _shapes = shapes;
142 }