Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / fully_connected / fully_connected_kernel_base.cpp
index 20e6e8d..9b4cbb7 100644 (file)
@@ -47,27 +47,27 @@ namespace kernel_selector
         return jit;
     }
 
-    std::unique_ptr<FullyConnectedKernelBase::DispatchData> FullyConnectedKernelBase::SetDefault(const fully_connected_params& params) const
+    FullyConnectedKernelBase::DispatchData FullyConnectedKernelBase::SetDefault(const fully_connected_params& params, int) const // Fills fp16UnitUsed plus the global/local work sizes from params; the unnamed int is never read here.
     {
-        std::unique_ptr<DispatchData> dispatchData = std::unique_ptr<DispatchData>(new DispatchData());
-        dispatchData->fp16UnitUsed = params.inputs[0].GetDType() == Datatype::F16;
+        DispatchData dispatchData;
+        dispatchData.fp16UnitUsed = params.inputs[0].GetDType() == Datatype::F16; // true when the first input's data type is F16
 
         // Determine global work sizes.
-        dispatchData->gws0 = params.output.LogicalSize();
-        dispatchData->gws1 = dispatchData->gws2 = 1;
+        dispatchData.gws0 = params.output.LogicalSize();
+        dispatchData.gws1 = dispatchData.gws2 = 1; // only dimension 0 varies; the other two are fixed to 1
 
         // Find largest positive local work size that is divider for global work size.
-        dispatchData->lws0 = std::min(std::max(dispatchData->gws0, static_cast<size_t>(1)), static_cast<size_t>(32));
-        while (dispatchData->gws0 % dispatchData->lws0 != 0)
+        dispatchData.lws0 = std::min(std::max(dispatchData.gws0, static_cast<size_t>(1)), static_cast<size_t>(32)); // starting candidate: gws0 clamped to the range [1, 32]
+        while (dispatchData.gws0 % dispatchData.lws0 != 0)
         {
-            --dispatchData->lws0;
+            --dispatchData.lws0; // shrink until lws0 divides gws0 with no remainder
         }
-        dispatchData->lws1 = dispatchData->lws2 = 1;
+        dispatchData.lws1 = dispatchData.lws2 = 1;
 
-        return std::move(dispatchData);
+        return dispatchData; // handed back by value
     }
 
-    KernelsData FullyConnectedKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options, DataLayout dl, std::vector<WeightsLayout> wl, float estimated_time) const
+    KernelsData FullyConnectedKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options, DataLayout dl, std::vector<WeightsLayout> wl, float estimated_time, const std::string exeMode, int autoTuneIndex) const
     {
         if (!Validate(params, options) ||
             wl.empty())
@@ -117,15 +117,31 @@ namespace kernel_selector
         
         auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, options);
 
-        const std::unique_ptr<DispatchData> runInfo = SetDefault(newParams);
-        auto cldnn_jit = GetJitConstants(newParams, *runInfo.get());
+        const DispatchData runInfo = SetDefault(newParams, autoTuneIndex);
+        auto cldnn_jit = GetJitConstants(newParams, runInfo);
         std::string jit = CreateJit(kernelName, cldnn_jit, entry_point);
 
         auto& kernel = kd.kernels[0];
-        FillCLKernelData(kernel, *runInfo.get(), params.engineInfo, kernelName, jit, entry_point, ROUND_ROBIN, true, !orgParams.bias.empty(), 1, newParams.int8_quantization, newParams.output_calibration);
+        FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point, exeMode, true, !orgParams.bias.empty(), 1, newParams.int8_quantization, newParams.output_calibration);
 
         kd.estimatedTime = estimated_time;
-        kd.autoTuneIndex = -1;
+        kd.autoTuneIndex = autoTuneIndex;
         return{ kd };
     }
+
+    // Maps autoTuneIndex to its entry in autoTuneOptions; an index that is
+    // negative or past the end of that list yields DEFAULT instead.
+    std::string FullyConnectedKernelBase::GetAutoTuneOptions(int autoTuneIndex) const
+    {
+        if (autoTuneIndex < 0 || static_cast<size_t>(autoTuneIndex) >= autoTuneOptions.size())
+            return DEFAULT; // the sign check runs first, so the unsigned cast never sees a negative value
+
+        return autoTuneOptions[autoTuneIndex];
+    }
+
+    KernelsData FullyConnectedKernelBase::GetTunedKernelsDataByIndex(const Params& params, const optional_params& options, DataLayout dl, std::vector<WeightsLayout> wl, float estimated_time, const int autoTuneIndex) const
+    {
+        const std::string tunedExeMode = GetAutoTuneOptions(autoTuneIndex); // option string selected by the tuning index
+        return GetCommonKernelsData(params, options, dl, wl, estimated_time, tunedExeMode, autoTuneIndex); // forward everything to the shared builder
+    }
 }