Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / gpu / fully_connected_grad_weights_gpu.cpp
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "fully_connected_grad_weights_inst.h"
18 #include "primitive_gpu_base.h"
19 #include "implementation_map.h"
20 #include "error_handler.h"
21 #include "network_impl.h"
22 #include "kernel_selector_helper.h"
23 #include "fully_connected_grad_weights/fully_connected_grad_weights_kernel_selector.h"
24 #include "fully_connected_grad_weights/fully_connected_grad_weights_kernel_base.h"
25 #include "api/CPP/fully_connected_grad_weights.hpp"
26
27 namespace cldnn { namespace gpu {
28
29 struct fully_connected_grad_weights_gpu : typed_primitive_gpu_impl<fully_connected_grad_weights>
30 {
31     using parent = typed_primitive_gpu_impl<fully_connected_grad_weights>;
32     using parent::parent;
33
34 protected:
35
36     virtual bool validate_impl(const typed_primitive_inst<fully_connected_grad_weights>& instance) const override
37     {
38         bool res = true;
39
40         if (instance.use_momentum())
41         {
42             CLDNN_ERROR_LAYOUT_MISMATCH(_outer.id(), "Filter memory", instance.weights_memory().get_layout(), "previous weights grad memory", _outer.prev_weights_grad().get_output_layout(), "");
43             CLDNN_ERROR_LAYOUT_MISMATCH(_outer.id(), "Bias memory", instance.bias_memory().get_layout(), "previous bias grad memory", _outer.prev_bias_grad().get_output_layout(), "");
44         }
45
46         return res;
47     }
48
49     virtual kernel::kernel_arguments_data get_arguments(typed_primitive_inst<fully_connected_grad_weights>& instance, int32_t) const override
50     {
51         kernel::kernel_arguments_data args = parent::get_arguments(instance, 1);
52         args.weights = &instance.weights_memory();
53         args.bias = instance.bias_term() ? &instance.bias_memory() : nullptr;
54         args.prev_weights_grad = instance.use_momentum() ? &instance.prev_weights_grad() : nullptr;
55         args.prev_bias_grad = instance.bias_term() ? instance.use_momentum() ? &instance.prev_bias_grad() : nullptr : nullptr;
56
57         args.lr = instance.get_network().get_learning_rate();
58
59         return args;
60     }
61
62 public:
63
64     static primitive_impl* create(const fully_connected_grad_weights_node& arg)
65     {
66         auto fully_connected_grad_weights_params = get_default_learning_params<kernel_selector::fully_connected_grad_weights_params>(arg);
67         auto fully_connected_grad_weights_optional_params = get_default_learning_optional_params<kernel_selector::fully_connected_grad_weights_optional_params>(arg.get_program());
68
69         fully_connected_grad_weights_params.gradient = true;
70         fully_connected_grad_weights_params.inputs.push_back(convert_data_tensor(arg.get_dependency(1).get_output_layout()));
71
72         auto& kernel_selector = kernel_selector::fully_connected_grad_weights_kernel_selector::Instance();
73         auto best_kernels = kernel_selector.GetBestKernels(fully_connected_grad_weights_params, fully_connected_grad_weights_optional_params);
74         CLDNN_ERROR_BOOL(arg.id(), "Best_kernel.empty()", best_kernels.empty(), "Cannot find a proper kernel with this arguments");
75
76         auto fully_connected_grad_weights = new fully_connected_grad_weights_gpu(arg, best_kernels[0]);
77
78         return fully_connected_grad_weights;
79     };
80 };
81
82
83 namespace {
84     struct attach {
85         attach() {
86             auto val_fw = fully_connected_grad_weights_gpu::create;
87
88             implementation_map<fully_connected_grad_weights>::add({
89                 { std::make_tuple(engine_types::ocl, data_types::f32, format::yxfb), val_fw },
90                 { std::make_tuple(engine_types::ocl, data_types::f16, format::yxfb), val_fw },
91                 { std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw },
92                 { std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw },
93                 { std::make_tuple(engine_types::ocl, data_types::f32, format::byxf), val_fw },
94                 { std::make_tuple(engine_types::ocl, data_types::f16, format::byxf), val_fw },
95             });
96         }
97         ~attach() {}
98     };
99     attach attach_impl;
100 }
101 } }