Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / eltwise / eltwise_kernel_fs_bs_yx_bsv4_fsv32.cpp
index 571a013..e644505 100644 (file)
@@ -1,5 +1,5 @@
 /*
-// Copyright (c) 2018 Intel Corporation
+// Copyright (c) 2018-2019 Intel Corporation
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
 */
 
 #include "eltwise_kernel_fs_bs_yx_bsv4_fsv32.h"
-#include "kernel_selector_utils.h" 
+#include "kernel_selector_utils.h"
 
 namespace kernel_selector {
 
@@ -31,6 +31,7 @@ namespace kernel_selector {
         k.EnableBatching();
         k.EnableInt8Quantization();
         k.EnableOutputCalibration();
+        k.EnableEltwiseStride();
         return k;
     }
 
@@ -46,6 +47,7 @@ namespace kernel_selector {
         kd.lws1 = 1;
         kd.lws2 = 8;
 
+        kd.effiency = FORCE_PRIORITY_3;
         return kd;
     }
 
@@ -100,6 +102,12 @@ namespace kernel_selector {
             }
 
             inputs_decls += const_str + " __global " + toCLType(params.inputs[i].GetDType()) + "* input" + std::to_string(i) + ", ";
+
+            if (!params.stride.empty())
+            {
+                jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_X", params.stride[i].x));
+                jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_Y", params.stride[i].y));
+            }
         }
 
         jit.AddConstant(MakeJitConstant("INPUTS_DECLS", inputs_decls));
@@ -177,17 +185,67 @@ namespace kernel_selector {
 
             switch (ew.mode)
             {
-            case EltwiseMode::ADD:      op += input0_str + " + " + input1_str; break;
-            case EltwiseMode::SUB:      op += input0_str + " - " + input1_str; break;
-            case EltwiseMode::MUL:      op += input0_str + " * " + input1_str; break;
-            case EltwiseMode::DIV:      op += input0_str + " / " + input1_str; break;
-            case EltwiseMode::MODULU:   op += cast_type + "fmod(" + input0_str + ", " + input1_str + ")"; break;
-            case EltwiseMode::MIN:      op += cast_type + "fmin(" + input0_str + ", " + input1_str + ")"; break;
-            case EltwiseMode::MAX:      op += cast_type + "fmax(" + input0_str + ", " + input1_str + ")"; break;
-            case EltwiseMode::POW:      op += cast_type + "pow(" + input0_str + ", " + input1_str + ")"; break;
-            case EltwiseMode::SQRT:     op += cast_type + "sqrt(" + input0_str + ")"; break;
-            case EltwiseMode::RSQRT:    op += cast_type + "1/sqrt(" + input0_str + ")"; break;
-            case EltwiseMode::ASSIGN:   op += input0_str; break;
+            case EltwiseMode::ADD:          op += input0_str + " + " + input1_str; break;
+            case EltwiseMode::SUB:          op += input0_str + " - " + input1_str; break;
+            case EltwiseMode::MUL:          op += input0_str + " * " + input1_str; break;
+            case EltwiseMode::DIV:          op += input0_str + " / " + input1_str; break;
+            case EltwiseMode::MODULU:
+            case EltwiseMode::MIN:
+            case EltwiseMode::MAX:
+            {
+                auto mode = (ew.mode == EltwiseMode::MODULU ? "mod" : (ew.mode == EltwiseMode::MIN ? "min" : "max"));
+                auto input_0_type = params.inputs[0].GetDType();
+                auto input_1_type = params.inputs[1].GetDType();
+
+                // input_0 == int
+                if (input_0_type == kernel_selector::Datatype::INT8 ||
+                    input_0_type == kernel_selector::Datatype::INT32 ||
+                    input_0_type == kernel_selector::Datatype::INT64)
+                {
+                    // input_0 == int && input_1 == int
+                    if (input_1_type == kernel_selector::Datatype::INT8 ||
+                        input_1_type == kernel_selector::Datatype::INT32 ||
+                        input_1_type == kernel_selector::Datatype::INT64)
+                    {
+                        if (ew.mode == EltwiseMode::MODULU)
+                            op += input0_str + " % " + input1_str;
+                        else
+                            op += cast_type + mode + "(" + input0_str + ", " + input1_str + ")";
+                    }
+                    // input_0 == int && input_1 != int
+                    else
+                    {
+                        op += cast_type + "f" + mode + "(convert_float(" + input0_str + "), " + input1_str + ")";
+                    }
+                }
+                // input_0 != int && input_1 == int
+                else if (input_1_type == kernel_selector::Datatype::INT8 ||
+                    input_1_type == kernel_selector::Datatype::INT32 ||
+                    input_1_type == kernel_selector::Datatype::INT64)
+                {
+                    op += cast_type + "f" + mode + "(" + input0_str + ", convert_float(" + input1_str + "))";
+                }
+                // input_0 != int && input_1 != int
+                else
+                {
+                    op += cast_type + "f" + mode + "(" + input0_str + ", " + input1_str + ")";
+                }
+            } break;
+            case EltwiseMode::POW:          op += cast_type + "pow(" + input0_str + ", " + input1_str + ")"; break;
+            case EltwiseMode::SQRT:         op += cast_type + "sqrt(" + input0_str + ")"; break;
+            case EltwiseMode::RSQRT:        op += cast_type + "1/sqrt(" + input0_str + ")"; break;
+            case EltwiseMode::SQUARED_DIFF: op += cast_type + "((" + input0_str + " - " + input1_str + ")"
+                                                            " * (" + input0_str + " - " + input1_str + "))"; break;
+            case EltwiseMode::EQ:           op += cast_type + "(" + input0_str + " == " + input1_str + ")"; break;
+            case EltwiseMode::NE:           op += cast_type + "(" + input0_str + " != " + input1_str + ")"; break;
+            case EltwiseMode::LT:           op += cast_type + "(" + input0_str + " < " + input1_str + ")"; break;
+            case EltwiseMode::LE:           op += cast_type + "(" + input0_str + " <= " + input1_str + ")"; break;
+            case EltwiseMode::GT:           op += cast_type + "(" + input0_str + " > " + input1_str + ")"; break;
+            case EltwiseMode::GE:           op += cast_type + "(" + input0_str + " >= " + input1_str + ")"; break;
+            case EltwiseMode::LOGIC_AND:    op += cast_type + "(" + input0_str + " && " + input1_str + ")"; break;
+            case EltwiseMode::LOGIC_OR:     op += cast_type + "(" + input0_str + " || " + input1_str + ")"; break;
+            case EltwiseMode::LOGIC_XOR:    op += cast_type + "(!" + input0_str + " != !" + input1_str + ")"; break;
+            case EltwiseMode::ASSIGN:       op += input0_str; break;
             default:
                 break;
             }
@@ -211,6 +269,11 @@ namespace kernel_selector {
             jit.Merge(GetTensorFriendlyWorkGroupsJit(params.inputs[0]));
         }
 
+        if (!params.stride.empty())
+        {
+            jit.AddConstant(MakeJitConstant("INPUT_STRIDED", 1));
+        }
+
         ///////////////
         return jit;
     }
@@ -219,4 +282,4 @@ namespace kernel_selector {
     {
         return GetCommonKernelsData(params, options);
     }
-}
\ No newline at end of file
+}