1 // Copyright (c) 2016-2020 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
17 #include "tensor_type.h"
18 #include "concatenation_kernel_base.h"
22 namespace kernel_selector {
// GetConcatChannel: maps the concatenation axis in `params.axis` to the
// tensor data channel it concatenates along (X, Y, Z, W, FEATURE or BATCH).
// The final `return Tensor::DataChannelName::X;` appears to be the fallback for
// any axis value not listed.
// NOTE(review): the `case`/`default` labels for the X/Y/Z/W returns and the
// fallback are not visible in this view; confirm.
23 Tensor::DataChannelName ConcatenationKernelBase::GetConcatChannel(const concatenation_params& params) const {
24 switch (params.axis) {
26 return Tensor::DataChannelName::X;
28 return Tensor::DataChannelName::Y;
30 return Tensor::DataChannelName::Z;
32 return Tensor::DataChannelName::W;
33 case ConcatAxis::FEATURE:
34 return Tensor::DataChannelName::FEATURE;
35 case ConcatAxis::BATCH:
36 return Tensor::DataChannelName::BATCH;
38 return Tensor::DataChannelName::X;
// GetConcatChannelIndex: returns the position of the concat channel within the
// OUTPUT tensor's layout, as reported by DataTensor::Channelndex.
// Validate compares this result against -1.
// NOTE(review): GetCommonKernelsData also uses this output-layout index to index
// each input's dims. That assumes inputs share the output's channel ordering;
// confirm.
42 int32_t ConcatenationKernelBase::GetConcatChannelIndex(const concatenation_params& params) const {
43 return DataTensor::Channelndex(params.output.GetLayout(), GetConcatChannel(params));
// Validate: checks that `p` is a concatenation parameter set and that the concat
// channel has an index other than -1 in the output layout.
// NOTE(review): the bodies of both `if` checks are not visible in this view;
// presumably each rejects the params (returns false). Confirm.
46 bool ConcatenationKernelBase::Validate(const Params& p, const optional_params&) const {
47 if (p.GetType() != KernelType::CONCATENATION) {
// The downcast below relies on the KernelType check above.
51 const concatenation_params& params = static_cast<const concatenation_params&>(p);
53 if (GetConcatChannelIndex(params) == -1) {
// GetJitConstants: starts from the base-params JIT constants and adds
// CONCAT_<axis> = 1 plus CONCAT_AXIS_INDEX, the concat channel's index in the
// output layout.
60 JitConstants ConcatenationKernelBase::GetJitConstants(const concatenation_params& params) const {
61 JitConstants jit = MakeBaseParamsJitConstants(params);
// Flag that tells the kernel source which axis is being concatenated.
64 MakeJitConstant("CONCAT_" + toString(params.axis), 1),
67 jit.AddConstant(MakeJitConstant("CONCAT_AXIS_INDEX", GetConcatChannelIndex(params)));
// SetDefault: derives the default dispatch sizes from the first input.
// - Global sizes: gws0 = Y, gws1 = F, gws2 = B. Any channel absent from the
//   layout (index -1) contributes a size of 1.
// - lws0: gws0 clamped to [1, 32], then searched until it divides gws0.
// - The result is marked DONT_USE_IF_HAVE_SOMETHING_ELSE.
// When called from GetCommonKernelsData, inputs[0] is the single input that
// kernel handles.
71 ConcatenationKernelBase::DispatchData ConcatenationKernelBase::SetDefault(const concatenation_params& params) const {
74 const auto& dims = params.inputs[0].GetDims();
75 auto layout = params.inputs[0].GetLayout();
// Channel positions in order B, F, Y, X. idx[3] (X) is not used by the global
// sizes visible here.
77 std::vector<int> idx = { DataTensor::Channelndex(layout, Tensor::DataChannelName::BATCH),
78 DataTensor::Channelndex(layout, Tensor::DataChannelName::FEATURE),
79 DataTensor::Channelndex(layout, Tensor::DataChannelName::Y),
80 DataTensor::Channelndex(layout, Tensor::DataChannelName::X) };
82 // Determine global work sizes.
83 kd.gws0 = idx[2] != -1 ? dims[idx[2]].v : 1; // Y
84 kd.gws1 = idx[1] != -1 ? dims[idx[1]].v : 1; // F
85 kd.gws2 = idx[0] != -1 ? dims[idx[0]].v : 1; // B
// Local size starts at min(max(gws0, 1), 32). The loop below keeps adjusting it
// until it divides gws0 evenly, presumably because OpenCL 1.x requires global
// sizes to be multiples of local sizes.
// NOTE(review): the loop body (presumably a decrement of lws0) is not visible
// in this view; confirm.
87 kd.lws0 = std::min(std::max(kd.gws0, static_cast<size_t>(1)), static_cast<size_t>(32));
88 while (kd.gws0 % kd.lws0 != 0) {
94 kd.efficiency = DONT_USE_IF_HAVE_SOMETHING_ELSE;
// GetCommonKernelsData: builds one kernel per input. Each kernel:
// - is compiled from a copy of the params that holds only that input;
// - is passed the input (argument index i), the output (index 0), and one
//   UINT32 scalar holding `lastOffset`.
// `lastOffset` is the running sum of the preceding inputs' sizes along the
// concat channel. kd.estimatedTime is the maximum of FORCE_PRIORITY_1 and every
// kernel's dispatch efficiency.
// NOTE(review): the body of the Validate-failure branch is not visible in this
// view; presumably it returns an empty KernelsData. Confirm.
98 KernelsData ConcatenationKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options) const {
99 if (!Validate(params, options)) {
103 const concatenation_params& orgParams = static_cast<const concatenation_params&>(params);
// One kernel slot is allocated per input.
105 KernelData kd = KernelData::Default<concatenation_params>(params, orgParams.inputs.size());
107 uint32_t lastOffset = 0;
// Index is taken from the output layout (see GetConcatChannelIndex) and applied
// to each input's dims below.
108 const auto concatChannelIndex = GetConcatChannelIndex(orgParams);
109 float efficiency = FORCE_PRIORITY_1;
110 size_t ifm_offset = 0;
111 for (size_t i = 0; i < orgParams.inputs.size(); i++) {
112 const auto& input = orgParams.inputs[i];
// Specialize the params so this kernel sees only the current input.
114 auto newParams = orgParams;
115 newParams.inputs.resize(1);
116 newParams.inputs[0] = input;
117 size_t ifm = input.Feature().v;
// Aligned when both the feature offset so far and this input's feature count
// are multiples of 16.
// NOTE(review): the update of ifm_offset is not visible in this view; it
// presumably accumulates `ifm` per input. Confirm.
118 newParams.isAligned = ifm_offset % 16 == 0 && ifm % 16 == 0;
121 auto& kernel = kd.kernels[i];
122 DispatchData runInfo = SetDefault(newParams);
123 auto cldnnJit = GetJitConstants(newParams);
124 auto entryPoint = GetEntryPoint(kernelName, newParams.layerID, options);
125 auto jit = CreateJit(kernelName, cldnnJit, entryPoint);
127 kernel.workGroups.global = {runInfo.gws0, runInfo.gws1, runInfo.gws2};
128 kernel.workGroups.local = {runInfo.lws0, runInfo.lws1, runInfo.lws2};
129 kernel.kernelString = GetKernelString(kernelName, jit, entryPoint, params.engineInfo);
// Arguments: input i, the output at index 0, and this kernel's scalar at index 0.
130 kernel.arguments.push_back({ArgumentDescriptor::Types::INPUT, (uint32_t)i});
131 kernel.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 0});
// The scalar carries this input's starting offset along the concat channel.
134 s.t = ScalarDescriptor::Types::UINT32;
135 s.v.u32 = lastOffset;
136 kernel.scalars.push_back(s);
137 kernel.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 0});
// Advance the offset by this input's extent along the concat channel.
139 lastOffset += (uint32_t)input.GetDims()[concatChannelIndex].v;
140 efficiency = std::max(efficiency, runInfo.efficiency);
143 kd.estimatedTime = efficiency;
147 } // namespace kernel_selector