*/
#include "fully_connected_kernel_fb_io_b8_f8.h"
-#include "kernel_selector_utils.h"
namespace kernel_selector
{
return k;
}
- std::unique_ptr<FullyConnected_fb_io_b8_f8::DispatchData> FullyConnected_fb_io_b8_f8::SetDefault(const fully_connected_params& arg) const
+ FullyConnected_fb_io_b8_f8::DispatchData FullyConnected_fb_io_b8_f8::SetDefault(const fully_connected_params& arg, int ) const
{
auto kd = FullyConnectedBlockKernelBase::SetDefault(arg);
const auto& output = arg.output;
size_t groups_per_batches = GetLocalGroupsSize(arg);
- kd->gws0 = output.LogicalSize() / (GetNeuronsPerWorkItem(arg) * GetBatchesPerWorkItem(arg) * groups_per_batches);
- kd->gws1 = groups_per_batches;
- kd->lws0 = 8;
- kd->lws1 = 1;
+ kd.gws0 = Align(output.LogicalSize() / (GetNeuronsPerWorkItem(arg) * GetBatchesPerWorkItem(arg) * groups_per_batches), 8);
+ kd.gws1 = groups_per_batches;
+ kd.lws0 = 8;
+ kd.lws1 = 1;
- return std::move(kd);
+ return kd;
}
bool FullyConnected_fb_io_b8_f8::Validate(const Params& p, const optional_params& o) const
const auto batches = output.Batch().v;
const auto x_size = output.LogicalSize() / batches;
+ const auto& input = params.inputs[0];
+ const auto input_x_size = input.LogicalSize() / input.Batch().v;
+ const bool proper_input_aligment = (input_x_size % 8) == 0;
+ const bool proper_output_aligment = (output.LogicalSize() / (GetNeuronsPerWorkItem(params) * GetBatchesPerWorkItem(params) * GetLocalGroupsSize(params)) % 8) == 0;
const bool bSupportedBatch = (batches % 8) == 0;
const bool bSupportedFeature = (x_size % 8) == 0;
if (!bSupportedBatch ||
- !bSupportedFeature)
+ !bSupportedFeature ||
+ !proper_input_aligment ||
+ !proper_output_aligment)
{
return false;
}
KernelsData FullyConnected_fb_io_b8_f8::GetKernelsData(const Params& params, const optional_params& optParams) const
{
assert(params.GetType() == KernelType::FULLY_CONNECTED);
-
+ KernelsData res = {};
const auto& orgParams = static_cast<const fully_connected_params&>(params);
float estimated_time =
orgParams.inputs[0].GetDType() == Datatype::F16 && orgParams.output.Batch().v >= 16 ?
FORCE_PRIORITY_3 : FORCE_PRIORITY_5;
- return GetCommonKernelsData(params, optParams, DataLayout::fb, { WeightsLayout::io }, estimated_time);
+ for (size_t i = 0; i < autoTuneOptions.size(); i++)
+ {
+ KernelsData kd = GetTunedKernelsDataByIndex(params, optParams, DataLayout::fb, { WeightsLayout::io }, estimated_time, (int)i);
+ if (!kd.empty())
+ {
+ res.emplace_back(kd[0]);
+ }
+ }
+
+ return res;
}
-}
\ No newline at end of file
+}