1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "graph_transformer.h"
13 #include <cpp/ie_cnn_network.h>
14 #include <details/ie_cnn_network_tools.h>
15 #include <details/caseless.hpp>
16 #include "network_serializer.h"
17 #include "cnn_network_impl.hpp"
18 #include "blob_factory.hpp"
19 #include "graph_tools.hpp"
20 #include <shape_infer/const_infer/ie_const_infer_holder.hpp>
22 using namespace InferenceEngine;
23 using namespace InferenceEngine::details;
25 namespace InferenceEngine {
27 void checkConstWithBlobs(const CNNLayerPtr& layer) {
28 if (layer == nullptr) {
29 THROW_IE_EXCEPTION << "Invalid argument: layer is nullable";
31 if (layer->type != "Const") {
32 THROW_IE_EXCEPTION << "Unexpected layer type '" << layer->name << "'";
34 if (layer->blobs.size() != 1) {
35 THROW_IE_EXCEPTION << "Unexpected blobs count " << layer->blobs.size() << " for layer '" << layer->name << "'";
37 if (layer->insData.size() != 0) {
38 THROW_IE_EXCEPTION << "Unexpected inputs count " << layer->insData.size() << " for layer '" << layer->name << "'";
40 if (layer->outData.size() != 1) {
41 THROW_IE_EXCEPTION << "Unexpected outputs count " << layer->outData.size() << " for layer '" << layer->name << "'";
45 ConstTransformer::ConstTransformer(details::CNNNetworkImpl* _network) {
47 THROW_IE_EXCEPTION << "[ERROR]: Failed to init ConstTransformer with null pointer of network";
// Replaces constant subgraphs of the network with Const layers that carry
// pre-computed blobs.
//
// @param constLayers  layer name -> flag from getConstLayers(); `true` means the
//                     layer is only consumed for shape information and can be
//                     removed outright, `false` means its constant output data
//                     must be preserved via a Const layer's blobs["custom"].
// @param constData    blobs pre-computed by getConstData(), keyed by output Data name.
// @param sortedLayers topologically sorted layers of `network`.
// @return names of the Const layers remaining in the graph
//         (consumed later by trimShapeInputs()).
//
// NOTE(review): this listing is an excerpt — some lines (mostly closing braces
// and a couple of statements) are not visible in this chunk; code tokens below
// are reproduced untouched.
std::vector<std::string>
ConstTransformer::foldConstSubgraphsInternal(const std::map<std::string, bool>& constLayers, const BlobMap& constData,
                                             const std::vector<CNNLayerPtr>& sortedLayers) {
    std::vector<std::string> remainingConstLayers;
    for (const auto& layer : sortedLayers) {
        if (constLayers.find(layer->name) != constLayers.end()) {
            // const layer doesn't need parent connections -> erase them
            for (const auto& insData : layer->insData) {
                auto& inputTo = insData.lock()->getInputTo();
                inputTo.erase(layer->name);
                // Note: to resolve corner case above layers can be marked as const with const data, just to be removed properly..
                // and maybe this logic wouldn't be needed
                if (inputTo.empty()) {
                    // parent Data is now unconsumed: detach it from its creator and drop it
                    auto creator = insData.lock()->getCreatorLayer().lock();
                    auto it = std::find(creator->outData.begin(), creator->outData.end(), insData.lock());
                    if (it != creator->outData.end()) {
                        network->removeData((*it)->getName());
                        creator->outData.erase(it);
            layer->insData.clear();
            // shape-only const layer (`true` flag): disconnect its outputs from
            // all consumers and delete the layer entirely
            if (constLayers.at(layer->name)) {
                for (const auto& outData : layer->outData) {
                    for (const auto& inputTo : outData->getInputTo()) {
                        CNNLayerPtr inputToLayer;
                        std::string inputToName;
                        std::tie(inputToName, inputToLayer) = inputTo;
                        auto& insData = inputToLayer->insData;
                        auto insDataIt = std::find_if(insData.begin(), insData.end(),
                                                      [&outData](const DataWeakPtr& current) {
                                                          return current.lock()->getName() == outData->getName();
                        // remove connection with const data, because for const child it's not needed, for dynamic - new one will be created
                        if (insDataIt != insData.end()) {
                            insDataIt = inputToLayer->insData.erase(insDataIt);
                    network->removeData(outData->getName());
                network->removeLayer(layer->name);
            // if only one output data is not const - do nothing, otherwise - run procedure below
            // note: multiple const output data requires multiple layers with blob["custom"] to keep const data
            bool keepConstData = layer->outData.size() == 1;
                auto outData = layer->outData[0];
                for (const auto& inputTo : outData->getInputTo()) {
                    if (constLayers.find(inputTo.first) != constLayers.end()) {
                        keepConstData = false;
            if (!constLayers.at(layer->name)) {
                auto outData = layer->outData[0];
                if (layer->blobs.find("custom") == layer->blobs.end()) {
                    // if there's no const data - set it
                    const auto it = constData.find(outData->getName());
                    if (it != constData.end()) {
                        layer->blobs["custom"] = it->second;
                if (layer->type != "Const") {
                    // layer was calculated during the Const Propagation, need to hide its semantic (type, params)
                    LayerParams layerParams{layer->name + "__" + outData->getName() + "__Const", "Const",
                    auto newLayer = std::make_shared<CNNLayer>(layerParams);
                    // re-parent every output Data onto the hiding Const layer
                    for (const auto& data : layer->outData) {
                        data->getCreatorLayer() = newLayer;
                    newLayer->outData = layer->outData;
                    newLayer->blobs["custom"] = layer->blobs["custom"];
                    network->removeLayer(layer->name);
                    network->addLayer(newLayer);
                    remainingConstLayers.push_back(newLayer->name);
                // Layer with `Const` type should be also considered on trimming shape inputs
                remainingConstLayers.push_back(layer->name);
            for (const auto& outData : layer->outData) {
                for (const auto& inputTo : outData->getInputTo()) {
                    CNNLayerPtr inputToLayer;
                    std::string inputToName;
                    std::tie(inputToName, inputToLayer) = inputTo;
                    auto& insData = inputToLayer->insData;
                    auto insDataIt = std::find_if(insData.begin(), insData.end(),
                                                  [&outData](const DataWeakPtr& current) {
                                                      return current.lock()->getName() == outData->getName();
                    // remove connection with const data, because for const child it's not needed, for dynamic - new one will be created
                    if (insDataIt != insData.end()) {
                        insDataIt = inputToLayer->insData.erase(insDataIt);
                    if (constLayers.find(inputToName) == constLayers.end()) {
                        // next layer is not const, need to attach const data to it via blobs["custom"] of new Const layer
                        LayerParams layerParams{layer->name + "__" + outData->getName() + "__Const", "Const",
                        auto newLayer = std::make_shared<CNNLayer>(layerParams);
                        remainingConstLayers.push_back(newLayer->name);
                        const auto it = constData.find(outData->getName());
                        if (it != constData.end()) {
                            newLayer->blobs["custom"] = it->second;
                        // fresh Data node carries the const tensor desc to the dynamic consumer
                        auto newData = std::make_shared<Data>(outData->getName() + "__" + inputToName,
                                                              outData->getTensorDesc());
                        newData->getCreatorLayer() = newLayer;
                        newData->getInputTo()[inputToName] = inputToLayer;
                        newLayer->outData = {newData};
                        network->addLayer(newLayer);
                        network->getData(newData->getName()) = newData;
                        inputToLayer->insData.insert(insDataIt, newData);
            for (const auto& data : layer->outData) {
                network->removeData(data->getName());
            network->removeLayer(layer->name);
    return remainingConstLayers;
// Classifies layers of the (topologically sorted) graph as constant.
//
// @param sortedLayers topologically sorted layers of `network`.
// @return layer name -> flag: `true` when the const layer is consumed only as
//         the shape (second) input of shape-taking layers and can be deleted,
//         `false` when its constant data output must be preserved.
//
// NOTE(review): excerpt — several closing-brace lines (and thus some exact
// else-branch nesting) are not visible in this chunk; tokens are untouched.
const std::map<std::string, bool> ConstTransformer::getConstLayers(const std::vector<CNNLayerPtr>& sortedLayers) {
    std::map<std::string, bool> mapConstLayers;
    // collect all const layers, which inputs are const layers.
    for (const auto& layer : sortedLayers) {
        // Layers with "Shape" and "Const" type are Const by definition
        if (layer->type == "Shape" || layer->type == "Const") {
            mapConstLayers[layer->name] = false;
        } else if (layer->type != "Quantize") {
            // a layer is const iff every producer of its inputs is already marked const
            bool isAllInputsConst = true;
            for (auto const& data : layer->insData) {
                auto creatorName = data.lock()->getCreatorLayer().lock()->name;
                if (mapConstLayers.find(creatorName) == mapConstLayers.end()) {
                    isAllInputsConst = false;
            if (isAllInputsConst && !layer->insData.empty()) mapConstLayers[layer->name] = false;
    // Add mark for const layers, if it's used for shape taking layers as second input
    // true - is used and can be deleted from graph, as no influence on data, false - opposite
    std::map<std::string, bool> mapVisitedLayers = mapConstLayers;
    // reverse topological walk so consumers' marks are settled before producers'
    for (auto rit = sortedLayers.rbegin(); rit != sortedLayers.rend(); rit++) {
        auto currentLayer = (*rit);
        std::string currentLayerName = currentLayer->name;
        bool isCurrentConst = mapConstLayers.find(currentLayerName) != mapConstLayers.end();
        for (int i = 0; i < currentLayer->insData.size(); i++) {
            std::string creatorName;
            if (currentLayer->insData[i].lock()) {
                auto creator = currentLayer->insData[i].lock()->getCreatorLayer().lock();
                    creatorName = creator->name;
            bool isCreatorConst = mapConstLayers.find(creatorName) != mapConstLayers.end();
            if (isCreatorConst) {
                // mark second const input of shape taking layers (Reshape, Interp..), if they wasn't visited before
                if ((i == 1) && (shapeTaking.find(currentLayer->type)) != shapeTaking.end()) {
                    if (!mapConstLayers[creatorName]) {
                        if (!mapVisitedLayers.at(creatorName)) {
                            mapConstLayers[creatorName] = true;
                if (isCurrentConst) {
                    if (mapConstLayers.at(currentLayerName)) {
                        // deletable consumer propagates deletability to an unvisited producer
                        if (!mapConstLayers[creatorName]) {
                            if (!mapVisitedLayers.at(creatorName)) {
                                mapConstLayers[creatorName] = true;
                        // else-context partially elided in this excerpt: producer keeps its data
                        mapConstLayers[creatorName] = false;
                    // consumer keeps const data -> producer must keep it too
                    mapConstLayers[creatorName] = false;
            mapVisitedLayers[creatorName] = true;
        mapVisitedLayers[currentLayerName] = true;
    return mapConstLayers;
// Runs constant folding: infers an output blob for every layer marked const.
// Layers flagged for-shape-only (constLayers value == true) are allowed to
// have no reference const-infer implementation.
//
// @param constLayers  layer name -> for-shape flag, from getConstLayers().
// @param sortedLayers topologically sorted layers of `network`.
// @return Data name -> inferred blob.
//
// NOTE(review): excerpt — declarations of `constData`, the shadowing
// `CNNLayerPtr layer` / `ResponseDesc resp`, and some closing braces are on
// lines not visible in this chunk; tokens below are untouched.
const BlobMap ConstTransformer::getConstData(const std::map<std::string, bool>& constLayers, const std::vector<CNNLayerPtr>& sortedLayers) {
    ShapeInfer::ConstInferHolder holder;
    // Collects input blobs for a layer: reuses already-inferred const blobs,
    // otherwise builds a blob from the Data's tensor desc (Shape-layer case).
    auto getInputBlobs = [&constData](const std::vector<DataWeakPtr>& insData,
                                      bool isForShape) -> std::vector<Blob::CPtr> {
        std::vector<Blob::CPtr> inputBlobs;
        // special case of Const layers: no inputs, no input blobs
        if (insData.empty()) {
        for (const auto& data : insData) {
            std::string dataName = data.lock()->getName();
            if (constData.find(dataName) != constData.end()) {
                // get blobs, inferred before
                inputBlobs.push_back(constData.at(dataName));
                // special case of Shape layer: no input data, but blob contains info about dimensions, layout and etc...
                auto blob = make_blob_with_precision(data.lock()->getTensorDesc());
                inputBlobs.push_back(blob);
    // Allocates one output blob per output Data, matching its tensor desc.
    auto getOutputBlobs = [](const std::vector<DataPtr>& outData) -> std::vector<Blob::Ptr> {
        std::vector<Blob::Ptr> outputBlobs;
        for (const auto& data : outData) {
            auto blob = make_blob_with_precision(data->getTensorDesc());
            outputBlobs.push_back(blob);
    for (const auto& layer : sortedLayers) {
        if (constLayers.find(layer->name) != constLayers.end()) {
            std::string layerName = layer->name;
            bool isForShape = constLayers.at(layerName);
            IE_ASSERT(StatusCode::OK == network->getLayerByName(layerName.c_str(), layer, &resp));
            auto implPtr = holder.getConstInferImpl(layer->type);
            if (!implPtr && !isForShape)
                THROW_IE_EXCEPTION << "Failed to find reference implementation for `"
                                       + layer->name + "` Layer with `" + layer->type + "` Type on constant propagation";
            auto outputBlobs = getOutputBlobs(layer->outData);
            implPtr->infer(getInputBlobs(layer->insData, isForShape), layer->params, layer->blobs, outputBlobs);
            for (int i = 0; i < layer->outData.size(); i++) {
                std::string dataName = layer->outData[i]->getName();
                auto shapes = layer->outData[i]->getTensorDesc().getDims();
                // publish under the graph's expected dims with a default layout for that rank
                outputBlobs[i]->getTensorDesc().reshape(shapes, TensorDesc::getLayoutByDims(shapes));
                constData[dataName] = outputBlobs[i];
// Removes Const layers that only feed the shape (second) input of
// shape-taking layers (Reshape, Interp, ...): the input edge is detached and,
// once nothing else consumes the const data, the Data and the Const layer
// themselves are removed from the network.
//
// @param constLayers names returned by foldConstSubgraphsInternal().
//
// NOTE(review): excerpt — declarations of `layer`/`resp` and some closing
// braces are on lines not visible in this chunk; tokens below are untouched.
void ConstTransformer::trimShapeInputs(const std::vector<std::string>& constLayers) {
    for (const auto& layerName : constLayers) {
        IE_ASSERT(StatusCode::OK == network->getLayerByName(layerName.c_str(), layer, &resp));
        if (layer->outData.size() == 1 && layer->type == "Const" && layer->insData.empty()) {
            auto constData = layer->outData[0];
            // iterate over a copy: the loop body mutates constData->getInputTo()
            std::map<std::string, CNNLayerPtr> inputToMap = constData->getInputTo();
            for (const auto& inputTo : inputToMap) {
                CNNLayerPtr inputToLayer = inputTo.second;
                if (shapeTaking.find(inputToLayer->type) != shapeTaking.end()) {
                    auto& insData = inputToLayer->insData;
                    auto it = std::find_if(insData.begin(), insData.end(),
                                           [&constData](const DataWeakPtr& current) {
                                               return current.lock()->getName() == constData->getName();
                    // only detach when the const data is exactly the second (index 1, shape) input
                    if (it != insData.end() && std::distance(insData.begin(), it) == 1) {
                        inputToLayer->insData.erase(it);
                        constData->getInputTo().erase(inputTo.first);
            // nothing consumes the const data anymore -> drop both data and layer
            if (constData->getInputTo().empty()) {
                network->removeData(constData->getName());
                network->removeLayer(layer->name);
338 void ConstTransformer::foldConstSubgraphs() {
339 auto sortedLayers = details::CNNNetSortTopologically(*network);
340 auto constLayers = getConstLayers(sortedLayers);
341 auto constData = getConstData(constLayers, sortedLayers);
342 foldConstSubgraphsInternal(constLayers, constData, sortedLayers);
345 void ConstTransformer::fullTrim() {
346 auto sortedLayers = details::CNNNetSortTopologically(*network);
347 auto constMapLayers = getConstLayers(sortedLayers);
348 auto constData = getConstData(constMapLayers, sortedLayers);
349 auto constLayers = foldConstSubgraphsInternal(constMapLayers, constData, sortedLayers);
350 trimShapeInputs(constLayers);
// Moves weights/biases that arrive as Const-layer inputs (ports 1 and 2) of
// Convolution / FullyConnected layers into the layer's own storage
// (_weights/_biases and blobs["weights"]/["biases"]), then removes the Const
// layers and the now-unused input edges.
//
// NOTE(review): excerpt — the null checks guarding the THROW statements and
// some closing braces are on lines not visible in this chunk; tokens below
// are untouched.
void ConstTransformer::moveWeights() {
    for (const auto& layerIt : network->allLayers()) {
        WeightableLayer* weightableLayer = dynamic_cast<WeightableLayer*>(layerIt.second.get());
        // only case-insensitive Convolution/FullyConnected with extra inputs qualify
        if ((weightableLayer != nullptr) &&
            (CaselessEq<std::string>()(weightableLayer->type, "Convolution") || CaselessEq<std::string>()(weightableLayer->type, "FullyConnected")) &&
            (weightableLayer->insData.size() > 1)) {
            // ports: 0 = activations, 1 = weights, 2 = biases
            if (weightableLayer->insData.size() > 3) {
                THROW_IE_EXCEPTION << "Unexpected inputs count for " << weightableLayer->name;
            const DataPtr insData = weightableLayer->insData[1].lock();
                THROW_IE_EXCEPTION << "Weights input is absent for layer " << weightableLayer->name;
            InferenceEngine::Blob::Ptr weightsBlob;
            const CNNLayerPtr weightsLayer = insData->getCreatorLayer().lock();
                THROW_IE_EXCEPTION << "Weights layer absent for layer " << weightableLayer->name;
            bool removePathOnWeights = false;
            if (CaselessEq<std::string>()(weightsLayer->type, "Const")) {
                checkConstWithBlobs(weightsLayer);
                weightsBlob = weightsLayer->blobs.begin()->second;
                // NOTE(review): removeData is called with the *layer* name here
                // (the weights Data name may differ) — confirm this matches
                // CNNNetworkImpl's naming convention for Const outputs.
                network->removeData(weightsLayer->name);
                network->removeLayer(weightsLayer->name);
                weightableLayer->_weights = weightsBlob;
                weightableLayer->blobs["weights"] = weightsBlob;
                removePathOnWeights = true;
            bool removePathOnBiases = false;
            if (weightableLayer->insData.size() > 2) {
                const DataPtr insData = weightableLayer->insData[2].lock();
                    THROW_IE_EXCEPTION << "Biases input is absent for layer " << weightableLayer->name;
                const CNNLayerPtr biasesLayer = insData->getCreatorLayer().lock();
                    THROW_IE_EXCEPTION << "Biases layer absent for layer " << weightableLayer->name;
                checkConstWithBlobs(biasesLayer);
                weightableLayer->_biases = biasesLayer->blobs.begin()->second;
                weightableLayer->blobs["biases"] = weightableLayer->_biases;
                network->removeData(biasesLayer->name);
                network->removeLayer(biasesLayer->name);
                removePathOnBiases = true;
            // finally drop the consumed weight/bias input edges, keeping port 0
            if (removePathOnWeights && removePathOnBiases) {
                weightableLayer->insData.erase(weightableLayer->insData.begin() + 1, weightableLayer->insData.end());
            } else if (removePathOnWeights) {
                weightableLayer->insData.erase(weightableLayer->insData.begin() + 1, weightableLayer->insData.begin() + 2);
            } else if (removePathOnBiases) {
                weightableLayer->insData.erase(weightableLayer->insData.begin() + 2, weightableLayer->insData.end());
418 } // namespace InferenceEngine