2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "apply_adam_inst.h"
18 #include "primitive_gpu_base.h"
19 #include "implementation_map.h"
20 #include "error_handler.h"
21 #include "kernel_selector_helper.h"
22 #include "eltwise/eltwise_kernel_selector.h"
23 #include "eltwise/eltwise_kernel_base.h"
25 namespace cldnn { namespace gpu {
27 struct apply_adam_gpu : typed_primitive_gpu_impl<apply_adam>
29 using parent = typed_primitive_gpu_impl<apply_adam>;
34 virtual kernel::kernel_arguments_data get_arguments(typed_primitive_inst<apply_adam>& instance, int32_t) const override
36 kernel::kernel_arguments_data args;
38 args.inputs = { &instance.input_memory(), &instance.m_memory(), &instance.v_memory(), &instance.beta1_power_memory(), &instance.beta2_power_memory() };
39 args.output = &instance.output_memory();
46 static primitive_impl* create(const apply_adam_node &arg)
48 auto ew_params = get_default_params<kernel_selector::eltwise_params>(arg);
49 auto ew_optional_params = get_default_optional_params<kernel_selector::eltwise_optional_params>(arg.get_program());
50 const float lr = arg.get_primitive()->lr;
51 const float beta1 = arg.get_primitive()->beta1;
52 const float beta2 = arg.get_primitive()->beta2;
54 (arg.input().get_output_layout().data_type == data_types::f16) ?
55 std::max(0.00007f, arg.get_primitive()->epsilon) : // prevent underflow if the epsilon is too small for fp16
56 arg.get_primitive()->epsilon;
58 ew_params.inputs.push_back(convert_data_tensor(arg.m().get_output_layout()));
59 ew_params.inputs.push_back(convert_data_tensor(arg.v().get_output_layout()));
60 ew_params.inputs.push_back(convert_data_tensor(arg.beta1_power().get_output_layout()));
61 ew_params.inputs.push_back(convert_data_tensor(arg.beta2_power().get_output_layout()));
63 //lr_t = lr * sqrt(1 - pow(beta2, t_f)) / (1 - pow(beta1, t_f))
64 ew_params.eltwiseParams.operations.push_back({
65 { kernel_selector::eltwise_params::InputType::Scalar(1), kernel_selector::eltwise_params::InputType::Buffer(3) },
66 kernel_selector::eltwise_mode::SUB });
68 ew_params.eltwiseParams.operations.push_back({
69 { kernel_selector::eltwise_params::InputType::Scalar(1), kernel_selector::eltwise_params::InputType::Buffer(4) },
70 kernel_selector::eltwise_mode::SUB });
72 ew_params.eltwiseParams.operations.push_back({
73 { kernel_selector::eltwise_params::InputType::Intermediate(1) },
74 kernel_selector::eltwise_mode::SQRT });
76 ew_params.eltwiseParams.operations.push_back({
77 { kernel_selector::eltwise_params::InputType::Intermediate(2), kernel_selector::eltwise_params::InputType::Scalar(lr) },
78 kernel_selector::eltwise_mode::MUL });
80 ew_params.eltwiseParams.operations.push_back({
81 { kernel_selector::eltwise_params::InputType::Intermediate(3), kernel_selector::eltwise_params::InputType::Intermediate(0) },
82 kernel_selector::eltwise_mode::DIV });
84 //m_t = beta1 * m_f + (1 - beta1) * input_grad
85 ew_params.eltwiseParams.operations.push_back({
86 { kernel_selector::eltwise_params::InputType::Scalar(beta1), kernel_selector::eltwise_params::InputType::Buffer(1) },
87 kernel_selector::eltwise_mode::MUL });
89 ew_params.eltwiseParams.operations.push_back({
90 { kernel_selector::eltwise_params::InputType::Scalar(1), kernel_selector::eltwise_params::InputType::Scalar(beta1) },
91 kernel_selector::eltwise_mode::SUB });
93 ew_params.eltwiseParams.operations.push_back({
94 { kernel_selector::eltwise_params::InputType::Intermediate(6), kernel_selector::eltwise_params::InputType::Buffer(0) },
95 kernel_selector::eltwise_mode::MUL });
97 ew_params.eltwiseParams.operations.push_back({
98 { kernel_selector::eltwise_params::InputType::Intermediate(5), kernel_selector::eltwise_params::InputType::Intermediate(7) },
99 kernel_selector::eltwise_mode::ADD });
101 //save the result in m mutable_data primitive
102 ew_params.eltwiseParams.updateInputIds.push_back({ 1, 8 });
104 ////v_t = beta2 * v_f + (1 - beta2) * input_grad * input_grad
105 ew_params.eltwiseParams.operations.push_back({
106 { kernel_selector::eltwise_params::InputType::Scalar(beta2), kernel_selector::eltwise_params::InputType::Buffer(2) },
107 kernel_selector::eltwise_mode::MUL });
109 ew_params.eltwiseParams.operations.push_back({
110 { kernel_selector::eltwise_params::InputType::Scalar(1), kernel_selector::eltwise_params::InputType::Scalar(beta2) },
111 kernel_selector::eltwise_mode::SUB });
113 ew_params.eltwiseParams.operations.push_back({
114 { kernel_selector::eltwise_params::InputType::Intermediate(10), kernel_selector::eltwise_params::InputType::Buffer(0) },
115 kernel_selector::eltwise_mode::MUL });
117 ew_params.eltwiseParams.operations.push_back({
118 { kernel_selector::eltwise_params::InputType::Intermediate(11), kernel_selector::eltwise_params::InputType::Buffer(0) },
119 kernel_selector::eltwise_mode::MUL });
121 ew_params.eltwiseParams.operations.push_back({
122 { kernel_selector::eltwise_params::InputType::Intermediate(9), kernel_selector::eltwise_params::InputType::Intermediate(12) },
123 kernel_selector::eltwise_mode::ADD });
125 //save the result in v mutable_data primitive
126 ew_params.eltwiseParams.updateInputIds.push_back({ 2, 13 });
128 ////result = var - lr_t * m_t / (sqrt(v_t) + epsilon)
129 ew_params.eltwiseParams.operations.push_back({
130 { kernel_selector::eltwise_params::InputType::Intermediate(13) },
131 kernel_selector::eltwise_mode::SQRT });
133 ew_params.eltwiseParams.operations.push_back({
134 { kernel_selector::eltwise_params::InputType::Intermediate(14), kernel_selector::eltwise_params::InputType::Scalar(epsilon) },
135 kernel_selector::eltwise_mode::ADD });
137 ew_params.eltwiseParams.operations.push_back({
138 { kernel_selector::eltwise_params::InputType::Intermediate(4), kernel_selector::eltwise_params::InputType::Intermediate(8) },
139 kernel_selector::eltwise_mode::MUL });
141 ew_params.eltwiseParams.operations.push_back({
142 { kernel_selector::eltwise_params::InputType::Intermediate(16), kernel_selector::eltwise_params::InputType::Intermediate(15) },
143 kernel_selector::eltwise_mode::DIV });
145 ew_params.eltwiseParams.operations.push_back({
146 { kernel_selector::eltwise_params::InputType::OutBuffer(), kernel_selector::eltwise_params::InputType::Intermediate(17) },
147 kernel_selector::eltwise_mode::SUB });
149 ew_params.eltwiseParams.layoutBased = true;
151 auto& kernel_selector = kernel_selector::eltwise_kernel_selector::Instance();
152 auto best_kernels = kernel_selector.GetBestKernels(ew_params, ew_optional_params);
154 CLDNN_ERROR_BOOL(arg.id(), "Best_kernel.empty()", best_kernels.empty(), "Cannot find a proper kernel with this arguments");
156 auto norm = new apply_adam_gpu(arg, best_kernels[0]);
165 auto val_fw = apply_adam_gpu::create;
167 implementation_map<apply_adam>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::yxfb), val_fw);
168 implementation_map<apply_adam>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::yxfb), val_fw);
169 implementation_map<apply_adam>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw);
170 implementation_map<apply_adam>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw);
171 implementation_map<apply_adam>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::byxf), val_fw);
172 implementation_map<apply_adam>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::byxf), val_fw);