1 // Copyright (c) 2016 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
16 #include "fully_connected_kernel_fb_io_b8_f8.h"
18 namespace kernel_selector {
// Builds the ParamsKey describing which parameter combinations this kernel
// accepts, by enabling capabilities on the key object `k`.
// NOTE(review): the declaration of `k`, the return statement and possibly
// further Enable* calls sit on lines that are not visible in this excerpt --
// confirm against the full file before relying on this list being complete.
19 ParamsKey FullyConnected_fb_io_b8_f8::GetSupportedKey() const {
// Input data type: F32 and F16 are both enabled.
21 k.EnableInputDataType(Datatype::F32);
22 k.EnableInputDataType(Datatype::F16);
// Output data type: F32 and F16 are both enabled.
23 k.EnableOutputDataType(Datatype::F32);
24 k.EnableOutputDataType(Datatype::F16);
// Weights type: F32 and F16 are both enabled.
25 k.EnableInputWeightsType(WeightsType::F32);
26 k.EnableInputWeightsType(WeightsType::F16);
// All input layouts are enabled, while the only output layout enabled here
// is DataLayout::fb.
27 k.EnableAllInputLayout();
28 k.EnableOutputLayout(DataLayout::fb);
// Both the bias-per-feature and the non-bias configurations are enabled.
30 k.EnableBiasPerFeature();
31 k.EnableNonBiasTerm();
// Produces the DispatchData for this kernel: starts from
// FullyConnectedBlockKernelBase::SetDefault(arg) and then overrides work-size
// fields using the output size and the per-work-item helper values.
// NOTE(review): the remainder of the parameter list, the assignment target of
// the Align(...) expression, any further field assignments and the return
// statement are on lines not visible in this excerpt.
36 FullyConnected_fb_io_b8_f8::DispatchData FullyConnected_fb_io_b8_f8::SetDefault(const fully_connected_params& arg,
38 auto kd = FullyConnectedBlockKernelBase::SetDefault(arg);
40 const auto& output = arg.output;
// Result of GetLocalGroupsSize(arg); used both as a divisor in the Align(...)
// expression below and as the value stored in kd.gws1.
42 size_t groups_per_batches = GetLocalGroupsSize(arg);
// output.LogicalSize() divided by
// (neurons per work item * batches per work item * groups_per_batches),
// then passed through Align(..., 8) -- presumably rounding up to a multiple
// of 8; confirm Align's semantics. Validate() requires this same quotient to
// be divisible by 8.
44 Align(output.LogicalSize() / (GetNeuronsPerWorkItem(arg) * GetBatchesPerWorkItem(arg) * groups_per_batches), 8);
45 kd.gws1 = groups_per_batches;
// Decides whether this kernel can be used for the given parameters.
// First defers to FullyConnectedBlockKernelBase::Validate, then computes four
// divisible-by-8 conditions (output batch, per-batch output size, per-batch
// input size, and the output work-size quotient) and tests them together.
// NOTE(review): the bodies of both `if` statements and the final return are
// on lines not visible in this excerpt -- presumably `return false` inside
// the ifs and `return true` at the end; confirm against the full file.
52 bool FullyConnected_fb_io_b8_f8::Validate(const Params& p, const optional_params& o) const {
53 if (!FullyConnectedBlockKernelBase::Validate(p, o)) {
// Unchecked downcast: `p` is treated as fully_connected_params from here on.
57 const auto& params = static_cast<const fully_connected_params&>(p);
59 const auto& output = params.output;
60 const auto batches = output.Batch().v;
// Output logical size per batch entry (integer division).
61 const auto x_size = output.LogicalSize() / batches;
// Only the first input is inspected.
63 const auto& input = params.inputs[0];
// Input logical size per batch entry (integer division).
64 const auto input_x_size = input.LogicalSize() / input.Batch().v;
65 const bool proper_input_aligment = (input_x_size % 8) == 0;
// `/` and `%` have equal precedence and associate left to right, so the
// expression below is (LogicalSize / (neurons * batches * groups)) % 8.
// The quotient is the same one SetDefault() passes to Align(..., 8).
66 const bool proper_output_aligment =
67 (output.LogicalSize() /
68 (GetNeuronsPerWorkItem(params) * GetBatchesPerWorkItem(params) * GetLocalGroupsSize(params)) % 8) == 0;
69 const bool bSupportedBatch = (batches % 8) == 0;
70 const bool bSupportedFeature = (x_size % 8) == 0;
// The branch is taken as soon as any one of the four conditions is false.
72 if (!bSupportedBatch || !bSupportedFeature || !proper_input_aligment || !proper_output_aligment) {
// Collects kernel data for this kernel by calling GetTunedKernelsDataByIndex
// once per entry of autoTuneOptions and appending results to `res`.
// NOTE(review): the tail of the estimated_time expression, the declarations
// of `res` and `kd`, any guard around the emplace_back and the return
// statement are on lines not visible in this excerpt.
79 KernelsData FullyConnected_fb_io_b8_f8::GetKernelsData(const Params& params, const optional_params& optParams) const {
// Sanity check on the parameter type; assert is compiled out when NDEBUG is
// defined, so the static_cast below is unchecked in such builds.
80 assert(params.GetType() == KernelType::FULLY_CONNECTED);
82 const auto& orgParams = static_cast<const fully_connected_params&>(params);
// The estimate depends on whether the first input is F16 and the output batch
// is at least 16; the values selected for each case are on lines not shown.
84 float estimated_time = orgParams.inputs[0].GetDType() == Datatype::F16 && orgParams.output.Batch().v >= 16
// Index-based loop: the index itself is forwarded (as int) to the call below.
88 for (size_t i = 0; i < autoTuneOptions.size(); i++) {
// Requests DataLayout::fb with WeightsLayout::io as the only weights layout
// listed. Presumably the result is stored in `kd` on a line not visible here.
90 GetTunedKernelsDataByIndex(params, optParams, DataLayout::fb, {WeightsLayout::io}, estimated_time, static_cast<int>(i));
// Only the first element of `kd` is appended to `res`.
92 res.emplace_back(kd[0]);
98 } // namespace kernel_selector