Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / gpu / gather_gpu.cpp
1 /*
2 // Copyright (c) 2019 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "gather_inst.h"
18 #include "primitive_gpu_base.h"
19 #include "implementation_map.h"
20 #include "kernel_selector_helper.h"
21 #include "gather/gather_kernel_selector.h"
22 #include "gather/gather_kernel_ref.h"
23 #include "error_handler.h"
24
25 using namespace cldnn;
26
27 namespace cldnn
28 {
29 namespace gpu
30 {
31     kernel_selector::gather_axis convert_axis(gather::gather_axis axis)
32     {
33         switch (axis)
34         {
35             case gather::along_x: return kernel_selector::gather_axis::X;
36             case gather::along_y: return kernel_selector::gather_axis::Y;
37             case gather::along_f: return kernel_selector::gather_axis::FEATURE;
38             case gather::along_b: return kernel_selector::gather_axis::BATCH;
39             default:
40                 return kernel_selector::gather_axis::X;
41         }
42     }
43
44     struct gather_gpu : typed_primitive_gpu_impl<gather>
45     {
46         using parent = typed_primitive_gpu_impl<gather>;
47         using parent::parent;
48
49     public:
50
51         static primitive_impl* create(const gather_node& arg)
52         {
53             auto gather_params = get_default_params<kernel_selector::gather_params>(arg);
54             auto gather_optional_params = get_default_optional_params<kernel_selector::gather_optional_params>(arg.get_program());
55
56             gather_params.axis = convert_axis(arg.get_primitive()->axis);
57
58             gather_params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout()));
59
60             auto& kernel_selector = kernel_selector::gather_kernel_selector::Instance();
61             auto best_kernels = kernel_selector.GetBestKernels(gather_params, gather_optional_params);
62
63             CLDNN_ERROR_BOOL(arg.id(), "Best_kernel.empty()", best_kernels.empty(), "Cannot find a proper kernel with this arguments");
64
65             auto gather = new gather_gpu(arg, best_kernels[0]);
66
67             return gather;
68         }
69     };
70
71     namespace
72     {
73         struct attach
74         {
75             attach()
76             {
77                 auto val_fw = gather_gpu::create;
78                 implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw);
79                 implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw);
80             }
81             ~attach() = default;
82         };
83         attach attach_impl;
84     }
85 } //namespace cldnn
86 } //namespace gpu