/*
-// Copyright (c) 2016 Intel Corporation
+// Copyright (c) 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
*/
///////////////////////////////////////////////////////////////////////////////////////////////////
+
#pragma once
+
#include "api/CPP/program.hpp"
#include "refcounted_obj.h"
-#include "topology_impl.h"
#include "engine_impl.h"
-#include "program_node.h"
-#include "memory_impl.h"
#include <list>
-#include <algorithm>
namespace cldnn
{
+struct topology_impl;
struct primitive_impl;
+struct program_node;
class layout_optimizer;
-class constants_propagator;
-
+class pass_manager;
+class program_impl_wrapper;
/*
cldnn_program implementation
*/
struct program_impl : public refcounted_obj<program_impl>
{
- friend struct program_node;
-
+ friend class calculate_prior_boxes; // to be removed when possible
+ friend class graph_initializations; // to be removed when possible
+ friend class prepare_padding; // to be removed when possible
+ friend class propagate_constants; // to be removed when possible
+ friend class prepare_primitive_fusing; // to be removed when possible
+ friend class prepare_conv_eltw_fusing; // to be removed when possible
+ friend class prepare_conv_eltw_read_write_opt; // to be removed when possible
+ friend class reorder_inputs; // to be removed when possible
+ friend class program_impl_wrapper; // this class is intended to extend the interface of program_impl for
+ // the usage within tests_core_internal project only
public:
- program_impl(engine_impl& engine_ref, topology_impl const& topology, build_options const& options, bool is_internal);
-
- void dump_memory_pool() const;
-
- engine_impl& get_engine() const { return *engine; }
- build_options get_options() const { return options; }
- bool is_debug_build() const { return options.get<build_option_type::debug>()->enabled(); }
-
- std::list<std::shared_ptr<program_node>> get_nodes() const;
- std::list<program_node*> get_processing_order() const { return processing_order; }
- std::list<primitive_id> get_optimized_out() const { return optimized_out; }
- program_node& get_node(primitive_id const& id)
+ // Ordered list of the program's nodes (the processing order) paired with a
+ // node -> list-iterator map, so the position of a node can be looked up
+ // without scanning the list. Nodes are referenced by raw pointer and are not owned.
+ struct nodes_ordering
{
- try
+ public:
+ typedef std::list<program_node*> list_of_nodes;
+ typedef list_of_nodes::const_iterator const_iterator;
+ typedef list_of_nodes::iterator node_iterator;
+ // Read-only iteration over the nodes in processing order.
+ const_iterator begin() const { return _processing_order.begin(); }
+ const_iterator end() const { return _processing_order.end(); }
+
+ void calc_processing_order_visit(program_node* node);
+ void calc_processing_order(program_impl& p);
+ // 1-based position of 'node' in the order; throws std::out_of_range (std::map::at) if the node is not tracked.
+ int32_t get_processing_number(program_node* node) const { return get_processing_number(get_processing_iterator(*node)); }
+ // int32_t get_processing_number(const_iterator iter) const { return 1+(int32_t)std::distance(begin(), iter); }
+ // 1-based position of 'iter'; std::distance walks the list from its front, so the cost is linear in the position.
+ int32_t get_processing_number(node_iterator iter) const { return 1 + (int32_t)std::distance(_processing_order.begin(), const_iterator(iter)); }
+ void calculate_BFS_processing_order();
+ // Number of nodes in the order; const so it is usable through 'const nodes_ordering&'.
+ size_t size() const { return _processing_order.size(); }
+ bool is_correct(program_node* node);
+
+ // Iterator pointing at 'node' inside the order; throws std::out_of_range if the node is not tracked.
+ node_iterator get_processing_iterator(program_node& node) const
{
- return *nodes_map.at(id);
+ return processing_order_iterators.at(&node);
}
- catch (...)
+ // Empties both the ordered list and the iterator lookup map.
+ void clear()
{
- throw std::runtime_error("Program doesn't contain primtive node: " + id);
+ processing_order_iterators.clear();
+ _processing_order.clear();
}
- }
- bool has_node(const primitive_id& prim) const
- {
- return nodes_map.count(prim) > 0;
- }
+ // Places 'node' directly before 'key_node' and records its iterator.
+ // Throws std::out_of_range if 'key_node' is not tracked.
+ void insert(program_node* key_node, program_node* node)
+ {
+ node_iterator _where = processing_order_iterators.at(key_node);
+ processing_order_iterators[node] = _processing_order.insert(_where, node);
+ }
- program_node const& get_node(primitive_id const& id) const
- {
- try
+ // Places 'node' directly after 'key_node' and records its iterator.
+ // Throws std::out_of_range if 'key_node' is not tracked.
+ void insert_next(program_node* key_node, program_node* node)
{
- return *nodes_map.at(id);
+ node_iterator _where = std::next(processing_order_iterators.at(key_node));
+ processing_order_iterators[node] = _processing_order.insert(_where, node);
}
- catch (...)
+
+ // Removes 'key_node' from the order. The list iterator is read before the map
+ // entry is dropped, then used to erase the list element.
+ // Throws std::out_of_range if 'key_node' is not tracked.
+ void erase(program_node* key_node)
{
- throw std::runtime_error("Program doesn't contain primtive node: " + id);
+ node_iterator i = processing_order_iterators.at(key_node);
+ processing_order_iterators.erase(key_node);
+ _processing_order.erase(i);
}
- }
+
+ private:
+ list_of_nodes _processing_order;
+ // node -> its position in _processing_order (std::list iterators stay valid across insert/erase of other elements)
+ std::map<program_node*, node_iterator> processing_order_iterators;
+ };
+
+ // Range over exactly one element. The object doubles as its own iterator and
+ // provides the begin()/end()/++/!=/* operations required by range-based for:
+ // begin() points at the element, end() holds nullptr.
+ // The element is referenced, not owned, and must outlive this object.
+ template <class T>
+ struct single_element_container
+ {
+ single_element_container(T& t) : elem(&t)
+ {}
+ constexpr size_t size() const { return 1; }
+ single_element_container begin() const { return single_element_container(elem); }
+ single_element_container end() const { return single_element_container(nullptr); }
+ // Stepping past the only element turns the iterator into end().
+ single_element_container& operator ++() { elem = nullptr; return *this; }
+ bool operator !=(single_element_container const& sec) const { return elem != sec.elem; }
+
+ // Yields the wrapped element itself rather than a copy, so non-copyable T works
+ // and changes made through the reference reach the original object.
+ T& operator *() const { return *elem; }
+
+ private:
+ single_element_container(T* t) : elem(t)
+ {}
+
+ T* elem;
+ };
+ program_impl(engine_impl& engine_ref, topology_impl const& topology, build_options const& options, bool is_internal, bool no_optimizations=false);
+ /* Constructor used to build a program from a subset of the nodes of another program (used in propagate_constants). */
+ program_impl(engine_impl& engine_ref, std::set<std::shared_ptr<program_node>> const &nodes, build_options const& options, bool is_internal);
+ ~program_impl();
+ engine_impl& get_engine() const { return *engine; }
+ const build_options& get_options() const { return options; }
+ std::list<program_node*>& get_inputs() { return inputs; } // ToDo: redesign the trim-to-output pass so this can be const, as well as get_engine and get_options
+ std::vector<program_node*>& get_outputs() { return outputs; } // ToDo: redesign the reorder-inputs pass so this can be const, as well as get_engine and get_options
+ bool is_debug_build() const { return options.get<build_option_type::debug>()->enabled(); }
+ const nodes_ordering& get_processing_order() const;
+ nodes_ordering& get_processing_order();
+ const std::list<primitive_id>& get_optimized_out() const { return optimized_out; }
+ bool has_node(const primitive_id& prim) const { return nodes_map.count(prim) > 0; }
+ program_node& get_node(primitive_id const& id);
+ program_node const& get_node(primitive_id const& id) const;
+ std::shared_ptr<program_node> get_node_ptr(const primitive_id& prim) { return nodes_map.at(prim); } // std::map::at: throws std::out_of_range for an unknown id
+ std::shared_ptr<program_node> get_node_ptr(const primitive_id& prim) const { return nodes_map.at(prim); } // std::map::at: throws std::out_of_range for an unknown id
+ void dump_memory_pool() const;
+
+ // Returns the already existing program_node for the given primitive 'prim' (lookup in 'nodes_map')
+ // if it was previously created; otherwise creates it and then returns the new program_node.
+ program_node& get_or_create(std::shared_ptr<primitive> prim);
+
+ // Inserts the given program_node 'node' as an intermediate node between 'next' and its
+ // dependency at index 'prev_idx'.
+ void add_intermediate(program_node& node, program_node& next, size_t prev_idx,
+ bool connect_int_node_with_old_dep = true,
+ bool move_usrs_of_prev_to_node = false);
+
+ // Gets or creates the program_node for the given primitive 'prim' and inserts it as an intermediate
+ // node between 'next' and its dependency at index 'prev_idx'.
+ void add_intermediate(std::shared_ptr<primitive> prim, program_node& next, size_t prev_idx,
+ bool connect_int_node_with_old_dep = true,
+ bool move_usrs_of_prev_to_node = false);
+
+ // Removes a node from the graph and deletes it afterwards.
+ // Prerequisite: the node cannot be marked as output and has to have exactly one dependency.
+ // Returns whether 'node' has been extracted and removed successfully.
+ bool extract_and_remove(program_node& node);
+
+ // Returns whether 'node' has been removed.
+ bool remove_if_dangling(program_node& node);
+
+ void mark_if_constant(program_node& node); // presumably marks 'node' as constant assuming all of its dependencies are marked properly - confirm against the definition
+ // mark if the node is in data flow assuming that all dependencies are marked properly
+ void mark_if_data_flow(program_node& node);
+ // NOTE(review): "Reverses connection - user becomes dependency" describes reverse_connection(), which is declared in the private section, not the declaration below.
+
+ void remove_nodes(std::list<program_node*>& to_remove);
+ void dump_program(const char* stage, bool with_full_info, std::function<bool(program_node const&)> const& filter = nullptr) const;
private:
uint32_t prog_id = 0;
-
engine_impl::ptr engine; // dereferenced and handed out by get_engine()
build_options options; // returned by get_options(); also queried by is_debug_build()
-
std::list<program_node*> inputs; // exposed (mutable) through get_inputs()
std::vector<program_node*> outputs; // exposed (mutable) through get_outputs()
- std::list<program_node*> processing_order;
+ nodes_ordering processing_order; // ordered node list plus per-node iterator lookup (see nodes_ordering)
+ std::unique_ptr<pass_manager> pm; // pass_manager is only forward-declared in this header
std::map<primitive_id, std::shared_ptr<program_node>> nodes_map; // primitive id -> shared node pointer; backs has_node() and get_node_ptr()
-
std::list<primitive_id> optimized_out; // returned by get_optimized_out()
- // TODO: Remove once we will get full support for input/output padding in all primitive implementations.
- bool output_size_handling_enabled;
-
/*
** High-level functions, in order of usage
*/
- void init_graph(topology_impl const& topology);
- void pre_optimize_graph();
- void post_optimize_graph();
- void compile_graph();
+ /* Builds the nodes' internal structure from the given topology. */
+ void prepare_nodes(topology_impl const& topology);
+ /* Builds the nodes' internal structure from a subset of the nodes of another program (used in propagate_constants). */
+ void prepare_nodes(std::set<std::shared_ptr<program_node>> const& nodes);
+ void add_node_dependencies(program_node* node_ptr);
+ void copy_node_dependencies(program_node* dest, program_node* src);
+ void build_program(bool is_internal); // 'is_internal' shares its name with the constructors' flag; presumably forwarded from there - confirm
+ void init_graph();
+ void set_options();
+
+ void run_graph_compilation();
+ void pre_optimize_graph(bool is_internal);
+ void post_optimize_graph(bool is_internal);
void cleanup();
/*
- ** Initialization functions
- */
- void set_outputs();
- void calc_processing_order();
- void calc_prior_boxes();
-
- /*
** Analysis functions
*/
- void mark_constants();
- void mark_data_flow();
// TODO: Remove once we will get full support for input/output padding in all primitive implementations.
- void analyze_output_size_handling_need();
- void replace_nodes_pre();
- void replace_nodes_post();
- void handle_lstm();
+ bool analyze_output_size_handling_need();
+
+ // handle split, deconvolution and upsampling
void handle_reshape();
/*
** Optimization functions
*/
- void trim_to_outputs();
- void remove_redundant_reorders();
- void calculate_BFS_processing_order();
- void reorder_inputs(layout_optimizer& lo);
- void pre_optimize_bias(layout_optimizer& lo);
- void post_optimize_weights(layout_optimizer& lo);
void apply_needed_padding(program_node& node, program_node& prev_node, const padding& needed_padding);
- void prepare_padding();
- void propagate_constants();
- void prepare_buffer_fusing();
- void fuse_skip_layers(program_node* node);
- void prepare_primitive_fusing();
- void prepare_depthwise_sep_opt();
- void prep_opt_depthwise_sep_post();
- void update_processing_numbers();
/*
** Memory pool functions
/*
** Utilities
*/
+ void add_split_outputs();
+ // Reverses connection - user becomes dependency.
+ void reverse_connection(program_node& dep_node, program_node& user_node);
- //returns already existing program_node for given primitive 'prim' (lookup in 'nodes_map')
- //if it was previously created, otherwise creates and then returns program_node
- program_node& get_or_create(std::shared_ptr<primitive> prim);
-
- // Inserts given program_node 'node' as an intermediate node between 'next' and it's
- // dependency at 'prev_idx' index.
- void add_intermediate(program_node& node, program_node& next, size_t prev_idx, bool connect_int_node_with_old_dep = true);
-
- // Gets or creates program_node for given primitive 'prim' and inserts it as an intermediate
- // node between 'next' and it's dependency at 'prev_idx' index.
- void add_intermediate(std::shared_ptr<primitive> prim, program_node& next, size_t prev_idx, bool connect_int_node_with_old_dep = true)
- {
- add_intermediate(get_or_create(prim), next, prev_idx, connect_int_node_with_old_dep);
- }
-
- void add_connection(program_node& prev, program_node& next)
- {
- prev.users.push_back(&next);
- next.dependencies.push_back(&prev);
- }
+ void add_connection(program_node& prev, program_node& next); // now only declared here; the removed inline version appended &next to prev.users and &prev to next.dependencies
- void remove_connection(program_node& prev, program_node& next)
- {
- prev.users.remove(&next);
- next.dependencies.erase(std::remove(next.dependencies.begin(), next.dependencies.end(), &prev), next.dependencies.end());
- }
-
- void remove_all_connections(program_node& node) {
- // since the graph is not topological sorted, we need to remove the node from both dependencies and users
- for (auto &e : node.users) {
- e->dependencies.erase(std::remove(e->dependencies.begin(), e->dependencies.end(), &node), e->dependencies.end());
- }
- for(auto &e : node.dependencies) {
- e->users.remove(&node);
- }
- node.dependencies.clear();
- node.users.clear();
- }
+ void remove_connection(program_node& prev, program_node& next); // now only declared here; the removed inline version dropped &next from prev.users and &prev from next.dependencies
- bool processing_order_is_correct(program_node* node)
- {
- for (auto& dep : node->get_dependencies())
- {
- if (node->processing_num < dep->processing_num)
- {
- return false;
- }
- }
- return true;
- }
+ void remove_all_connections(program_node& node); // now only declared here; the removed inline version unlinked 'node' from every user and dependency, then cleared both of its lists
void rename(program_node & node, primitive_id const & new_id);
void swap_names(program_node& node1, program_node& node2);
//old_node - node which will be replaced
//new_node - node which will replace the old one
- //replace_whole_branch - if set to true, 'old_node' will be replaced with all its dependencies and new_node will retain its dependencies
- // old's dependencies which are post-dominates by 'old_node' will also be removed
- void replace(program_node& old_node, program_node& new_node, bool replace_whole_branch, bool check_output_layouts_integrity = true);
-
- //returns if 'node' has been removed
- bool remove_if_dangling(program_node& node, bool detach_whole_branch = false);
-
- //removes a node from the graph and deletes it afterwards,
- //prereq: node cannot be marked as output and has to have exactly one dependency
- //returns if 'node' has been extracted and removed successfully
- bool extract_and_remove(program_node& node);
- void replace_data_with_optimized(std::map<primitive_id, memory_impl::ptr> const& replace_map);
- void dump_program(const char* stage, bool with_full_info, std::function<bool(program_node const&)> const& filter = nullptr) const;
- //Dumps weights and biasses in serialization process, not working yet, in progress.
- void dump_weights_and_biasses(std::vector<unsigned long long>& offsets, std::vector<std::string>& data_names, std::ofstream& file_stream) const;
- //Makes serialization with given name.
- //Placeholder, not working yet, in progress.
- void serialize(std::string network_name, std::function<bool(program_node const&)> const& filter = nullptr) const;
-
- template <typename T>
- void optimize_bias(T& node, layout_optimizer& lo);
-
- template <typename T>
- void optimize_weights(T& node, layout_optimizer& lo);
-
- template <typename T>
- void optimize_depthwise_sep_pre(T& node);
-
- template <typename T>
- void optimize_depthwise_sep_post(T& node);
+ void replace(program_node& old_node, program_node& new_node); // old_node is replaced by new_node; the previous signature's replace_whole_branch and check_output_layouts_integrity parameters no longer exist
};
+
}
API_CAST(::cldnn_program, cldnn::program_impl)