Publishing 2019 R3 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / reorder / reorder_kernel_binary.cpp
index 3e16d8a..7ddc661 100644 (file)
@@ -24,10 +24,13 @@ ParamsKey ReorderKernelBinary::GetSupportedKey() const {
     k.EnableInputDataType(Datatype::F32);
     k.EnableInputDataType(Datatype::BINARY);
     k.EnableOutputDataType(Datatype::BINARY);
+    k.EnableOutputDataType(Datatype::F32);
+    k.EnableOutputDataType(Datatype::F16);
     k.EnableDifferentTypes();
     k.EnableInputLayout(DataLayout::bfyx);
     k.EnableInputLayout(DataLayout::b_fs_yx_32fp);
     k.EnableOutputLayout(DataLayout::b_fs_yx_32fp);
+    k.EnableOutputLayout(DataLayout::bfyx);
     k.EnableTensorOffset();
     k.EnableTensorPitches();
     k.EnableBatching();
@@ -42,13 +45,17 @@ JitConstants ReorderKernelBinary::GetJitConstants(const reorder_params& params)
     const auto& input = newParams.inputs[0];
     jit.AddConstant(MakeJitConstant("ELEMENTS_COUNT", input.LogicalSize()));
     jit.AddConstant(MakeJitConstant("IFM_PACK_SIZE", 32));
-    jit.AddConstant(MakeJitConstant("OUTPUT_PACKED_FEATURES_NUM", CeilDiv(params.output.Feature().v, 32)));
 
     if (input.GetDType() == Datatype::BINARY) {
         jit.AddConstant(MakeJitConstant("BINARY_INPUT", 1));
         jit.AddConstant(MakeJitConstant("INPUT_PACKED_FEATURES_NUM", CeilDiv(input.Feature().v, 16)));
     }
 
+    if (params.output.GetDType() == Datatype::BINARY) {
+        jit.AddConstant(MakeJitConstant("BINARY_OUTPUT", 1));
+        jit.AddConstant(MakeJitConstant("OUTPUT_PACKED_FEATURES_NUM", CeilDiv(params.output.Feature().v, 32)));
+    }
+
     return jit;
 }
 
@@ -76,10 +83,18 @@ KernelsData ReorderKernelBinary::GetKernelsData(const Params& params, const opti
 
     const reorder_params& orgParams = static_cast<const reorder_params&>(params);
 
+    if (orgParams.inputs[0].GetDType() != Datatype::BINARY &&
+        orgParams.output.GetDType() != Datatype::BINARY)
+        return {};
+
     if (orgParams.inputs[0].GetDType() == Datatype::BINARY &&
         orgParams.inputs[0].GetLayout() != DataLayout::b_fs_yx_32fp)
         return {};
 
+    if (orgParams.output.GetDType() == Datatype::BINARY &&
+        orgParams.output.GetLayout() != DataLayout::b_fs_yx_32fp)
+        return {};
+
     auto estimatedTime = FORCE_PRIORITY_6;
 
     return GetCommonKernelsData(orgParams, options, estimatedTime);