jit.AddConstant(MakeJitConstant("EPSILON", params.batchNormParams.epsilon));
if (params.batchNormParams.with_inv_var)
jit.AddConstant(MakeJitConstant("FORWARD", 1));
+ if (params.batchNormParams.with_scale_shift)
+ jit.AddConstant(MakeJitConstant("SCALE_SHIFT", 1));
+ if (params.batchNormParams.with_mean_var_out)
+ jit.AddConstant(MakeJitConstant("MEAN_VAR_OUT", 1));
return jit;
}
auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
auto& kernel = kd.kernels[0];
- int inputs_num = 1 + orgParams.batchNormParams.with_inv_var;
+ int inputs_num = 1 + orgParams.batchNormParams.with_inv_var + 2*orgParams.batchNormParams.with_scale_shift + 2 * orgParams.batchNormParams.with_mean_var_out;
FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point, "", false, false, inputs_num);
kd.estimatedTime = estimatedTime;