2 // Copyright (c) 2019 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
19 #include "fully_connected_inst.h"
20 #include "primitive_gpu_base.h"
21 #include "implementation_map.h"
22 #include "kernel_selector_helper.h"
23 #include "fully_connected/fully_connected_kernel_selector.h"
24 #include "fully_connected/fully_connected_params.h"
26 #include "network_impl.h"
27 #include "error_handler.h"
28 #include "kernel_runner.h"
30 #include "api/CPP/reorder.hpp"
31 #include "api/CPP/input_layout.hpp"
33 namespace cldnn { namespace gpu {
// Implementation of the fully_connected primitive in the cldnn::gpu namespace,
// built on typed_primitive_gpu_impl<fully_connected>.
// Besides the selected kernel it can hold an auxiliary "reorder" network that
// is executed on the input before the call is forwarded to the parent's
// execute_impl().
36 struct fully_connected_gpu : typed_primitive_gpu_impl<fully_connected>
38 using parent = typed_primitive_gpu_impl<fully_connected>;
// Auxiliary networks; execute_impl() only ever uses element 0.
// NOTE(review): presumably initialized from the constructor's `reorders`
// argument - the member initializer is not visible in this view; confirm.
40 std::vector<network_impl::ptr> _reorders; // TODO: move this reorder to graph compiler
// Memory bound as the single kernel input. Assigned in execute_impl() and read
// in the const method get_arguments().
41 memory_impl::cptr new_input_mem; // TODO: remove this hack
// Constructor.
//   arg      - the fully_connected node this implementation is created for.
//   kd       - kernel data chosen by the kernel selector (see create()).
//   reorders - reorder networks, taken by value.
// NOTE(review): the initializer list/body is not visible here; it presumably
// forwards (arg, kd) to `parent` and stores `reorders` - confirm.
43 fully_connected_gpu(const fully_connected_node& arg, const kernel_selector::kernel_data& kd, std::vector<network_impl::ptr> reorders)
// Collects the kernel arguments for `instance`. The second (unnamed) int32_t
// parameter is not used by the visible code.
50 virtual kernel::kernel_arguments_data get_arguments(typed_primitive_inst<fully_connected>& instance, int32_t) const override
52 kernel::kernel_arguments_data args;
// The only input is whatever execute_impl() stored in new_input_mem, which is
// not necessarily the instance's own input memory.
54 args.inputs = { new_input_mem };
55 args.output = &instance.output_memory();
56 args.weights = &instance.weights_memory();
// Optional arguments: each is nullptr when the instance reports the
// corresponding term as absent.
57 args.bias = instance.bias_term() ? &instance.bias_memory() : nullptr;
58 args.weights_quantization_factors = instance.weights_quantization_factors_term() ? &instance.weights_quantization_factors_memory() : nullptr;
59 args.output_calibration_factors = instance.output_calibration_factors_term() ? &instance.output_calibration_factors_memory() : nullptr;
// Selects the input memory for this run (running the reorder network if there
// is one) and then delegates to parent::execute_impl().
//   events   - events passed in by the caller; copied so that an extra event
//              can be appended without touching the caller's vector.
//   instance - the primitive instance being executed.
// Returns whatever parent::execute_impl() returns.
66 event_impl::ptr execute_impl(const std::vector<event_impl::ptr>& events, fully_connected_inst& instance) override
68 std::vector<event_impl::ptr> tmp_events(events);
// No reorder network: the kernel reads the instance input directly.
70 if (_reorders.empty())
72 new_input_mem = &instance.input_memory();
// NOTE(review): the branch structure is not visible in this view; the
// statements below presumably form the alternative branch and run only when
// _reorders is non-empty (otherwise _reorders[0] would be out of range).
76 auto network = _reorders[0];
// "input" is the id under which create() adds the input_layout primitive.
77 network->set_input_data("input", instance.input_memory());
78 network->execute(tmp_events);
79 auto output_id = network->get_output_ids()[0];
// The kernel input becomes the output memory of the network's first output.
80 new_input_mem = &network->get_primitive(output_id)->output_memory();
// Add the event of that output primitive to the events handed to the parent.
82 tmp_events.push_back(network->get_primitive_event(output_id));
85 return parent::execute_impl(tmp_events, instance);
// Factory registered in the implementation map: builds kernel-selector
// parameters from `arg`, asks the selector for the best kernels, optionally
// builds a reorder network for the input, and allocates a fully_connected_gpu
// with `new`.
// NOTE(review): the return statement is not visible here; the result is
// presumably the raw pointer `fc`, owned by the caller - confirm.
88 static primitive_impl* create(const fully_connected_node& arg)
90 auto fc_params = get_weights_bias_default_params<kernel_selector::fully_connected_params>(arg);
91 auto fc_optional_params = get_default_weights_bias_optional_params<kernel_selector::fully_connected_optional_params>(arg.get_program());
// Input reordering is enabled in the optional params; a differing input layout
// in the selected kernel's params is handled by the layout comparison below.
92 fc_optional_params.allowInputReordering = true;
// Activation parameters are converted only when the primitive has activation.
94 if(arg.get_primitive()->with_activation)
95 convert_activation_func_params(arg.get_primitive(), fc_params.activation);
// Replace the output tensor description with its FlattenFeatureAndSpatials() form.
97 fc_params.output = fc_params.output.FlattenFeatureAndSpatials();
99 const auto primitive = arg.get_primitive();
// Non-empty weights_quantization_factors on the primitive: turn on int8
// quantization and pass the (flattened) factors tensor and the input
// quantization factor to the selector.
101 if (primitive->weights_quantization_factors.size() > 0)
103 fc_params.int8_quantization = true;
104 fc_params.weights_quantization_factors.push_back(convert_data_tensor(arg.weights_quantization_factors().get_output_layout()).FlattenFeatureAndSpatials());
105 fc_params.input_quantization_factor = arg.get_input_qf();
// Non-empty output_calibration_factors on the primitive: enable output
// calibration and pass the (flattened) calibration tensor.
107 if (primitive->output_calibration_factors.size() > 0)
109 fc_params.output_calibration = true;
110 fc_params.output_calibration_factors.push_back(convert_data_tensor(arg.output_calibration_factors().get_output_layout()).FlattenFeatureAndSpatials());
// NOTE(review): whether the next assignment is conditional (e.g. an `else` of
// the check above) is not visible in this view - confirm against the full file.
113 fc_params.output_quantization_factor = arg.get_output_qf();
// Tuning runner created from the program's engine, with `true` as second argument.
116 fc_optional_params.tuningParams.runner = std::make_shared<gpu::kernel_runner>(arg.get_program().get_engine(), true);
118 auto& kernel_selector = kernel_selector::fully_connected_kernel_selector::Instance();
119 auto best_kernels = kernel_selector.GetBestKernels(fc_params, fc_optional_params);
// Report an error through CLDNN_ERROR_BOOL when no kernel was found;
// best_kernels[0] is dereferenced right after this check.
121 CLDNN_ERROR_BOOL(arg.id(), "Best_kernel.empty()", best_kernels.empty(), "Cannot find a proper kernel with this arguments");
// Parameters stored with the first returned kernel, viewed as fully_connected_params.
123 const auto& new_fc_params = *static_cast<kernel_selector::fully_connected_params*>(best_kernels[0].params.get());
124 std::vector<network_impl::ptr> reorders;
// The selected kernel's first-input layout differs from the one requested:
// build a small network ("input" -> "reorder") that converts the input to the
// new layout while keeping the input's data type.
125 if (fc_params.inputs[0].GetLayout() != new_fc_params.inputs[0].GetLayout())
127 const auto& input_layout = arg.input().get_output_layout();
// NOTE(review): `tpl` is declared on a line that is not visible in this view.
129 tpl.add(std::make_shared<cldnn::input_layout>("input", input_layout));
130 tpl.add(std::make_shared<cldnn::reorder>("reorder", "input", from_data_layout(new_fc_params.inputs[0].GetLayout()), input_layout.data_type));
131 reorders.push_back(arg.get_program().get_engine().build_network(tpl, cldnn::build_options(), true));
// Raw allocation; `reorders` is passed by value, so the vector is copied here.
134 auto fc = new fully_connected_gpu(arg, best_kernels[0], reorders);
// Registers fully_connected_gpu::create in implementation_map<fully_connected>
// for every (engine type, data type, format) tuple listed below.
// NOTE(review): the enclosing definition and the closing tokens of the add()
// call are not visible in this view.
144 auto val_fw = fully_connected_gpu::create;
146 implementation_map<fully_connected>::add({
// f32 / f16 entries for the yxfb, bfyx and byxf formats.
147 { std::make_tuple(engine_types::ocl, data_types::f32, format::yxfb), val_fw },
148 { std::make_tuple(engine_types::ocl, data_types::f16, format::yxfb), val_fw },
149 { std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw },
150 { std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw },
151 { std::make_tuple(engine_types::ocl, data_types::f32, format::byxf), val_fw },
152 { std::make_tuple(engine_types::ocl, data_types::f16, format::byxf), val_fw },
// i8 / u8 entries.
153 { std::make_tuple(engine_types::ocl, data_types::i8, format::bfyx), val_fw },
155 { std::make_tuple(engine_types::ocl, data_types::i8, format::byxf_af32), val_fw },
156 { std::make_tuple(engine_types::ocl, data_types::i8, format::fs_bs_yx_bsv4_fsv32), val_fw },
158 { std::make_tuple(engine_types::ocl, data_types::i8, format::b_fs_yx_fsv4), val_fw },
159 { std::make_tuple(engine_types::ocl, data_types::u8, format::b_fs_yx_fsv4), val_fw },