1 // Copyright (C) 2018 Intel Corporation
3 // SPDX-License-Identifier: Apache-2.0
10 #include <ie_layer_validators.hpp>
11 #include "shape_infer/ie_reshape_io_controllers.hpp"
13 using namespace InferenceEngine;
14 using namespace ShapeInfer;
16 void DefaultChecker::run(const std::vector<DataPtr>& dataVec, const std::string& layerName) {
17 std::string errorBase = "Failed to init controller for reshaping layer `" + layerName + "`";
18 if (dataVec.empty()) THROW_IE_EXCEPTION << errorBase + ": vector of data is empty";
19 for (const auto& data : dataVec) {
20 if (!data) THROW_IE_EXCEPTION << errorBase + ": pointer to the data is null";
24 InputController::InputController(const std::vector<DataPtr>& dataVec, const std::string& layerName,
25 bool irShapesOnInit, const DefaultChecker::Ptr& checker)
26 : _dataVec(dataVec), _layerName(layerName), _irShapesOnInit(irShapesOnInit) {
27 checker->run(_dataVec, layerName);
28 for (const auto& data : _dataVec) {
30 _dataNames.push_back(data->name);
31 _shapes.emplace_back();
32 _irShapes.emplace_back();
35 if (_irShapesOnInit) {
36 _irShapes = getIRShapesInternal();
40 void InputController::setShapeByName(const SizeVector& shape, const std::string& dataName) {
41 long pos = getPositionByName(dataName);
45 std::vector<SizeVector> InputController::getShapes(bool check) {
46 if (check) checkCorrespondence();
50 void InputController::applyChanges() {
51 checkCorrespondence();
52 for (int i = 0; i < _dataVec.size(); i++) {
53 auto data = _dataVec[i];
54 if (data) data->setDims(_shapes[i]);
58 void InputController::checkCorrespondence() {
59 if (_shapes.size() != _dataVec.size()) {
60 THROW_IE_EXCEPTION << "ReshapeLauncher: Number of data(" << _dataVec.size()
61 << ") doesn't match with number of shapes(" << _shapes.size() << ") for layer '"
62 << _layerName << "'!";
64 for (const auto& shape : _shapes) {
65 if (shape.empty()) THROW_IE_EXCEPTION << "ReshapeLauncher error: shape is not set";
67 // TODO: iterate and check for emptiness and size matching
70 void InputController::reset() {
71 for (auto& shape : _shapes) {
76 std::vector<SizeVector> InputController::getIRShapes() {
77 return _irShapesOnInit ? _irShapes : getIRShapesInternal();
80 std::vector<SizeVector> InputController::getIRShapesInternal() {
81 std::vector<SizeVector> shapes;
82 for (const auto& data : _dataVec) {
84 shapes.push_back(data->getTensorDesc().getDims());
90 SizeVector InputController::getIRShapeByName(const std::string& dataName) {
91 long pos = getPositionByName(dataName);
92 return _irShapes[pos];
95 long InputController::getPositionByName(const std::string& dataName) {
96 auto pos = std::distance(_dataNames.begin(), std::find(_dataNames.begin(), _dataNames.end(), dataName));
97 if (pos < 0 || pos >= _dataNames.size()) {
98 THROW_IE_EXCEPTION << "Failed to find shape that corresponds Data name=" << dataName;
103 void InputController::setShapeByIndex(const SizeVector& shape, size_t index) {
104 size_t numShapes = _shapes.size();
105 if (index >= numShapes) {
106 THROW_IE_EXCEPTION << "Failed to set shape for index(" << index << ") that is more than number of shapes: "
109 _shapes[index] = shape;
112 OutputController::OutputController(const std::vector<DataPtr>& data, const std::string& layerName, bool irShapesOnInit,
113 const DefaultChecker::Ptr& checker)
114 : InputController(data, layerName, irShapesOnInit, checker) {}
116 void OutputController::propagateShapes(const std::set<ReshapeLauncher::Ptr>& launchers) {
117 checkCorrespondence();
119 for (auto const& outData : _dataVec) {
120 for (auto const& inputTo : outData->inputTo) {
121 CNNLayerPtr layer = inputTo.second;
122 if (layer == nullptr) {
123 THROW_IE_EXCEPTION << "Failed to propagate shapes for layer ("
125 << "): connected layer is null";
127 auto layerName = layer->name;
128 auto foundLauncher = std::find_if(launchers.begin(), launchers.end(),
129 [&layerName](const ReshapeLauncher::Ptr& launcher) {
130 return launcher->getLayerName() == layerName;
132 if (foundLauncher == launchers.end())
133 THROW_IE_EXCEPTION << "Failed to find ReshapeLauncher for layer: '" << layerName << "'";
134 (*foundLauncher)->setShapeByName(_shapes[idx], outData->name);
140 void OutputController::setShapes(const std::vector<SizeVector>& shapes) {