Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / eltwise.cpp
index 1ee22cc..2a6835b 100644 (file)
@@ -1,5 +1,5 @@
 /*
-// Copyright (c) 2016 Intel Corporation
+// Copyright (c) 2016-2019 Intel Corporation
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -30,22 +30,53 @@ primitive_type_id eltwise_type_id()
 
 layout eltwise_inst::calc_output_layout(eltwise_node const& node)
 {
+    assert((bool)node.get_primitive()->output_data_type == false
+           && "Output data type forcing is not supported for eltwise_inst_node!");
+
     auto input_node_layout = node.input().get_non_padded_output_layout();
+
+    auto size = input_node_layout.size;
+    for (size_t i = 1; i < node.inputs_count(); i++)
+    {
+        size = tensor::max(size, node.input(i).get_non_padded_output_layout().size);
+    }
+    auto output_layout = layout(input_node_layout.data_type, input_node_layout.format, size);
+    auto mode = node.get_primitive()->mode;
     //list of operations supported for integer types
     if (input_node_layout.data_type == data_types::i8 ||
         input_node_layout.data_type == data_types::i32 ||
         input_node_layout.data_type == data_types::i64)
     {
-        auto mode = node.get_primitive()->mode;
-        std::vector<eltwise_mode> eltwise_int_modes = { eltwise_mode::sum, eltwise_mode::sub, eltwise_mode::prod, eltwise_mode::div };
+        std::vector<eltwise_mode> eltwise_int_modes = { eltwise_mode::sum, eltwise_mode::sub, eltwise_mode::prod, eltwise_mode::div, eltwise_mode::min, eltwise_mode::max, eltwise_mode::mod,
+                                                        eltwise_mode::eq, eltwise_mode::ne, eltwise_mode::lt, eltwise_mode::le, eltwise_mode::gt, eltwise_mode::ge,
+                                                        eltwise_mode::logic_and, eltwise_mode::logic_or, eltwise_mode::logic_xor };
         if (std::find(eltwise_int_modes.begin(), eltwise_int_modes.end(), mode) == eltwise_int_modes.end())
             CLDNN_ERROR_MESSAGE(node.id(), "Requested eltwise mode is not supported for integer types.");
     }
 
-    return input_node_layout;
+    // Logic and comparison operations should return i8 for any inputs
+    std::vector<eltwise_mode> eltwise_bool_modes = { eltwise_mode::eq, eltwise_mode::ne, eltwise_mode::lt, eltwise_mode::le,
+                                                     eltwise_mode::gt, eltwise_mode::ge,
+                                                     eltwise_mode::logic_and, eltwise_mode::logic_or, eltwise_mode::logic_xor };
+    if (std::find(eltwise_bool_modes.begin(), eltwise_bool_modes.end(), mode) != eltwise_bool_modes.end())
+    {
+        output_layout.data_type = data_types::i8;
+        if (node.get_primitive()->with_activation)
+            CLDNN_ERROR_MESSAGE(node.id(), "Activations are not supported for logical operations.");
+    }
+
+    auto eltw = std::static_pointer_cast<const eltwise>((node.get_primitive()));
+    if (!eltw->stride.empty())
+    {
+        // we can safely use only first stride, since we're using first input, and input / stride should give exact same value for every input
+        input_node_layout.size.spatial[0] /= eltw->stride[0].spatial[0];
+        input_node_layout.size.spatial[1] /= eltw->stride[0].spatial[1];
+        return input_node_layout;
+    }
+    return output_layout;
 }
 
-static inline std::string stringify_vector(std::vector<float> v)
+static inline std::string stringify_vector(const std::vector<float>& v)
 {
     std::stringstream s;
 
@@ -90,13 +121,43 @@ std::string eltwise_inst::to_string(eltwise_node const& node)
             break;
     case eltwise_mode::min:
             str_mode = "min";
-         break;
+            break;
     case eltwise_mode::pow:
             str_mode = "pow";
             break;
+    case eltwise_mode::squared_diff:
+            str_mode = "squared_diff";
+            break;
     case eltwise_mode::mod:
             str_mode = "mod";
             break;
+    case eltwise_mode::eq:
+            str_mode = "equal";
+            break;
+    case eltwise_mode::ne:
+            str_mode = "not equal";
+            break;
+    case eltwise_mode::lt:
+            str_mode = "less";
+            break;
+    case eltwise_mode::le:
+            str_mode = "less-or-equal";
+            break;
+    case eltwise_mode::gt:
+            str_mode = "greater";
+            break;
+    case eltwise_mode::ge:
+            str_mode = "greater-or-equal";
+            break;
+    case eltwise_mode::logic_and:
+            str_mode = "and";
+            break;
+    case eltwise_mode::logic_or:
+            str_mode = "or";
+            break;
+    case eltwise_mode::logic_xor:
+            str_mode = "xor";
+            break;
     default:
             str_mode = "not supported mode";
             break;
@@ -126,21 +187,78 @@ std::string eltwise_inst::to_string(eltwise_node const& node)
 eltwise_inst::typed_primitive_inst(network_impl& network, eltwise_node const& node)
     :parent(network, node)
 {
-    auto input_layout = node.input().get_output_layout();
-    auto batch_size = input_layout.size.batch[0];
-    auto feature_size = input_layout.size.feature[0];
+    check_inputs_count(node);
+    // check for stride
+    auto prim = node.get_primitive();
+    if (!prim->stride.empty())
+    {
+        // number of strides must match number of inputs
+        CLDNN_ERROR_NOT_EQUAL(node.id(), "Eltwise inputs count", node.inputs_count(), "Eltwise strides count", prim->stride.size(), "");
 
-    auto input_batch_size = input_layout.size.batch[0];
-    auto input_feature_size = input_layout.size.feature[0];
+        const auto out_x = node.get_output_layout().size.spatial[0];
+        const auto out_y = node.get_output_layout().size.spatial[1];
+        // check if strides are correctly set. I.e INPUT_SIZE_X / STRIDE_X = OUTPUT_SIZE_X, same for Y dimension
+        for (size_t i = 0; i < node.inputs_count(); i++)
+        {
+            const auto& in_layout = node.input(i).get_output_layout();
+            auto stride = prim->stride[i];
 
-    if (batch_size != 1)
+            const auto in_x_div_stride_x = in_layout.size.spatial[0] / stride.spatial[0];
+            if(in_x_div_stride_x != out_x)
+                CLDNN_ERROR_NOT_EQUAL(node.id(), "Eltwise input_x / stride_x", in_x_div_stride_x, "Eltwise output_x", out_x, "");
+
+            const auto in_y_div_stride_y = in_layout.size.spatial[1] / stride.spatial[1];
+            if(in_y_div_stride_y != out_y)
+                CLDNN_ERROR_NOT_EQUAL(node.id(), "Eltwise inputyx / stride_y", in_y_div_stride_y, "Eltwise output_y", out_y, "");
+        }
+    }
+    else
     {
-        CLDNN_ERROR_NOT_EQUAL(node.id(), "Eltwise batch size", batch_size, "input batch size", input_batch_size, "");
+        std::vector<int32_t> input0_size = node.input().get_output_layout().size.raw.vector();
+        for (size_t i = 1; i < node.inputs_count(); i++)
+        {
+            std::vector<int32_t> input_size = node.input(i).get_output_layout().size.raw.vector();
+            for (size_t d = 0; d < input0_size.size(); d++)
+            {
+                bool sizes_equal = input0_size[d] == input_size[d];
+                bool broadcast = (input0_size[d] == 1 || input_size[d] == 1) && (input0_size[d] != 1 || input_size[d] != 1);
+                CLDNN_ERROR_BOOL(node.id(), "Sizes equal or broadcast is possible", !(sizes_equal || broadcast), "Invalid input shapes");
+            }
+        }
     }
+}
 
-    if (feature_size != 1)
+void eltwise_inst::check_inputs_count(eltwise_node const &node)
+{
+    const size_t inputs_number = node.get_primitive()->input.size();
+    const eltwise_mode mode = node.get_primitive()->mode;
+
+    switch (mode)
     {
-        CLDNN_ERROR_NOT_EQUAL(node.id(), "Eltwise feature size", feature_size, "input feature size", input_feature_size, "");
+        case eltwise_mode::sum:
+        case eltwise_mode::sub:
+        case eltwise_mode::div:
+        case eltwise_mode::prod:
+        case eltwise_mode::max:
+        case eltwise_mode::min:
+        case eltwise_mode::mod:
+        case eltwise_mode::logic_and:
+        case eltwise_mode::logic_or:
+        case eltwise_mode::logic_xor:
+            if (inputs_number < 2)
+                CLDNN_ERROR_MESSAGE(node.id(), "Invalid eltwise inputs number (should be equal at least to 2). Actual: " + std::to_string(inputs_number));
+            break;
+        case eltwise_mode::eq:
+        case eltwise_mode::ne:
+        case eltwise_mode::lt:
+        case eltwise_mode::le:
+        case eltwise_mode::gt:
+        case eltwise_mode::ge:
+        case eltwise_mode::squared_diff:
+        case eltwise_mode::pow:
+            if (inputs_number != 2)
+                CLDNN_ERROR_MESSAGE(node.id(), "Invalid eltwise inputs number (should be equal to 2). Actual: " + std::to_string(inputs_number));
+            break;
     }
 }
 }