Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / batch_norm / batch_norm_kernel_base.cpp
index ebf881f..064d8a5 100644 (file)
@@ -36,6 +36,10 @@ namespace kernel_selector
         jit.AddConstant(MakeJitConstant("EPSILON", params.batchNormParams.epsilon));
         if (params.batchNormParams.with_inv_var)
             jit.AddConstant(MakeJitConstant("FORWARD", 1));
+               if (params.batchNormParams.with_scale_shift)
+                       jit.AddConstant(MakeJitConstant("SCALE_SHIFT", 1));
+               if (params.batchNormParams.with_mean_var_out)
+                       jit.AddConstant(MakeJitConstant("MEAN_VAR_OUT", 1));
 
         return jit;
     }
@@ -79,7 +83,7 @@ namespace kernel_selector
         auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
 
         auto& kernel = kd.kernels[0];
-        int inputs_num = 1 + orgParams.batchNormParams.with_inv_var;
+        int inputs_num = 1 + orgParams.batchNormParams.with_inv_var + 2*orgParams.batchNormParams.with_scale_shift + 2 * orgParams.batchNormParams.with_mean_var_out;
         FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point, "", false, false, inputs_num);
 
         kd.estimatedTime = estimatedTime;