1 // Copyright (c) 2016 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
16 #include "fully_connected_kernel_base.h"
17 #include "kernel_selector_utils.h"
18 #include "common_tools.h"
23 namespace kernel_selector {
// Builds the JIT constants for a fully connected kernel, starting from the set produced by
// WeightBiasKernelBase::GetJitConstants(params) and adding the constants below.
//   params - fully connected parameters; inputs[0] and the quantization/calibration fields are read.
//   (unnamed DispatchData) - accepted but not referenced in the visible body.
// Return type is JitConstants.
// NOTE(review): some lines of this definition (closing braces, an apparent else, the return) are not
// visible in this view; the comments describe only the statements shown.
24 JitConstants FullyConnectedKernelBase::GetJitConstants(const fully_connected_params& params,
25 const FullyConnectedKernelBase::DispatchData&) const {
26 JitConstants jit = WeightBiasKernelBase::GetJitConstants(params);
27 const auto& input = params.inputs[0];
// Logical size of input 0 divided by the value of its batch dimension.
// NOTE(review): no zero check on input.Batch().v is visible here - confirm a non-zero batch is guaranteed.
28 const auto x_size = input.LogicalSize() / input.Batch().v;
30 jit.AddConstant(MakeJitConstant("INPUT0_ELEMENTS_COUNT", x_size));
31 jit.AddConstant(MakeJitConstant("QUANTIZATION_TERM", params.int8_quantization));
// Quantization factors are emitted only when int8 quantization is enabled. Only element 0 of
// weights_quantization_factors is read.
33 if (params.int8_quantization) {
34 jit.AddConstants({MakeJitConstant("W_QF", params.weights_quantization_factors[0])});
35 jit.AddConstants({MakeJitConstant("I_QF", params.input_quantization_factor)});
// With output calibration enabled: CALIBRATION_TERM is defined and O_QF comes from element 0 of
// output_calibration_factors.
37 if (params.output_calibration) {
38 jit.AddConstant(MakeJitConstant("CALIBRATION_TERM", params.output_calibration));
39 jit.AddConstant(MakeJitConstant("O_QF", params.output_calibration_factors[0]));
// NOTE(review): presumably the else branch of the calibration check (the else line is not visible
// here) - O_QF comes from output_quantization_factor instead. Confirm against the full file.
42 jit.AddConstants({MakeJitConstant("O_QF", params.output_quantization_factor)});
// Fills a DispatchData with default work sizes derived from the given fully connected parameters.
//   params - fully connected parameters; inputs[0] data type and the output logical size are read.
// NOTE(review): the remainder of the parameter list, the loop body and the return statement are not
// visible in this view.
49 FullyConnectedKernelBase::DispatchData FullyConnectedKernelBase::SetDefault(const fully_connected_params& params,
51 DispatchData dispatchData;
// Set when input 0 has the F16 data type.
52 dispatchData.fp16UnitUsed = params.inputs[0].GetDType() == Datatype::F16;
54 // Determine global work sizes.
// Dimension 0 is the logical size of the output; dimensions 1 and 2 are fixed to 1.
55 dispatchData.gws0 = params.output.LogicalSize();
56 dispatchData.gws1 = dispatchData.gws2 = 1;
58 // Find the largest positive local work size that is a divisor of the global work size.
// The starting candidate is gws0 clamped to the range [1, 32]; the loop continues while the
// candidate does not divide gws0 evenly.
// NOTE(review): the loop body is not visible here - presumably it decrements lws0; confirm.
59 dispatchData.lws0 = std::min(std::max(dispatchData.gws0, static_cast<size_t>(1)), static_cast<size_t>(32));
60 while (dispatchData.gws0 % dispatchData.lws0 != 0) {
63 dispatchData.lws1 = dispatchData.lws2 = 1;
// Assembles the KernelData for a fully connected kernel: checks the input layout, prepares the
// parameter object held by the KernelData, updates the weights parameters, generates the JIT source
// and fills the single kernel entry.
//   params / options - generic parameters, downcast below to the fully connected types.
//   wl               - list of weights layouts; an empty list triggers the early-exit condition.
//   exeMode          - std::string taken by value (not referenced in the visible lines).
//   autoTuneIndex    - forwarded to SetDefault and stored in the resulting KernelData.
// NOTE(review): dl, kernelName and estimated_time are used below but their declarations (presumably
// further parameters) are not visible in this view, nor are several branch bodies, call arguments and
// the return statement.
68 KernelsData FullyConnectedKernelBase::GetCommonKernelsData(const Params& params,
69 const optional_params& options,
71 std::vector<WeightsLayout> wl,
73 const std::string exeMode,
74 int autoTuneIndex) const {
// Early-exit condition: validation failed or no weights layout was supplied (branch body not visible).
75 if (!Validate(params, options) || wl.empty()) {
// static_cast is used, so no runtime type check is performed on params/options here.
79 const auto& orgParams = static_cast<const fully_connected_params&>(params);
80 const auto& orgOptParams = static_cast<const fully_connected_optional_params&>(options);
// The input is "proper" when its layout equals dl. If it does not, and its pitches do not differ
// from its logical dims, fyxb is accepted for dl == fb and bfyx for dl == bf.
82 bool bProperInput = orgParams.inputs[0].GetLayout() == dl;
83 if (!bProperInput && !orgParams.inputs[0].PitchesDifferFromLogicalDims()) {
84 bProperInput = (dl == DataLayout::fb && orgParams.inputs[0].GetLayout() == DataLayout::fyxb) ||
85 (dl == DataLayout::bf && orgParams.inputs[0].GetLayout() == DataLayout::bfyx);
// A non-proper input is still supported when the optional params allow input reordering.
88 const bool bSupportedInput = orgOptParams.allowInputReordering || bProperInput;
90 if (!bSupportedInput) {
// newParams refers to the parameter object owned by kd (kd.params); changes below go into kd.
94 KernelData kd = KernelData::Default<fully_connected_params>(params);
95 fully_connected_params& newParams = *static_cast<fully_connected_params*>(kd.params.get());
// Input 0 is transformed to layout dl via TransformIgnorePadding and the KernelData is flagged for
// input reordering.
// NOTE(review): presumably guarded by a condition on bProperInput that is not visible here - confirm.
98 newParams.inputs[0] = newParams.inputs[0].TransformIgnorePadding(dl);
99 kd.reorderInput = true;
// NOTE(review): the remaining arguments and the handling of `succeed` are not visible in this view.
102 bool succeed = UpdateWeightsParams(newParams,
105 kd.weightsReorderParams,
// Exactly one kernel entry is produced.
112 kd.kernels.resize(1);
114 auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, options);
// Dispatch data and JIT constants are computed from newParams, not from the original params.
116 const DispatchData runInfo = SetDefault(newParams, autoTuneIndex);
117 auto cldnn_jit = GetJitConstants(newParams, runInfo);
118 std::string jit = CreateJit(kernelName, cldnn_jit, entry_point);
// Bias presence is read from the original params; the quantization and calibration flags are read
// from newParams. The other arguments of this call are not visible here.
120 auto& kernel = kd.kernels[0];
121 FillCLKernelData(kernel,
129 !orgParams.bias.empty(),
131 newParams.int8_quantization,
132 newParams.output_calibration);
134 // TODO Pass estimated time only through DispatchData
135 kd.estimatedTime = estimated_time;
136 kd.autoTuneIndex = autoTuneIndex;
// Returns the entry of autoTuneOptions at autoTuneIndex when the index lies in
// [0, autoTuneOptions.size()).
// NOTE(review): the out-of-range path (whatever is returned otherwise) is not visible in this view.
140 std::string FullyConnectedKernelBase::GetAutoTuneOptions(int autoTuneIndex) const {
// The size is cast to int so the comparison with the signed index is signed/signed.
141 if ((autoTuneIndex >= 0) && (autoTuneIndex < static_cast<int>(autoTuneOptions.size()))) {
142 return autoTuneOptions[autoTuneIndex];
// Forwards to GetCommonKernelsData, passing along the option string obtained from
// GetAutoTuneOptions(autoTuneIndex).
//   params / options - forwarded parameters (params is visibly passed first).
//   wl               - list of weights layouts.
//   estimated_time   - float estimate supplied by the caller.
//   autoTuneIndex    - index used to look up the auto-tune option string.
// NOTE(review): part of the parameter list and most call arguments are not visible in this view;
// the option string presumably fills the exeMode parameter of GetCommonKernelsData - confirm.
148 KernelsData FullyConnectedKernelBase::GetTunedKernelsDataByIndex(const Params& params,
149 const optional_params& options,
151 std::vector<WeightsLayout> wl,
152 float estimated_time,
153 const int autoTuneIndex) const {
154 return GetCommonKernelsData(params,
159 GetAutoTuneOptions(autoTuneIndex),
163 } // namespace kernel_selector