Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / fully_connected_grad_input / fully_connected_grad_input_kernel_base.cpp
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "fully_connected_grad_input_kernel_base.h"
18 #include "kernel_selector_utils.h"
19
20 namespace kernel_selector 
21 {
22     JitConstants FullyConnectedGradInputKernelBase::GetJitConstants(const fully_connected_grad_input_params& params) const
23     {
24         return WeightBiasKernelBase::GetJitConstants(params);
25     }
26
27     FullyConnectedGradInputKernelBase::DispatchData FullyConnectedGradInputKernelBase::SetDefault(const fully_connected_grad_input_params& params) const
28     {
29         DispatchData kd;
30
31         kd.fp16UnitUsed = params.inputs[0].GetDType() == Datatype::F16;
32         size_t gws0 = params.output.Batch().v * params.weights.IFM().v;
33         size_t lws0 = std::min(gws0, static_cast<size_t>(32));
34         while (gws0 % lws0)
35         {
36             lws0--;
37         }
38         kd.gws0 = gws0;
39         kd.gws1 = params.weights.X().v;
40         kd.gws2 = params.weights.Y().v;
41         kd.lws0 = lws0;
42         kd.lws1 = 1;
43         kd.lws2 = 1;
44         kd.effiency = DONT_USE_IF_HAVE_SOMETHING_ELSE;
45         return kd;
46     }
47
48     KernelsData FullyConnectedGradInputKernelBase::GetKernelsData(const Params& params, const optional_params& options) const
49     {
50         assert(params.GetType() == KernelType::FULLY_CONNECTED_GRAD_INPUT);
51
52         const fully_connected_grad_input_params& orgParams = static_cast<const fully_connected_grad_input_params&>(params);
53
54         const std::vector<WeightsLayout> weightsLayouts = {
55             WeightsLayout::oi,
56             WeightsLayout::io,
57             WeightsLayout::oiyx,
58             WeightsLayout::iyxo,
59             WeightsLayout::yxio,
60             WeightsLayout::oyxi
61         };
62
63         DispatchData runInfo = SetDefault(orgParams);
64         KernelData kd = KernelData::Default<fully_connected_grad_input_params>(params);
65         fully_connected_grad_input_params& newParams = *static_cast<fully_connected_grad_input_params*>(kd.params.get());
66
67         bool succeed = UpdateWeightsParams(
68             newParams,
69             options,
70             weightsLayouts,
71             kd.weightsReorderParams);
72
73         if (!succeed)
74         {
75             return{};
76         }
77
78         auto cldnn_jit = GetJitConstants(orgParams);
79         auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, options);
80         auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
81
82         auto& kernel = kd.kernels[0];
83         FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point, DEFAULT, true, !orgParams.bias.empty());
84         kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, 1 });
85
86         kd.estimatedTime = runInfo.effiency;
87
88         return{ kd };
89     }
90 }