Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / fully_connected / fully_connected_kernel_fb_io_b8_f8.cpp
index 1a3e98d..839f940 100644 (file)
@@ -15,7 +15,6 @@
 */
 
 #include "fully_connected_kernel_fb_io_b8_f8.h"
-#include "kernel_selector_utils.h"
 
 namespace kernel_selector 
 {
@@ -37,19 +36,19 @@ namespace kernel_selector
         return k;
     }
 
-    std::unique_ptr<FullyConnected_fb_io_b8_f8::DispatchData> FullyConnected_fb_io_b8_f8::SetDefault(const fully_connected_params& arg) const
+    FullyConnected_fb_io_b8_f8::DispatchData FullyConnected_fb_io_b8_f8::SetDefault(const fully_connected_params& arg, int ) const
     {
         auto kd = FullyConnectedBlockKernelBase::SetDefault(arg);
 
         const auto& output = arg.output;
         
         size_t groups_per_batches = GetLocalGroupsSize(arg);
-        kd->gws0 = output.LogicalSize() / (GetNeuronsPerWorkItem(arg) * GetBatchesPerWorkItem(arg) * groups_per_batches);
-        kd->gws1 = groups_per_batches;
-        kd->lws0 = 8;
-        kd->lws1 = 1;
+        kd.gws0 = Align(output.LogicalSize() / (GetNeuronsPerWorkItem(arg) * GetBatchesPerWorkItem(arg) * groups_per_batches), 8);
+        kd.gws1 = groups_per_batches;
+        kd.lws0 = 8;
+        kd.lws1 = 1;
 
-        return std::move(kd);
+        return kd;
     }
 
     bool FullyConnected_fb_io_b8_f8::Validate(const Params& p, const optional_params& o) const
@@ -65,11 +64,17 @@ namespace kernel_selector
         const auto batches = output.Batch().v;
         const auto x_size = output.LogicalSize() / batches;
 
+        const auto& input = params.inputs[0];
+        const auto input_x_size = input.LogicalSize() / input.Batch().v;
+        const bool proper_input_aligment = (input_x_size % 8) == 0;
+        const bool proper_output_aligment = (output.LogicalSize() / (GetNeuronsPerWorkItem(params) * GetBatchesPerWorkItem(params) * GetLocalGroupsSize(params)) % 8) == 0;
         const bool bSupportedBatch = (batches % 8) == 0;
         const bool bSupportedFeature = (x_size % 8) == 0;
 
         if (!bSupportedBatch ||
-            !bSupportedFeature)
+            !bSupportedFeature ||
+            !proper_input_aligment ||
+            !proper_output_aligment)
         {
             return false;
         }
@@ -80,13 +85,22 @@ namespace kernel_selector
     KernelsData FullyConnected_fb_io_b8_f8::GetKernelsData(const Params& params, const optional_params& optParams) const
     {
         assert(params.GetType() == KernelType::FULLY_CONNECTED);
-
+        KernelsData res = {};
         const auto& orgParams = static_cast<const fully_connected_params&>(params);
 
         float estimated_time =
             orgParams.inputs[0].GetDType() == Datatype::F16 && orgParams.output.Batch().v >= 16 ?
             FORCE_PRIORITY_3 : FORCE_PRIORITY_5;
         
-        return GetCommonKernelsData(params, optParams, DataLayout::fb, { WeightsLayout::io }, estimated_time);
+        for (size_t i = 0; i < autoTuneOptions.size(); i++)
+        {
+            KernelsData kd = GetTunedKernelsDataByIndex(params, optParams, DataLayout::fb, { WeightsLayout::io }, estimated_time, (int)i);
+            if (!kd.empty())
+            {
+                res.emplace_back(kd[0]);
+            }
+        }
+
+        return res;
     }
-}
\ No newline at end of file
+}