updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / graph_transformer.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "graph_transformer.h"
6
7 #include <vector>
8 #include <string>
9 #include <iterator>
10 #include <map>
11 #include <memory>
12
13 #include <cpp/ie_cnn_network.h>
14 #include <details/ie_cnn_network_tools.h>
15 #include <details/caseless.hpp>
16 #include "network_serializer.h"
17 #include "cnn_network_impl.hpp"
18 #include "blob_factory.hpp"
19 #include "graph_tools.hpp"
20 #include <shape_infer/const_infer/ie_const_infer_holder.hpp>
21
22 using namespace InferenceEngine;
23 using namespace InferenceEngine::details;
24
25 namespace InferenceEngine {
26
27 void checkConstWithBlobs(const CNNLayerPtr& layer) {
28     if (layer == nullptr) {
29         THROW_IE_EXCEPTION << "Invalid argument: layer is nullable";
30     }
31     if (layer->type != "Const") {
32         THROW_IE_EXCEPTION << "Unexpected layer type '" << layer->name << "'";
33     }
34     if (layer->blobs.size() != 1) {
35         THROW_IE_EXCEPTION << "Unexpected blobs count " << layer->blobs.size() << " for layer '" << layer->name << "'";
36     }
37     if (layer->insData.size() != 0) {
38         THROW_IE_EXCEPTION << "Unexpected inputs count " << layer->insData.size() << " for layer '" << layer->name << "'";
39     }
40     if (layer->outData.size() != 1) {
41         THROW_IE_EXCEPTION << "Unexpected outputs count " << layer->outData.size() << " for layer '" << layer->name << "'";
42     }
43 }
44
45 ConstTransformer::ConstTransformer(details::CNNNetworkImpl* _network) {
46     if (!_network)
47         THROW_IE_EXCEPTION << "[ERROR]: Failed to init ConstTransformer with null pointer of network";
48
49     network = _network;
50 }
51
/**
 * Replaces the constant subgraphs described by @p constLayers with their
 * pre-computed results from @p constData.
 *
 * @param constLayers map of const-layer name -> flag; `true` means the layer
 *        only feeds shape-taking inputs and can be removed outright, `false`
 *        means its data must be preserved (as a Const layer with blobs["custom"]).
 * @param constData  blobs inferred for const outputs, keyed by data name.
 * @param sortedLayers layers in topological order (producers before consumers).
 * @return names of Const layers that remain in the graph and should later be
 *         considered by trimShapeInputs().
 */
std::vector<std::string>
ConstTransformer::foldConstSubgraphsInternal(const std::map<std::string, bool>& constLayers, const BlobMap& constData,
                                               const std::vector<CNNLayerPtr>& sortedLayers) {
    std::vector<std::string> remainingConstLayers;
    for (const auto& layer : sortedLayers) {
        if (constLayers.find(layer->name) != constLayers.end()) {
            // const layer doesn't need parent connections -> erase them
            for (const auto& insData : layer->insData) {
                auto& inputTo = insData.lock()->getInputTo();
                inputTo.erase(layer->name);
                // Note: to resolve corner case above layers can be marked as const with const data, just to be removed properly..
                // and maybe this logic wouldn't be needed
                if (inputTo.empty()) {
                    // This layer was the last consumer: the parent's output data
                    // is now dangling, so drop it from the parent and the network.
                    auto creator = insData.lock()->getCreatorLayer().lock();
                    auto it = std::find(creator->outData.begin(), creator->outData.end(), insData.lock());
                    if (it != creator->outData.end()) {
                        network->removeData((*it)->getName());
                        creator->outData.erase(it);
                    }
                }
            }
            layer->insData.clear();

            if (constLayers.at(layer->name)) {
                // Flag == true: layer fed only shape-taking inputs -> remove it
                // together with its output data; consumers drop the connection.
                for (const auto& outData : layer->outData) {
                    for (const auto& inputTo : outData->getInputTo()) {
                        CNNLayerPtr inputToLayer;
                        std::string inputToName;
                        std::tie(inputToName, inputToLayer) = inputTo;
                        auto& insData = inputToLayer->insData;
                        auto insDataIt = std::find_if(insData.begin(), insData.end(),
                                                      [&outData](const DataWeakPtr& current) {
                                                          return current.lock()->getName() == outData->getName();
                                                      });
                        // remove connection with const data, because for const child it's not needed, for dynamic - new one will be created
                        if (insDataIt != insData.end()) {
                            insDataIt = inputToLayer->insData.erase(insDataIt);
                        }
                    }
                    network->removeData(outData->getName());
                }
                network->removeLayer(layer->name);
            } else {
                // if only one output data is not const - do nothing, otherwise - run procedure below
                // note: multiple const output data requires multiple layers with blob["custom"] to keep const data
                bool keepConstData = layer->outData.size() == 1;
                if (keepConstData) {
                    auto outData = layer->outData[0];
                    for (const auto& inputTo : outData->getInputTo()) {
                        // A const consumer will be folded itself, so the data
                        // cannot simply stay attached in place.
                        if (constLayers.find(inputTo.first) != constLayers.end()) {
                            keepConstData = false;
                        }
                    }
                }
                if (keepConstData) {
                    if (!constLayers.at(layer->name)) {
                        auto outData = layer->outData[0];
                        if (layer->blobs.find("custom") == layer->blobs.end()) {
                            // if there's no const data - set it
                            const auto it = constData.find(outData->getName());
                            if (it != constData.end()) {
                                layer->blobs["custom"] = it->second;
                            }
                        }
                        if (layer->type != "Const") {
                            // layer was calculated during the Const Propagation, need to hide its semantic (type, params)
                            LayerParams layerParams{layer->name + "__" + outData->getName() + "__Const", "Const",
                                                    layer->precision};
                            auto newLayer = std::make_shared<CNNLayer>(layerParams);
                            for (const auto& data : layer->outData) {
                                data->getCreatorLayer() = newLayer;
                            }
                            newLayer->outData = layer->outData;
                            newLayer->blobs["custom"] = layer->blobs["custom"];
                            network->removeLayer(layer->name);
                            network->addLayer(newLayer);
                            remainingConstLayers.push_back(newLayer->name);
                        } else {
                            // Layer with `Const` type should be also considered on trimming shape inputs
                            remainingConstLayers.push_back(layer->name);
                        }
                    }
                } else {
                    // Multiple outputs (or const consumers): fan the data out into
                    // one new Const layer per non-const consumer, each carrying the
                    // inferred blob, then delete the original layer and its data.
                    for (const auto& outData : layer->outData) {
                        for (const auto& inputTo : outData->getInputTo()) {
                            CNNLayerPtr inputToLayer;
                            std::string inputToName;
                            std::tie(inputToName, inputToLayer) = inputTo;
                            auto& insData = inputToLayer->insData;
                            auto insDataIt = std::find_if(insData.begin(), insData.end(),
                                                          [&outData](const DataWeakPtr& current) {
                                                              return current.lock()->getName() == outData->getName();
                                                          });
                            // remove connection with const data, because for const child it's not needed, for dynamic - new one will be created
                            if (insDataIt != insData.end()) {
                                // keep the returned iterator: the new data is
                                // inserted at the same input position below
                                insDataIt = inputToLayer->insData.erase(insDataIt);
                            }
                            if (constLayers.find(inputToName) == constLayers.end()) {
                                // next layer is not const, need to attach const data to it via blobs["custom"] of new Const layer
                                LayerParams layerParams{layer->name + "__" + outData->getName() + "__Const", "Const",
                                                        layer->precision};
                                auto newLayer = std::make_shared<CNNLayer>(layerParams);
                                remainingConstLayers.push_back(newLayer->name);
                                const auto it = constData.find(outData->getName());
                                if (it != constData.end()) {
                                    newLayer->blobs["custom"] = it->second;
                                }
                                auto newData = std::make_shared<Data>(outData->getName() + "__" + inputToName,
                                                                      outData->getTensorDesc());
                                newData->getCreatorLayer() = newLayer;
                                newData->getInputTo()[inputToName] = inputToLayer;
                                newLayer->outData = {newData};
                                network->addLayer(newLayer);
                                network->getData(newData->getName()) = newData;
                                inputToLayer->insData.insert(insDataIt, newData);
                            }
                        }
                    }
                    for (const auto& data : layer->outData) {
                        network->removeData(data->getName());
                    }
                    network->removeLayer(layer->name);
                }
            }
        }
    }
    return remainingConstLayers;
}
180
181 const std::map<std::string, bool> ConstTransformer::getConstLayers(const std::vector<CNNLayerPtr>& sortedLayers) {
182     std::map<std::string, bool> mapConstLayers;
183     // collect all const layers, which inputs are const layers.
184     for (const auto& layer : sortedLayers) {
185         // Layers with "Shape" and "Const" type are Const by definition
186         if (layer->type == "Shape" || layer->type == "Const") {
187             mapConstLayers[layer->name] = false;
188         } else if (layer->type != "Quantize") {
189             bool isAllInputsConst = true;
190             for (auto const& data : layer->insData) {
191                 auto creatorName = data.lock()->getCreatorLayer().lock()->name;
192                 if (mapConstLayers.find(creatorName) == mapConstLayers.end()) {
193                     isAllInputsConst = false;
194                 }
195             }
196             if (isAllInputsConst && !layer->insData.empty()) mapConstLayers[layer->name] = false;
197         }
198     }
199     // Add mark for const layers, if it's used for shape taking layers as second input
200     // true - is used and can be deleted from graph, as no influence on data, false - opposite
201     std::map<std::string, bool> mapVisitedLayers = mapConstLayers;
202     for (auto rit = sortedLayers.rbegin(); rit != sortedLayers.rend(); rit++) {
203         auto currentLayer = (*rit);
204         std::string currentLayerName = currentLayer->name;
205         bool isCurrentConst = mapConstLayers.find(currentLayerName) != mapConstLayers.end();
206         for (int i = 0; i < currentLayer->insData.size(); i++) {
207             std::string creatorName;
208             if (currentLayer->insData[i].lock()) {
209                 auto creator = currentLayer->insData[i].lock()->getCreatorLayer().lock();
210                 if (creator) {
211                     creatorName = creator->name;
212                 }
213             }
214             bool isCreatorConst = mapConstLayers.find(creatorName) != mapConstLayers.end();
215             if (isCreatorConst) {
216                 // mark second const input of shape taking layers (Reshape, Interp..), if they wasn't visited before
217                 if ((i == 1) && (shapeTaking.find(currentLayer->type)) != shapeTaking.end()) {
218                     if (!mapConstLayers[creatorName]) {
219                         if (!mapVisitedLayers.at(creatorName)) {
220                             mapConstLayers[creatorName] = true;
221                         }
222                     }
223                 } else {
224                     if (isCurrentConst) {
225                         if (mapConstLayers.at(currentLayerName)) {
226                             if (!mapConstLayers[creatorName]) {
227                                 if (!mapVisitedLayers.at(creatorName)) {
228                                     mapConstLayers[creatorName] = true;
229                                 }
230                             }
231                         } else {
232                             mapConstLayers[creatorName] = false;
233                         }
234                     } else {
235                         mapConstLayers[creatorName] = false;
236                     }
237                 }
238             }
239             mapVisitedLayers[creatorName] = true;
240         }
241         mapVisitedLayers[currentLayerName] = true;
242     }
243     return mapConstLayers;
244 }
245
246 const BlobMap ConstTransformer::getConstData(const std::map<std::string, bool>& constLayers, const std::vector<CNNLayerPtr>& sortedLayers) {
247     ShapeInfer::ConstInferHolder holder;
248     BlobMap constData;
249     auto getInputBlobs = [&constData](const std::vector<DataWeakPtr>& insData,
250                                       bool isForShape) -> std::vector<Blob::CPtr> {
251         std::vector<Blob::CPtr> inputBlobs;
252         // special case of Const layers: no inputs, no input blobs
253         if (insData.empty()) {
254             return {};
255         }
256         for (const auto& data : insData) {
257             std::string dataName = data.lock()->getName();
258             if (constData.find(dataName) != constData.end()) {
259                 // get blobs, inferred before
260                 inputBlobs.push_back(constData.at(dataName));
261             } else {
262                 // special case of Shape layer: no input data, but blob contains info about dimensions, layout and etc...
263                 auto blob = make_blob_with_precision(data.lock()->getTensorDesc());
264                 inputBlobs.push_back(blob);
265             }
266         }
267         return inputBlobs;
268     };
269
270     auto getOutputBlobs = [](const std::vector<DataPtr>& outData) -> std::vector<Blob::Ptr> {
271         std::vector<Blob::Ptr> outputBlobs;
272         for (const auto& data : outData) {
273             auto blob = make_blob_with_precision(data->getTensorDesc());
274             blob->allocate();
275             outputBlobs.push_back(blob);
276         }
277         return outputBlobs;
278     };
279
280     for (const auto& layer : sortedLayers) {
281         if (constLayers.find(layer->name) != constLayers.end()) {
282             std::string layerName = layer->name;
283             bool isForShape = constLayers.at(layerName);
284             CNNLayerPtr layer;
285             ResponseDesc resp;
286             IE_ASSERT(StatusCode::OK == network->getLayerByName(layerName.c_str(), layer, &resp));
287
288             auto implPtr = holder.getConstInferImpl(layer->type);
289             if (!implPtr && !isForShape)
290                 THROW_IE_EXCEPTION << "Failed to find reference implementation for `"
291                                       + layer->name + "` Layer with `" + layer->type + "` Type on constant propagation";
292             if (!isForShape) {
293                 auto outputBlobs = getOutputBlobs(layer->outData);
294                 implPtr->infer(getInputBlobs(layer->insData, isForShape), layer->params, layer->blobs, outputBlobs);
295                 for (int i = 0; i < layer->outData.size(); i++) {
296                     std::string dataName = layer->outData[i]->getName();
297                     auto shapes = layer->outData[i]->getTensorDesc().getDims();
298                     outputBlobs[i]->getTensorDesc().reshape(shapes, TensorDesc::getLayoutByDims(shapes));
299                     constData[dataName] = outputBlobs[i];
300                 }
301             }
302         }
303     }
304     return constData;
305 }
306
307 void ConstTransformer::trimShapeInputs(const std::vector<std::string>& constLayers) {
308     for (const auto& layerName : constLayers) {
309         CNNLayerPtr layer;
310         ResponseDesc resp;
311         IE_ASSERT(StatusCode::OK == network->getLayerByName(layerName.c_str(), layer, &resp));
312
313         if (layer->outData.size() == 1 && layer->type == "Const" && layer->insData.empty()) {
314             auto constData = layer->outData[0];
315             std::map<std::string, CNNLayerPtr> inputToMap = constData->getInputTo();
316             for (const auto& inputTo : inputToMap) {
317                 CNNLayerPtr inputToLayer = inputTo.second;
318                 if (shapeTaking.find(inputToLayer->type) != shapeTaking.end()) {
319                     auto& insData = inputToLayer->insData;
320                     auto it = std::find_if(insData.begin(), insData.end(),
321                                            [&constData](const DataWeakPtr& current) {
322                                                return current.lock()->getName() == constData->getName();
323                                            });
324                     if (it != insData.end() && std::distance(insData.begin(), it) == 1) {
325                         inputToLayer->insData.erase(it);
326                         constData->getInputTo().erase(inputTo.first);
327                     }
328                 }
329             }
330             if (constData->getInputTo().empty()) {
331                 network->removeData(constData->getName());
332                 network->removeLayer(layer->name);
333             }
334         }
335     }
336 }
337
338 void ConstTransformer::foldConstSubgraphs() {
339     auto sortedLayers = details::CNNNetSortTopologically(*network);
340     auto constLayers = getConstLayers(sortedLayers);
341     auto constData = getConstData(constLayers, sortedLayers);
342     foldConstSubgraphsInternal(constLayers, constData, sortedLayers);
343 }
344
345 void ConstTransformer::fullTrim() {
346     auto sortedLayers = details::CNNNetSortTopologically(*network);
347     auto constMapLayers = getConstLayers(sortedLayers);
348     auto constData = getConstData(constMapLayers, sortedLayers);
349     auto constLayers = foldConstSubgraphsInternal(constMapLayers, constData, sortedLayers);
350     trimShapeInputs(constLayers);
351 }
352
/**
 * Moves weights/biases of Convolution and FullyConnected layers from dedicated
 * Const input layers (inputs #1 and #2) into the layer's own _weights/_biases
 * members and blobs map, then removes the Const layers and the corresponding
 * input edges. Layers whose weights/biases are not produced by Const layers
 * are left untouched.
 */
void ConstTransformer::moveWeights() {
    for (const auto& layerIt : network->allLayers()) {
        WeightableLayer* weightableLayer = dynamic_cast<WeightableLayer*>(layerIt.second.get());
        // Candidates: weightable Convolution/FullyConnected with at least one
        // extra input beyond the data input (case-insensitive type match).
        if ((weightableLayer != nullptr) &&
            (CaselessEq<std::string>()(weightableLayer->type, "Convolution") || CaselessEq<std::string>()(weightableLayer->type, "FullyConnected")) &&
            (weightableLayer->insData.size() > 1)) {
            // At most: data, weights, biases.
            if (weightableLayer->insData.size() > 3) {
                THROW_IE_EXCEPTION << "Unexpected inputs count for " << weightableLayer->name;
            }

            // Input #1 is the weights path.
            const DataPtr insData = weightableLayer->insData[1].lock();
            if (!insData) {
                THROW_IE_EXCEPTION << "Weights input is absent for layer " << weightableLayer->name;
            }

            InferenceEngine::Blob::Ptr weightsBlob;
            const CNNLayerPtr weightsLayer = insData->getCreatorLayer().lock();
            if (!weightsLayer) {
                THROW_IE_EXCEPTION << "Weights layer absent for layer " << weightableLayer->name;
            }

            bool removePathOnWeights = false;
            if (CaselessEq<std::string>()(weightsLayer->type, "Const")) {
                checkConstWithBlobs(weightsLayer);

                weightsBlob = weightsLayer->blobs.begin()->second;
                // NOTE(review): removeData is given the *layer* name here (and for
                // biases below); this only removes the data if the Const layer's
                // output data shares its name — confirm against Data naming.
                network->removeData(weightsLayer->name);
                network->removeLayer(weightsLayer->name);

                weightableLayer->_weights = weightsBlob;
                weightableLayer->blobs["weights"] = weightsBlob;
                removePathOnWeights = true;
            }

            // Input #2 (if present) is the biases path.
            bool removePathOnBiases = false;
            if (weightableLayer->insData.size() > 2) {
                const DataPtr insData = weightableLayer->insData[2].lock();
                if (!insData) {
                    THROW_IE_EXCEPTION << "Biases input is absent for layer " << weightableLayer->name;
                }

                const CNNLayerPtr biasesLayer = insData->getCreatorLayer().lock();
                if (!biasesLayer) {
                    THROW_IE_EXCEPTION << "Biases layer absent for layer " << weightableLayer->name;
                }

                checkConstWithBlobs(biasesLayer);

                weightableLayer->_biases = biasesLayer->blobs.begin()->second;
                weightableLayer->blobs["biases"] = weightableLayer->_biases;
                network->removeData(biasesLayer->name);
                network->removeLayer(biasesLayer->name);
                removePathOnBiases = true;
            }

            // Drop only the input edges whose Const producers were absorbed;
            // a non-Const weights/biases path keeps its edge.
            if (removePathOnWeights && removePathOnBiases) {
                weightableLayer->insData.erase(weightableLayer->insData.begin() + 1, weightableLayer->insData.end());
            } else if (removePathOnWeights) {
                weightableLayer->insData.erase(weightableLayer->insData.begin() + 1, weightableLayer->insData.begin() + 2);
            } else if (removePathOnBiases) {
                weightableLayer->insData.erase(weightableLayer->insData.begin() + 2, weightableLayer->insData.end());
            }
        }
    }
}
418 }  // namespace InferenceEngine