Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / gpu / batch_norm_gpu.cpp
index 8adb888..f5364ad 100644 (file)
@@ -37,17 +37,20 @@ protected:
     {
         kernel::kernel_arguments_data args;
 
-        
-        if (!instance.use_global_stats())
-        {
-            args.inputs = { &instance.input_memory() };
-            if (instance.forwad_pass())
-                args.inputs.push_back(&instance.inv_variance_memory());
-        }
-        else
-        {
-            args.inputs = { &instance.input_memory(), &instance.mean_memory(), &instance.variance_memory() };
-        }
+               args.inputs = { &instance.input_memory() };
+
+               if (instance.use_global_stats()) {
+                       args.inputs.push_back(&instance.mean_memory());
+                       args.inputs.push_back(&instance.variance_memory());
+               }
+
+               if (instance.use_scale_shift()) {
+                       args.inputs.push_back(&instance.scale_memory());
+                       args.inputs.push_back(&instance.shift_memory());
+               }
+
+               if (instance.forwad_pass())
+                       args.inputs.push_back(&instance.inv_variance_memory());
 
         args.output = &instance.output_memory();
 
@@ -58,13 +61,17 @@ public:
 
     static primitive_impl* create(const batch_norm_node &arg) 
     { 
-        if (!arg.use_global_stats())
+        if (!arg.use_global_stats()
+                       || arg.calc_mean_var() )
         {
             auto norm_params = get_default_params<kernel_selector::batch_norm_params>(arg);
             auto norm_optional_params = get_default_optional_params<kernel_selector::batch_norm_optional_params>(arg.get_program());
 
             norm_params.batchNormParams.epsilon = arg.get_primitive()->epsilon;
             norm_params.batchNormParams.with_inv_var = arg.forwad_pass();
+                       norm_params.batchNormParams.with_scale_shift = arg.use_scale_shift();
+            if (arg.calc_mean_var())
+                           norm_params.batchNormParams.with_mean_var_out = arg.calc_mean_var();
 
             auto& kernel_selector = kernel_selector::batch_norm_kernel_selector::Instance();
             auto best_kernels = kernel_selector.GetBestKernels(norm_params, norm_optional_params);
@@ -86,7 +93,7 @@ public:
 
             ew_params.inputs.push_back(convert_data_tensor(arg.mean().get_output_layout()));
             ew_params.inputs.push_back(convert_data_tensor(arg.variance().get_output_layout()));
-
+                       
             ew_params.operations.push_back({
                 { kernel_selector::eltwise_params::InputType::Buffer(0), kernel_selector::eltwise_params::InputType::Buffer(1) },
                 kernel_selector::eltwise_mode::SUB });
@@ -103,6 +110,19 @@ public:
                 { kernel_selector::eltwise_params::InputType::Intermediate(0), kernel_selector::eltwise_params::InputType::Intermediate(2) },
                 kernel_selector::eltwise_mode::MUL });
 
+                       if (arg.use_scale_shift()) {
+                               ew_params.inputs.push_back(convert_data_tensor(arg.scale().get_output_layout()));
+                               ew_params.inputs.push_back(convert_data_tensor(arg.shift().get_output_layout()));
+
+                               ew_params.operations.push_back({
+                                       { kernel_selector::eltwise_params::InputType::Intermediate(3), kernel_selector::eltwise_params::InputType::Buffer(3) },
+                                       kernel_selector::eltwise_mode::MUL });
+
+                               ew_params.operations.push_back({
+                                       { kernel_selector::eltwise_params::InputType::Intermediate(4), kernel_selector::eltwise_params::InputType::Buffer(4) },
+                                       kernel_selector::eltwise_mode::ADD });
+                       }
+
             ew_params.layoutBased = true;
 
             auto& kernel_selector = kernel_selector::eltwise_kernel_selector::Instance();