2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "fully_connected_grad_weights_inst.h"
18 #include "primitive_gpu_base.h"
19 #include "implementation_map.h"
20 #include "error_handler.h"
21 #include "network_impl.h"
22 #include "kernel_selector_helper.h"
23 #include "fully_connected_grad_weights/fully_connected_grad_weights_kernel_selector.h"
24 #include "fully_connected_grad_weights/fully_connected_grad_weights_kernel_base.h"
25 #include "api/CPP/fully_connected_grad_weights.hpp"
27 namespace cldnn { namespace gpu {
29 struct fully_connected_grad_weights_gpu : typed_primitive_gpu_impl<fully_connected_grad_weights>
31 using parent = typed_primitive_gpu_impl<fully_connected_grad_weights>;
// Extra validation for the fully-connected weights-gradient implementation.
// When momentum is enabled (instance.use_momentum()), the layout of the
// current weights memory is compared with the output layout of the node's
// prev_weights_grad() input, and the layout of the bias memory with the
// output layout of prev_bias_grad(). CLDNN_ERROR_LAYOUT_MISMATCH presumably
// raises an error tagged with _outer.id() when the layouts differ — confirm
// against error_handler.h. (_outer is not declared in this view; presumably
// the program node this implementation was created from, inherited from
// the parent class.)
// NOTE(review): get_arguments() only binds prev_bias_grad when
// instance.bias_term() is also true, but the bias-layout check below runs on
// use_momentum() alone. If momentum can be enabled for a node without a bias
// term, bias_memory() / prev_bias_grad() may not be valid here — consider
// guarding the second check with instance.bias_term().
36 virtual bool validate_impl(const typed_primitive_inst<fully_connected_grad_weights>& instance) const override
40 if (instance.use_momentum())
// Weights vs. previous weights gradient.
42 CLDNN_ERROR_LAYOUT_MISMATCH(_outer.id(), "Filter memory", instance.weights_memory().get_layout(), "previous weights grad memory", _outer.prev_weights_grad().get_output_layout(), "");
// Bias vs. previous bias gradient (see NOTE above about bias_term()).
43 CLDNN_ERROR_LAYOUT_MISMATCH(_outer.id(), "Bias memory", instance.bias_memory().get_layout(), "previous bias grad memory", _outer.prev_bias_grad().get_output_layout(), "");
// Builds the kernel argument set for this primitive.
// Starts from the parent's default arguments. The int32_t parameter is
// unnamed in this override, so the caller's value is ignored and the literal
// 1 is always forwarded to parent::get_arguments().
// NOTE(review): the meaning of that 1 is not visible here (presumably a split
// index) — confirm against primitive_gpu_base.h.
// On top of the parent's arguments it attaches:
//   - weights:           always, from instance.weights_memory();
//   - bias:              only when instance.bias_term() is true, else nullptr;
//   - prev_weights_grad: only when instance.use_momentum() is true, else nullptr;
//   - prev_bias_grad:    only when both bias_term() and use_momentum() are
//                        true, else nullptr;
//   - lr:                the learning rate reported by instance.get_network().
49 virtual kernel::kernel_arguments_data get_arguments(typed_primitive_inst<fully_connected_grad_weights>& instance, int32_t) const override
51 kernel::kernel_arguments_data args = parent::get_arguments(instance, 1);
52 args.weights = &instance.weights_memory();
53 args.bias = instance.bias_term() ? &instance.bias_memory() : nullptr;
54 args.prev_weights_grad = instance.use_momentum() ? &instance.prev_weights_grad() : nullptr;
// Nested ternary; equivalent to
// (bias_term() && use_momentum()) ? &prev_bias_grad() : nullptr.
55 args.prev_bias_grad = instance.bias_term() ? instance.use_momentum() ? &instance.prev_bias_grad() : nullptr : nullptr;
57 args.lr = instance.get_network().get_learning_rate();
// Factory for the GPU implementation of a fully_connected_grad_weights node;
// this is the function registered as val_fw in implementation_map.
// Steps:
//   1. build default learning params from the node and default learning
//      optional params from the node's program;
//   2. mark the params as a gradient computation and append the output layout
//      of dependency 1 as an additional kernel input (presumably the forward
//      input of the layer — confirm against the node definition);
//   3. query the fully_connected_grad_weights kernel selector for candidate
//      kernels and report an error via CLDNN_ERROR_BOOL if none is found;
//   4. construct the implementation with the first returned kernel
//      (presumably the best-ranked one).
// Returns an object allocated with raw `new`; the caller (presumably the
// implementation_map machinery) is responsible for taking ownership.
64 static primitive_impl* create(const fully_connected_grad_weights_node& arg)
66 auto fully_connected_grad_weights_params = get_default_learning_params<kernel_selector::fully_connected_grad_weights_params>(arg);
67 auto fully_connected_grad_weights_optional_params = get_default_learning_optional_params<kernel_selector::fully_connected_grad_weights_optional_params>(arg.get_program());
69 fully_connected_grad_weights_params.gradient = true;
// Extra input: output layout of the node's second dependency (index 1).
70 fully_connected_grad_weights_params.inputs.push_back(convert_data_tensor(arg.get_dependency(1).get_output_layout()));
72 auto& kernel_selector = kernel_selector::fully_connected_grad_weights_kernel_selector::Instance();
73 auto best_kernels = kernel_selector.GetBestKernels(fully_connected_grad_weights_params, fully_connected_grad_weights_optional_params);
74 CLDNN_ERROR_BOOL(arg.id(), "Best_kernel.empty()", best_kernels.empty(), "Cannot find a proper kernel with this arguments");
76 auto fully_connected_grad_weights = new fully_connected_grad_weights_gpu(arg, best_kernels[0]);
78 return fully_connected_grad_weights;
86 auto val_fw = fully_connected_grad_weights_gpu::create;
88 implementation_map<fully_connected_grad_weights>::add({
89 { std::make_tuple(engine_types::ocl, data_types::f32, format::yxfb), val_fw },
90 { std::make_tuple(engine_types::ocl, data_types::f16, format::yxfb), val_fw },
91 { std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw },
92 { std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw },
93 { std::make_tuple(engine_types::ocl, data_types::f32, format::byxf), val_fw },
94 { std::make_tuple(engine_types::ocl, data_types::f16, format::byxf), val_fw },