2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "activation_kernel_base.h"
18 #include "kernel_selector_utils.h"
20 namespace kernel_selector
// Computes the default dispatch configuration for an activation kernel from
// the output tensor of `arg`: global and local work sizes, an efficiency hint
// and the fp16 usage flag are all written into `runInfo`.
// NOTE(review): `runInfo` is presumably the DispatchData object this function
// returns - confirm.
23 ActivationKernelBase::DispatchData ActivationKernelBase::SetDefault(const activation_params& arg) const
25 const auto& out = arg.output;
// Default global work size is (X, Y, Feature * Batch) of the output tensor.
28 std::vector<size_t> global = { out.X().v, out.Y().v, out.Feature().v*out.Batch().v };
// For the yxfb layout the same three extents are reordered to
// (Feature * Batch, X, Y).
29 if (out.GetLayout() == DataLayout::yxfb)
31 global[0] = out.Feature().v*out.Batch().v;
32 global[1] = out.X().v;
33 global[2] = out.Y().v;
// Local work sizes are derived from the final global sizes by the shared helper.
35 std::vector<size_t> local = GetOptimalLocalWorkGroupSizes(global);
36 runInfo.gws0 = global[0];
37 runInfo.gws1 = global[1];
38 runInfo.gws2 = global[2];
39 runInfo.lws0 = local[0];
40 runInfo.lws1 = local[1];
41 runInfo.lws2 = local[2];
// Efficiency is set to DONT_USE_IF_HAVE_SOMETHING_ELSE (presumably the lowest
// priority, so that other implementations are preferred - confirm).
43 runInfo.effiency = DONT_USE_IF_HAVE_SOMETHING_ELSE;
// The fp16 flag is set exactly when the output data type is F16.
44 runInfo.fp16UnitUsed = out.GetDType() == Datatype::F16;
// Assembles the JIT constants for an activation kernel, starting from the
// base constants produced by MakeBaseParamsJitConstants(params).
// The second parameter (DispatchData) is unnamed and therefore unused here.
49 JitConstants ActivationKernelBase::GetJitConstants(const activation_params& params, DispatchData) const
51 JitConstants jit = MakeBaseParamsJitConstants(params);
53 const auto& inputNlParams = params.inputActivationParams;
// PARAMS_NUM is obtained from GetActivationAdditionalParamsNumber for the
// configured activation function (presumably registered on `jit` - confirm).
56 MakeJitConstant("PARAMS_NUM", GetActivationAdditionalParamsNumber(params.activation.function)),
// Only when input activation parameters are supplied: ADDITIONAL_PARAMS is
// built from the first entry and PARAMETERIZED is defined with an empty value.
59 if (!inputNlParams.empty())
62 MakeJitConstant("ADDITIONAL_PARAMS", inputNlParams[0]),
63 MakeJitConstant("PARAMETERIZED", ""),
// Checks the kernel type of both the params and the optional params.
// The condition below is true when either `p` or `o` is not of kernel type
// ACTIVATION; presumably validation fails in that case - confirm.
70 bool ActivationKernelBase::Validate(const Params& p, const optional_params& o) const
72 if (p.GetType() != KernelType::ACTIVATION ||
73 o.GetType() != KernelType::ACTIVATION)
// Builds the kernel data for an activation kernel: validates the inputs,
// creates a default KernelData for activation_params, computes dispatch data
// and JIT constants, fills the first kernel entry and appends the optional
// kernel arguments.
81 KernelsData ActivationKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options) const
// Inputs rejected by Validate() are not processed further (presumably an empty
// result is returned - confirm).
83 if (!Validate(params, options))
88 KernelData kd = KernelData::Default<activation_params>(params);
// `newParams` refers to the activation_params object held by `kd`.
90 activation_params& newParams = *static_cast<activation_params*>(kd.params.get());
// NOTE(review): `kernel_id` is not referenced in the lines that follow, and
// `entry_point` below is produced by an equivalent GetEntryPoint call using
// newParams.layerID - check whether this extra call is still needed.
91 const std::string kernel_id = GetEntryPoint(kernelName, params.layerID, options);
93 auto runInfo = SetDefault(newParams);
94 auto cldnn_jit = GetJitConstants(newParams, runInfo);
95 auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);
96 auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
// Only the first kernel entry of `kd` is populated.
98 auto& kernel = kd.kernels[0];
99 FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point);
// With `gradient` set, an extra INPUT argument descriptor with value 1 is
// appended (presumably the index of a second input buffer - confirm).
101 if (newParams.gradient)
102 kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, 1 });
// With input activation parameters present, a SLOPE argument descriptor with
// value 0 is appended.
104 if (!newParams.inputActivationParams.empty())
106 kernel.arguments.push_back({ ArgumentDescriptor::Types::SLOPE, 0 });
// The estimated time of the kernel data is taken from the dispatch efficiency.
109 kd.estimatedTime = runInfo.effiency;