2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "softmax_kernel_base.h"
19 namespace kernel_selector
// Builds the JIT constant set used to specialise a softmax kernel.
//  - Starts from the common constants produced by MakeBaseParamsJitConstants(params).
//  - Adds a value-less "ALONG_<dim>" constant naming the softmax axis (params.dim);
//    NOTE(review): presumably the kernel source tests it with #ifdef - confirm.
//  - Adds the dispatch-derived sizes from kd. Only the first work-size dimension
//    (lws0 / gws0) is exported, as LWS / GWS.
// kd is taken by value; the visible code only reads from it.
21 JitConstants SoftmaxKernelBase::GetJitConstants(const softmax_params& params, SoftmaxKernelBase::DispatchData kd) const
23 JitConstants mem_consts = MakeBaseParamsJitConstants(params);
// Axis selector: the name encodes the dimension, the value is intentionally empty.
25 mem_consts.AddConstants({
26 MakeJitConstant("ALONG_" + toString(params.dim), "")
// Dispatch-data constants (counts/sizes computed by SetDefault in derived kernels).
29 mem_consts.AddConstants({
30 MakeJitConstant("ITEMS_NUM", kd.itemsNum),
31 MakeJitConstant("LWS", kd.lws0),
32 MakeJitConstant("GWS", kd.gws0),
33 MakeJitConstant("DATA_SETS_COUNT",kd.dataSetsCount),
34 MakeJitConstant("DATA_SET_SIZE", kd.dataSetSize),
35 MakeJitConstant("LEFTOVERS", kd.leftovers),
// Produces the baseline dispatch data for softmax kernels.
// Among the fields set here:
//  - fp16UnitUsed is true exactly when the first input's data type is Datatype::F16.
//  - leftovers, normIndex, dataSetsCount and dataSetSize are reset to 0.
// Derived kernels start from this result and overwrite what they need. For example,
// SoftmaxKernelBaseBF::SetDefault calls Parent::SetDefault and then sets dataSetSize /
// dataSetsCount (presumably Parent is this class - confirm).
// The optional_params argument is unnamed and therefore unused by this implementation.
41 SoftmaxKernelBase::DispatchData SoftmaxKernelBase::SetDefault(const softmax_params& params, const optional_params&) const
54 runInfo.fp16UnitUsed = params.inputs[0].GetDType() == Datatype::F16;
55 runInfo.leftovers = 0;
57 runInfo.normIndex = 0;
58 runInfo.dataSetsCount = 0;
59 runInfo.dataSetSize = 0;
// Base applicability check for every softmax kernel.
// Both the params and the optional params must report KernelType::SOFT_MAX.
// NOTE(review): presumably returns false when either does not (the branch body
// is not shown here) - confirm.
// GetCommonKernelsData relies on this check before static_cast-ing to softmax_params.
64 bool SoftmaxKernelBase::Validate(const Params& p, const optional_params& o) const
66 if (p.GetType() != KernelType::SOFT_MAX ||
67 o.GetType() != KernelType::SOFT_MAX)
// Shared kernel-data construction for softmax kernel implementations:
//  1. Validate(params, options). On failure it presumably returns an empty result
//     (the early-exit body is not shown) - confirm.
//  2. Build default KernelData, then compute dispatch data via SetDefault.
//  3. Generate JIT constants and the entry point, create the JIT source, and fill
//     kd.kernels[0].
//  4. Copy the dispatch data's efficiency estimate into kd.estimatedTime.
// Note: here `kd` is the KernelData, while in GetJitConstants `kd` is the DispatchData.
75 KernelsData SoftmaxKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options) const
77 if (!Validate(params, options))
// Safe downcast: Validate has already required KernelType::SOFT_MAX.
82 const softmax_params& orgParams = static_cast<const softmax_params&>(params);
83 KernelData kd = KernelData::Default<softmax_params>(params);
// SetDefault/GetJitConstants may be overridden by derived kernels (e.g. SoftmaxKernelBaseBF).
85 auto runInfo = SetDefault(orgParams, options);
86 auto cldnn_jit = GetJitConstants(orgParams, runInfo);
87 auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, options);
88 auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
// Only a single kernel is populated; it uses the dispatch data computed above.
90 auto& kernel = kd.kernels[0];
91 FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point);
// "effiency" is the DispatchData member name as declared elsewhere (sic).
93 kd.estimatedTime = runInfo.effiency;
// Applicability check for softmax kernels that treat the input as batch x flattened data.
//  - Requires the base SOFT_MAX check (Parent::Validate).
//  - Rejects any fused activation other than ActivationFunction::NONE
//    (presumably by returning false; branch body not shown - confirm).
//  - Inputs in the bf / fb layouts are handled by a separate branch (body not shown;
//    presumably accepted - confirm).
//  - Otherwise the softmax dimension (presumably params.dim; the switch head is not
//    shown) decides. Of X, Y and feature, only the softmax axis may exceed 1.
//    Any other dimension is rejected.
98 bool SoftmaxKernelBaseBF::Validate(const Params& p, const optional_params& o) const
100 if (!Parent::Validate(p, o))
// Safe downcast: Parent::Validate checks KernelType::SOFT_MAX.
105 const softmax_params& params = static_cast<const softmax_params&>(p);
106 const auto& input = params.inputs[0];
108 if (params.activation.function != ActivationFunction::NONE)
113 if (input.GetLayout() == DataLayout::bf ||
114 input.GetLayout() == DataLayout::fb)
// Softmax along one axis: the other two of {X, Y, feature} must be size 1.
121 case SoftmaxDim::X: return input.Y().v == 1 && input.Feature().v == 1;
122 case SoftmaxDim::Y: return input.X().v == 1 && input.Feature().v == 1;
123 case SoftmaxDim::FEATURE: return input.X().v == 1 && input.Y().v == 1;
124 default: return false;
// Dispatch data for batch x flattened-data softmax kernels.
// Starts from the parent's defaults, then:
//  - dataSetSize = feature extent of the input after FlattenFeatureAndSpatials()
//    (presumably feature * spatial sizes combined into one axis - confirm);
//  - dataSetsCount = batch size, i.e. one data set per batch item.
// The remainder of this function is outside the visible region.
128 SoftmaxKernelBase::DispatchData SoftmaxKernelBaseBF::SetDefault(const softmax_params& params, const optional_params& options) const
130 const auto& input = params.inputs[0];
132 DispatchData kd = Parent::SetDefault(params, options);
134 auto flatten_input = input.FlattenFeatureAndSpatials();
135 kd.dataSetSize = flatten_input.Feature().v;
136 kd.dataSetsCount = input.Batch().v;