Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / primitive_inst.cpp
index 32c7861..30ff836 100644 (file)
@@ -22,6 +22,7 @@
 #include "input_layout_inst.h"
 #include "max_unpooling_inst.h"
 #include "apply_adam_inst.h"
+#include "fused_conv_eltwise_inst.h"
 
 #include "network_impl.h"
 #include "engine_impl.h"
@@ -40,11 +41,12 @@ uint32_t primitive_inst::get_network_id() const
 
 event_impl::ptr primitive_inst::execute(const std::vector<event_impl::ptr>& events)
 {
-    CLDNN_ERROR_BOOL(id(), "Invalid/unset input", !_has_valid_input, "Cannot execute primitive " + id() + " with invalid/unset input");
+    const auto primitive_id = id();
+    CLDNN_ERROR_BOOL(primitive_id, "Invalid/unset input", !_has_valid_input, "Cannot execute primitive " + primitive_id + " with invalid/unset input");
     on_execute();
 
     if (_exec_deps.size() == 0)
-       return _impl->execute(events, *this);      
+        return _impl->execute(events, *this);
 
     std::vector<event_impl::ptr> dependencies;
     dependencies.reserve(_exec_deps.size());
@@ -53,15 +55,15 @@ event_impl::ptr primitive_inst::execute(const std::vector<event_impl::ptr>& even
         auto id = input->id();
         try {
             // if the requested event deos not exits it means that it has not been executed, so the processing_order is wrong or synchronization failed.
-            auto ev = get_network().get_primitive_event(id); 
+            auto ev = get_network().get_primitive_event(id);
             dependencies.emplace_back(ev);
-        }
+            }
         catch (const std::out_of_range& oor) {
-            std::string temp = std::string("internal CLDNN error: execution order corrupted.") + std::string("\n") +  std::string(oor.what() + std::string("\n"));
+            std::string temp = std::string("internal CLDNN error: execution order corrupted.") + std::string("\n") + std::string(oor.what() + std::string("\n"));
             CLDNN_ERROR_MESSAGE(id, temp);
         }
     }
-    return _impl->execute(dependencies, *this);  
+    return _impl->execute(dependencies, *this);
 }
 
 void primitive_inst::build_deps()
@@ -95,6 +97,16 @@ primitive_inst::primitive_inst(network_impl& network, program_node const& node,
             //For certain primitives, it is known which dependency is used for synchronization only
             else if (user->is_type<apply_adam>() && (user->as<apply_adam>().has_additional_dep()) && (user->as<apply_adam>().additional_dep().id() == node.id()))
                 user_count--;
+            else if (user->is_type<fused_conv_eltwise>())
+            {
+                if ((*user->as<fused_conv_eltwise>().get_users().begin())->is_type<mutable_data>())
+                {
+                    if (user->as<fused_conv_eltwise>().get_dependency(1).id() == node.id())
+                    {
+                        user_count--;
+                    }
+                }
+            } 
         }
 
         if (user_count == 1 && mutable_data_count == 1)
@@ -119,15 +131,9 @@ memory_impl::ptr primitive_inst::allocate_output()
         return get_network().get_engine().allocate_memory(layout, _node.id(), get_network_id(), _node.get_memory_dependencies(), false);
     }
     else if (_network.is_internal() ||
-        _node.is_type<data>() ||
-        _node.is_type<mutable_data>() ||
-        _node.is_type<input_layout>() ||
-        //for max_unpooling initial zero values are significant
-        _node.is_type<max_unpooling>() ||
-        //apply adam's output initial val should be either 0 or use same buffer as mutable_data after it (no allocation needed)
-        _node.is_type<apply_adam>() ||
-        _node.can_be_optimized() ||
-        _node.is_output())
+             (!_node.can_share_buffer()) ||
+             _node.can_be_optimized() ||
+            _node.is_output())
     {
         return get_network().get_engine().allocate_memory(layout);
     }