2 // Copyright (c) 2017 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "constants_propagator.h"
18 #include "engine_impl.h"
19 #include "program_impl.h"
20 #include "network_impl.h"
21 #include "memory_impl.h"
23 #include "api/CPP/input_layout.hpp"
25 using namespace cldnn;
27 constants_propagator::constants_propagator(program_impl::ptr program) : prog(program)
31 void constants_propagator::visit_node(program_node& node)
33 if (node.is_constant())
34 handle_constant(node);
// Evaluates the collected constant sub-topology and returns its results.
//
// Builds a network from `tpl` with the program's engine:
//   - data optimization is disabled (optimize_data(false));
//   - outputs are requested via build_option::outputs(const_outputs).
// Each trivial constant input in `const_inputs` has its attached memory bound as input
// data. reset_execution(true) is then called to wait for computations to complete
// before the outputs are read.
//
// Returns: one (primitive id, output memory) pair per network output.
//
// NOTE(review): the visible body looks incomplete. Confirm against the intended source:
//   - the `if (!has_non_trivial_constants)` guard shows no statement of its own
//     (presumably an early `return {};` when there is nothing to compute);
//   - `bo` is used without a visible declaration (presumably the build options object);
//   - no call that starts network execution is visible before reset_execution(true);
//   - `ret` is filled but never returned; flowing off the end of a non-void function
//     is undefined behavior.
37 std::list<std::pair<primitive_id, memory_impl::ptr>> constants_propagator::calculate()
39 if (!has_non_trivial_constants)
43 bo.set_option(build_option::optimize_data(false));
44 bo.set_option(build_option::outputs(const_outputs));
// The third build_network() argument is passed as `true`; its meaning is not visible here.
45 network_impl::ptr net = prog->get_engine().build_network(tpl, bo, true);
// Bind each trivial constant (data node) to the input_layout that add_deps_to_tpl()
// registered under the same id.
46 for (auto& cin : const_inputs)
47 net->set_input_data(cin->id(), cin->get_attached_memory());
50 net->reset_execution(true); //wait for computations to complete
51 auto outputs = net->get_outputs();
// Pair every network output's id with its output memory.
53 std::list<std::pair<primitive_id, memory_impl::ptr>> ret;
54 for (auto& out : outputs)
55 ret.push_back({ out->id(), &out->output_memory() });
// Handles a node that visit_node() reported as constant.
//
// `data` nodes (trivial constants) are skipped. For any other constant node that also
// has a non-constant user, the node's id is appended to `const_outputs`. calculate()
// then requests that node as an output of the constant-evaluation network.
//
// NOTE(review): add_constant() is not called anywhere in the visible source. As a result,
// has_non_trivial_constants and tpl are never populated for these nodes. Presumably an
// `add_constant(node);` call belongs at the start of this branch; confirm.
60 void constants_propagator::handle_constant(program_node& node)
62 if (!node.is_type<data>())
65 if (node.has_non_const_user())
66 const_outputs.push_back(node.id());
// Registers a non-trivial constant node (one that is not a plain `data` primitive)
// for evaluation by calculate(). Specifically, it:
//   - sets has_non_trivial_constants, the flag calculate() checks before doing any work;
//   - appends the node's id to const_outputs when node.is_endpoint() or node.is_output()
//     is true;
//   - hands the node's dependencies to add_deps_to_tpl(), which adds the `data` ones to
//     tpl as network inputs.
//
// NOTE(review): the visible code has two apparent gaps; confirm against the intended source:
//   - the `if (node.is_type<data>())` check has no statement of its own, so as written it
//     guards the has_non_trivial_constants assignment instead of rejecting `data` nodes
//     (presumably a `return;` after the check is missing);
//   - the node itself is never added to tpl in the visible source (presumably a
//     `tpl.add(...)` for the node is missing).
70 void constants_propagator::add_constant(program_node& node)
72 if (node.is_type<data>())
76 has_non_trivial_constants = true;
78 // If a node is either an endpoint or an output, always add it as an output.
79 if (node.is_endpoint() || node.is_output())
80 const_outputs.push_back(node.id());
82 // If a non-trivial constant has a trivial input, add this input as an input for our network.
83 add_deps_to_tpl(node.get_dependencies());
86 void constants_propagator::add_deps_to_tpl(const std::vector<program_node*>& deps)
89 Nodes can share dependencies, if we already have dep in tpl, don't add it again.
96 for (auto& dep : deps)
98 if (dep->is_type<data>())
100 if (is_already_in_tpl(dep->id())) continue;
101 tpl.add(std::make_shared<input_layout>(dep->id(), dep->as<data>().get_primitive()->mem.get_layout()));
102 const_inputs.push_back(&dep->as<data>());
// Reports whether `id` equals one of the primitive ids currently in `tpl`; it returns
// true as soon as a match is found.
//
// add_deps_to_tpl() uses this so that a dependency shared by several nodes is added to
// the topology only once. Each call walks the ids returned by tpl.get_primitives_id().
//
// NOTE(review): the not-found path (`return false;` after the loop) is not visible here.
// Confirm it exists: flowing off the end of a bool-returning function is undefined behavior.
107 bool constants_propagator::is_already_in_tpl(const primitive_id& id)
109 for (auto const& id_in_tpl : tpl.get_primitives_id())
111 if (id == id_in_tpl) return true;