2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "lookup_table_inst.h"
18 #include "primitive_gpu_base.h"
19 #include "implementation_map.h"
20 #include "error_handler.h"
21 #include "kernel_selector_helper.h"
22 #include "lookup_table/lookup_table_kernel_selector.h"
23 #include "lookup_table/lookup_table_kernel_base.h"
24 #include "kernel_runner.h"
// GPU implementation of the lookup_table primitive, built on the generic
// typed_primitive_gpu_impl base. Instances are constructed in create() from
// the best kernel returned by lookup_table_kernel_selector.
29 struct lookup_table_gpu : typed_primitive_gpu_impl<lookup_table>
// Shorthand for the base class; used to forward to its default behaviour.
31 using parent = typed_primitive_gpu_impl<lookup_table>;
// Instance-level validation hook (overrides the base-class version).
// Reports an error through CLDNN_ERROR_DATA_TYPES_MISMATCH when the data type
// of input memory #1 differs from the data type of the output memory.
// NOTE(review): the error message labels the operand "Input memory" although
// input index 1 (not 0) is checked — confirm which input (data vs. indices)
// is meant to share the output's data type.
36 virtual bool validate_impl(const typed_primitive_inst<lookup_table>& instance) const override
40 // Check that input memory #1 and the output memory use the same data type
// (this implementation is registered for f32, f16 and i8, not only FP16/FP32).
41 CLDNN_ERROR_DATA_TYPES_MISMATCH(_outer.id(), "Input memory", instance.input_memory(1).get_layout().data_type, "output memory", instance.output_memory().get_layout().data_type, "");
// Builds the kernel argument set for this primitive instance, starting from
// the parent's default argument set.
// The second int32_t parameter (presumably a split index — confirm against the
// base class) is unnamed and ignored: index 0 is always forwarded to the parent.
46 virtual kernel::kernel_arguments_data get_arguments(typed_primitive_inst<lookup_table>& instance, int32_t) const override
48 kernel::kernel_arguments_data args = parent::get_arguments(instance, 0);
// Factory for lookup_table_gpu.
// Steps:
//   1. Read the axis settings (axis, with_axis) from the node's primitive.
//   2. Fill the kernel-selector parameters: the indices tensor, the lookup
//      axis, and the number of values.
//   3. Ask lookup_table_kernel_selector for the best kernels.
//   4. Wrap the first returned kernel in a new lookup_table_gpu.
// If no kernel is found, this is reported via CLDNN_ERROR_BOOL.
// NOTE(review): the implementation is allocated with a raw `new`; ownership
// presumably passes to the caller through the returned primitive_impl* — confirm.
55 static primitive_impl* create(const lookup_table_node &arg)
57 const auto& primitive = arg.get_primitive();
58 //const auto& input_layout = arg.input().get_output_layout();
60 //const auto& input_size = input_layout.size;
// Requested lookup axis, and whether an explicit axis was supplied.
62 const auto& axis = primitive->axis;
63 const auto& with_axis = primitive->with_axis;
65 auto lookt_params = get_default_params<kernel_selector::lookup_table_params>(arg);
66 auto lookt_optional_params = get_default_optional_params<kernel_selector::lookup_table_optional_params>(arg.get_program());
// The indices tensor is taken from the output layout of the node's indices input.
68 lookt_params.inputIndices = convert_data_tensor(arg.indices().get_output_layout());
// Map the primitive's axis onto the kernel-selector axis. numberOfValues is the
// size of the indices tensor along that same axis.
72 case lookup_table::batch:
73 lookt_params.lookUpTableAxis = kernel_selector::lookt_axis::BATCH;
74 lookt_params.numberOfValues = (uint32_t)lookt_params.inputIndices.Batch().v;
76 case lookup_table::feature:
77 lookt_params.lookUpTableAxis = kernel_selector::lookt_axis::FEATURE;
78 lookt_params.numberOfValues = (uint32_t)lookt_params.inputIndices.Feature().v;
// x axis
81 lookt_params.lookUpTableAxis = kernel_selector::lookt_axis::X;
82 lookt_params.numberOfValues = (uint32_t)lookt_params.inputIndices.X().v;
// y axis
85 lookt_params.lookUpTableAxis = kernel_selector::lookt_axis::Y;
86 lookt_params.numberOfValues = (uint32_t)lookt_params.inputIndices.Y().v;
// Fallback (presumably the path taken when no explicit axis is given — confirm):
// the number of values is the x size of the indices tensor.
93 lookt_params.numberOfValues = (uint32_t)lookt_params.inputIndices.X().v;
95 auto& kernel_selector = kernel_selector::lookup_table_kernel_selector::Instance();
97 kernel_selector::KernelsData best_kernels = kernel_selector.GetBestKernels(lookt_params, lookt_optional_params);
99 CLDNN_ERROR_BOOL(arg.id(), "Best_kernel.empty()", best_kernels.empty(), "Cannot find a proper kernel with this arguments");
// The first entry of best_kernels is used.
// NOTE(review): `conv` holds a lookup_table_gpu, not a convolution; the name
// looks copy-pasted from another primitive.
101 auto conv = new lookup_table_gpu(arg, best_kernels[0]);
// Register lookup_table_gpu::create as the engine_types::ocl implementation
// of lookup_table for f32, f16 and i8 data, in both yxfb and bfyx formats.
110 implementation_map<lookup_table>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::yxfb), lookup_table_gpu::create);
111 implementation_map<lookup_table>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::yxfb), lookup_table_gpu::create);
112 implementation_map<lookup_table>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::yxfb), lookup_table_gpu::create);
113 implementation_map<lookup_table>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), lookup_table_gpu::create);
114 implementation_map<lookup_table>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), lookup_table_gpu::create);
115 implementation_map<lookup_table>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::bfyx), lookup_table_gpu::create);