1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
10 #include <ie_layers.h>
11 #include <graph_tools.hpp>
14 #include <blob_factory.hpp>
16 #include "shape_infer/built-in/ie_built_in_holder.hpp"
17 #include "shape_infer/ie_reshaper.hpp"
18 #include "details/caseless.hpp"
19 #include "details/ie_cnn_network_tools.h"
20 #include "ie_reshaper.hpp"
21 #include "ie_cnn_layer_builder.h"
23 using namespace InferenceEngine;
24 using namespace InferenceEngine::details;
25 using namespace ShapeInfer;
27 Reshaper::Reshaper(Builder::Network* network): network(network) {}
29 static std::vector<CNNLayerPtr> SortTopologicallyStartsFrom(const std::vector<DataPtr> &inputs) {
30 std::vector<CNNLayerPtr> all_layers;
31 CNNNetForestDFS(inputs, [&](CNNLayerPtr current){
32 all_layers.push_back(current);
34 std::reverse(all_layers.begin(), all_layers.end());
38 Reshaper::Reshaper(std::vector<DataPtr> insDatas, const LauncherCreator::Ptr& launcherCreator): network(nullptr) {
39 auto builtIn = std::make_shared<BuiltInShapeInferHolder>();
40 _allTypes = getTypeNamesFromExtension(builtIn);
41 _extensions.push_back(builtIn);
43 _allSortedLayers = SortTopologicallyStartsFrom(insDatas);
44 for (auto &in_data : insDatas) {
45 for (auto layer : in_data->inputTo) {
46 _inputLayers.insert(layer.second);
50 if (_inputLayers.empty() || _allSortedLayers.empty())
51 THROW_IE_EXCEPTION << "Unsupported model for shape inference: failed to collect inputs and layers";
53 for (auto const& currentLayer : _allSortedLayers) {
54 auto createdLauncher = launcherCreator->createNotInputLauncher(currentLayer.get(), _extensions);
55 _launchers.insert(createdLauncher);
59 Reshaper::Reshaper(ICNNNetwork& network, const LauncherCreator::Ptr& launcherCreator): network(nullptr) {
60 auto builtIn = std::make_shared<BuiltInShapeInferHolder>();
61 _allTypes = getTypeNamesFromExtension(builtIn);
62 _extensions.push_back(builtIn);
64 auto inputLayers = CNNNetGetAllInputLayers(network);
65 for (const auto& layer : inputLayers) {
66 _inputLayers.insert(layer);
69 _allSortedLayers = CNNNetSortTopologically(network);
70 if (_inputLayers.empty() || _allSortedLayers.empty())
71 THROW_IE_EXCEPTION << "Unsupported model for shape inference: failed to collect inputs and layers";
72 for (auto const& currentLayer : _allSortedLayers) {
73 auto foundInput = std::find_if(_inputLayers.begin(), _inputLayers.end(),
74 [¤tLayer](const CNNLayerPtr& inputLayer) {
75 return currentLayer->name == inputLayer->name;
77 ReshapeLauncher::Ptr createdLauncher;
78 if (foundInput == _inputLayers.end()) {
79 createdLauncher = launcherCreator->createNotInputLauncher(currentLayer.get(), _extensions);
81 createdLauncher = launcherCreator->createInputLauncher(currentLayer.get(), _extensions);
83 _launchers.insert(createdLauncher);
87 void Reshaper::AddExtension(const IShapeInferExtensionPtr& extension) {
88 if (!extension) THROW_IE_EXCEPTION << "Failed to add empty shape infer extension";
91 network->getContext().addExtension(extension);
95 auto newLayerTypes = getTypeNamesFromExtension(extension);
96 std::string badLayerTypes;
97 for (const auto& type : newLayerTypes) {
98 auto ret = _allTypes.insert(type);
100 if (!badLayerTypes.empty())
101 badLayerTypes += ", ";
102 badLayerTypes += type;
105 if (!badLayerTypes.empty())
106 THROW_IE_EXCEPTION << "Failed to add extension with already registered types:" << badLayerTypes;
108 for (auto const& layerType : newLayerTypes) {
109 auto foundLauncher = _launchers.begin();
110 // find all layers with given type
111 std::vector<ReshapeLauncher::Ptr> launchersToInsert;
112 while (foundLauncher != _launchers.end()) {
113 foundLauncher = std::find_if(foundLauncher, _launchers.end(),
114 [&layerType](const ReshapeLauncher::Ptr& launcher) {
115 return layerType == launcher->getLayerType();
117 if (foundLauncher != _launchers.end()) {
118 IShapeInferImpl::Ptr impl;
119 StatusCode sts = extension->getShapeInferImpl(impl, layerType.c_str(), nullptr);
120 if (sts == OK && impl != nullptr) {
121 auto newLauncher = std::make_shared<ReshapeLauncher>((*foundLauncher)->getLayer(), impl);
122 newLauncher->setShapeInferImpl(impl);
123 launchersToInsert.push_back(newLauncher);
124 foundLauncher = _launchers.erase(foundLauncher);
126 THROW_IE_EXCEPTION << "Failed to get registered Shape Infer Implementation for type: " << layerType;
130 for (const auto& launcher : launchersToInsert) {
131 _launchers.insert(launcher);
134 _extensions.push_back(extension);
137 ReshapeLauncher::Ptr Reshaper::getLauncherByLayerName(const std::string& layerName) const {
138 auto foundLauncher = std::find_if(_launchers.begin(), _launchers.end(),
139 [&layerName](const ReshapeLauncher::Ptr& launcher) {
140 return launcher->getLayerName() == layerName;
142 if (foundLauncher == _launchers.end())
143 THROW_IE_EXCEPTION << "Failed to reshape layer ('" << layerName << "'): can't find the corresponding launcher";
144 return *foundLauncher;
147 StatusCode Reshaper::run(const std::map<std::string, SizeVector>& inputShapes, ResponseDesc* resp) {
149 return networkShapeInfer(inputShapes, resp);
151 // Reset all shapes from previous run
152 for (const auto& launcher : _launchers) {
156 // Set new input shapes
157 for (auto const& input : _inputLayers) {
158 std::string layerName = input->name;
159 for (auto const& outData : input->outData) {
160 std::string dataName = outData->name;
161 auto foundShapeIt = inputShapes.find(dataName);
162 auto foundLauncher = getLauncherByLayerName(layerName);
163 if (foundShapeIt != inputShapes.end()) {
164 foundLauncher->setShapeByName(foundShapeIt->second, dataName);
166 foundLauncher->setIRShapeByName(dataName);
172 for (auto& layer : _allSortedLayers) {
173 auto foundLauncher = getLauncherByLayerName(layer->name);
174 foundLauncher->reshape(_launchers);
175 foundLauncher->constInfer(_launchers);
179 for (auto& layer : _allSortedLayers) {
180 auto foundLauncher = getLauncherByLayerName(layer->name);
181 foundLauncher->applyChanges(layer.get());
186 StatusCode Reshaper::runNoApply(const std::map<std::string, SizeVector>& inputShapes, ResponseDesc* resp) {
187 // Reset all shapes from previous run
188 for (const auto& launcher : _launchers) {
192 // Set new input shapes
193 for (auto const& input : _inputLayers) {
194 std::string layerName = input->name;
195 for (auto const& inData_w : input->insData) {
196 auto inData = inData_w.lock();
197 auto dataName = inData->name;
198 auto foundShapeIt = inputShapes.find(dataName);
199 auto foundLauncher = getLauncherByLayerName(layerName);
200 if (foundShapeIt != inputShapes.end()) {
201 foundLauncher->setShapeByName(foundShapeIt->second, dataName);
203 foundLauncher->setIRShapeByName(dataName);
209 for (auto& layer : _allSortedLayers) {
210 auto foundLauncher = getLauncherByLayerName(layer->name);
211 foundLauncher->reshape(_launchers);
216 StatusCode Reshaper::apply(ResponseDesc* resp) {
218 for (auto& layer : _allSortedLayers) {
219 auto foundLauncher = getLauncherByLayerName(layer->name);
220 foundLauncher->applyChanges(layer.get());
225 SizeVector Reshaper::getResultShapeFor(DataPtr &data, ResponseDesc* resp) {
226 auto creator_layer = data->creatorLayer.lock();
227 std::string creator_layer_name;
229 creator_layer_name = creator_layer->name;
231 auto foundLauncher = getLauncherByLayerName(creator_layer_name);
232 return foundLauncher->getShapeByName(data->getName());
235 StatusCode Reshaper::networkShapeInfer(const std::map<std::string, SizeVector>& inputShapes, ResponseDesc* resp) {
237 return DescriptionBuffer(GENERAL_ERROR, resp) << "Cannot infer shapes! Network is not loaded.";
238 std::vector<Builder::Layer> propagatedLayers;
239 Builder::Network propagatedNetwork(*network);
241 // Set new input shapes
242 for (auto& layer : propagatedNetwork) {
243 if (inputShapes.find(layer->getName()) == inputShapes.end() ||
244 details::CaselessEq<std::string>()(layer->getType(), "Const"))
247 if (layer->getOutputPorts().size() != 1)
248 return DescriptionBuffer(GENERAL_ERROR, resp) << "Cannot infer shapes! Input layers can have only one output port.";
250 layer->getOutputPorts()[0].setShape(inputShapes.find(layer->getName())->second);
253 std::map<idx_t, std::map<std::string, std::string>> preparedParams;
254 // Prepare params for split layer
255 for (auto& layer : propagatedNetwork) {
256 if ((layer->getType() == "Reshape" || layer->getType() == "Flatten") &&
257 layer->getInputPorts().size() != 2 && !layer->getInputPorts()[0].shape().empty() &&
258 layer->getParameters().find("axis") != layer->getParameters().end() &&
259 (layer->getParameters().find("dim") == layer->getParameters().end() ||
260 layer->getParameters().at("dim").as<std::vector<int>>().empty())) {
261 auto inputShape = layer->getInputPorts()[0].shape();
262 size_t inputShapeTotal = std::accumulate(inputShape.begin(), inputShape.end(), 1lu,
263 std::multiplies<size_t>());
264 std::vector<int> dim;
265 size_t axis = layer->getParameters().at("axis");
266 for (size_t i = 0; i < axis; i++) {
267 dim.emplace_back(inputShape[i]);
268 inputShapeTotal /= inputShape[i];
270 if (dim.size() < inputShape.size())
271 dim.emplace_back(inputShapeTotal);
272 layer->getParameters()["dim"] = dim;
275 std::map<std::string, std::string> params = InferenceEngine::Builder::convertParameters2Strings(layer->getParameters());
276 if (layer->getType() == "Split") {
277 Builder::SplitLayer splitLayer(layer);
278 std::vector<size_t> sizes;
279 size_t axisSize = splitLayer.getInputPort().shape()[splitLayer.getAxis()];
280 size_t uninitOuts(0);
281 for (const auto& port : layer->getOutputPorts()) {
282 if (port.shape().empty()) {
285 } else if (port.shape().size() <= splitLayer.getAxis()) {
286 THROW_IE_EXCEPTION << "Incorrect output shapes in Split layer " << layer->getName();
288 sizes.push_back(port.shape()[splitLayer.getAxis()]);
289 axisSize -= port.shape()[splitLayer.getAxis()];
293 if ((axisSize && !uninitOuts) || (axisSize && uninitOuts && axisSize % uninitOuts))
294 THROW_IE_EXCEPTION << "Incorrect output shapes in Split layer " << layer->getName();
296 size_t commonSize = uninitOuts != 0 ? axisSize / uninitOuts : 0;
297 for (size_t i = 0; i < sizes.size() && commonSize; i++) {
299 sizes[i] = commonSize;
302 std::string out_sizes;
303 for (const auto& size : sizes) {
304 if (!out_sizes.empty())
306 out_sizes += std::to_string(size);
308 if (!out_sizes.empty())
309 params["out_sizes"] = out_sizes;
312 preparedParams[layer->getId()] = params;
315 // Try to propagate shapes
316 for (auto& layer : propagatedNetwork) {
317 // constant layer does not change during the shape inference and also the Const blob always has C layout and
318 // doesn't know its real shape, so don't run shape propagation for it
319 if (details::CaselessEq<std::string>()(layer->getType(), "Const"))
321 const auto impl = network->getContext().getShapeInferImpl(layer->getType());
323 return DescriptionBuffer(NOT_FOUND, resp) <<
324 "Cannot infer shapes! Shape infer implementation was not found for type " << layer->getType() << ".";
325 std::vector<SizeVector> inShapes;
326 std::vector<SizeVector> outShapes;
327 std::map<std::string, std::string> params;
328 std::map<std::string, Blob::Ptr> blobs;
330 std::vector<Blob::CPtr> inBlobs;
331 for (const auto& inPort : layer->getInputPorts().empty() ? layer->getOutputPorts() : layer->getInputPorts()) {
332 if (inPort.getParameters().find("type") == inPort.getParameters().end()) {
333 inBlobs.push_back(inPort.getData()->getData());
336 params = preparedParams[layer->getId()];
338 for (const auto& port : layer->getInputPorts()) {
339 if (port.getParameters().find("type") == port.getParameters().end() ||
340 port.getData()->getData()->cbuffer() == nullptr)
342 blobs[port.getParameters().at("type")] = port.getData()->getData();
344 for (const auto& it : layer->getParameters()) {
345 if (!it.second.is<Blob::CPtr>())
347 blobs[it.first] = std::const_pointer_cast<Blob>(it.second.as<Blob::CPtr>());
350 StatusCode sts = impl->inferShapes(inBlobs, params, blobs, outShapes, resp);
354 if (outShapes.size() != layer->getOutputPorts().size())
355 return DescriptionBuffer(GENERAL_ERROR, resp) << "Cannot infer shapes! The number of output shapes is not "
356 "equal the number of output ports for layer "
359 for (size_t i = 0; i < outShapes.size(); i++) {
360 layer->getOutputPorts()[i].setShape(outShapes[i]);
362 for (const auto& connection : propagatedNetwork.getLayerConnections(layer->getId())) {
363 if (connection.from().layerId() != layer->getId())
365 auto nextLayer = propagatedNetwork.getLayer(connection.to().layerId());
366 nextLayer->getInputPorts()[connection.to().portId()].setShape(outShapes[connection.from().portId()]);
371 for (auto& layer : *network) {
372 const auto& propagatedLayer = propagatedNetwork.getLayer(layer->getId());
373 for (size_t i = 0; i < layer->getInputPorts().size(); i++) {
374 layer->getInputPorts()[i].setShape(propagatedLayer->getInputPorts()[i].shape());
376 for (size_t i = 0; i < layer->getOutputPorts().size(); i++) {
377 layer->getOutputPorts()[i].setShape(propagatedLayer->getOutputPorts()[i].shape());
383 caseless_set<std::string> Reshaper::getTypeNamesFromExtension(const IShapeInferExtensionPtr& extension) {
384 char** types = nullptr;
385 unsigned int size = 0;
387 StatusCode sts = extension->getShapeInferTypes(types, size, &resp);
388 if (sts != OK) THROW_IE_EXCEPTION << "Failed to get types from extension: " << resp.msg;
389 caseless_set<std::string> typesSet;
390 for (int i = 0; i < size; i++) {
391 std::string type(types[i], strlen(types[i]));
393 typesSet.insert(type);
400 LauncherCreator::createNotInputLauncher(const CNNLayer* layer, const std::vector<IShapeInferExtensionPtr>& extensions) {
401 auto layerType = layer->type;
402 if ((::details::equal(layerType, "memory") && layer->GetParamAsInt("index")) ||
403 ::details::equal(layerType, "const") || ::details::equal(layerType, "input")) {
404 THROW_IE_EXCEPTION << "Failed to reshape: Layer with type `" << layerType
405 << "` can't be intermediate layer in network";
408 for (const auto& extension : extensions) {
409 IShapeInferImpl::Ptr impl = nullptr;
410 StatusCode sts = extension->getShapeInferImpl(impl, layerType.c_str(), nullptr);
411 if (sts == OK && impl != nullptr) {
412 if (::details::equal(layerType, "memory") && !layer->GetParamAsInt("index")) {
413 return std::make_shared<OutMemoryReshapeLauncher>(layer, nullptr);
415 return std::make_shared<ReshapeLauncher>(layer, impl);
418 return std::make_shared<FakeReshapeLauncher>(layer, nullptr);
422 LauncherCreator::createInputLauncher(const CNNLayer* layer, const std::vector<IShapeInferExtensionPtr>& extensions) {
423 auto layerType = layer->type;
424 if (::details::equal(layerType, "memory") && layer->GetParamAsInt("index")) {
425 return std::make_shared<InputReshapeLauncher>(layer, nullptr);
426 } else if (::details::equal(layerType, "const")) {
427 return std::make_shared<ConstReshapeLauncher>(layer, nullptr);
428 } else if (::details::equal(layerType, "input")) {
429 return std::make_shared<InputReshapeLauncher>(layer, nullptr);
431 THROW_IE_EXCEPTION << "Failed to reshape: Layer with type `" << layerType
432 << "` can't be input. Supported input types: Input, Const and Memory(with index=1)";