*/
#include "fully_connected_kernel_mmad_batched.h"
-#include "kernel_selector_utils.h"
namespace kernel_selector
{
return jit;
}
// Builds the dispatch data (global/local work sizes and priority) for this kernel.
// NOTE(review): this span appears to be a diff fragment: '-' lines carry the old
// pointer-based version and '+' lines the new by-value version of the same statements.
// Old signature: returned a std::unique_ptr to Parent::DispatchData.
- std::unique_ptr<FullyConnected_mmad_batched::Parent::DispatchData> FullyConnected_mmad_batched::SetDefault(const fully_connected_params& params) const
// New signature: returns DispatchData by value and takes an extra unnamed int,
// which is not read anywhere in this body.
+ FullyConnected_mmad_batched::DispatchData FullyConnected_mmad_batched::SetDefault(const fully_connected_params& params, int) const
{
// Start from the parent's defaults; the fields assigned below are overwritten.
auto runInfo = Parent::SetDefault(params);
// Number of output features, rounded up with RoundUp(of_maps, sub_group_size)
// to obtain the work size used for gws1.
const auto of_maps = params.output.Feature().v;
const size_t of_threads_per_batch = RoundUp(of_maps, sub_group_size);
// Global work size (old form, through the pointer).
- runInfo->gws0 = params.output.Batch().v / 8; // we process 8 batches in a single WG
- runInfo->gws1 = of_threads_per_batch;
- runInfo->gws2 = 1;
// Global work size (new form, direct member access).
// NOTE(review): the batch count is divided by 8 with no rounding here;
// presumably the batch is expected to be a multiple of 8 - TODO confirm.
+ runInfo.gws0 = params.output.Batch().v / 8; // we process 8 batches in a single WG
+ runInfo.gws1 = of_threads_per_batch;
+ runInfo.gws2 = 1;
// Local work size (old form): sub_group_size items along dimension 1.
- runInfo->lws0 = 1;
- runInfo->lws1 = sub_group_size;
- runInfo->lws2 = 1;
// Local work size (new form): same values, direct member access.
+ runInfo.lws0 = 1;
+ runInfo.lws1 = sub_group_size;
+ runInfo.lws2 = 1;
// Old form: set the priority and move the owning pointer out.
- runInfo->effiency = FORCE_PRIORITY_1;
- return std::move(runInfo);
// New form: set the priority and return the local object by value.
+ runInfo.effiency = FORCE_PRIORITY_1;
+ return runInfo;
}
// Produces the kernel data candidates for the given params/options.
// NOTE(review): this span appears to be a diff fragment: '-' lines are the old body
// and '+' lines the new body.
KernelsData FullyConnected_mmad_batched::GetKernelsData(const Params& params, const optional_params& options) const
{
// Old body: a single GetCommonKernelsData call with the fs_bs_yx_bsv4_fsv32 data layout,
// the os_is_yx_isa8_osv8_isv4 weights layout and FORCE_PRIORITY_1.
- return GetCommonKernelsData(params, options, DataLayout::fs_bs_yx_bsv4_fsv32,
- { WeightsLayout::os_is_yx_isa8_osv8_isv4 }, FORCE_PRIORITY_1);
// New body: one GetTunedKernelsDataByIndex call per index of autoTuneOptions
// (declared outside this view), using the same layouts and priority as the old body.
+ KernelsData res = {};
+ for (size_t i = 0; i < autoTuneOptions.size(); i++)
+ {
+ KernelsData kd = GetTunedKernelsDataByIndex(params, options, DataLayout::fs_bs_yx_bsv4_fsv32,
+ { WeightsLayout::os_is_yx_isa8_osv8_isv4 }, FORCE_PRIORITY_1, (int)i);
// An empty result is skipped; otherwise only the first entry (kd[0]) is kept.
+ if (!kd.empty())
+ {
+ res.emplace_back(kd[0]);
+ }
+ }
// res holds at most one entry per autoTuneOptions index and may be empty.
+ return res;
}
-}
\ No newline at end of file
+}