1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
9 #include <ie_layer_validators.hpp>
10 #include <blob_factory.hpp>
11 #include "shape_infer/ie_reshape_io_controllers.hpp"
13 using namespace InferenceEngine;
14 using namespace ShapeInfer;
16 void DefaultChecker::run(const std::vector<DataPtr>& dataVec, const std::string& layerName) {
17 std::string errorBase = "Failed to init controller for reshaping layer `" + layerName + "`";
18 if (dataVec.empty()) THROW_IE_EXCEPTION << errorBase + ": vector of data is empty";
19 for (const auto& data : dataVec) {
20 if (!data) THROW_IE_EXCEPTION << errorBase + ": pointer to the data is null";
/**
 * @brief Builds a reshape I/O controller over @p dataVec for the layer @p layerName.
 *
 * First runs @p checker over the data (the default DefaultChecker::run throws on an empty
 * vector or a null entry). Then, for each data object, it records:
 *  - its name in _dataNames,
 *  - its current tensor-descriptor dims in _irShapes,
 *  - a null placeholder in _inferedData.
 * The placeholders are filled in later, e.g. by getBlobs() or setBlobByName().
 *
 * NOTE(review): some lines of this constructor are not visible in this listing, including the
 * code after the loop. checkCorrespondence() requires `_shapes` to hold one entry per data
 * object, so `_shapes` is presumably seeded from `_irShapes` there. Confirm this.
 */
24 InputController::InputController(const std::vector<DataPtr>& dataVec, const std::string& layerName,
25 const DefaultChecker::Ptr& checker) : _dataVec(dataVec), _layerName(layerName) {
26 checker->run(_dataVec, layerName);
27 for (const auto& data : _dataVec) {
29 _dataNames.push_back(data->name);
30 SizeVector dims = data->getTensorDesc().getDims();
31 _irShapes.push_back(dims);
32 // TODO probably need to create blobs with dimensions, not on getBlobs stage
33 _inferedData.push_back(nullptr);
39 void InputController::setShapeByName(const SizeVector& shape, const std::string& dataName) {
40 long pos = getPositionByName(dataName);
44 SizeVector InputController::getShapeByName(const std::string& dataName) {
45 long pos = getPositionByName(dataName);
49 std::vector<SizeVector> InputController::getShapes(bool check) {
50 if (check) checkCorrespondence();
54 void InputController::applyChanges() {
55 checkCorrespondence();
56 for (int i = 0; i < _dataVec.size(); i++) {
57 auto data = _dataVec[i];
58 if (data) data->setDims(_shapes[i]);
62 void InputController::checkCorrespondence() {
63 if (_shapes.size() != _dataVec.size()) {
64 THROW_IE_EXCEPTION << "ReshapeLauncher: Number of data(" << _dataVec.size()
65 << ") doesn't match with number of shapes(" << _shapes.size() << ") for layer '"
66 << _layerName << "'!";
68 // TODO: iterate and check for emptiness and size matching
/**
 * NOTE(review): the body of reset() is not visible in this listing. Judging by its name and the
 * `_irShapes` snapshot taken in the constructor, it presumably restores `_shapes` to the IR
 * shapes. Confirm this before relying on it.
 */
71 void InputController::reset() {
/**
 * @brief Returns the shapes recorded from the IR for all wrapped data objects.
 *
 * NOTE(review): the body is not visible in this listing. It presumably returns `_irShapes`,
 * which the constructor fills from each data object's tensor-descriptor dims. Confirm this.
 */
75 std::vector<SizeVector> InputController::getIRShapes() {
79 SizeVector InputController::getIRShapeByName(const std::string& dataName) {
80 long pos = getPositionByName(dataName);
81 return _irShapes[pos];
84 long InputController::getPositionByName(const std::string& dataName) {
85 auto pos = std::distance(_dataNames.begin(), std::find(_dataNames.begin(), _dataNames.end(), dataName));
86 if (pos < 0 || pos >= _dataNames.size()) {
87 THROW_IE_EXCEPTION << "Failed to find shape that corresponds Data name=" << dataName;
/**
 * @brief Overrides the pending shape at position @p index.
 * @param shape new shape stored into `_shapes[index]`
 * @param index zero-based position; must be smaller than the current number of shapes
 * @throws via THROW_IE_EXCEPTION when @p index >= number of shapes.
 *
 * NOTE(review): the throw expression continues on lines not visible in this listing; it
 * presumably streams `numShapes`. The message says "more than", but an index equal to the
 * number of shapes is rejected too.
 */
92 void InputController::setShapeByIndex(const SizeVector& shape, size_t index) {
93 size_t numShapes = _shapes.size();
94 if (index >= numShapes) {
95 THROW_IE_EXCEPTION << "Failed to set shape for index(" << index << ") that is more than number of shapes: "
98 _shapes[index] = shape;
101 bool InputController::isDataAvailable() {
102 if (_inferedData.empty()) return false;
103 for (const auto& data : _inferedData) {
104 if (!data) return false;
105 else if (data->cbuffer() == nullptr) return false;
110 std::vector<Blob::CPtr> InputController::getBlobs(bool check) {
111 if (check) checkCorrespondence();
112 for (int i = 0; i < _dataVec.size(); i++) {
113 if (_inferedData[i] == nullptr || _inferedData[i]->cbuffer() == nullptr) {
114 TensorDesc desc = _dataVec[i]->getTensorDesc();
115 desc.setDims(_shapes[i]);
116 // special case of Shape layer: no input data, but blob contains info about dimensions, layout and etc...
117 auto blob = make_blob_with_precision(desc);
118 _inferedData[i] = blob;
124 void InputController::setBlobByName(const Blob::CPtr& blob, const std::string& dataName) {
125 long pos = getPositionByName(dataName);
126 _inferedData[pos] = blob;
129 OutputController::OutputController(const std::vector<DataPtr>& data, const std::string& layerName,
130 const DefaultChecker::Ptr& checker)
131 : InputController(data, layerName, checker) {}
/**
 * @brief Pushes this controller's pending output shapes to the launchers of consumer layers.
 *
 * For every output data object and every layer in its `inputTo` map:
 *  - find the launcher in @p launchers whose getLayerName() equals that layer's name,
 *  - hand it the shape via setShapeByName(), keyed by the output data's name.
 *
 * @throws via THROW_IE_EXCEPTION in three cases:
 *  - the shape and data counts differ (checkCorrespondence),
 *  - a consumer layer pointer is null,
 *  - no launcher matches a consumer layer.
 *
 * NOTE(review): `idx` is declared and advanced on lines not visible in this listing. It must
 * advance once per output data (not once per consumer) so that `_shapes[idx]` belongs to
 * `outData`. Confirm this.
 * NOTE(review): the launcher lookup is a linear std::find_if, so each consumer costs
 * O(#launchers).
 */
133 void OutputController::propagateShapes(const std::set<ReshapeLauncher::Ptr>& launchers) {
134 checkCorrespondence();
136 for (auto const& outData : _dataVec) {
137 for (auto const& inputTo : outData->inputTo) {
138 CNNLayerPtr layer = inputTo.second;
139 if (layer == nullptr) {
140 THROW_IE_EXCEPTION << "Failed to propagate shapes for layer (" << inputTo.first
141 << "): connected layer is null";
143 auto layerName = layer->name;
144 auto foundLauncher = std::find_if(launchers.begin(), launchers.end(),
145 [&layerName](const ReshapeLauncher::Ptr& launcher) {
146 return launcher->getLayerName() == layerName;
148 if (foundLauncher == launchers.end())
149 THROW_IE_EXCEPTION << "Failed to find ReshapeLauncher for layer: '" << layerName << "'";
150 (*foundLauncher)->setShapeByName(_shapes[idx], outData->name);
/**
 * @brief Pushes this controller's output blobs to the launchers of consumer layers.
 *
 * This mirrors propagateShapes(), but forwards `_inferedData[idx]` via setBlobByName().
 *
 * @throws via THROW_IE_EXCEPTION when a consumer layer pointer is null or no launcher in
 *         @p launchers matches a consumer layer.
 *
 * NOTE(review): several things to confirm here:
 *  - Unlike propagateShapes(), no checkCorrespondence() call is visible, so nothing here checks
 *    that `_inferedData` has one entry per output data.
 *  - `idx` is declared and advanced on lines not visible in this listing.
 *  - The null-layer error text says "shapes" although blobs are being propagated.
 */
156 // Combine with propagate shapes
157 void OutputController::propagateBlobs(const std::set<ReshapeLauncher::Ptr>& launchers) {
159 for (auto const& outData : _dataVec) {
160 for (auto const& inputTo : outData->inputTo) {
161 CNNLayerPtr layer = inputTo.second;
162 if (layer == nullptr) {
163 THROW_IE_EXCEPTION << "Failed to propagate shapes for layer (" << inputTo.first
164 << "): connected layer is null";
166 auto layerName = layer->name;
167 auto foundLauncher = std::find_if(launchers.begin(), launchers.end(),
168 [&layerName](const ReshapeLauncher::Ptr& launcher) {
169 return launcher->getLayerName() == layerName;
171 if (foundLauncher == launchers.end())
172 THROW_IE_EXCEPTION << "Failed to find ReshapeLauncher for layer: '" << layerName << "'";
173 (*foundLauncher)->setBlobByName(_inferedData[idx], outData->name);
/**
 * NOTE(review): the body of setShapes() is not visible in this listing. It presumably replaces
 * `_shapes` with @p shapes; confirm this.
 * No size check against the output data is visible here. propagateShapes() does call
 * checkCorrespondence() before using the shapes.
 */
179 void OutputController::setShapes(const std::vector<SizeVector>& shapes) {
183 void OutputController::setBlobs(const std::vector<Blob::Ptr>& blobs) {
184 _inferedData.clear();
185 for (const auto& blob : blobs) {
186 _inferedData.push_back(blob);
/**
 * @brief Creates one blob per output data object and collects them in data order.
 *
 * Each blob comes from make_blob_with_precision(), built from the data's tensor descriptor
 * with its dims overridden by the matching pending shape.
 *
 * NOTE(review): several things to confirm here:
 *  - Parts of this function, including its end, are not visible in this listing. Presumably
 *    the blobs are allocated there and `blobs` is returned.
 *  - `_shapes[i]` is read without checkCorrespondence(), so fewer pending shapes than data
 *    objects would index out of range.
 *  - The `int` loop index is compared against an unsigned size().
 */
190 std::vector<Blob::Ptr> OutputController::createBlobs() {
191 std::vector<Blob::Ptr> blobs;
192 for (int i = 0; i < _dataVec.size(); i++) {
193 TensorDesc desc = _dataVec[i]->getTensorDesc();
194 desc.setDims(_shapes[i]);
195 auto blob = make_blob_with_precision(desc);
197 blobs.push_back(blob);