{
kernel::kernel_arguments_data args;
-
- if (!instance.use_global_stats())
- {
- args.inputs = { &instance.input_memory() };
- if (instance.forwad_pass())
- args.inputs.push_back(&instance.inv_variance_memory());
- }
- else
- {
- args.inputs = { &instance.input_memory(), &instance.mean_memory(), &instance.variance_memory() };
- }
+ args.inputs = { &instance.input_memory() };
+
+ if (instance.use_global_stats()) {
+ args.inputs.push_back(&instance.mean_memory());
+ args.inputs.push_back(&instance.variance_memory());
+ }
+
+ if (instance.use_scale_shift()) {
+ args.inputs.push_back(&instance.scale_memory());
+ args.inputs.push_back(&instance.shift_memory());
+ }
+
+ if (instance.forwad_pass())
+ args.inputs.push_back(&instance.inv_variance_memory());
args.output = &instance.output_memory();
static primitive_impl* create(const batch_norm_node &arg)
{
- if (!arg.use_global_stats())
+ if (!arg.use_global_stats()
+ || arg.calc_mean_var() )
{
auto norm_params = get_default_params<kernel_selector::batch_norm_params>(arg);
auto norm_optional_params = get_default_optional_params<kernel_selector::batch_norm_optional_params>(arg.get_program());
norm_params.batchNormParams.epsilon = arg.get_primitive()->epsilon;
norm_params.batchNormParams.with_inv_var = arg.forwad_pass();
+ norm_params.batchNormParams.with_scale_shift = arg.use_scale_shift();
+ if (arg.calc_mean_var())
+ norm_params.batchNormParams.with_mean_var_out = arg.calc_mean_var();
auto& kernel_selector = kernel_selector::batch_norm_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(norm_params, norm_optional_params);
ew_params.inputs.push_back(convert_data_tensor(arg.mean().get_output_layout()));
ew_params.inputs.push_back(convert_data_tensor(arg.variance().get_output_layout()));
-
+
ew_params.operations.push_back({
{ kernel_selector::eltwise_params::InputType::Buffer(0), kernel_selector::eltwise_params::InputType::Buffer(1) },
kernel_selector::eltwise_mode::SUB });
{ kernel_selector::eltwise_params::InputType::Intermediate(0), kernel_selector::eltwise_params::InputType::Intermediate(2) },
kernel_selector::eltwise_mode::MUL });
+ if (arg.use_scale_shift()) {
+ ew_params.inputs.push_back(convert_data_tensor(arg.scale().get_output_layout()));
+ ew_params.inputs.push_back(convert_data_tensor(arg.shift().get_output_layout()));
+
+ ew_params.operations.push_back({
+ { kernel_selector::eltwise_params::InputType::Intermediate(3), kernel_selector::eltwise_params::InputType::Buffer(3) },
+ kernel_selector::eltwise_mode::MUL });
+
+ ew_params.operations.push_back({
+ { kernel_selector::eltwise_params::InputType::Intermediate(4), kernel_selector::eltwise_params::InputType::Buffer(4) },
+ kernel_selector::eltwise_mode::ADD });
+ }
+
ew_params.layoutBased = true;
auto& kernel_selector = kernel_selector::eltwise_kernel_selector::Instance();