*/
#include "fully_connected_kernel_bf_io_gemm.h"
-#include "kernel_selector_utils.h"
namespace kernel_selector {
return k;
}
- std::unique_ptr<FullyConnected_bf_io_GEMM::Parent::DispatchData> FullyConnected_bf_io_GEMM::SetDefault(const fully_connected_params& params) const
+ FullyConnected_bf_io_GEMM::DispatchData FullyConnected_bf_io_GEMM::SetDefault(const fully_connected_params& params, int autoTuneIndex) const
{
- auto runInfo = Parent::SetDefault(params);
+ auto runInfo = Parent::SetDefault(params, autoTuneIndex);
const uint32_t localWorkSizeX = 64;
const uint32_t globalWorkSizeX = localWorkSizeX;
std::vector<size_t> global = { globalWorkSizeX, params.output.Feature().v, params.output.Batch().v };
std::vector<size_t> local = { localWorkSizeX, 1, 1 };
- runInfo->gws0 = global[0];
- runInfo->gws1 = global[1];
- runInfo->gws2 = 1;
+ runInfo.gws0 = global[0];
+ runInfo.gws1 = global[1];
+ runInfo.gws2 = 1;
- runInfo->lws0 = local[0];
- runInfo->lws1 = local[1];
- runInfo->lws2 = 1;
+ runInfo.lws0 = local[0];
+ runInfo.lws1 = local[1];
+ runInfo.lws2 = 1;
- runInfo->effiency = FORCE_PRIORITY_6;
+ runInfo.effiency = FORCE_PRIORITY_6;
- return std::move(runInfo);
+ return runInfo;
}
JitConstants FullyConnected_bf_io_GEMM::GetJitConstants(const fully_connected_params& params, const DispatchData& kd) const
KernelsData FullyConnected_bf_io_GEMM::GetKernelsData(const Params& params, const optional_params& options) const
{
- return GetCommonKernelsData(params, options, DataLayout::bf, { WeightsLayout::oiyx }, FORCE_PRIORITY_6);
+ KernelsData res = {};
+ for (size_t i = 0; i < autoTuneOptions.size(); i++)
+ {
+ KernelsData kd = GetTunedKernelsDataByIndex(params, options, DataLayout::bf, { WeightsLayout::oiyx }, FORCE_PRIORITY_6, (int)i);
+ if (!kd.empty())
+ {
+ res.emplace_back(kd[0]);
+ }
+ }
+
+ return res;
}
}