Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / gpu / lookup_table_gpu.cpp
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "lookup_table_inst.h"
18 #include "primitive_gpu_base.h"
19 #include "implementation_map.h"
20 #include "error_handler.h"
21 #include "kernel_selector_helper.h"
22 #include "lookup_table/lookup_table_kernel_selector.h"
23 #include "lookup_table/lookup_table_kernel_base.h"
24 #include "kernel_runner.h"
25
26 namespace cldnn {
27     namespace gpu {
28
29         struct lookup_table_gpu : typed_primitive_gpu_impl<lookup_table>
30         {
31             using parent = typed_primitive_gpu_impl<lookup_table>;
32             using parent::parent;
33
34         protected:
35
36             virtual bool validate_impl(const typed_primitive_inst<lookup_table>& instance) const override
37             {
38                 bool res = true;
39
40                 // Check whether all memory elements use the same unit type (FP16 or FP32).
41                 CLDNN_ERROR_DATA_TYPES_MISMATCH(_outer.id(), "Input memory", instance.input_memory(1).get_layout().data_type, "output memory", instance.output_memory().get_layout().data_type, "");
42
43                 return res;
44             }
45
46             virtual kernel::kernel_arguments_data get_arguments(typed_primitive_inst<lookup_table>& instance, int32_t) const override
47             {
48                 kernel::kernel_arguments_data args = parent::get_arguments(instance, 0);
49
50                 return args;
51             }
52
53         public:
54
55             static primitive_impl* create(const lookup_table_node &arg)
56             {
57                 const auto& primitive = arg.get_primitive();
58                 //const auto& input_layout = arg.input().get_output_layout();
59
60                 //const auto& input_size = input_layout.size;
61
62                 const auto& axis = primitive->axis;
63                 const auto& with_axis = primitive->with_axis;
64
65                 auto lookt_params = get_default_params<kernel_selector::lookup_table_params>(arg);
66                 auto lookt_optional_params = get_default_optional_params<kernel_selector::lookup_table_optional_params>(arg.get_program());
67
68                 lookt_params.inputIndices = convert_data_tensor(arg.indices().get_output_layout());
69                 if (with_axis) {
70                     switch (axis)
71                     {
72                     case lookup_table::batch:
73                         lookt_params.lookUpTableAxis = kernel_selector::lookt_axis::BATCH;
74                         lookt_params.numberOfValues = (uint32_t)lookt_params.inputIndices.Batch().v;
75                         break;
76                     case lookup_table::feature:
77                         lookt_params.lookUpTableAxis = kernel_selector::lookt_axis::FEATURE;
78                         lookt_params.numberOfValues = (uint32_t)lookt_params.inputIndices.Feature().v;
79                         break;
80                     case lookup_table::x:
81                         lookt_params.lookUpTableAxis = kernel_selector::lookt_axis::X;
82                         lookt_params.numberOfValues = (uint32_t)lookt_params.inputIndices.X().v;
83                         break;
84                     case lookup_table::y:
85                         lookt_params.lookUpTableAxis = kernel_selector::lookt_axis::Y;
86                         lookt_params.numberOfValues = (uint32_t)lookt_params.inputIndices.Y().v;
87                         break;
88                     default:
89                         break;
90                     }
91                 }
92                 else
93                     lookt_params.numberOfValues = (uint32_t)lookt_params.inputIndices.X().v;
94                 
95                 auto& kernel_selector = kernel_selector::lookup_table_kernel_selector::Instance();
96                 
97                 kernel_selector::KernelsData best_kernels = kernel_selector.GetBestKernels(lookt_params, lookt_optional_params);
98
99                 CLDNN_ERROR_BOOL(arg.id(), "Best_kernel.empty()", best_kernels.empty(), "Cannot find a proper kernel with this arguments");
100
101                 auto conv = new lookup_table_gpu(arg, best_kernels[0]);
102
103                 return conv;
104             }
105         };
106
107         namespace {
108             struct attach {
109                 attach() {
110                     implementation_map<lookup_table>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::yxfb), lookup_table_gpu::create);
111                     implementation_map<lookup_table>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::yxfb), lookup_table_gpu::create);
112                     implementation_map<lookup_table>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::yxfb), lookup_table_gpu::create);
113                     implementation_map<lookup_table>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), lookup_table_gpu::create);
114                     implementation_map<lookup_table>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), lookup_table_gpu::create);
115                     implementation_map<lookup_table>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::bfyx), lookup_table_gpu::create);
116                 }
117                 ~attach() {}
118             };
119             attach attach_impl;
120         }
121     }
122 }