Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / graph_optimizer / trim_to_outputs.cpp
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18
19 #include "pass_manager.h"
20
21 //ToDo: remove those include with the appropriate code below once we will have support for multiple outputs of a primitive
22 #include "batch_norm_inst.h"
23 #include "max_unpooling_inst.h"
24 #include "pooling_inst.h"
25
26 using namespace cldnn;
27
28 //This pass optimizes out nodes which have no impact on outputs
29 void trim_to_outputs::run(program_impl& p)
30 {
31     const size_t actual_nodes = p.get_processing_order().size();
32     if (!actual_nodes) //degenerated case but can happen
33         return;
34
35     if (p.get_outputs().size() == actual_nodes)
36         return;
37
38     //do backward bfs starting from all outputs
39     std::list<const std::vector<program_node*>*> stack = { &(p.get_outputs()) };
40
41     std::vector<program_node*> special_nodes;
42     for (auto& node : p.get_processing_order())
43     {
44         if (node->is_type<input_layout>() ||  //input layout may become disconnected during prior boxes calculations so it may have not been marked at this place but we don't want to remove it
45             node->is_type<max_unpooling>() || // ToDo: remove this after support for multi-outputs in primitives will be implemented.
46             node->is_type<batch_norm>() ||
47             (node->is_type<pooling>() && node->as<pooling>().get_primitive()->mode == pooling_mode::max_with_argmax))
48                 special_nodes.push_back(node);
49     }
50     stack.push_back(&special_nodes);
51
52     while (!stack.empty())
53     {
54         auto nodes_list = stack.front();
55         stack.pop_front();
56
57         for (auto& node : *nodes_list)
58         {
59             if (!node->is_marked())
60             {
61                 node->mark();
62                 if (!node->get_dependencies().empty())
63                     stack.push_back(&node->get_dependencies());
64             }
65         }
66     }
67
68     //all not-marked nodes should be removed
69     std::list<program_node*> to_rem;
70     for (auto& node : p.get_processing_order())
71     {
72         if (!node->is_marked())
73             to_rem.push_back(node);
74     }
75     p.remove_nodes(to_rem);
76 }