Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / fully_connected / fully_connected_kernel_fb_io_block.cpp
index b32c8a5..01a7061 100644 (file)
@@ -15,7 +15,6 @@
 */
 
 #include "fully_connected_kernel_fb_io_block.h"
-#include "kernel_selector_utils.h"
 
 namespace kernel_selector 
 {
@@ -35,9 +34,10 @@ namespace kernel_selector
         return k;
     }
 
-    std::unique_ptr<FullyConnected_fb_io_block::FullyConnectedKernelBase::DispatchData> FullyConnected_fb_io_block::SetDefault(const fully_connected_params& arg) const
+
+    FullyConnected_fb_io_block::DispatchData FullyConnected_fb_io_block::SetDefault(const fully_connected_params& arg, int /*autoTuneIndex*/) const
     {
-        auto kd = std::unique_ptr<DispatchData>(new DispatchData(*FullyConnectedKernelBase::SetDefault(arg)));
+        auto kd = FullyConnectedKernelBase::SetDefault(arg);
         const auto& output = arg.output;
         
         auto batch_size = output.Batch().v;
@@ -55,38 +55,37 @@ namespace kernel_selector
         // for at least one input data set from batch.
         auto rg_count = CeilDiv(response_size, units_per_sg_read);
 
-        kd->lws0 = sub_group_size;
+        kd.lws0 = sub_group_size;
         // Number of work items needed to process all response groups.
-        kd->gws0 = rg_count * sub_group_size;
-        kd->lws1 = 1;
-        kd->gws1 = batch_size / units_per_sg_read;
-
-        kd->unit_byte_size    = unit_byte_size;
-        kd->chunk_type        = chunk_type;
-        kd->chunk_byte_size   = chunk_byte_size;
-        kd->units_per_chunk   = units_per_chunk;
-        kd->bytes_per_sg_read = sub_group_size * chunk_byte_size;
-        kd->units_per_sg_read = units_per_sg_read;
-        kd->rg_count          = (uint32_t)rg_count;
-        kd->last_rg_size      = response_size % units_per_sg_read;
-        return std::move(kd);
+        kd.gws0 = rg_count * sub_group_size;
+        kd.lws1 = 1;
+        kd.gws1 = batch_size / units_per_sg_read;
+
+        kd.unit_byte_size    = unit_byte_size;
+        kd.chunk_type        = chunk_type;
+        kd.chunk_byte_size   = chunk_byte_size;
+        kd.units_per_chunk   = units_per_chunk;
+        kd.bytes_per_sg_read = sub_group_size * chunk_byte_size;
+        kd.units_per_sg_read = units_per_sg_read;
+        kd.rg_count          = (uint32_t)rg_count;
+        kd.last_rg_size      = response_size % units_per_sg_read;
+        return kd;
     }
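
The interesting part of this hunk is the switch from `std::unique_ptr<DispatchData>` to plain value semantics: once the block-specific fields (unit_byte_size, rg_count, ...) live on the base DispatchData (see the GetJitConstants hunk below), no polymorphic return type is needed, and the removed `return std::move(kd);` was actively harmful, since std::move on a local return value suppresses NRVO. A minimal standalone sketch of the pattern; all names here are hypothetical stand-ins, not the clDNN types:

    #include <cstddef>
    #include <cstdint>

    struct DispatchDataSketch               // stand-in for kernel_selector's DispatchData
    {
        std::size_t   gws0 = 0, gws1 = 0;
        std::size_t   lws0 = 0, lws1 = 0;
        std::uint32_t rg_count = 0;
    };

    DispatchDataSketch MakeDispatchData(std::size_t batch_size, std::size_t response_size,
                                        std::size_t units_per_sg_read, std::size_t sub_group_size)
    {
        DispatchDataSketch kd;              // one local object, filled in place
        const std::size_t rg_count = (response_size + units_per_sg_read - 1) / units_per_sg_read;
        kd.lws0     = sub_group_size;
        kd.gws0     = rg_count * sub_group_size;
        kd.lws1     = 1;
        kd.gws1     = batch_size / units_per_sg_read;
        kd.rg_count = static_cast<std::uint32_t>(rg_count);
        return kd;  // by value: NRVO (or at worst a move) applies; std::move() here would block elision
    }
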
 
     JitConstants FullyConnected_fb_io_block::GetJitConstants(const fully_connected_params& params, const FullyConnectedKernelBase::DispatchData& run_info) const
     {
-        auto &d = static_cast<const DispatchData&>(run_info);
         auto cldnn_jit = FullyConnectedKernelBase::GetJitConstants(params, run_info);
         cldnn_jit.AddConstants({
-            MakeJitConstant("SUB_GROUP_SIZE",        d.lws0),
-            MakeJitConstant("WORK_ITEMS_PER_BATCH",  d.gws1),
-            MakeJitConstant("UNIT_BYTE_SIZE",        d.unit_byte_size),
-            MakeJitConstant("CHUNK_TYPE",            d.chunk_type),
-            MakeJitConstant("CHUNK_BYTE_SIZE",       d.chunk_byte_size),
-            MakeJitConstant("UNITS_PER_CHUNK",       d.units_per_chunk),
-            MakeJitConstant("BYTES_PER_SG_READ",     d.bytes_per_sg_read),
-            MakeJitConstant("UNITS_PER_SG_READ",     d.units_per_sg_read),
-            MakeJitConstant("RG_COUNT",              d.rg_count),
-            MakeJitConstant("LAST_RG_SIZE",          d.last_rg_size),
+            MakeJitConstant("SUB_GROUP_SIZE",        run_info.lws0),
+            MakeJitConstant("WORK_ITEMS_PER_BATCH",  run_info.gws1),
+            MakeJitConstant("UNIT_BYTE_SIZE",        run_info.unit_byte_size),
+            MakeJitConstant("CHUNK_TYPE",            run_info.chunk_type),
+            MakeJitConstant("CHUNK_BYTE_SIZE",       run_info.chunk_byte_size),
+            MakeJitConstant("UNITS_PER_CHUNK",       run_info.units_per_chunk),
+            MakeJitConstant("BYTES_PER_SG_READ",     run_info.bytes_per_sg_read),
+            MakeJitConstant("UNITS_PER_SG_READ",     run_info.units_per_sg_read),
+            MakeJitConstant("RG_COUNT",              run_info.rg_count),
+            MakeJitConstant("LAST_RG_SIZE",          run_info.last_rg_size),
         });
         return cldnn_jit;
     }
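
With the per-kernel fields promoted to FullyConnectedKernelBase::DispatchData, the static_cast to the derived type disappears and run_info can be read directly. Each MakeJitConstant pair ultimately becomes a preprocessor definition injected into the OpenCL kernel source; a simplified sketch of that name/value-to-#define idea follows (hypothetical helper, not the real clDNN API, which handles many more value types):

    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    struct JitConstantSketch                // one macro name/value pair
    {
        std::string name;
        std::string value;
    };

    template <typename T>
    JitConstantSketch MakeJitConstantSketch(const std::string& name, T value)
    {
        std::ostringstream ss;
        ss << value;                        // stringify the value for the #define
        return { name, ss.str() };
    }

    int main()
    {
        std::vector<JitConstantSketch> jit = {
            MakeJitConstantSketch("SUB_GROUP_SIZE",    16),
            MakeJitConstantSketch("UNITS_PER_SG_READ", 32),
        };
        for (const auto& c : jit)           // rendered preamble prepended to the kernel source
            std::cout << "#define " << c.name << ' ' << c.value << '\n';
    }
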
@@ -144,6 +143,18 @@ namespace kernel_selector
         //       (fb == fyxb flatten fyx, not yxfb flatten yxf).
 //       The order of the add operations causes some numeric changes; to avoid them, for now we use yxfb/oiyx instead.
         // return GetCommonKernelsData(params, optParams, DataLayout::fb, WeightsLayout::io, estimated_time);
-        return GetCommonKernelsData(params, optParams, DataLayout::yxfb, { WeightsLayout::yxio }, estimated_time);
-    }
+        // return GetCommonKernelsData(params, optParams, DataLayout::yxfb, { WeightsLayout::yxio }, estimated_time);
+
+        KernelsData res = {};
+        for (size_t i = 0; i < autoTuneOptions.size(); i++)
+        {
+            KernelsData kd = GetTunedKernelsDataByIndex(params, optParams, DataLayout::yxfb, { WeightsLayout::yxio }, estimated_time, (int)i);
+            if (!kd.empty())
+            {
+                res.emplace_back(kd[0]);
+            }
+        }
+
+        return res;
+    }
 }
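
The replacement GetKernelsData follows the auto-tune pattern used elsewhere in kernel_selector: build one candidate KernelsData per tuning configuration and return them all, so the selector can rank the candidates and pick the fastest. A generic sketch of that collect-the-valid-candidates loop, with hypothetical names (BuildCandidate stands in for GetTunedKernelsDataByIndex) and assuming a configuration can fail for a given shape:

    #include <optional>
    #include <vector>

    struct KernelCandidateSketch { int tuneIndex; /* ... kernel payload ... */ };

    // Stand-in for GetTunedKernelsDataByIndex: a config may not fit the given
    // shapes, hence the optional result.
    std::optional<KernelCandidateSketch> BuildCandidate(int tuneIndex)
    {
        if (tuneIndex % 2 != 0)             // pretend odd configs are invalid here
            return std::nullopt;
        return KernelCandidateSketch{ tuneIndex };
    }

    std::vector<KernelCandidateSketch> CollectCandidates(int numTuneOptions)
    {
        std::vector<KernelCandidateSketch> res;
        for (int i = 0; i < numTuneOptions; ++i)
            if (auto kd = BuildCandidate(i))    // skip configs that produced nothing
                res.push_back(*kd);
        return res;                             // caller ranks and picks the fastest
    }
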