2 // Copyright (c) 2017 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "program_node.h"
18 #include "program_impl.h"
19 #include "primitive_inst.h"
20 #include "to_string_utils.h"
21 #include "json_object.h"
24 using namespace cldnn;
26 program_node::program_node(std::shared_ptr<primitive> prim, program_impl & prog) : desc(prim), myprog(prog), org_id(prim->id)
29 output_layout.data_padding = prim->output_padding;
32 void program_node::replace_dependency(size_t idx, program_node& new_dep)
34 if (idx >= dependencies.size())
36 if (dependencies[idx] == &new_dep)
39 dependencies[idx]->users.remove(this);
40 myprog.remove_if_dangling(*dependencies[idx]);
42 dependencies[idx] = &new_dep;
43 new_dep.users.push_back(this);
46 void program_node::replace_dependency(program_node const& old_dep, program_node& new_dep)
48 for (size_t i = 0; i < dependencies.size(); ++i)
49 if (dependencies[i] == &old_dep)
50 return replace_dependency(i, new_dep);
53 std::vector<primitive_id> program_node::get_dependencies_ids() const
55 std::vector<primitive_id> dep_ids;
56 for (auto& dependency : dependencies)
57 dep_ids.push_back(dependency->get_primitive()->id);
61 void program_node::remove_dependency(size_t idx)
63 if (idx >= dependencies.size())
66 dependencies[idx]->users.remove(this);
67 myprog.remove_if_dangling(*dependencies[idx]);
68 dependencies.erase(dependencies.begin() + idx);
71 std::set<primitive_id> program_node::get_memory_dependencies() const
73 return memory_dependencies;
76 void program_node::add_memory_dependency(primitive_id prim)
78 memory_dependencies.insert(prim);
81 void program_node::add_memory_dependency(std::vector<primitive_id> prim_list)
83 memory_dependencies.insert(prim_list.begin(),prim_list.end());
86 std::unique_ptr<json_composite> program_node::desc_to_json() const
88 std::unique_ptr<json_composite> node_info = std::unique_ptr<json_composite>(new json_composite());
89 node_info->add("ptr", "node_" + std::to_string(reinterpret_cast<uintptr_t>(this)));
90 node_info->add("id", id());
91 node_info->add("type", get_extr_type(typeid(*this).name()));
92 node_info->add("internal", bool_to_str(this->is_type<internal_primitive>()));
93 node_info->add("valid output layout", bool_to_str(valid_output_layout));
95 json_composite output_layout_info;
96 output_layout_info.add("data type", dt_to_str(output_layout.data_type));
97 output_layout_info.add("format", fmt_to_str(output_layout.format));
98 output_layout_info.add("size", output_layout.size.to_string());
100 json_composite padding_info;
101 padding_info.add("lower size", output_layout.data_padding.lower_size().to_string());
102 padding_info.add("upper size", output_layout.data_padding.upper_size().to_string());
103 output_layout_info.add("padding info", padding_info);
105 node_info->add("output layout", output_layout_info);
107 node_info->add("in data flow", bool_to_str(data_flow));
108 node_info->add("constant", bool_to_str(constant));
109 node_info->add("in data flow", bool_to_str(data_flow));
110 node_info->add("output", bool_to_str(output));
113 std::vector<std::string> deps_ptrs;
116 auto itr = dependencies.begin();
117 while (itr != dependencies.end())
123 deps_ptrs.push_back(std::to_string(reinterpret_cast<uintptr_t>(*itr++)));
125 if (deps_ptrs.empty())
127 deps_ptrs.push_back("null");
130 node_info->add("dependencies", deps_ptrs);
132 std::vector<std::string> users_ptrs;
135 auto itr = users.begin();
136 while (itr != users.end())
142 users_ptrs.push_back(std::to_string(reinterpret_cast<uintptr_t>(*itr++)));
144 if (users_ptrs.empty())
146 users_ptrs.push_back("null");
149 node_info->add("users", users_ptrs);
150 std::vector<std::string> impls;
153 impls.push_back("null");
158 #pragma clang diagnostic push
159 #pragma clang diagnostic ignored "-Wpotentially-evaluated-expression"
161 impls.push_back(selected_impl->get_kernel_name());
163 #pragma clang diagnostic pop
166 node_info->add("implementation", impls);
170 void program_node::remove_dependency(program_node & node)
172 for (size_t i = 0; i < dependencies.size(); ++i)
173 if (dependencies[i] == &node)
174 remove_dependency(i);
// Reports whether this node is disconnected from the graph.
// whole_branch: when false, a node that still has dependencies is not considered
// detached (checked below).
// NOTE(review): the opening brace, any other early-return conditions and the final
// return of this definition are not visible in this view — confirm the complete set
// of conditions against the full file before relying on this description.
177 bool program_node::is_detached(bool whole_branch)
181 if (!whole_branch && !dependencies.empty())
186 layout program_node::calc_output_layout() const
188 return type()->calc_output_layout(*this);
191 layout program_node::get_output_layout(bool invalidate_users_if_changed)
193 if (valid_output_layout)
194 return output_layout;
196 auto new_layout = calc_output_layout();
197 set_output_layout(new_layout, invalidate_users_if_changed);
201 layout program_node::get_output_layout() const
203 if (!valid_output_layout)
204 throw std::runtime_error("Output layout not calculated");
206 return output_layout;
209 layout program_node::get_non_padded_output_layout(bool invalidate_users_if_changed)
211 auto out_layout = get_output_layout(invalidate_users_if_changed);
212 auto result = layout({ out_layout.data_type, out_layout.format, out_layout.size });
216 bool program_node::set_output_layout(layout new_layout, bool invalidate_users_if_changed)
218 merge_output_padding(new_layout.data_padding);
219 new_layout.data_padding = output_layout.data_padding;
220 bool changed = (new_layout != output_layout);
221 if (changed && invalidate_users_if_changed) //output_layout has changed! invalidate users
224 output_layout = new_layout;
225 valid_output_layout = true;
229 bool program_node::recalc_output_layout(bool invalidate_users_if_changed)
231 return set_output_layout(calc_output_layout(), invalidate_users_if_changed);
234 bool program_node::has_padded_dependency()
236 return std::any_of(get_dependencies().begin(), get_dependencies().end(), [](program_node* node) { return node->is_padded(); });
239 bool program_node::has_padded_dependency() const
241 return std::any_of(get_dependencies().begin(), get_dependencies().end(), [](const program_node* node) { return node->is_padded(); });
244 void program_node::invalidate_users() const
246 for (auto& user : users)
248 if (user->valid_output_layout)
250 user->valid_output_layout = false;
251 user->invalidate_users();
256 primitive_id details::internal_program_node_base::get_next_internal_id()
258 static std::atomic<uint64_t> counter{ 0 };
259 auto idx = counter++;
260 return primitive_id("_cldnn_internal_") + std::to_string(idx);
263 details::internal_program_node_base::internal_program_node_base(program_impl & prog) : program_node(nullptr, prog), internal_id(get_next_internal_id())
// Installs the implementation selected for this internal node, replacing any
// previously selected one. Ownership of `impl` is transferred; `impl` is left empty.
267 void details::internal_program_node_base::set_implementation(std::unique_ptr<primitive_impl>&& impl)
269 selected_impl = std::move(impl);