Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / index_select / index_select_kernel_base.cpp
1 // Copyright (c) 2018 Intel Corporation
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15
#include "index_select_kernel_base.h"

#include <stdexcept>

#include "kernel_selector_utils.h"
19
20
21 namespace kernel_selector 
22 {
23     JitConstants IndexSelectKernelBase::GetJitConstants(const index_select_params& params)
24     {
25         JitConstants jit = MakeBaseParamsJitConstants(params);
26
27         jit.AddConstant(MakeJitConstant("AXES_NUMBER", params.axes.size()));
28
29         if (params.reverse) {
30             jit.AddConstant(MakeJitConstant("REVERSE", 1));
31         }
32
33         for (size_t i = 0; i < params.axes.size(); i++)
34         {
35             std::string size_name = "REVERSE_AXIS_SIZE";
36             size_t size_value = 0;
37             if (params.axes.size() > 1) {
38                 std::stringstream ss;
39                 ss << "REVERSE_" << toString(params.axes[i]) << "_SIZE";
40                 size_name = ss.str();
41             }
42             jit.AddConstant(MakeJitConstant(toString(params.axes[i]), ""));
43             if (params.reverse) {
44                 if (params.axes[i] == IndexSelectAxis::BATCH)
45                 {
46                     size_value = params.inputs.at(0).Batch().v;
47                 }
48                 else if (params.axes[i] == IndexSelectAxis::X)
49                 {
50                     size_value = params.inputs.at(0).X().v;
51                 }
52                 else if (params.axes[i] == IndexSelectAxis::Y)
53                 {
54                     size_value = params.inputs.at(0).Y().v;
55                 }
56                 else if (params.axes[i] == IndexSelectAxis::FEATURE)
57                 {
58                     size_value = params.inputs.at(0).Feature().v;
59                 }
60             }
61             jit.AddConstant(MakeJitConstant(size_name, size_value));
62         }
63
64         return jit;
65     }
66
67     IndexSelectKernelBase::DispatchData IndexSelectKernelBase::SetDefault(const index_select_params& params)
68     {
69         const auto& output = params.output;
70         DispatchData kd;
71
72         kd.fp16UnitUsed = params.inputs[0].GetDType() == Datatype::F16;
73
74         std::vector<size_t> global;
75         
76         if(params.axes.size() == 1) {
77             if (params.reverse)
78             {
79                 if (params.axes[0] == IndexSelectAxis::BATCH)
80                 {
81                     global = { 1, params.inputs.at(0).Batch().v, output.Feature().v };
82                 }
83                 else if (params.axes[0] == IndexSelectAxis::X)
84                 {
85                     global = { output.Batch().v, params.inputs.at(0).X().v, output.Feature().v };
86                 }
87                 else if (params.axes[0] == IndexSelectAxis::Y)
88                 {
89                     global = { output.Batch().v, params.inputs.at(0).Y().v, output.Feature().v };
90                 }
91                 else if (params.axes[0] == IndexSelectAxis::FEATURE)
92                 {
93                     global = { output.Batch().v, params.inputs.at(0).Feature().v, output.Y().v };
94                 }
95             }
96             else
97             {
98                 const auto indices = params.inputs.at(1).X().v;
99
100                 if (params.axes[0] == IndexSelectAxis::BATCH)
101                 {
102                     global = { 1, indices, output.Feature().v };
103                 }
104                 else if (params.axes[0] == IndexSelectAxis::X || params.axes[0] == IndexSelectAxis::Y)
105                 {
106                     global = { output.Batch().v, indices, output.Feature().v };
107                 }
108                 else if (params.axes[0] == IndexSelectAxis::FEATURE)
109                 {
110                     global = { output.Batch().v, indices, output.Y().v };
111                 }
112             }
113         }
114         else
115         {
116             if (params.reverse)
117             {
118                 global = { output.Batch().v, output.Y().v, output.Feature().v };
119             }
120         }
121
122         const auto& local = GetOptimalLocalWorkGroupSizes(global);
123
124         kd.gws0 = global[0];
125         kd.gws1 = global[1];
126         kd.gws2 = global[2];
127
128         kd.lws0 = local[0];
129         kd.lws1 = local[1];
130         kd.lws2 = local[2];
131
132         return kd;
133     }
134
135     KernelsData IndexSelectKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options, float estimated_time) const
136     {
137         assert(params.GetType() == KernelType::INDEX_SELECT);
138
139         const auto& prim_params = static_cast<const index_select_params&>(params); // NOLINT(cppcoreguidelines-pro-type-static-cast-downcast)
140         
141         auto run_info     = SetDefault(prim_params);
142         KernelData k_data = KernelData::Default<index_select_params>(params);
143
144         auto cldnn_jit   = GetJitConstants(prim_params);
145         auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, options);
146         auto jit         = CreateJit(kernelName, cldnn_jit, entry_point);
147
148         auto& kernel = k_data.kernels[0];
149         FillCLKernelData(kernel, run_info, params.engineInfo, kernelName, jit, entry_point, DEFAULT, false, false, (uint32_t)prim_params.inputs.size());
150
151         k_data.estimatedTime = estimated_time;
152
153         return {k_data};
154     }
155 }