Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / eltwise / eltwise_kernel_base.cpp
index 5feac0c..85cedc3 100644 (file)
@@ -1,5 +1,5 @@
 /*
-// Copyright (c) 2016 Intel Corporation
+// Copyright (c) 2016-2019 Intel Corporation
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
 */
 
 #include "eltwise_kernel_base.h"
-#include "kernel_selector_utils.h" 
+#include "kernel_selector_utils.h"
 
 namespace kernel_selector
 {
@@ -31,6 +31,16 @@ namespace kernel_selector
         case EltwiseMode::MAX:
         case EltwiseMode::POW:
         case EltwiseMode::MODULU:
+        case EltwiseMode::EQ:
+        case EltwiseMode::NE:
+        case EltwiseMode::LT:
+        case EltwiseMode::LE:
+        case EltwiseMode::GT:
+        case EltwiseMode::GE:
+        case EltwiseMode::LOGIC_AND:
+        case EltwiseMode::LOGIC_OR:
+        case EltwiseMode::LOGIC_XOR:
+        case EltwiseMode::SQUARED_DIFF:
             return 2;
         case EltwiseMode::SQRT:
         case EltwiseMode::RSQRT:
@@ -54,6 +64,16 @@ namespace kernel_selector
             k.EnableOutputCalibration();
         }
 
+        if (!stride.empty())
+        {
+            k.EnableEltwiseStride();
+        }
+
+        if (broadcast)
+        {
+            k.EnableEltwiseBroadcast();
+        }
+
         return k;
     }
 
@@ -109,6 +129,7 @@ namespace kernel_selector
         jit.AddConstants({
             MakeJitConstant("ELTWISE_LAYOUT_BASED", params.layoutBased),
             MakeJitConstant("QUANTIZATION_TERM",    params.int8_quantization),
+            MakeJitConstant("ELTWISE_BROADCAST",    params.broadcast),
         });
 
         if (params.int8_quantization)
@@ -140,6 +161,11 @@ namespace kernel_selector
             }
 
             inputs_decls += const_str + " __global " + toCLType(params.inputs[i].GetDType()) + "* input" + std::to_string(i) + ", ";
+            if (!params.stride.empty())
+            {
+                jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_X", params.stride[i].x));
+                jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_Y", params.stride[i].y));
+            }
             if (useVload8)
             {
                 vload_decls += "\\\n\tconst " + toCLType(params.inputs[i].GetDType()) + "8 in" + std::to_string(i);
@@ -196,7 +222,7 @@ namespace kernel_selector
                 }
             }
 
-            std::string input0_str, input1_str, cast_type, op;
+            std::string input0_str, input1_str, cast_type, output_cast, op;
 
             if (useVload8)
             {
@@ -214,6 +240,11 @@ namespace kernel_selector
                 op = "const UNIT_TYPE tmp" + op_num_str + " = ";
             }
 
+            if (params.output.GetDType() == Datatype::INT8 && !params.int8_quantization) {
+                output_cast = "(char)";
+                cast_type = "(" + toCLType(params.inputs[op_num].GetDType()) + ")";
+            }
+
             input0_str = cast_type + "INPUT_" + op_num_str + "_0";
             input1_str = cast_type + "INPUT_" + op_num_str + "_1";
 
@@ -238,17 +269,67 @@ namespace kernel_selector
 
             switch (ew.mode)
             {
-            case EltwiseMode::ADD:      op += input0_str + " + " + input1_str; break;
-            case EltwiseMode::SUB:      op += input0_str + " - " + input1_str; break;
-            case EltwiseMode::MUL:      op += input0_str + " * " + input1_str; break;
-            case EltwiseMode::DIV:      op += input0_str + " / " + input1_str; break;
-            case EltwiseMode::MODULU:   op += cast_type + "fmod(" + input0_str + ", " + input1_str + ")"; break;
-            case EltwiseMode::MIN:      op += cast_type + "fmin(" + input0_str + ", " + input1_str + ")"; break;
-            case EltwiseMode::MAX:      op += cast_type + "fmax(" + input0_str + ", " + input1_str + ")"; break;
-            case EltwiseMode::POW:      op += cast_type + "pow(" + input0_str + ", " + input1_str + ")"; break;
-            case EltwiseMode::SQRT:     op += cast_type + "sqrt(" + input0_str + ")"; break;
-            case EltwiseMode::RSQRT:    op += cast_type + "1/sqrt(" + input0_str + ")"; break;
-            case EltwiseMode::ASSIGN:   op += input0_str; break;
+            case EltwiseMode::ADD:         op += input0_str + " + " + input1_str; break;
+            case EltwiseMode::SUB:         op += input0_str + " - " + input1_str; break;
+            case EltwiseMode::MUL:         op += input0_str + " * " + input1_str; break;
+            case EltwiseMode::DIV:         op += input0_str + " / " + input1_str; break;
+            case EltwiseMode::MODULU:
+            case EltwiseMode::MIN:
+            case EltwiseMode::MAX:
+            {
+                auto mode = (ew.mode == EltwiseMode::MODULU ? "mod" : (ew.mode == EltwiseMode::MIN ? "min" : "max" ));
+                auto input_0_type = params.inputs[0].GetDType();
+                auto input_1_type = params.inputs[1].GetDType();
+
+                // input_0 == int
+                if (input_0_type == kernel_selector::Datatype::INT8 ||
+                    input_0_type == kernel_selector::Datatype::INT32 ||
+                    input_0_type == kernel_selector::Datatype::INT64)
+                {
+                    // input_0 == int && input_1 == int
+                    if (input_1_type == kernel_selector::Datatype::INT8 ||
+                        input_1_type == kernel_selector::Datatype::INT32 ||
+                        input_1_type == kernel_selector::Datatype::INT64)
+                    {
+                        if (ew.mode == EltwiseMode::MODULU)
+                            op += input0_str + " % " + input1_str;
+                        else
+                            op += cast_type + mode + "(" + input0_str + ", " + input1_str + ")";
+                    }
+                    // input_0 == int && input_1 != int
+                    else
+                    {
+                        op += cast_type + "f" + mode + "(convert_float(" + input0_str + "), " + input1_str + ")";
+                    }
+                }
+                // input_0 != int && input_1 == int
+                else if ( input_1_type == kernel_selector::Datatype::INT8 ||
+                          input_1_type == kernel_selector::Datatype::INT32 ||
+                          input_1_type == kernel_selector::Datatype::INT64)
+                {
+                    op += cast_type + "f" + mode + "(" + input0_str + ", convert_float(" + input1_str + "))";
+                }
+                // input_0 != int && input_1 != int
+                else
+                {
+                    op += cast_type + "f" + mode + "(" + input0_str + ", " + input1_str + ")";
+                }
+            } break;
+            case EltwiseMode::POW:          op += cast_type + "pow(" + input0_str + ", " + input1_str + ")"; break;
+            case EltwiseMode::SQRT:         op += cast_type + "sqrt(" + input0_str + ")"; break;
+            case EltwiseMode::RSQRT:        op += cast_type + "1/sqrt(" + input0_str + ")"; break;
+            case EltwiseMode::SQUARED_DIFF: op += cast_type + "((" + input0_str + " - " + input1_str + ")"
+                                                            " * (" + input0_str + " - " + input1_str + "))"; break;
+            case EltwiseMode::EQ:           op += output_cast + "(" + input0_str + " == " + input1_str + ")"; break;
+            case EltwiseMode::NE:           op += output_cast + "(" + input0_str + " != " + input1_str + ")"; break;
+            case EltwiseMode::LT:           op += output_cast + "(" + input0_str + " < " + input1_str + ")"; break;
+            case EltwiseMode::LE:           op += output_cast + "(" + input0_str + " <= " + input1_str + ")"; break;
+            case EltwiseMode::GT:           op += output_cast + "(" + input0_str + " > " + input1_str + ")"; break;
+            case EltwiseMode::GE:           op += output_cast + "(" + input0_str + " >= " + input1_str + ")"; break;
+            case EltwiseMode::LOGIC_AND:    op += output_cast + "(" + input0_str + " && " + input1_str + ")"; break;
+            case EltwiseMode::LOGIC_OR:     op += output_cast + "(" + input0_str + " || " + input1_str + ")"; break;
+            case EltwiseMode::LOGIC_XOR:    op += output_cast + "(!" + input0_str + " != !" + input1_str + ")"; break;
+            case EltwiseMode::ASSIGN:       op += input0_str; break;
             default:
                 break;
             }
@@ -259,7 +340,7 @@ namespace kernel_selector
         }
 
         for (size_t update_input_idx = 0; update_input_idx < updateInputs.size(); update_input_idx++)
-            do_eltwise += "\\\n\tinput" + std::to_string(updateInputs[update_input_idx].inputId) + 
+            do_eltwise += "\\\n\tinput" + std::to_string(updateInputs[update_input_idx].inputId) +
             "[GET_INDEX(INPUT, " + std::to_string(updateInputs[update_input_idx].inputId) +
             ")] = tmp" + std::to_string(updateInputs[update_input_idx].tmpId) + ";";
 
@@ -267,9 +348,14 @@ namespace kernel_selector
 
         jit.AddConstant(MakeJitConstant("DO_ELTWISE", do_eltwise));
 
-        if (params.layoutBased || params.int8_quantization)
+        if (params.layoutBased || params.int8_quantization || params.broadcast)
+        {
+            jit.Merge(GetTensorFriendlyWorkGroupsJit(params.output));
+        }
+
+        if (!params.stride.empty())
         {
-            jit.Merge(GetTensorFriendlyWorkGroupsJit(params.inputs[0]));
+            jit.AddConstant(MakeJitConstant("INPUT_STRIDED", 1));
         }
 
         return jit;
@@ -284,12 +370,17 @@ namespace kernel_selector
     {
         DispatchData kd;
 
-        if (params.layoutBased || params.int8_quantization)
+        if (params.layoutBased || params.int8_quantization || params.broadcast)
         {
-            auto global = GetTensorFriendlyWorkGroups(params.inputs[0]);
+            auto global = GetTensorFriendlyWorkGroups(params.output);
             kd.gws0 = global[0];
             kd.gws1 = global[1];
             kd.gws2 = global[2];
+            if (!params.stride.empty())
+            {
+                kd.gws0 /= params.stride[0].x;
+                kd.gws0 /= params.stride[0].y;
+            }
         }
         else if (CheckInputsOutputNoPitchSameDims(params))
         {
@@ -346,7 +437,7 @@ namespace kernel_selector
         kernel.workGroups.global = { runInfo.gws0, runInfo.gws1, runInfo.gws2 };
         kernel.workGroups.local = { runInfo.lws0, runInfo.lws1, runInfo.lws2 };
 
-        kernel.kernelString = GetKernelString(kernelName, jit, entry_point, params.engineInfo, ROUND_ROBIN);
+        kernel.kernelString = GetKernelString(kernelName, jit, entry_point, params.engineInfo, DEFAULT);
         kernel.arguments = GetArgsDesc((uint32_t)newParams.inputs.size(), false, false, newParams.int8_quantization, newParams.output_calibration);
 
         kd.estimatedTime = DONT_USE_IF_HAVE_SOMETHING_ELSE;