2 // Copyright (c) 2019 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "gather_kernel_ref.h"
18 #include "kernel_selector_utils.h"
20 namespace kernel_selector
// Resolve the gather axis (params.axis, a GatherAxis value) to a numeric
// channel index within the output tensor's layout.
// NOTE(review): lines are missing from this chunk — the switch statement
// header, braces and the other axis cases are not visible; only the
// FEATURE/BATCH labels survive. Confirm against the full file.
22 static int32_t GetGatherChannelIndex(const gather_params& params)
// Default channel when no switch case overwrites it.
24     Tensor::DataChannelName name = Tensor::DataChannelName::X;
32     case GatherAxis::FEATURE:
34     case GatherAxis::BATCH:
// Channelndex [sic — upstream API spelling] maps the channel name to its
// numeric index for the output tensor's data layout.
39     return DataTensor::Channelndex(params.output.GetLayout(), name);
// Advertise this kernel's capabilities to the kernel selector: FP16/FP32
// input and output, any data layout, tensor offsets/pitches, mixed
// input/output element types, and FP32-formatted indices input.
// NOTE(review): the `ParamsKey k;` declaration and `return k;` lines are
// not visible in this chunk.
42 ParamsKey GatherKernelRef::GetSupportedKey() const
45     k.EnableInputDataType(Datatype::F16);
46     k.EnableInputDataType(Datatype::F32);
47     k.EnableOutputDataType(Datatype::F16);
48     k.EnableOutputDataType(Datatype::F32);
49     k.EnableAllInputLayout();
50     k.EnableAllOutputLayout();
51     k.EnableTensorOffset();
52     k.EnableTensorPitches();
// Input and output element types are allowed to differ.
54     k.EnableDifferentTypes();
// The lookup (indices) tensor is expected in F32 format.
55     k.EnableLookUpTableIndicesFormat(Datatype::F32);
// Size (in elements) of one contiguous "part": the product of input[0]'s
// dimension sizes from the gather axis down to the innermost dimension.
// NOTE(review): the accumulator declaration (`partSize`) and the return
// statement are not visible in this chunk.
59 static size_t getPartSize(const gather_params& params, int32_t axis)
// Dimentions [sic — upstream API spelling] is the tensor's rank; walk dims
// from index (Dimentions() - axis - 1) down to 0, multiplying their extents.
62     for (size_t i = params.inputs[0].Dimentions() - axis; i > 0; --i)
63         partSize *= params.inputs[0].GetDims()[i-1].v;
// Number of parts the data input splits into: its total element count
// divided by the size of one part (see getPartSize).
67 static size_t getNumberOfParts(const gather_params& params, size_t partSize)
69     return params.inputs[0].LogicalSize() / partSize;
// Number of elements in one slice strictly below the gather axis: the
// product of input[0]'s dimension sizes after `axis` (one dimension fewer
// than getPartSize covers).
72 static size_t getSliceSize(const gather_params& params, int32_t axis)
74     size_t numberOfItemsInSlice = 1;
// Walk the trailing dimensions, multiplying their extents together.
75     for (size_t i = params.inputs[0].Dimentions() - axis - 1; i > 0; --i)
76         numberOfItemsInSlice *= params.inputs[0].GetDims()[i-1].v;
77     return numberOfItemsInSlice;
// Compute the dispatch configuration (global/local work sizes) for the
// gather kernel.
// NOTE(review): lines are missing from this chunk — the gws1/gws2 and
// lws1/lws2 assignments and the return statement are not visible.
80 CommonDispatchData GatherKernelRef::SetDefault(const gather_params& params, const optional_params&) const
82     CommonDispatchData runInfo;
// Channel index of the gather axis within the output layout.
84     const int32_t axis = GetGatherChannelIndex(params);
// One work item per (part, index) pair: parts of the data input times the
// number of indices in the second input.
86     const size_t numberOfParts = params.inputs[0].LogicalSize() / getPartSize(params, axis);
88     size_t gws = numberOfParts * params.inputs[1].LogicalSize();
90     const size_t vectorSize = 16;
// Round the global size up to a multiple of the local work-group size so
// gws0 is evenly divisible by lws0.
92     runInfo.gws0 = Align(gws, vectorSize);
96     runInfo.lws0 = vectorSize;
// Prefer FP16 execution units when the data input is half precision.
100     runInfo.fp16UnitUsed = params.inputs[0].GetDType() == Datatype::F16;
// Build the JIT #defines consumed by the OpenCL gather kernel: the gather
// axis plus the derived part/slice/count geometry.
// NOTE(review): the `return jit;` line is not visible in this chunk.
105 JitConstants GatherKernelRef::GetJitConstants(const gather_params& params) const
// Start from the standard per-layer constants (tensor sizes, types, etc.).
107     JitConstants jit = MakeBaseParamsJitConstants(params);
109     int32_t axis = GetGatherChannelIndex(params);
110     size_t partSize = getPartSize(params, axis);
111     size_t sliceSize = getSliceSize(params, axis);
112     size_t numberOfParts = getNumberOfParts(params, partSize);
// Total number of indices to gather (logical size of the second input).
113     size_t numberOfIndexes = params.inputs[1].LogicalSize();
115     jit.AddConstant(MakeJitConstant("AXIS", axis));
116     jit.AddConstant(MakeJitConstant("PART_SIZE", partSize));
117     jit.AddConstant(MakeJitConstant("SLICE_SIZE", sliceSize));
118     jit.AddConstant(MakeJitConstant("PARTS_NUMBER", numberOfParts));
// One "computational operation" per (part, index) pair — matches the gws
// sizing in SetDefault.
119     jit.AddConstant(MakeJitConstant("COMPUTATIONAL_OPERATIONS_NUMBER", numberOfParts * numberOfIndexes));
// Assemble the KernelData for a gather primitive: dispatch configuration,
// JIT constants, entry point and CL kernel arguments.
// NOTE(review): the tail of this function (e.g. the final return) falls
// outside this chunk.
124 KernelsData GatherKernelRef::GetKernelsData(const Params& params, const optional_params& options) const
126     KernelData kd = KernelData::Default<gather_params>(params);
127     gather_params& newParams = *static_cast<gather_params*>(kd.params.get());
// Programmer-error guard: this kernel must only receive GATHER params.
129     assert(params.GetType() == KernelType::GATHER);
131     auto runInfo = SetDefault(newParams, options);
132     auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);
133     auto cldnn_jit = GetJitConstants(newParams);
134     std::string jit = CreateJit(kernelName, cldnn_jit, entry_point);
136     auto& kernel = kd.kernels[0];
// Trailing `2` presumably is the number of input tensors (data + indices);
// confirm against the FillCLKernelData signature.
138     FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point, "", false, false, 2);
140     kd.estimatedTime = DONT_USE_IF_HAVE_SOMETHING_ELSE;