Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / batch_norm.cpp
index 0aea3e6..2b972b1 100644 (file)
@@ -18,6 +18,7 @@
 #include "primitive_type_base.h"
 #include "error_handler.h"
 #include "json_object.h"
+#include "mutable_data_inst.h"
 
 namespace cldnn
 {
@@ -29,40 +30,43 @@ primitive_type_id batch_norm_type_id()
 
 layout batch_norm_inst::calc_output_layout(batch_norm_node const& node)
 {
+    assert((bool)node.get_primitive()->output_data_type == false
+           && "Output data type forcing is not supported for batch_norm_node!");
     return node.input().get_non_padded_output_layout();
 }
 
 std::string batch_norm_inst::to_string(batch_norm_node const& node)
 {
-    auto desc      = node.get_primitive();
-    auto node_info = node.desc_to_json();
-    auto& mean     = node.mean();
     bool variance_term = node.variance_term();
-    auto& inv_var  = node.inv_variance();
 
     std::stringstream primitive_description;
-
     json_composite batch_norm_info;
     if (node.use_global_stats())
     {
-        batch_norm_info.add("mean_id", mean.id());
+        batch_norm_info.add("mean_id", node.mean().id());
         if (variance_term)
         {
             batch_norm_info.add("variance_id", node.variance().id());
         }
     }
+    if (node.use_scale_shift())
+    {
+        batch_norm_info.add("scale_id", node.scale().id());
+        batch_norm_info.add("shift_id", node.shift().id());
+    }
     if (node.forwad_pass())
     {
-        batch_norm_info.add("inv_var", inv_var.id());
+        batch_norm_info.add("inv_var", node.inv_variance().id());
     }
-    batch_norm_info.add("epsilon", desc->epsilon);
+    batch_norm_info.add("epsilon", node.get_primitive()->epsilon);
 
-    node_info->add("batch norm info", batch_norm_info);
-    node_info->dump(primitive_description);
+    node.desc_to_json()->add("batch norm info", batch_norm_info);
+    node.desc_to_json()->dump(primitive_description);
 
     return primitive_description.str();
 }
 
+
 batch_norm_inst::typed_primitive_inst(network_impl& network, batch_norm_node const& node)
     :parent(network, node) 
 {
@@ -71,8 +75,27 @@ batch_norm_inst::typed_primitive_inst(network_impl& network, batch_norm_node con
         auto mean_format = node.mean().get_output_layout().format;
         auto variance_format = node.variance().get_output_layout().format;
 
-        CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "Mean format", mean_format.value, "supported mean formats", format::yxfb, format::bfyx);
-        CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "Variance format", variance_format.value, "supported variance formats", format::yxfb, format::bfyx);
+        CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "Mean format", mean_format.value, "supported mean formats", format::yxfb, format::bfyx, format::byxf);
+        CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "Variance format", variance_format.value, "supported variance formats", format::yxfb, format::bfyx, format::byxf);
+
+               auto is_mean_mutable_data = node.mean().is_type<mutable_data>();
+               auto is_var_mutable_data = node.variance().is_type<mutable_data>();
+
+               CLDNN_ERROR_BOOL(node.id(), "mean and variance are not the same type", (is_mean_mutable_data != is_var_mutable_data), "");
     }
+
+       if (use_scale_shift()) {
+               auto scale_format = node.scale().get_output_layout().format;
+               auto shift_format = node.shift().get_output_layout().format;
+
+               CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "Scale format", scale_format.value, "supported scale formats", format::yxfb, format::bfyx, format::byxf);
+               CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "Shift format", shift_format.value, "supported shift formats", format::yxfb, format::bfyx, format::byxf);
+       }
+
+       if (forwad_pass())
+       {
+               auto is_inv_var_mutable_data = node.inv_variance().is_type<mutable_data>();
+               CLDNN_ERROR_BOOL(node.id(), "inv_variance is not mutable_data type", !is_inv_var_mutable_data, "");
+       }
+}
 }
-}
\ No newline at end of file