2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "fully_connected_grad_input_kernel_base.h"
18 #include "kernel_selector_utils.h"
20 namespace kernel_selector
// Returns the JIT constants used to build this kernel.
// Nothing kernel-specific is added here: the result is exactly what
// WeightBiasKernelBase::GetJitConstants yields for the same params.
22 JitConstants FullyConnectedGradInputKernelBase::GetJitConstants(const fully_connected_grad_input_params& params) const
24 return WeightBiasKernelBase::GetJitConstants(params);
// Fills in the default DispatchData for the given params: FP16 flag,
// work sizes derived from the output batch and the weights dimensions,
// and the efficiency rating of this kernel.
27 FullyConnectedGradInputKernelBase::DispatchData FullyConnectedGradInputKernelBase::SetDefault(const fully_connected_grad_input_params& params) const
// FP16 is flagged based on the data type of the first input only.
31 kd.fp16UnitUsed = params.inputs[0].GetDType() == Datatype::F16;
// Dimension-0 size: output batch count multiplied by the weights' IFM count.
// (presumably the global work size for dimension 0 -- confirm against the kernel)
32 size_t gws0 = params.output.Batch().v * params.weights.IFM().v;
// Starting candidate for the dimension-0 local size: gws0 capped at 32.
33 size_t lws0 = std::min(gws0, static_cast<size_t>(32));
// Dimensions 1 and 2 take the weights' X and Y sizes respectively.
39 kd.gws1 = params.weights.X().v;
40 kd.gws2 = params.weights.Y().v;
// NOTE(review): judging by the constant's name this rates the kernel as a
// last-resort choice; the actual value is defined elsewhere -- confirm.
44 kd.effiency = DONT_USE_IF_HAVE_SOMETHING_ELSE;
// Assembles the KernelData for this kernel from generic Params/options:
// computes the dispatch configuration, updates the weights parameters,
// generates the JIT code and fills in the single CL kernel entry.
48 KernelsData FullyConnectedGradInputKernelBase::GetKernelsData(const Params& params, const optional_params& options) const
// Precondition: params must be of type FULLY_CONNECTED_GRAD_INPUT. This is
// only checked where assert is active; the static_cast below relies on it.
50 assert(params.GetType() == KernelType::FULLY_CONNECTED_GRAD_INPUT);
52 const fully_connected_grad_input_params& orgParams = static_cast<const fully_connected_grad_input_params&>(params);
// Weights layouts for this kernel (presumably the set handed to
// UpdateWeightsParams below -- confirm).
54 const std::vector<WeightsLayout> weightsLayouts = {
// Dispatch configuration is derived from the caller's original params.
63 DispatchData runInfo = SetDefault(orgParams);
64 KernelData kd = KernelData::Default<fully_connected_grad_input_params>(params);
// newParams refers to the params object held by kd, not to the caller's params.
65 fully_connected_grad_input_params& newParams = *static_cast<fully_connected_grad_input_params*>(kd.params.get());
// The weights-update result is captured in 'succeed'; kd.weightsReorderParams
// is passed as the final argument of the call.
67 bool succeed = UpdateWeightsParams(
71 kd.weightsReorderParams);
// NOTE(review): the JIT constants and entry point are built from orgParams,
// not from newParams -- confirm this is intended.
78 auto cldnn_jit = GetJitConstants(orgParams);
79 auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, options);
80 auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
// Only the first kernel slot of kd is populated.
82 auto& kernel = kd.kernels[0];
// The last argument is true exactly when the original params carry a bias.
83 FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point, DEFAULT, true, !orgParams.bias.empty());
// Appends one more kernel argument: an INPUT descriptor with index 1.
84 kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, 1 });
// The estimated time reported for this kernel is the efficiency rating from SetDefault.
86 kd.estimatedTime = runInfo.effiency;