Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / gpu / apply_adam_gpu.cpp
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "apply_adam_inst.h"
18 #include "primitive_gpu_base.h"
19 #include "implementation_map.h"
20 #include "error_handler.h"
21 #include "kernel_selector_helper.h"
22 #include "eltwise/eltwise_kernel_selector.h"
23 #include "eltwise/eltwise_kernel_base.h"
24
25 namespace cldnn { namespace gpu {
26
27 struct apply_adam_gpu : typed_primitive_gpu_impl<apply_adam>
28 {
29     using parent = typed_primitive_gpu_impl<apply_adam>;
30     using parent::parent;
31
32 protected:
33
34     virtual kernel::kernel_arguments_data get_arguments(typed_primitive_inst<apply_adam>& instance, int32_t) const override
35     {
36         kernel::kernel_arguments_data args;
37
38         args.inputs = { &instance.input_memory(), &instance.m_memory(), &instance.v_memory(), &instance.beta1_power_memory(), &instance.beta2_power_memory() };
39         args.output = &instance.output_memory();
40
41         return args;
42     }
43
44 public:
45
46     static primitive_impl* create(const apply_adam_node &arg) 
47     { 
48         auto ew_params = get_default_params<kernel_selector::eltwise_params>(arg);
49         auto ew_optional_params = get_default_optional_params<kernel_selector::eltwise_optional_params>(arg.get_program());
50         const float lr = arg.get_primitive()->lr;
51         const float beta1 = arg.get_primitive()->beta1;
52         const float beta2 = arg.get_primitive()->beta2;
53         const float epsilon =
54             (arg.input().get_output_layout().data_type == data_types::f16) ?
55             std::max(0.00007f, arg.get_primitive()->epsilon) : // prevent underflow if the epsilon is too small for fp16
56             arg.get_primitive()->epsilon;
57
58         ew_params.inputs.push_back(convert_data_tensor(arg.m().get_output_layout()));
59         ew_params.inputs.push_back(convert_data_tensor(arg.v().get_output_layout()));
60         ew_params.inputs.push_back(convert_data_tensor(arg.beta1_power().get_output_layout()));
61         ew_params.inputs.push_back(convert_data_tensor(arg.beta2_power().get_output_layout()));
62
63         //lr_t = lr * sqrt(1 - pow(beta2, t_f)) / (1 - pow(beta1, t_f))
64         ew_params.eltwiseParams.operations.push_back({
65             { kernel_selector::eltwise_params::InputType::Scalar(1), kernel_selector::eltwise_params::InputType::Buffer(3) },
66             kernel_selector::eltwise_mode::SUB });
67
68         ew_params.eltwiseParams.operations.push_back({
69             { kernel_selector::eltwise_params::InputType::Scalar(1), kernel_selector::eltwise_params::InputType::Buffer(4) },
70             kernel_selector::eltwise_mode::SUB });
71
72         ew_params.eltwiseParams.operations.push_back({
73             { kernel_selector::eltwise_params::InputType::Intermediate(1) },
74             kernel_selector::eltwise_mode::SQRT });
75
76         ew_params.eltwiseParams.operations.push_back({
77             { kernel_selector::eltwise_params::InputType::Intermediate(2), kernel_selector::eltwise_params::InputType::Scalar(lr) },
78             kernel_selector::eltwise_mode::MUL });
79
80         ew_params.eltwiseParams.operations.push_back({
81             { kernel_selector::eltwise_params::InputType::Intermediate(3), kernel_selector::eltwise_params::InputType::Intermediate(0) },
82             kernel_selector::eltwise_mode::DIV });
83
84         //m_t = beta1 * m_f + (1 - beta1) * input_grad
85         ew_params.eltwiseParams.operations.push_back({
86             { kernel_selector::eltwise_params::InputType::Scalar(beta1), kernel_selector::eltwise_params::InputType::Buffer(1) },
87             kernel_selector::eltwise_mode::MUL });
88         
89         ew_params.eltwiseParams.operations.push_back({
90             { kernel_selector::eltwise_params::InputType::Scalar(1), kernel_selector::eltwise_params::InputType::Scalar(beta1) },
91             kernel_selector::eltwise_mode::SUB });
92         
93         ew_params.eltwiseParams.operations.push_back({
94             { kernel_selector::eltwise_params::InputType::Intermediate(6), kernel_selector::eltwise_params::InputType::Buffer(0) },
95             kernel_selector::eltwise_mode::MUL });
96         
97         ew_params.eltwiseParams.operations.push_back({
98             { kernel_selector::eltwise_params::InputType::Intermediate(5), kernel_selector::eltwise_params::InputType::Intermediate(7) },
99             kernel_selector::eltwise_mode::ADD });
100
101         //save the result in m mutable_data primitive
102         ew_params.eltwiseParams.updateInputIds.push_back({ 1, 8 });
103         
104         ////v_t = beta2 * v_f + (1 - beta2) * input_grad * input_grad
105         ew_params.eltwiseParams.operations.push_back({
106             { kernel_selector::eltwise_params::InputType::Scalar(beta2), kernel_selector::eltwise_params::InputType::Buffer(2) },
107             kernel_selector::eltwise_mode::MUL });
108         
109         ew_params.eltwiseParams.operations.push_back({
110             { kernel_selector::eltwise_params::InputType::Scalar(1), kernel_selector::eltwise_params::InputType::Scalar(beta2) },
111             kernel_selector::eltwise_mode::SUB });
112         
113         ew_params.eltwiseParams.operations.push_back({
114             { kernel_selector::eltwise_params::InputType::Intermediate(10), kernel_selector::eltwise_params::InputType::Buffer(0) },
115             kernel_selector::eltwise_mode::MUL });
116         
117         ew_params.eltwiseParams.operations.push_back({
118             { kernel_selector::eltwise_params::InputType::Intermediate(11), kernel_selector::eltwise_params::InputType::Buffer(0) },
119             kernel_selector::eltwise_mode::MUL });
120         
121         ew_params.eltwiseParams.operations.push_back({
122             { kernel_selector::eltwise_params::InputType::Intermediate(9), kernel_selector::eltwise_params::InputType::Intermediate(12) },
123             kernel_selector::eltwise_mode::ADD });
124         
125         //save the result in v mutable_data primitive
126         ew_params.eltwiseParams.updateInputIds.push_back({ 2, 13 });
127
128         ////result = var - lr_t * m_t / (sqrt(v_t) + epsilon)
129         ew_params.eltwiseParams.operations.push_back({
130             { kernel_selector::eltwise_params::InputType::Intermediate(13) },
131             kernel_selector::eltwise_mode::SQRT });
132         
133         ew_params.eltwiseParams.operations.push_back({
134             { kernel_selector::eltwise_params::InputType::Intermediate(14), kernel_selector::eltwise_params::InputType::Scalar(epsilon) },
135             kernel_selector::eltwise_mode::ADD });
136         
137         ew_params.eltwiseParams.operations.push_back({
138             { kernel_selector::eltwise_params::InputType::Intermediate(4), kernel_selector::eltwise_params::InputType::Intermediate(8) },
139             kernel_selector::eltwise_mode::MUL });
140         
141         ew_params.eltwiseParams.operations.push_back({
142             { kernel_selector::eltwise_params::InputType::Intermediate(16), kernel_selector::eltwise_params::InputType::Intermediate(15) },
143             kernel_selector::eltwise_mode::DIV });
144
145         ew_params.eltwiseParams.operations.push_back({
146             { kernel_selector::eltwise_params::InputType::OutBuffer(), kernel_selector::eltwise_params::InputType::Intermediate(17) },
147             kernel_selector::eltwise_mode::SUB });
148
149         ew_params.eltwiseParams.layoutBased = true;
150
151         auto& kernel_selector = kernel_selector::eltwise_kernel_selector::Instance();
152         auto best_kernels = kernel_selector.GetBestKernels(ew_params, ew_optional_params);
153
154         CLDNN_ERROR_BOOL(arg.id(), "Best_kernel.empty()", best_kernels.empty(), "Cannot find a proper kernel with this arguments");
155
156         auto norm = new apply_adam_gpu(arg, best_kernels[0]);
157
158         return norm;
159     };
160 };
161
162 namespace {
163     struct attach {
164         attach() {
165             auto val_fw = apply_adam_gpu::create;
166
167             implementation_map<apply_adam>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::yxfb), val_fw);
168             implementation_map<apply_adam>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::yxfb), val_fw);
169             implementation_map<apply_adam>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw);
170             implementation_map<apply_adam>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw);
171             implementation_map<apply_adam>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::byxf), val_fw);
172             implementation_map<apply_adam>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::byxf), val_fw);
173         }
174         ~attach() {}
175     };
176     attach attach_impl;
177 }
178 } }