2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "fully_connected_kernel_fb_io_block.h"
19 namespace kernel_selector
// Returns the ParamsKey describing what this kernel accepts:
// F16 input and output data, F16 or F32 weights, any input layout,
// the fb output layout, per-feature bias, and the non-bias term.
21 ParamsKey FullyConnected_fb_io_block::GetSupportedKey() const
24 k.EnableInputDataType(Datatype::F16);
25 k.EnableOutputDataType(Datatype::F16);
26 k.EnableInputWeightsType(WeightsType::F16); // weights may be F16 ...
27 k.EnableInputWeightsType(WeightsType::F32); // ... or F32, while data is F16 only
28 k.EnableAllInputLayout();
29 k.EnableOutputLayout(DataLayout::fb);
31 k.EnableBiasPerFeature();
32 k.EnableNonBiasTerm();
// Builds the dispatch data for this kernel on top of the base-class defaults.
//
// The output feature axis is split into "response groups" of units_per_sg_read
// features each. A read is done in 4-byte chunks of units_per_chunk units
// (4 / sizeof(short), i.e. 2 when sizeof(short) == 2), and one read over
// sub_group_size (16) work items covers units_per_sg_read units (32 in that case).
//
// Dimension 0: rg_count * sub_group_size work items in local groups of
// sub_group_size, i.e. one local group per response group.
// Dimension 1: batch_size / units_per_sg_read work items.
//
// The chunking constants are also stored in the returned DispatchData so that
// GetJitConstants can forward them to the kernel source as JIT constants.
// The trailing unnamed int parameter is not used.
38 FullyConnected_fb_io_block::DispatchData FullyConnected_fb_io_block::SetDefault(const fully_connected_params& arg, int ) const
40 auto kd = FullyConnectedKernelBase::SetDefault(arg);
41 const auto& output = arg.output;
43 auto batch_size = output.Batch().v;
44 auto response_size = output.Feature().v;
46 constexpr uint32_t unit_byte_size = sizeof(short); // size of one F16 element on typical targets (sizeof(short) == 2)
47 const char* chunk_type = "uint"; // emitted as CHUNK_TYPE; presumably the OpenCL 32-bit uint, matching chunk_byte_size
48 constexpr uint32_t chunk_byte_size = sizeof(uint32_t);
49 constexpr uint32_t sub_group_size = 16;
50 constexpr uint32_t units_per_chunk = chunk_byte_size / unit_byte_size;
51 constexpr uint32_t units_per_sg_read = sub_group_size * units_per_chunk;
54 // Number of response groups. Each group (except last) writes units_per_sg_read responses
55 // for at least one input data set from batch.
56 auto rg_count = CeilDiv(response_size, units_per_sg_read);
58 kd.lws0 = sub_group_size;
59 // Number of work items needed to process all response groups.
60 kd.gws0 = rg_count * sub_group_size;
62 kd.gws1 = batch_size / units_per_sg_read; // truncating division: relies on Validate() checking batches % 32 == 0, otherwise trailing batches would be dropped
64 kd.unit_byte_size = unit_byte_size;
65 kd.chunk_type = chunk_type;
66 kd.chunk_byte_size = chunk_byte_size;
67 kd.units_per_chunk = units_per_chunk;
68 kd.bytes_per_sg_read = sub_group_size * chunk_byte_size;
69 kd.units_per_sg_read = units_per_sg_read;
70 kd.rg_count = (uint32_t)rg_count;
71 kd.last_rg_size = response_size % units_per_sg_read; // features in the final, partial group; 0 when response_size divides evenly
// Returns the base-class JIT constants extended with this kernel's
// chunked-read parameters, read from run_info. run_info is presumably the
// DispatchData produced by SetDefault; verify against the caller.
75 JitConstants FullyConnected_fb_io_block::GetJitConstants(const fully_connected_params& params, const FullyConnectedKernelBase::DispatchData& run_info) const
77 auto cldnn_jit = FullyConnectedKernelBase::GetJitConstants(params, run_info);
78 cldnn_jit.AddConstants({
79 MakeJitConstant("SUB_GROUP_SIZE", run_info.lws0), // SetDefault sets lws0 to the sub-group size
80 MakeJitConstant("WORK_ITEMS_PER_BATCH", run_info.gws1), // taken directly from the dimension-1 global size
81 MakeJitConstant("UNIT_BYTE_SIZE", run_info.unit_byte_size),
82 MakeJitConstant("CHUNK_TYPE", run_info.chunk_type),
83 MakeJitConstant("CHUNK_BYTE_SIZE", run_info.chunk_byte_size),
84 MakeJitConstant("UNITS_PER_CHUNK", run_info.units_per_chunk),
85 MakeJitConstant("BYTES_PER_SG_READ", run_info.bytes_per_sg_read),
86 MakeJitConstant("UNITS_PER_SG_READ", run_info.units_per_sg_read),
87 MakeJitConstant("RG_COUNT", run_info.rg_count),
88 MakeJitConstant("LAST_RG_SIZE", run_info.last_rg_size),
// Returns whether this kernel can run the given fully-connected parameters.
// After the base-class checks, it recomputes the read geometry used by
// SetDefault: a 16-wide sub-group, sizeof(short)-byte elements and 4-byte
// chunks. It then tests the output batch and feature counts against that geometry.
93 bool FullyConnected_fb_io_block::Validate(const Params& p, const optional_params& o) const
95 if (!FullyConnectedKernelBase::Validate(p, o))
100 const auto& params = static_cast<const fully_connected_params&>(p); // unchecked downcast; assumes p really is fully_connected_params (TODO confirm the base Validate guarantees this)
102 const auto& output = params.output;
103 const auto responseSize = output.Feature().v;
104 const auto batches = output.Batch().v;
// NOTE(review): this divides by batches before the batch checks below run.
// A zero batch count would be a division by zero here unless the base-class
// Validate already rejects that case.
105 const auto xSize = output.LogicalSize() / batches;
// These constants mirror the ones in SetDefault and must stay in sync with
// them, so that the accepted shapes match the dispatch geometry.
107 constexpr uint32_t subGroupSize = 16;
108 constexpr uint32_t bytesPerElement = sizeof(short);
109 constexpr uint32_t chunkSizeInBytes = sizeof(uint32_t);
110 constexpr uint32_t chunkSizeInElements = chunkSizeInBytes / bytesPerElement;
111 constexpr uint32_t elementsPerBlockRead = subGroupSize * chunkSizeInElements;
113 const bool bSupportedBatch =
115 ((batches % 8) == 0) && // NOTE(review): implied by the next check, since elementsPerBlockRead (32 with 2-byte shorts) is a multiple of 8
116 ((batches % elementsPerBlockRead) == 0); // SetDefault computes gws1 as batch_size / units_per_sg_read with truncating division
118 const bool bSupportedFeature =
119 (responseSize > 0) &&
120 (((responseSize * bytesPerElement) % 4) == 0) && // features must fill whole 4-byte chunks (even count for 2-byte elements)
123 if (!bSupportedBatch ||
132 KernelsData FullyConnected_fb_io_block::GetKernelsData(const Params& params, const optional_params& optParams) const
134 assert(params.GetType() == KernelType::FULLY_CONNECTED);
136 const auto& orgParams = static_cast<const fully_connected_params&>(params);
138 float estimated_time =
139 orgParams.inputs[0].GetDType() == Datatype::F16 && orgParams.output.Batch().v >= 16 ?
140 FORCE_PRIORITY_3 : FORCE_PRIORITY_5;
142 // TODO: this should be fb_io, but the original code uses this kernel with yxfb and yxio
143 // (fb is fyxb with fyx flattened, not yxfb with yxf flattened).
144 // The order of the add operations causes some numeric differences; to avoid them, for now we use yxfb/yxio instead.
145 // return GetCommonKernelsData(params, optParams, DataLayout::fb, WeightsLayout::io, estimated_time);
146 //return GetCommonKernelsData(params, optParams, DataLayout::yxfb, { WeightsLayout::yxio }, estimated_time);
148 KernelsData res = {};
149 for (size_t i = 0; i < autoTuneOptions.size(); i++)
151 KernelsData kd = GetTunedKernelsDataByIndex(params, optParams, DataLayout::yxfb, { WeightsLayout::yxio }, estimated_time, (int)i);
154 res.emplace_back(kd[0]);