Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / cl_kernels / batch_norm_gpu_ref.cl
index 7fe1a8a..aaf60c3 100644 (file)
 __attribute__((reqd_work_group_size(LOCAL_SIZE, 1, 1)))
 KERNEL(batch_norm_gpu)(
     const __global UNIT_TYPE* input,
-#ifdef FORWARD
-     __global UNIT_TYPE* inv_var,
-#endif
+       #ifdef MEAN_VAR_OUT
+               __global UNIT_TYPE* mean_out,
+               __global UNIT_TYPE* variance_out,
+       #endif
+       #ifdef SCALE_SHIFT
+            __global UNIT_TYPE* scale,
+                __global UNIT_TYPE* shift,
+       #endif
+       #ifdef FORWARD
+               __global UNIT_TYPE* inv_var,
+       #endif
        __global UNIT_TYPE* output)
 {
     __local ACCUMULATOR_TYPE sum[LOCAL_SIZE];
@@ -56,7 +64,9 @@ KERNEL(batch_norm_gpu)(
     }
 
     UNIT_TYPE mean = sum[0] / (OUTPUT_BATCH_NUM * OUTPUT_SIZE_X * OUTPUT_SIZE_Y);
-
+#ifdef MEAN_VAR_OUT
+               mean_out[f] = mean;
+#endif
     sum[local_idx] = 0;
 
     input_idx = GET_DATA_INDEX(INPUT0, local_idx, f, 0, 0);
@@ -83,7 +93,9 @@ KERNEL(batch_norm_gpu)(
     }
 
     float variance = sum[0] / (OUTPUT_BATCH_NUM * OUTPUT_SIZE_X * OUTPUT_SIZE_Y);
-
+#ifdef MEAN_VAR_OUT
+       variance_out[f] = variance;
+#endif
     float inv_variance = (float)(1.0 / sqrt(variance + EPSILON));
 #ifdef FORWARD
     if (local_idx == 0)
@@ -95,9 +107,15 @@ KERNEL(batch_norm_gpu)(
     {
         for (uint x = 0; x < OUTPUT_SIZE_X; x++)
         {
-            output[out_idx] = inv_variance * (input[out_idx] - mean);
+                       #ifdef SCALE_SHIFT
+                               output[out_idx] = (inv_variance * (input[out_idx] - mean)) * scale[f] + shift[f];
+                       #else
+                               output[out_idx] = inv_variance * (input[out_idx] - mean);
+                       #endif
             out_idx += OUTPUT_X_PITCH;
         }
         out_idx += OUTPUT_Y_PITCH - OUTPUT_SIZE_X * OUTPUT_X_PITCH;
     }
-}
\ No newline at end of file
+}
+
+#undef LOCAL_SIZE
\ No newline at end of file