577f0a7db696a87087c0518dd813dcd28bbcfdaa
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / gpu / convolution_grad_weights_gpu.cpp
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "convolution_grad_weights_inst.h"
18 #include "primitive_gpu_base.h"
19 #include "implementation_map.h"
20 #include "error_handler.h"
21 #include "network_impl.h"
22 #include "kernel_selector_helper.h"
23 #include "convolution_grad_weights/convolution_grad_weights_kernel_selector.h"
24 #include "convolution_grad_weights/convolution_grad_weights_kernel_base.h"
25 #include <algorithm>
26
27 namespace cldnn {
28 namespace gpu {
29
30 struct convolution_grad_weights_gpu : typed_primitive_gpu_impl<convolution_grad_weights> {
31     using parent = typed_primitive_gpu_impl<convolution_grad_weights>;
32     using parent::parent;
33
34 protected:
35     bool validate_impl(const typed_primitive_inst<convolution_grad_weights>& instance) const override {
36         bool res = true;
37
38         CLDNN_ERROR_NOT_EQUAL(_outer.id(),
39                               "convolution_grad_weights filling value",
40                               _outer.get_output_layout().data_padding.filling_value(),
41                               "padding mode",
42                               0.0f,
43                               "Unknown padding mode in convolution_grad_weights.");
44         // Check whether all memory elements use the same unit type (FP16 or FP32).
45         CLDNN_ERROR_DATA_TYPES_MISMATCH(_outer.id(),
46                                         "Input grad memory",
47                                         instance.input_memory().get_layout().data_type,
48                                         "output memory",
49                                         instance.output_memory().get_layout().data_type,
50                                         "");
51         CLDNN_ERROR_DATA_TYPES_MISMATCH(_outer.id(),
52                                         "Input memory",
53                                         instance.input_memory(1).get_layout().data_type,
54                                         "output memory",
55                                         instance.output_memory().get_layout().data_type,
56                                         "");
57         CLDNN_ERROR_DATA_TYPES_MISMATCH(_outer.id(),
58                                         "Fp32",
59                                         data_types::f32,
60                                         "filter memory",
61                                         instance.weights_memory(0).get_layout().data_type,
62                                         "");
63
64         if (instance.use_momentum()) {
65             CLDNN_ERROR_LAYOUT_MISMATCH(_outer.id(),
66                                         "Filter memory",
67                                         instance.weights_memory(0).get_layout(),
68                                         "previous weights grad memory",
69                                         _outer.prev_weights_grad(0).get_output_layout(),
70                                         "");
71             CLDNN_ERROR_LAYOUT_MISMATCH(_outer.id(),
72                                         "Bias memory",
73                                         instance.bias_memory(0).get_layout(),
74                                         "previous bias grad memory",
75                                         _outer.prev_bias_grad(0).get_output_layout(),
76                                         "");
77         }
78
79         return res;
80     }
81
82     kernel::kernel_arguments_data get_arguments(typed_primitive_inst<convolution_grad_weights>& instance,
83                                                         int32_t split) const override {
84         kernel::kernel_arguments_data args = parent::get_arguments(instance, split);
85
86         args.weights = (memory_impl::cptr) &instance.weights_memory(split);
87         args.bias = (memory_impl::cptr) (instance.bias_term() ? &instance.bias_memory(split) : nullptr);
88         args.prev_weights_grad = (memory_impl::cptr) (instance.use_momentum() ? &instance.prev_weights_grad(split) : nullptr);
89         args.prev_bias_grad =
90             (memory_impl::cptr) (instance.bias_term() ? instance.use_momentum() ? &instance.prev_bias_grad(split) : nullptr : nullptr);
91         args.lr = instance.get_network().get_learning_rate();
92
93         return args;
94     }
95
96     int32_t get_split() const override { return _outer.get_split(); }
97
98 public:
99     static primitive_impl* create(const convolution_grad_weights_node& arg) {
100         const auto& primitive = arg.get_primitive();
101         const auto& weights_layout = arg.weights(0).get_output_layout();
102
103         switch (weights_layout.fused_format()) {
104             case fuse(data_types::f32, format::bfyx):
105             case fuse(data_types::f32, format::yxfb):
106             case fuse(data_types::f16, format::bfyx):
107             case fuse(data_types::f16, format::yxfb):
108                 break;
109             default:
110                 throw std::runtime_error("convolution_grad_weights weights format unsupported");
111         }
112
113         const auto& weights_size = weights_layout.size;
114
115         const auto& split = primitive->split();
116         const auto& stride = primitive->stride;
117 #if 0  // TODO: support dilation
118         const auto& dilation = primitive->dilation;
119 #else
120         const tensor dilation = {0, 0, 1, 1};
121 #endif
122         const auto depthwise_separable_opt = arg.get_depthwise_sep_opt();
123         const auto output_grad_w = arg.output_grad_w();
124
125         const auto& input_offset = primitive->input_offset;
126
127         auto conv_grad_weights_params = get_default_learning_params<kernel_selector::convolution_grad_weights_params>(
128             arg,
129             depthwise_separable_opt ? 1 : split);
130         auto conv_grad_weights_optional_params =
131             get_default_learning_optional_params<kernel_selector::convolution_grad_weights_optional_params>(
132                 arg.get_program());
133
134         conv_grad_weights_params.depthwise_separable_opt = depthwise_separable_opt;
135         conv_grad_weights_params.output_grad_w = output_grad_w;
136
137         conv_grad_weights_params.gradient = true;
138         conv_grad_weights_params.inputs.push_back(convert_data_tensor(arg.get_dependency(1).get_output_layout()));
139
140         conv_grad_weights_params.split = split;
141         conv_grad_weights_params.filterSize = {
142             (uint32_t)weights_size.spatial[0],
143             (uint32_t)weights_size.spatial[1],
144         };
145
146         conv_grad_weights_params.padding = {(uint32_t)std::max(-input_offset.spatial[0], 0),
147                                             (uint32_t)std::max(-input_offset.spatial[1], 0)};
148
149         conv_grad_weights_params.stride = {(uint32_t)stride.spatial[0], (uint32_t)stride.spatial[1]};
150
151         conv_grad_weights_params.dilation = {(uint32_t)dilation.spatial[0], (uint32_t)dilation.spatial[1]};
152
153         auto& kernel_selector = kernel_selector::convolution_grad_weights_kernel_selector::Instance();
154         auto best_kernels = kernel_selector.GetBestKernels(conv_grad_weights_params, conv_grad_weights_optional_params);
155
156         CLDNN_ERROR_BOOL(arg.id(),
157                          "Best_kernel.empty()",
158                          best_kernels.empty(),
159                          "Cannot find a proper kernel with this arguments");
160
161         auto deconv = new convolution_grad_weights_gpu(arg, best_kernels[0]);
162
163         return deconv;
164     }
165 };
166
167 namespace {
168 struct attach {
169     attach() {
170         implementation_map<convolution_grad_weights>::add(
171             std::make_tuple(engine_types::ocl, data_types::f32, format::yxfb),
172             convolution_grad_weights_gpu::create);
173         implementation_map<convolution_grad_weights>::add(
174             std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx),
175             convolution_grad_weights_gpu::create);
176         implementation_map<convolution_grad_weights>::add(
177             std::make_tuple(engine_types::ocl, data_types::f16, format::yxfb),
178             convolution_grad_weights_gpu::create);
179         implementation_map<convolution_grad_weights>::add(
180             std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx),
181             convolution_grad_weights_gpu::create);
182         implementation_map<convolution_grad_weights>::add(
183             std::make_tuple(engine_types::ocl, data_types::f32, format::byxf),
184             convolution_grad_weights_gpu::create);
185         implementation_map<convolution_grad_weights>::add(
186             std::make_tuple(engine_types::ocl, data_types::f16, format::byxf),
187             convolution_grad_weights_gpu::create);
188     }
189     ~attach() {}
190 };
191 attach attach_impl;
192 }  // namespace
193 }  // namespace gpu
194 }  // namespace cldnn