1 // Copyright (c) 2019 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
16 #include "contract_inst.h"
18 #include "primitive_gpu_base.h"
19 #include "implementation_map.h"
20 #include "kernel_selector_helper.h"
21 #include "error_handler.h"
22 #include "contract/contract_kernel_selector.h"
23 #include "contract/contract_kernel_base.h"
30 inline kernel_selector::ContractMode convert_to_contract_mode(contract_mode mode)
34 case contract_mode::sum: return kernel_selector::ContractMode::SUM;
35 case contract_mode::prod: return kernel_selector::ContractMode::PRODUCT;
36 case contract_mode::all: return kernel_selector::ContractMode::ALL;
37 case contract_mode::any: return kernel_selector::ContractMode::ANY;
38 case contract_mode::max: return kernel_selector::ContractMode::MAX;
41 return kernel_selector::ContractMode::SUM;
46 struct contract_gpu : typed_primitive_gpu_impl<contract>
48 using parent = typed_primitive_gpu_impl<contract>;
52 static primitive_impl* create(const contract_node& arg)
54 auto c_params = get_default_params<kernel_selector::contract_params>(arg, 1);
55 auto c_optional_params = get_default_optional_params<kernel_selector::contract_optional_params>(arg.get_program());
57 c_params.reduction_axes = arg.get_primitive()->reduction_axes;
58 c_params.mode = convert_to_contract_mode(arg.get_primitive()->mode);
60 auto& kernel_selector = kernel_selector::contract_kernel_selector::Instance();
61 auto best_kernels = kernel_selector.GetBestKernels(c_params, c_optional_params);
63 CLDNN_ERROR_BOOL(arg.id(), "Best_kernel.empty()", best_kernels.empty(), "Cannot find a proper kernel with this arguments");
65 return new contract_gpu(arg, best_kernels[0]);
72 auto val_fw = contract_gpu::create;
74 implementation_map<contract>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw);
75 implementation_map<contract>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw);
76 implementation_map<contract>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::bfyx), val_fw);
77 implementation_map<contract>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::bfyx), val_fw);
78 implementation_map<contract>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfyx), val_fw);
79 implementation_map<contract>::add(std::make_tuple(engine_types::ocl, data_types::i64, format::bfyx), val_fw);