2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "scale_grad_weights_kernel_base.h"
18 #include "kernel_selector_utils.h"
20 namespace kernel_selector
22 JitConstants ScaleGradWeightsKernelBase::GetJitConstants(const scale_grad_weights_params& params) const
24 JitConstants jit = training_kernel_base::GetJitConstants(params);
29 ScaleGradWeightsKernelBase::DispatchData ScaleGradWeightsKernelBase::SetDefault(const scale_grad_weights_params& params) const
33 kd.fp16UnitUsed = params.inputs[0].GetDType() == Datatype::F16;
35 kd.gws0 = params.inputs[0].Batch().v;
36 kd.gws1 = params.inputs[0].Feature().v;
39 kd.lws0 = params.inputs[0].Batch().v;
42 kd.effiency = DONT_USE_IF_HAVE_SOMETHING_ELSE;
46 KernelsData ScaleGradWeightsKernelBase::GetKernelsData(const Params& params, const optional_params& options) const
48 assert(params.GetType() == KernelType::SCALE_GRAD_WEIGHTS);
50 const scale_grad_weights_params& orgParams = static_cast<const scale_grad_weights_params&>(params);
52 DispatchData runInfo = SetDefault(orgParams);
53 KernelData kd = KernelData::Default<scale_grad_weights_params>(params);
55 auto cldnn_jit = GetJitConstants(orgParams);
56 auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, options);
57 auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
59 auto& kernel = kd.kernels[0];
60 FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point, DEFAULT, true, !orgParams.bias.empty(), 2);
62 if (orgParams.use_momentum)
64 kernel.arguments.push_back({ ArgumentDescriptor::Types::PREV_WEIGHTS_GRADIENT, 0 });
65 if (!orgParams.bias.empty())
66 kernel.arguments.push_back({ ArgumentDescriptor::Types::PREV_BIAS_GRADIENT, 0 });
68 kernel.arguments.push_back({ ArgumentDescriptor::Types::LEARNING_RATE, 0 });
70 kd.estimatedTime = runInfo.effiency;