Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / include / program_impl.h
index c3cb673..c518d9c 100644 (file)
@@ -1,5 +1,5 @@
 /*
-// Copyright (c) 2016 Intel Corporation
+// Copyright (c) 2019 Intel Corporation
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 */
 
 ///////////////////////////////////////////////////////////////////////////////////////////////////
+
 #pragma once
+
 #include "api/CPP/program.hpp"
 
 #include "refcounted_obj.h"
-#include "topology_impl.h"
 #include "engine_impl.h"
-#include "program_node.h"
-#include "memory_impl.h"
 
 #include <list>
-#include <algorithm>
 
 namespace cldnn
 {
 
+struct topology_impl;
 struct primitive_impl;
+struct program_node;
 class layout_optimizer;
-class constants_propagator;
-
+class pass_manager;
+class program_impl_wrapper;
 /*
     cldnn_program implementation
 */
 struct program_impl : public refcounted_obj<program_impl>
 {
-    friend struct program_node;
-
+    friend class calculate_prior_boxes;             // to be removed when possible
+    friend class graph_initializations;             // to be removed when possible
+    friend class prepare_padding;                   // to be removed when possible
+    friend class propagate_constants;               // to be removed when possible
+    friend class prepare_primitive_fusing;          // to be removed when possible
+    friend class prepare_conv_eltw_fusing;          // to be removed when possible
+    friend class prepare_conv_eltw_read_write_opt;  // to be removed when possible
+    friend class reorder_inputs;                    // to be removed when possible
+    friend class program_impl_wrapper;              // this class is intended to extend the interface of program_impl for 
+                                                    // the usage within tests_core_internal project only
 public:
-    program_impl(engine_impl& engine_ref, topology_impl const& topology, build_options const& options, bool is_internal);
-
-    void dump_memory_pool() const;
-
-    engine_impl& get_engine() const { return *engine; }
-    build_options get_options() const { return options; }
-    bool is_debug_build() const { return options.get<build_option_type::debug>()->enabled(); }
-
-    std::list<std::shared_ptr<program_node>> get_nodes() const;
-    std::list<program_node*> get_processing_order() const { return processing_order; }
-    std::list<primitive_id> get_optimized_out() const { return optimized_out; }
-    program_node& get_node(primitive_id const& id)
+    struct nodes_ordering
     {
-        try 
+    public:
+        typedef std::list<program_node*> list_of_nodes;
+        typedef list_of_nodes::const_iterator const_iterator;
+        typedef list_of_nodes::iterator node_iterator;
+        const_iterator begin() const { return _processing_order.begin(); }
+        const_iterator end() const { return _processing_order.end(); }
+
+        void calc_processing_order_visit(program_node* node);
+        void calc_processing_order(program_impl& p);
+        int32_t get_processing_number(program_node* node) const { return get_processing_number(get_processing_iterator(*node)); }
+        // int32_t get_processing_number(const_iterator iter) const { return 1+(int32_t)std::distance(begin(), iter); }
+        int32_t get_processing_number(node_iterator iter) const { return 1 + (int32_t)std::distance(_processing_order.begin(), const_iterator(iter)); }
+        void calculate_BFS_processing_order();
+        size_t size() { return _processing_order.size(); }
+        bool is_correct(program_node* node);
+        
+        node_iterator get_processing_iterator(program_node& node) const
         {
-            return *nodes_map.at(id);
+            return processing_order_iterators.at(&node);
         }
-        catch (...)
+        void clear()
         {
-            throw std::runtime_error("Program doesn't contain primtive node: " + id);
+            processing_order_iterators.clear();
+            _processing_order.clear();
         }
-    }
 
-    bool has_node(const primitive_id& prim) const
-    {
-        return nodes_map.count(prim) > 0;
-    }
+        void insert(program_node* key_node, program_node* node)
+        {
+            node_iterator _where = processing_order_iterators.at(key_node);
+            processing_order_iterators[node] = _processing_order.insert(_where, node);
+        }
 
-    program_node const& get_node(primitive_id const& id) const
-    {
-        try
+        void insert_next(program_node* key_node, program_node* node)
         {
-            return *nodes_map.at(id);
+            node_iterator _where = std::next(processing_order_iterators.at(key_node));
+            processing_order_iterators[node] = _processing_order.insert(_where, node);
         }
-        catch (...)
+
+        void erase(program_node* key_node)
         {
-            throw std::runtime_error("Program doesn't contain primtive node: " + id);
+            node_iterator i = processing_order_iterators.at(key_node);
+            processing_order_iterators.erase(key_node);
+            _processing_order.erase(i);
         }
-    }
+
+    private:
+        list_of_nodes _processing_order;
+        std::map<program_node*, node_iterator> processing_order_iterators;
+    };
+
+    template <class T>
+    struct single_element_container
+    {
+        single_element_container(T& t) : elem(&t)
+        {}
+        constexpr size_t size() const { return 1; }
+        single_element_container begin() const { return single_element_container(elem); }
+        single_element_container end() const { return single_element_container(nullptr); }
+        single_element_container& operator ++() { elem = nullptr; return *this; }
+        bool operator !=(single_element_container const& sec) { return elem != sec.elem; }
+
+       T operator *() { return *elem; }
+
+    private:
+        single_element_container(T* t) : elem(t)
+        {}
+
+        T* elem;
+    };
+    program_impl(engine_impl& engine_ref, topology_impl const& topology, build_options const& options, bool is_internal, bool no_optimizations=false);
+    /* constructor used to build a program from subset of nodes of other program (used in propagate_constants) */
+    program_impl(engine_impl& engine_ref, std::set<std::shared_ptr<program_node>> const &nodes, build_options const& options, bool is_internal);
+    ~program_impl();
+    engine_impl& get_engine() const { return *engine; }
+    const build_options& get_options() const { return options; }
+    std::list<program_node*>& get_inputs() { return inputs; }     // ToDo: redesign trim to ouptut pass to make it const as_well as get_engine and get options 
+    std::vector<program_node*>& get_outputs() { return outputs; }  // ToDo: redesign reorder-inputs pass to make it const as_well as get_engine and get options 
+    bool is_debug_build() const { return options.get<build_option_type::debug>()->enabled(); }
+    const nodes_ordering& get_processing_order() const;
+    nodes_ordering& get_processing_order();
+    const std::list<primitive_id>& get_optimized_out() const { return optimized_out; }
+    bool has_node(const primitive_id& prim) const { return nodes_map.count(prim) > 0; }
+    program_node& get_node(primitive_id const& id);
+    program_node const& get_node(primitive_id const& id) const;
+    std::shared_ptr<program_node> get_node_ptr(const primitive_id& prim) { return nodes_map.at(prim);  }
+    std::shared_ptr<program_node> get_node_ptr(const primitive_id& prim) const { return nodes_map.at(prim); }
+    void dump_memory_pool() const;
+
+    //returns already existing program_node for given primitive 'prim' (lookup in 'nodes_map')
+    //if it was previously created, otherwise creates and then returns program_node
+    program_node& get_or_create(std::shared_ptr<primitive> prim);
+
+    // Inserts given program_node 'node' as an intermediate node between 'next' and it's
+    //  dependency at 'prev_idx' index.
+    void add_intermediate(program_node& node, program_node& next, size_t prev_idx,
+                          bool connect_int_node_with_old_dep = true,
+                          bool move_usrs_of_prev_to_node = false);
+
+    // Gets or creates program_node for given primitive 'prim' and inserts it as an intermediate
+    // node between 'next' and it's dependency at 'prev_idx' index.
+    void add_intermediate(std::shared_ptr<primitive> prim, program_node& next, size_t prev_idx, 
+        bool connect_int_node_with_old_dep = true,
+        bool move_usrs_of_prev_to_node = false);
+
+    //removes a node from the graph and deletes it afterwards,
+    //prereq: node cannot be marked as output and has to have exactly one dependency
+    //returns if 'node' has been extracted and removed successfully
+    bool extract_and_remove(program_node& node);
+
+    //returns if 'node' has been removed
+    bool remove_if_dangling(program_node& node);
+
+    void mark_if_constant(program_node& node);
+    // mark if the node is in data flow assuming that all dependencies are marked properly
+    void mark_if_data_flow(program_node& node);
+    //Reverses connection - user becomes dependency.
+
+    void remove_nodes(std::list<program_node*>& to_remove);
+    void dump_program(const char* stage, bool with_full_info, std::function<bool(program_node const&)> const& filter = nullptr) const;
 
 private:
     uint32_t prog_id = 0;
-
     engine_impl::ptr engine;
     build_options options;
-
     std::list<program_node*> inputs;
     std::vector<program_node*> outputs;
-    std::list<program_node*> processing_order;
+    nodes_ordering processing_order;
+    std::unique_ptr<pass_manager> pm;
 
     std::map<primitive_id, std::shared_ptr<program_node>> nodes_map;
-
     std::list<primitive_id> optimized_out;
 
-    // TODO: Remove once we will get full support for input/output padding in all primitive implementations.
-    bool output_size_handling_enabled;
-
     /*
     ** High-level functions, in order of usage
     */
-    void init_graph(topology_impl const& topology);
-    void pre_optimize_graph();
-    void post_optimize_graph();
-    void compile_graph();
+    /* build nodes internal structure based on topology */
+    void prepare_nodes(topology_impl const& topology);
+    /* build nodes internal structure based on the subset of nodes of other program  (used in propagate_constants) */
+    void prepare_nodes(std::set<std::shared_ptr<program_node>> const& nodes);
+    void add_node_dependencies(program_node* node_ptr);
+    void copy_node_dependencies(program_node* dest, program_node* src);
+    void build_program(bool is_internal);
+    void init_graph();
+    void set_options();
+
+    void run_graph_compilation();
+    void pre_optimize_graph(bool is_internal);
+    void post_optimize_graph(bool is_internal);
     void cleanup();
 
     /*
-    ** Initialization functions
-    */
-    void set_outputs();
-    void calc_processing_order();
-    void calc_prior_boxes();
-
-    /*
     ** Analysis functions
     */
-    void mark_constants();
-    void mark_data_flow();
     // TODO: Remove once we will get full support for input/output padding in all primitive implementations.
-    void analyze_output_size_handling_need();
-    void replace_nodes_pre();
-    void replace_nodes_post();
-       void handle_lstm();
+    bool analyze_output_size_handling_need();
+
+    // handle split, deconvolution and upsampling
     void handle_reshape();
 
     /*
     ** Optimization functions
     */
-    void trim_to_outputs();
-    void remove_redundant_reorders();
-    void calculate_BFS_processing_order();
-    void reorder_inputs(layout_optimizer& lo);
-    void pre_optimize_bias(layout_optimizer& lo);
-    void post_optimize_weights(layout_optimizer& lo);
     void apply_needed_padding(program_node& node, program_node& prev_node, const padding& needed_padding);
-    void prepare_padding();
-    void propagate_constants();
-    void prepare_buffer_fusing();
-    void fuse_skip_layers(program_node* node);
-    void prepare_primitive_fusing();
-    void prepare_depthwise_sep_opt();
-    void prep_opt_depthwise_sep_post();
-    void update_processing_numbers();
 
     /*
     ** Memory pool functions
@@ -158,57 +228,15 @@ private:
     /*
     ** Utilities
     */
+    void add_split_outputs();
+    // mark if the node is constant assuming that all dependencies are marked properly
+    void reverse_connection(program_node& dep_node, program_node& user_node);
 
-    //returns already existing program_node for given primitive 'prim' (lookup in 'nodes_map')
-    //if it was previously created, otherwise creates and then returns program_node
-    program_node& get_or_create(std::shared_ptr<primitive> prim);
-
-    // Inserts given program_node 'node' as an intermediate node between 'next' and it's
-    //  dependency at 'prev_idx' index.
-    void add_intermediate(program_node& node, program_node& next, size_t prev_idx, bool connect_int_node_with_old_dep = true);
-
-    // Gets or creates program_node for given primitive 'prim' and inserts it as an intermediate
-    // node between 'next' and it's dependency at 'prev_idx' index.
-    void add_intermediate(std::shared_ptr<primitive> prim, program_node& next, size_t prev_idx, bool connect_int_node_with_old_dep = true)
-    {
-        add_intermediate(get_or_create(prim), next, prev_idx, connect_int_node_with_old_dep);
-    }
-
-    void add_connection(program_node& prev, program_node& next)
-    {
-        prev.users.push_back(&next);
-        next.dependencies.push_back(&prev);
-    }
+    void add_connection(program_node& prev, program_node& next);
 
-    void remove_connection(program_node& prev, program_node& next)
-    {
-        prev.users.remove(&next);
-        next.dependencies.erase(std::remove(next.dependencies.begin(), next.dependencies.end(), &prev), next.dependencies.end());
-    }
-
-    void remove_all_connections(program_node& node) {
-        // since the graph is not topological sorted, we need to remove the node from both dependencies and users
-        for (auto &e : node.users) {
-            e->dependencies.erase(std::remove(e->dependencies.begin(), e->dependencies.end(), &node), e->dependencies.end());
-        }
-        for(auto &e : node.dependencies) {
-            e->users.remove(&node);
-        }
-        node.dependencies.clear();
-               node.users.clear();
-    }
+    void remove_connection(program_node& prev, program_node& next);
 
-    bool processing_order_is_correct(program_node* node)
-    {
-        for (auto& dep : node->get_dependencies())
-        {
-            if (node->processing_num < dep->processing_num)
-            {
-                return false;
-            }
-        }
-        return true;
-    }
+    void remove_all_connections(program_node& node);
 
     void rename(program_node & node, primitive_id const & new_id);
     void swap_names(program_node& node1, program_node& node2);
@@ -216,37 +244,9 @@ private:
 
     //old_node - node which will be replaced
     //new_node - node which will replace the old one
-    //replace_whole_branch - if set to true, 'old_node' will be replaced with all its dependencies and new_node will retain its dependencies
-    //  old's dependencies which are post-dominates by 'old_node' will also be removed
-    void replace(program_node& old_node, program_node& new_node, bool replace_whole_branch, bool check_output_layouts_integrity = true);
-
-    //returns if 'node' has been removed
-    bool remove_if_dangling(program_node& node, bool detach_whole_branch = false);
-
-    //removes a node from the graph and deletes it afterwards,
-    //prereq: node cannot be marked as output and has to have exactly one dependency
-    //returns if 'node' has been extracted and removed successfully
-    bool extract_and_remove(program_node& node);
-    void replace_data_with_optimized(std::map<primitive_id, memory_impl::ptr> const& replace_map);
-    void dump_program(const char* stage, bool with_full_info, std::function<bool(program_node const&)> const& filter = nullptr) const;
-    //Dumps weights and biasses in serialization process, not working yet, in progress.
-    void dump_weights_and_biasses(std::vector<unsigned long long>& offsets, std::vector<std::string>& data_names, std::ofstream& file_stream) const;
-    //Makes serialization with given name.
-    //Placeholder, not working yet, in progress.
-    void serialize(std::string network_name, std::function<bool(program_node const&)> const& filter = nullptr) const;
-
-    template <typename T>
-    void optimize_bias(T& node, layout_optimizer& lo);
-
-    template <typename T>
-    void optimize_weights(T& node, layout_optimizer& lo);
-
-    template <typename T>
-    void optimize_depthwise_sep_pre(T& node);
-
-    template <typename T>
-    void optimize_depthwise_sep_post(T& node);
+    void replace(program_node& old_node, program_node& new_node);
 };
+
 }
 
 API_CAST(::cldnn_program, cldnn::program_impl)