*/
#include "fully_connected_kernel_fb_io_block.h"
-#include "kernel_selector_utils.h"
namespace kernel_selector
{
return k;
}
// SetDefault: builds the dispatch data for this kernel. It starts from
// FullyConnectedKernelBase::SetDefault(arg) and then fills in the work sizes
// (gws0/gws1, lws0/lws1) and the unit/chunk/read-size fields assigned below.
// This change replaces the heap-allocated std::unique_ptr<DispatchData> result
// with a DispatchData returned by value, and adds an unnamed int parameter that
// this body does not use.
// NOTE(review): sub_group_size, response_size, units_per_sg_read and the
// unit/chunk values are defined in lines not shown in this excerpt.
- std::unique_ptr<FullyConnected_fb_io_block::FullyConnectedKernelBase::DispatchData> FullyConnected_fb_io_block::SetDefault(const fully_connected_params& arg) const
+
+ FullyConnected_fb_io_block::DispatchData FullyConnected_fb_io_block::SetDefault(const fully_connected_params& arg, int ) const
{
- auto kd = std::unique_ptr<DispatchData>(new DispatchData(*FullyConnectedKernelBase::SetDefault(arg)));
+ auto kd = FullyConnectedKernelBase::SetDefault(arg);
const auto& output = arg.output;
auto batch_size = output.Batch().v;
// for at least one input data set from batch.
// rg_count is response_size divided by units_per_sg_read, rounded up by CeilDiv.
auto rg_count = CeilDiv(response_size, units_per_sg_read);
- kd->lws0 = sub_group_size;
+ kd.lws0 = sub_group_size;
// Number of work items needed to process all response groups.
- kd->gws0 = rg_count * sub_group_size;
- kd->lws1 = 1;
- kd->gws1 = batch_size / units_per_sg_read;
-
- kd->unit_byte_size = unit_byte_size;
- kd->chunk_type = chunk_type;
- kd->chunk_byte_size = chunk_byte_size;
- kd->units_per_chunk = units_per_chunk;
- kd->bytes_per_sg_read = sub_group_size * chunk_byte_size;
- kd->units_per_sg_read = units_per_sg_read;
- kd->rg_count = (uint32_t)rg_count;
- kd->last_rg_size = response_size % units_per_sg_read;
- return std::move(kd);
+ kd.gws0 = rg_count * sub_group_size;
+ kd.lws1 = 1;
// NOTE(review): if both operands are integers this division truncates; presumably
// batch_size is expected to be a multiple of units_per_sg_read - confirm.
+ kd.gws1 = batch_size / units_per_sg_read;
+
+ kd.unit_byte_size = unit_byte_size;
+ kd.chunk_type = chunk_type;
+ kd.chunk_byte_size = chunk_byte_size;
+ kd.units_per_chunk = units_per_chunk;
+ kd.bytes_per_sg_read = sub_group_size * chunk_byte_size;
+ kd.units_per_sg_read = units_per_sg_read;
+ kd.rg_count = (uint32_t)rg_count;
// Remainder of response_size after whole units_per_sg_read groups
// (0 when response_size divides evenly).
+ kd.last_rg_size = response_size % units_per_sg_read;
+ return kd;
}
// GetJitConstants: takes the JIT constants produced by
// FullyConnectedKernelBase::GetJitConstants(params, run_info) and adds the
// dispatch values for this kernel: SUB_GROUP_SIZE (lws0),
// WORK_ITEMS_PER_BATCH (gws1) and the unit/chunk/read-group fields of run_info.
// This change drops the static_cast to the derived DispatchData and reads the
// fields directly from run_info.
JitConstants FullyConnected_fb_io_block::GetJitConstants(const fully_connected_params& params, const FullyConnectedKernelBase::DispatchData& run_info) const
{
- auto &d = static_cast<const DispatchData&>(run_info);
auto cldnn_jit = FullyConnectedKernelBase::GetJitConstants(params, run_info);
cldnn_jit.AddConstants({
- MakeJitConstant("SUB_GROUP_SIZE", d.lws0),
- MakeJitConstant("WORK_ITEMS_PER_BATCH", d.gws1),
- MakeJitConstant("UNIT_BYTE_SIZE", d.unit_byte_size),
- MakeJitConstant("CHUNK_TYPE", d.chunk_type),
- MakeJitConstant("CHUNK_BYTE_SIZE", d.chunk_byte_size),
- MakeJitConstant("UNITS_PER_CHUNK", d.units_per_chunk),
- MakeJitConstant("BYTES_PER_SG_READ", d.bytes_per_sg_read),
- MakeJitConstant("UNITS_PER_SG_READ", d.units_per_sg_read),
- MakeJitConstant("RG_COUNT", d.rg_count),
- MakeJitConstant("LAST_RG_SIZE", d.last_rg_size),
+ MakeJitConstant("SUB_GROUP_SIZE", run_info.lws0),
+ MakeJitConstant("WORK_ITEMS_PER_BATCH", run_info.gws1),
+ MakeJitConstant("UNIT_BYTE_SIZE", run_info.unit_byte_size),
+ MakeJitConstant("CHUNK_TYPE", run_info.chunk_type),
+ MakeJitConstant("CHUNK_BYTE_SIZE", run_info.chunk_byte_size),
+ MakeJitConstant("UNITS_PER_CHUNK", run_info.units_per_chunk),
+ MakeJitConstant("BYTES_PER_SG_READ", run_info.bytes_per_sg_read),
+ MakeJitConstant("UNITS_PER_SG_READ", run_info.units_per_sg_read),
+ MakeJitConstant("RG_COUNT", run_info.rg_count),
+ MakeJitConstant("LAST_RG_SIZE", run_info.last_rg_size),
});
return cldnn_jit;
}
// (fb == fyxb flatten fyx, not yxfb flatten yxf).
// The order of the add operations causes some numeric changes. To avoid them, for now we use yxfb/oiyx instead.
// return GetCommonKernelsData(params, optParams, DataLayout::fb, WeightsLayout::io, estimated_time);
- return GetCommonKernelsData(params, optParams, DataLayout::yxfb, { WeightsLayout::yxio }, estimated_time);
- }
+ //return GetCommonKernelsData(params, optParams, DataLayout::yxfb, { WeightsLayout::yxio }, estimated_time);
+
+ KernelsData res = {};
+ for (size_t i = 0; i < autoTuneOptions.size(); i++)
+ {
+ KernelsData kd = GetTunedKernelsDataByIndex(params, optParams, DataLayout::yxfb, { WeightsLayout::yxio }, estimated_time, (int)i);
+ if (!kd.empty())
+ {
+ res.emplace_back(kd[0]);
+ }
+ }
+
+ return res;
+ }
}