2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "batch_norm_inst.h"
18 #include "primitive_gpu_base.h"
19 #include "implementation_map.h"
20 #include "error_handler.h"
21 #include "kernel_selector_helper.h"
22 #include "batch_norm/batch_norm_kernel_base.h"
23 #include "batch_norm/batch_norm_kernel_selector.h"
24 #include "eltwise/eltwise_kernel_selector.h"
25 #include "eltwise/eltwise_kernel_base.h"
27 namespace cldnn { namespace gpu {
29 struct batch_norm_gpu : typed_primitive_gpu_impl<batch_norm>
31 using parent = typed_primitive_gpu_impl<batch_norm>;
36 virtual kernel::kernel_arguments_data get_arguments(typed_primitive_inst<batch_norm>& instance, int32_t) const override
38 kernel::kernel_arguments_data args;
40 args.inputs = { &instance.input_memory() };
42 if (instance.use_global_stats()) {
43 args.inputs.push_back(&instance.mean_memory());
44 args.inputs.push_back(&instance.variance_memory());
47 if (instance.use_scale_shift()) {
48 args.inputs.push_back(&instance.scale_memory());
49 args.inputs.push_back(&instance.shift_memory());
52 if (instance.forwad_pass())
53 args.inputs.push_back(&instance.inv_variance_memory());
55 args.output = &instance.output_memory();
62 static primitive_impl* create(const batch_norm_node &arg)
64 if (!arg.use_global_stats()
65 || arg.calc_mean_var() )
67 auto norm_params = get_default_params<kernel_selector::batch_norm_params>(arg);
68 auto norm_optional_params = get_default_optional_params<kernel_selector::batch_norm_optional_params>(arg.get_program());
70 norm_params.batchNormParams.epsilon = arg.get_primitive()->epsilon;
71 norm_params.batchNormParams.with_inv_var = arg.forwad_pass();
72 norm_params.batchNormParams.with_scale_shift = arg.use_scale_shift();
73 if (arg.calc_mean_var())
74 norm_params.batchNormParams.with_mean_var_out = arg.calc_mean_var();
76 auto& kernel_selector = kernel_selector::batch_norm_kernel_selector::Instance();
77 auto best_kernels = kernel_selector.GetBestKernels(norm_params, norm_optional_params);
79 CLDNN_ERROR_BOOL(arg.id(), "Best_kernel.empty()", best_kernels.empty(), "Cannot find a proper kernel with this arguments");
81 auto norm = new batch_norm_gpu(arg, best_kernels[0]);
87 auto ew_params = get_default_params<kernel_selector::eltwise_params>(arg);
88 auto ew_optional_params = get_default_optional_params<kernel_selector::eltwise_optional_params>(arg.get_program());
90 (arg.input().get_output_layout().data_type == data_types::f16) ?
91 std::max(0.00007f, arg.get_primitive()->epsilon) : // prevent underflow if the epsilon is too small for fp16
92 arg.get_primitive()->epsilon;
94 ew_params.inputs.push_back(convert_data_tensor(arg.mean().get_output_layout()));
95 ew_params.inputs.push_back(convert_data_tensor(arg.variance().get_output_layout()));
97 ew_params.operations.push_back({
98 { kernel_selector::eltwise_params::InputType::Buffer(0), kernel_selector::eltwise_params::InputType::Buffer(1) },
99 kernel_selector::eltwise_mode::SUB });
101 ew_params.operations.push_back({
102 { kernel_selector::eltwise_params::InputType::Buffer(2), kernel_selector::eltwise_params::InputType::Scalar(epsilon) },
103 kernel_selector::eltwise_mode::ADD });
105 ew_params.operations.push_back({
106 { kernel_selector::eltwise_params::InputType::Intermediate(1) },
107 kernel_selector::eltwise_mode::RSQRT });
109 ew_params.operations.push_back({
110 { kernel_selector::eltwise_params::InputType::Intermediate(0), kernel_selector::eltwise_params::InputType::Intermediate(2) },
111 kernel_selector::eltwise_mode::MUL });
113 if (arg.use_scale_shift()) {
114 ew_params.inputs.push_back(convert_data_tensor(arg.scale().get_output_layout()));
115 ew_params.inputs.push_back(convert_data_tensor(arg.shift().get_output_layout()));
117 ew_params.operations.push_back({
118 { kernel_selector::eltwise_params::InputType::Intermediate(3), kernel_selector::eltwise_params::InputType::Buffer(3) },
119 kernel_selector::eltwise_mode::MUL });
121 ew_params.operations.push_back({
122 { kernel_selector::eltwise_params::InputType::Intermediate(4), kernel_selector::eltwise_params::InputType::Buffer(4) },
123 kernel_selector::eltwise_mode::ADD });
126 ew_params.layoutBased = true;
128 auto& kernel_selector = kernel_selector::eltwise_kernel_selector::Instance();
129 auto best_kernels = kernel_selector.GetBestKernels(ew_params, ew_optional_params);
131 CLDNN_ERROR_BOOL(arg.id(), "Best_kernel.empty()", best_kernels.empty(), "Cannot find a proper kernel with this arguments");
133 auto norm = new batch_norm_gpu(arg, best_kernels[0]);
// Registers batch_norm_gpu::create as the OCL implementation factory for
// every supported (data type, layout) combination: fp32/fp16 x yxfb/bfyx/byxf.
// NOTE(review): this appears to be the body of a static-registration object
// whose enclosing definition lies outside this excerpt — confirm against the
// full file before restructuring.
143 auto val_fw = batch_norm_gpu::create;
145 implementation_map<batch_norm>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::yxfb), val_fw);
146 implementation_map<batch_norm>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::yxfb), val_fw);
147 implementation_map<batch_norm>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw);
148 implementation_map<batch_norm>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw);
149 implementation_map<batch_norm>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::byxf), val_fw);
150 implementation_map<batch_norm>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::byxf), val_fw);