Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / fully_connected / fully_connected_kernel_bs_f_bsv8_af8.cpp
index 6b8fbfa..234a941 100644 (file)
@@ -15,7 +15,6 @@
 */
 
 #include "fully_connected_kernel_bs_f_bsv8_af8.h"
-#include "kernel_selector_utils.h"
 
 namespace kernel_selector 
 {
@@ -38,32 +37,32 @@ namespace kernel_selector
         return k;
     }
 
-    std::unique_ptr<FullyConnected_bs_f_bsv8_af8::DispatchData> FullyConnected_bs_f_bsv8_af8::SetDefault(const fully_connected_params& arg) const
+    FullyConnected_bs_f_bsv8_af8::DispatchData FullyConnected_bs_f_bsv8_af8::SetDefault(const fully_connected_params& arg, int ) const
     {
         auto kd = FullyConnectedBlockKernelBase::SetDefault(arg);
 
         size_t groups_per_batches = GetLocalGroupsSize(arg);
-        kd->gws0 = Align(arg.output.LogicalSize() / (GetNeuronsPerWorkItem(arg) * GetBatchesPerWorkItem(arg) * groups_per_batches), 8);
-        kd->gws1 = groups_per_batches;
-        kd->lws0 = 8;
-        kd->lws1 = 1;
+        kd.gws0 = Align(arg.output.LogicalSize() / (GetNeuronsPerWorkItem(arg) * GetBatchesPerWorkItem(arg) * groups_per_batches), 8);
+        kd.gws1 = groups_per_batches;
+        kd.lws0 = 8;
+        kd.lws1 = 1;
 
-        return std::move(kd);
+        return kd;
     }
     
     static bool check_input_layout(const DataTensor& t)
     {
         bool b16_layout = false;
         b16_layout |= t.GetLayout() == DataLayout::bs_f_bsv8__af8;
-        b16_layout |= DataTensor::Channelndex(t.GetLayout(), Tensor::DataChannelName::BATCH) == 0 && (t.Batch().v == 8); // TODO - check f alignment to 8
+        b16_layout |= DataTensor::Channelndex(t.GetLayout(), Tensor::DataChannelName::BATCH) == 0 && (t.Batch().v == 8);
         return b16_layout;
     }
 
     static bool check_output_layout(const DataTensor& t)
     {
         bool b16_layout = false;
-        b16_layout |= (t.GetLayout() == DataLayout::fb);
-        b16_layout |= (t.GetLayout() == DataLayout::bs_f_bsv8__af8) && (t.Batch().v == 8);
+        b16_layout |= (t.GetLayout() == DataLayout::fb) && (t.Batch().v == 8);
+        b16_layout |= (t.GetLayout() == DataLayout::bs_f_bsv8__af8);
         return b16_layout;
     }
 
@@ -85,11 +84,14 @@ namespace kernel_selector
         const bool bProperBatch =
             params.inputs[0].Batch().v >= 8 &&
             params.inputs[0].Batch().v % 8 == 0;
+        const bool bProperFeature =
+            params.inputs[0].Feature().v >= 8 &&
+            params.inputs[0].Feature().v % 8 == 0;
         const bool bProperInput = check_input_layout(params.inputs[0]);
         const bool bProperOutput = check_output_layout(params.output);
         const bool bSupportedLayout = optParams.allowInputReordering || bProperInput;
 
-        if (!bProperBatch || !bSupportedLayout || !bProperOutput)
+        if (!bProperBatch || !bProperFeature || !bSupportedLayout || !bProperOutput)
         {
             return false;
         }
@@ -99,6 +101,16 @@ namespace kernel_selector
 
     KernelsData FullyConnected_bs_f_bsv8_af8::GetKernelsData(const Params& params, const optional_params& optParams) const
     {
-        return GetCommonKernelsData(params, optParams, DataLayout::bs_f_bsv8__af8, { WeightsLayout::os_i_osv8__ai8 }, FORCE_PRIORITY_4);
+        KernelsData res = {};
+        for (size_t i = 0; i < autoTuneOptions.size(); i++)
+        {
+            KernelsData kd = GetTunedKernelsDataByIndex(params, optParams, DataLayout::bs_f_bsv8__af8, { WeightsLayout::os_i_osv8__ai8 }, FORCE_PRIORITY_4, (int)i);
+            if (!kd.empty())
+            {
+                res.emplace_back(kd[0]);
+            }
+        }
+
+        return res;
     }
-}
\ No newline at end of file
+}