Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / fully_connected / fully_connected_kernel_bf_io_gemm.cpp
index 61d5edc..8b762f9 100644 (file)
@@ -15,7 +15,6 @@
 */
 
 #include "fully_connected_kernel_bf_io_gemm.h"
-#include "kernel_selector_utils.h"
 
 namespace kernel_selector {
 
@@ -38,9 +37,9 @@ namespace kernel_selector {
         return k;
     }
 
-    std::unique_ptr<FullyConnected_bf_io_GEMM::Parent::DispatchData> FullyConnected_bf_io_GEMM::SetDefault(const fully_connected_params& params) const
+    FullyConnected_bf_io_GEMM::DispatchData FullyConnected_bf_io_GEMM::SetDefault(const fully_connected_params& params, int autoTuneIndex) const
     {
-        auto runInfo = Parent::SetDefault(params);
+        auto runInfo = Parent::SetDefault(params, autoTuneIndex);
 
         const uint32_t localWorkSizeX = 64;
         const uint32_t globalWorkSizeX = localWorkSizeX;
@@ -48,17 +47,17 @@ namespace kernel_selector {
         std::vector<size_t> global = { globalWorkSizeX, params.output.Feature().v, params.output.Batch().v };
         std::vector<size_t> local = { localWorkSizeX, 1, 1 };
 
-        runInfo->gws0 = global[0];
-        runInfo->gws1 = global[1];
-        runInfo->gws2 = 1;
+        runInfo.gws0 = global[0];
+        runInfo.gws1 = global[1];
+        runInfo.gws2 = 1;
 
-        runInfo->lws0 = local[0];
-        runInfo->lws1 = local[1];
-        runInfo->lws2 = 1;
+        runInfo.lws0 = local[0];
+        runInfo.lws1 = local[1];
+        runInfo.lws2 = 1;
 
-        runInfo->effiency = FORCE_PRIORITY_6;
+        runInfo.effiency = FORCE_PRIORITY_6;
 
-        return std::move(runInfo);
+        return runInfo;
     }
 
     JitConstants FullyConnected_bf_io_GEMM::GetJitConstants(const fully_connected_params& params, const DispatchData& kd) const
@@ -89,6 +88,16 @@ namespace kernel_selector {
 
     KernelsData FullyConnected_bf_io_GEMM::GetKernelsData(const Params& params, const optional_params& options) const
     {
-        return GetCommonKernelsData(params, options, DataLayout::bf, { WeightsLayout::oiyx }, FORCE_PRIORITY_6);
+        KernelsData res = {};
+        for (size_t i = 0; i < autoTuneOptions.size(); i++)
+        {
+            KernelsData kd = GetTunedKernelsDataByIndex(params, options, DataLayout::bf, { WeightsLayout::oiyx }, FORCE_PRIORITY_6, (int)i);
+            if (!kd.empty())
+            {
+                res.emplace_back(kd[0]);
+            }
+        }
+
+        return res;
     }
 }