Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / fully_connected_grad_weights / fully_connected_grad_weights_kernel_base.cpp
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "fully_connected_grad_weights_kernel_base.h"
18 #include "kernel_selector_utils.h"
19
20 namespace kernel_selector 
21 {
22     JitConstants FullyConnectedGradWeightsKernelBase::GetJitConstants(const fully_connected_grad_weights_params& params) const
23     {
24         JitConstants jit = training_kernel_base::GetJitConstants(params);
25
26         return jit;
27     }
28
29     FullyConnectedGradWeightsKernelBase::DispatchData FullyConnectedGradWeightsKernelBase::SetDefault(const fully_connected_grad_weights_params& params) const
30     {
31         DispatchData kd;
32
33         kd.fp16UnitUsed = params.inputs[0].GetDType() == Datatype::F16;
34         size_t gws0 = params.weights.OFM().v * params.weights.IFM().v;
35         size_t lws0 = std::min(gws0, static_cast<size_t>(32));
36         while (gws0 % lws0)
37         {
38             lws0--;
39         }
40         kd.gws0 = gws0;
41         kd.gws1 = params.weights.X().v;
42         kd.gws2 = params.weights.Y().v;
43         kd.lws0 = lws0;
44         kd.lws1 = 1;
45         kd.lws2 = 1;
46         kd.effiency = DONT_USE_IF_HAVE_SOMETHING_ELSE;
47         return kd;
48     }
49
50     KernelsData FullyConnectedGradWeightsKernelBase::GetKernelsData(const Params& params, const optional_params& options) const
51     {
52         assert(params.GetType() == KernelType::FULLY_CONNECTED_GRAD_WEIGHTS);
53
54         const fully_connected_grad_weights_params& orgParams = static_cast<const fully_connected_grad_weights_params&>(params);
55
56         const std::vector<WeightsLayout> weightsLayouts = {
57             WeightsLayout::oi,
58             WeightsLayout::io,
59             WeightsLayout::oiyx,
60             WeightsLayout::iyxo,
61             WeightsLayout::yxio,
62             WeightsLayout::oyxi
63         };
64
65         DispatchData runInfo = SetDefault(orgParams);
66         KernelData kd = KernelData::Default<fully_connected_grad_weights_params>(params);
67         fully_connected_grad_weights_params& newParams = *static_cast<fully_connected_grad_weights_params*>(kd.params.get());
68
69         bool succeed = UpdateWeightsParams(
70             newParams,
71             options,
72             weightsLayouts,
73             kd.weightsReorderParams);
74
75         if (!succeed)
76         {
77             return{};
78         }
79
80         auto cldnn_jit = GetJitConstants(orgParams);
81         auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, options);
82         auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
83
84         auto& kernel = kd.kernels[0];
85         FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point, DEFAULT, true, !orgParams.bias.empty());
86         if (orgParams.use_momentum)
87         {
88             kernel.arguments.push_back({ ArgumentDescriptor::Types::PREV_WEIGHTS_GRADIENT, 0 });
89             if (!orgParams.bias.empty())
90                 kernel.arguments.push_back({ ArgumentDescriptor::Types::PREV_BIAS_GRADIENT, 0 });
91         }
92         kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, 1 });
93         kernel.arguments.push_back({ ArgumentDescriptor::Types::LEARNING_RATE, 0 });
94
95         kd.estimatedTime = runInfo.effiency;
96
97         return{ kd };
98     }
99 }