Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / include / batch_norm_inst.h
index 9569527..175dc8e 100644 (file)
@@ -18,6 +18,7 @@
 #pragma once
 #include "api/CPP/batch_norm.hpp"
 #include "primitive_inst.h"
+#include "mutable_data_inst.h"
 
 namespace cldnn
 {
@@ -33,10 +34,34 @@ public:
     program_node& input() const { return get_dependency(0); }
     program_node& mean() const { return get_dependency(1); }
     program_node& variance() const { return get_dependency(2); }
-    program_node& inv_variance() const { return get_dependency(1); };
+    program_node& scale() const 
+       { 
+               if(get_dependencies().size() >= 5)
+                       return get_dependency(3); 
+               else
+                       return get_dependency(1);
+       }
+    program_node& shift() const 
+       { 
+               if (get_dependencies().size() >= 5)
+                       return get_dependency(4); 
+               else
+                       return get_dependency(2);
+       }
+    program_node& inv_variance() const 
+       { 
+               if (get_dependencies().size() == 2)
+                       return get_dependency(1);
+               else if (get_dependencies().size() == 6)
+                       return get_dependency(5);
+               else
+                       return get_dependency(3);
+       };
     bool variance_term() const { return !get_primitive()->variance.empty(); }
     bool use_global_stats() const { return !get_primitive()->mean.empty() && !get_primitive()->variance.empty(); };
+    bool use_scale_shift() const { return !get_primitive()->scale.empty() && !get_primitive()->shift.empty(); };
     bool forwad_pass() const { return !get_primitive()->inv_variance.empty(); };
+    bool calc_mean_var() const { return (use_global_stats() && mean().is_type<mutable_data>() && variance().is_type<mutable_data>()); };
 
 };
 
@@ -56,9 +81,33 @@ public:
 
     memory_impl& mean_memory() const { return dep_memory(1); }
     memory_impl& variance_memory() const { return dep_memory(2); }
-    memory_impl& inv_variance_memory() const { return dep_memory(1); };
+    memory_impl& scale_memory() const 
+       {
+               if (dependencies().size() >= 5)
+                       return dep_memory(3);
+               else
+                       return dep_memory(1);
+       }
+    memory_impl& shift_memory() const 
+       {
+               if (dependencies().size() >= 5)
+                       return dep_memory(4);
+               else
+                       return dep_memory(2);
+       }
+    memory_impl& inv_variance_memory() const
+       {
+               if (dependencies().size() == 2)
+                       return dep_memory(1);
+               else if (dependencies().size() == 6)
+                       return dep_memory(5);
+               else
+                       return dep_memory(3);
+       };
     bool use_global_stats() const { return !argument.mean.empty() && !argument.variance.empty(); };
+    bool use_scale_shift() const { return !argument.scale.empty() && !argument.scale.empty(); };
     bool forwad_pass() const { return !argument.inv_variance.empty(); };
+    bool calc_mean_var() const { return node.calc_mean_var(); };
 };
 
 using batch_norm_inst = typed_primitive_inst<batch_norm>;