2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
19 #include "pass_manager.h"
//ToDo: remove these includes, together with the corresponding code below, once primitives support multiple outputs
22 #include "batch_norm_inst.h"
23 #include "max_unpooling_inst.h"
24 #include "pooling_inst.h"
26 using namespace cldnn;
28 //This pass optimizes out nodes which have no impact on outputs
29 void trim_to_outputs::run(program_impl& p)
31 const size_t actual_nodes = p.get_processing_order().size();
32 if (!actual_nodes) //degenerated case but can happen
35 if (p.get_outputs().size() == actual_nodes)
38 //do backward bfs starting from all outputs
39 std::list<const std::vector<program_node*>*> stack = { &(p.get_outputs()) };
41 std::vector<program_node*> special_nodes;
42 for (auto& node : p.get_processing_order())
44 if (node->is_type<input_layout>() || //input layout may become disconnected during prior boxes calculations so it may have not been marked at this place but we don't want to remove it
45 node->is_type<max_unpooling>() || // ToDo: remove this after support for multi-outputs in primitives will be implemented.
46 node->is_type<batch_norm>() ||
47 (node->is_type<pooling>() && node->as<pooling>().get_primitive()->mode == pooling_mode::max_with_argmax))
48 special_nodes.push_back(node);
50 stack.push_back(&special_nodes);
52 while (!stack.empty())
54 auto nodes_list = stack.front();
57 for (auto& node : *nodes_list)
59 if (!node->is_marked())
62 if (!node->get_dependencies().empty())
63 stack.push_back(&node->get_dependencies());
68 //all not-marked nodes should be removed
69 std::list<program_node*> to_rem;
70 for (auto& node : p.get_processing_order())
72 if (!node->is_marked())
73 to_rem.push_back(node);
75 p.remove_nodes(to_rem);