d76d5e1b4cf7aa74dfe02263a627875751ff3fe9
[platform/upstream/dldt.git] /
1 // Copyright (c) 2016-2020 Intel Corporation
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15
16 #include <iostream>
17 #include "tensor_type.h"
18 #include "concatenation_kernel_base.h"
19 #include <algorithm>
20 #include <vector>
21
22 namespace kernel_selector {
23 Tensor::DataChannelName ConcatenationKernelBase::GetConcatChannel(const concatenation_params& params) const {
24     switch (params.axis) {
25         case ConcatAxis::X:
26             return Tensor::DataChannelName::X;
27         case ConcatAxis::Y:
28             return Tensor::DataChannelName::Y;
29         case ConcatAxis::Z:
30             return Tensor::DataChannelName::Z;
31         case ConcatAxis::W:
32             return Tensor::DataChannelName::W;
33         case ConcatAxis::FEATURE:
34             return Tensor::DataChannelName::FEATURE;
35         case ConcatAxis::BATCH:
36             return Tensor::DataChannelName::BATCH;
37         default:
38             return Tensor::DataChannelName::X;
39     }
40 }
41
42 int32_t ConcatenationKernelBase::GetConcatChannelIndex(const concatenation_params& params) const {
43     return DataTensor::Channelndex(params.output.GetLayout(), GetConcatChannel(params));
44 }
45
46 bool ConcatenationKernelBase::Validate(const Params& p, const optional_params&) const {
47     if (p.GetType() != KernelType::CONCATENATION) {
48         return false;
49     }
50
51     const concatenation_params& params = static_cast<const concatenation_params&>(p);
52
53     if (GetConcatChannelIndex(params) == -1) {
54         return false;
55     }
56
57     return true;
58 }
59
60 JitConstants ConcatenationKernelBase::GetJitConstants(const concatenation_params& params) const {
61     JitConstants jit = MakeBaseParamsJitConstants(params);
62
63     jit.AddConstants({
64         MakeJitConstant("CONCAT_" + toString(params.axis), 1),
65     });
66
67     jit.AddConstant(MakeJitConstant("CONCAT_AXIS_INDEX", GetConcatChannelIndex(params)));
68     return jit;
69 }
70
71 ConcatenationKernelBase::DispatchData ConcatenationKernelBase::SetDefault(const concatenation_params& params) const {
72     DispatchData kd;
73
74     const auto& dims = params.inputs[0].GetDims();
75     auto layout = params.inputs[0].GetLayout();
76
77     std::vector<int> idx = { DataTensor::Channelndex(layout, Tensor::DataChannelName::BATCH),
78                              DataTensor::Channelndex(layout, Tensor::DataChannelName::FEATURE),
79                              DataTensor::Channelndex(layout, Tensor::DataChannelName::Y),
80                              DataTensor::Channelndex(layout, Tensor::DataChannelName::X) };
81
82     // Determine global work sizes.
83     kd.gws0 = idx[2] != -1 ? dims[idx[2]].v : 1;  // Y
84     kd.gws1 = idx[1] != -1 ? dims[idx[1]].v : 1;  // F
85     kd.gws2 = idx[0] != -1 ? dims[idx[0]].v : 1;  // B
86
87     kd.lws0 = std::min(std::max(kd.gws0, static_cast<size_t>(1)), static_cast<size_t>(32));
88     while (kd.gws0 % kd.lws0 != 0) {
89         --kd.lws0;
90     }
91
92     kd.lws1 = 1;
93     kd.lws2 = 1;
94     kd.efficiency = DONT_USE_IF_HAVE_SOMETHING_ELSE;
95     return kd;
96 }
97
98 KernelsData ConcatenationKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options) const {
99     if (!Validate(params, options)) {
100         return {};
101     }
102
103     const concatenation_params& orgParams = static_cast<const concatenation_params&>(params);
104
105     KernelData kd = KernelData::Default<concatenation_params>(params, orgParams.inputs.size());
106
107     uint32_t lastOffset = 0;
108     const auto concatChannelIndex = GetConcatChannelIndex(orgParams);
109     float efficiency = FORCE_PRIORITY_1;
110     size_t ifm_offset = 0;
111     for (size_t i = 0; i < orgParams.inputs.size(); i++) {
112         const auto& input = orgParams.inputs[i];
113
114         auto newParams = orgParams;
115         newParams.inputs.resize(1);
116         newParams.inputs[0] = input;
117         size_t ifm = input.Feature().v;
118         newParams.isAligned = ifm_offset % 16 == 0 && ifm % 16 == 0;
119         ifm_offset += ifm;
120
121         auto& kernel = kd.kernels[i];
122         DispatchData runInfo = SetDefault(newParams);
123         auto cldnnJit = GetJitConstants(newParams);
124         auto entryPoint = GetEntryPoint(kernelName, newParams.layerID, options);
125         auto jit = CreateJit(kernelName, cldnnJit, entryPoint);
126
127         kernel.workGroups.global = {runInfo.gws0, runInfo.gws1, runInfo.gws2};
128         kernel.workGroups.local = {runInfo.lws0, runInfo.lws1, runInfo.lws2};
129         kernel.kernelString = GetKernelString(kernelName, jit, entryPoint, params.engineInfo);
130         kernel.arguments.push_back({ArgumentDescriptor::Types::INPUT, (uint32_t)i});
131         kernel.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 0});
132
133         ScalarDescriptor s;
134         s.t = ScalarDescriptor::Types::UINT32;
135         s.v.u32 = lastOffset;
136         kernel.scalars.push_back(s);
137         kernel.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 0});
138
139         lastOffset += (uint32_t)input.GetDims()[concatChannelIndex].v;
140         efficiency = std::max(efficiency, runInfo.efficiency);
141     }
142
143     kd.estimatedTime = efficiency;
144
145     return {kd};
146 }
147 }  // namespace kernel_selector