2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "fully_connected_grad_weights_kernel_base.h"
18 #include "kernel_selector_utils.h"
20 namespace kernel_selector
// Returns the JIT constants for the fully-connected grad-weights kernel.
// The set starts from training_kernel_base::GetJitConstants(params).
22 JitConstants FullyConnectedGradWeightsKernelBase::GetJitConstants(const fully_connected_grad_weights_params& params) const
24 JitConstants jit = training_kernel_base::GetJitConstants(params);
// Computes the default dispatch data for the fully-connected grad-weights kernel.
// fp16UnitUsed is set when input 0 has the F16 data type.
29 FullyConnectedGradWeightsKernelBase::DispatchData FullyConnectedGradWeightsKernelBase::SetDefault(const fully_connected_grad_weights_params& params) const
33 kd.fp16UnitUsed = params.inputs[0].GetDType() == Datatype::F16;
// Dimension 0 spans OFM * IFM of the weights (presumably one work-item per
// output/input feature pair). The local size starts as min(gws0, 32).
// NOTE(review): if OFM or IFM is 0, lws0 is 0 here. Confirm that any later use
// of lws0 (as a divisor or work-group size) guards against that case.
34 size_t gws0 = params.weights.OFM().v * params.weights.IFM().v;
35 size_t lws0 = std::min(gws0, static_cast<size_t>(32));
// Dimensions 1 and 2 span the weights' X and Y extents.
41 kd.gws1 = params.weights.X().v;
42 kd.gws2 = params.weights.Y().v;
// Marked DONT_USE_IF_HAVE_SOMETHING_ELSE, presumably so the selector prefers a
// more specialized kernel when one applies. ("effiency" is the existing field name.)
46 kd.effiency = DONT_USE_IF_HAVE_SOMETHING_ELSE;
// Builds the kernel data for the fully-connected grad-weights kernel. It does four things:
// computes dispatch sizes via SetDefault, updates weights parameters via UpdateWeightsParams,
// generates the JIT source, and assembles the kernel argument list.
50 KernelsData FullyConnectedGradWeightsKernelBase::GetKernelsData(const Params& params, const optional_params& options) const
// NOTE(review): assert is compiled out under NDEBUG, so release builds do not
// verify the params type before the static_cast below.
52 assert(params.GetType() == KernelType::FULLY_CONNECTED_GRAD_WEIGHTS);
54 const fully_connected_grad_weights_params& orgParams = static_cast<const fully_connected_grad_weights_params&>(params);
// Candidate weights layouts for this kernel (presumably handed to UpdateWeightsParams).
56 const std::vector<WeightsLayout> weightsLayouts = {
65 DispatchData runInfo = SetDefault(orgParams);
// newParams refers to the params object owned by kd (kd.params), not to orgParams.
66 KernelData kd = KernelData::Default<fully_connected_grad_weights_params>(params);
67 fully_connected_grad_weights_params& newParams = *static_cast<fully_connected_grad_weights_params*>(kd.params.get());
// Updates the weights parameters and writes into kd.weightsReorderParams
// (presumably a reorder of the weights into a supported layout).
69 bool succeed = UpdateWeightsParams(
73 kd.weightsReorderParams);
// Generate the kernel source from the JIT constants and the entry point.
80 auto cldnn_jit = GetJitConstants(orgParams);
81 auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, options);
82 auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
// Fill the first kernel. The last argument passes whether bias tensors are present.
84 auto& kernel = kd.kernels[0];
85 FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point, DEFAULT, true, !orgParams.bias.empty());
// With momentum, pass the previous weights gradient, plus the previous bias
// gradient when a bias is present.
86 if (orgParams.use_momentum)
88 kernel.arguments.push_back({ ArgumentDescriptor::Types::PREV_WEIGHTS_GRADIENT, 0 });
89 if (!orgParams.bias.empty())
90 kernel.arguments.push_back({ ArgumentDescriptor::Types::PREV_BIAS_GRADIENT, 0 });
// Bind input index 1 and the learning rate as additional kernel arguments.
92 kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, 1 });
93 kernel.arguments.push_back({ ArgumentDescriptor::Types::LEARNING_RATE, 0 });
// The estimated time reuses the dispatch efficiency chosen in SetDefault.
95 kd.estimatedTime = runInfo.effiency;