Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / broadcast.cpp
index 4113e53..d7f8738 100644 (file)
@@ -30,28 +30,39 @@ primitive_type_id broadcast_type_id()
 
 layout broadcast_inst::calc_output_layout(broadcast_node const& node)
 {
+    assert((bool)node.get_primitive()->output_data_type == false
+           && "Output data type forcing is not supported for broadcast_node!");
     auto input_layout = node.input().get_output_layout();
     auto desc         = node.get_primitive();
 
-    auto&& new_size = tensor::max(desc->broadcast_sizes, input_layout.size);
-    return {input_layout.data_type, input_layout.format, new_size};
+    return {input_layout.data_type, input_layout.format, desc->broadcast_sizes};
 }
 
 std::string broadcast_inst::to_string(broadcast_node const& node)
 {
-    auto desc = node.get_primitive();
+    auto desc                  = node.get_primitive();
+    auto node_info             = node.desc_to_json();
+    const auto& broadcast_sizes   = desc->broadcast_sizes;
+    const auto& broadcast_axes = desc->broadcast_axes;
+    auto& input                = node.input();
 
-    const auto& broadcast_sizes     = desc->broadcast_sizes;
+    std::stringstream primitive_description;
+    std::stringstream ss_broadcast_axes;
+
+    for (size_t i = 0; i < broadcast_axes.size(); ++i)
+    {
+        ss_broadcast_axes << broadcast_axes.at(i);
+        i != (broadcast_axes.size() - 1) ? ss_broadcast_axes << ", " : ss_broadcast_axes << "";
+    }
 
-    auto node_info  = node.desc_to_json();
-   
     json_composite broadcast_info;
-    broadcast_info.add("broadcast sizes", broadcast_sizes.to_string());
+    broadcast_info.add("input id", input.id());
+    broadcast_info.add("broadcast_sizes", broadcast_sizes.to_string());
+    broadcast_info.add("broadcast axes", ss_broadcast_axes.str());
 
     node_info->add("broadcast info", broadcast_info);
-
-    std::stringstream primitive_description;
     node_info->dump(primitive_description);
+
     return primitive_description.str();
 }
 
@@ -60,23 +71,56 @@ broadcast_inst::typed_primitive_inst(network_impl& network, broadcast_node const
 {
     auto input_layout = node.input().get_output_layout();
 
-    const auto input_format = input_layout.format;
     const auto& input_sizes = input_layout.size;
-
-    auto bc_sizes = argument.broadcast_sizes;
-
-    CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "Input format", input_format.value, "supported broadcast primitive input formats",
-                                  format::bfyx, format::yxfb, format::byxf);
-
-
-    // Check if sizes of broadcast are in proper range.
-    CLDNN_ERROR_TENSOR_SIZES_LESS_THAN(node.id(), "Broadcast sizes", bc_sizes, "0 value", {1, 1, 1, 1},
-                                       "Invalid broadcast size: non-positive value");
-
-    bc_sizes = tensor::max(bc_sizes, input_sizes);
-
-    // Check if sizes of broadcast are compatible with sizes of input.
-    CLDNN_ERROR_TENSOR_SIZES_NOT_DIVIDABLE(node.id(), "Broadcast sizes", bc_sizes, "input sizes", input_sizes,
+    const auto& output_sizes = argument.broadcast_sizes;
+
+    std::vector<tensor::value_type> input_dims = {input_sizes.batch[0], input_sizes.feature[0],
+                                                  input_sizes.spatial[1], input_sizes.spatial[0]};
+    std::vector<tensor::value_type> reordered_input_dims(4, 0);
+    std::set<uint16_t> existing;
+
+    const auto& broadcast_axes = node.get_primitive()->broadcast_axes;
+    size_t broadcast_axes_size = broadcast_axes.size();
+    size_t index = 0;
+    size_t input_index = broadcast_axes_size;
+
+    if (broadcast_axes_size > 4)
+    {
+        CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: broadcast_axes size should be less or equal 4.");
+    }
+    for (size_t i = 0; i < broadcast_axes_size; ++i)
+    {
+        if (broadcast_axes.at(i) >= 4)
+        {
+            CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: broadcast_axes index should be within broadcast_sizes range.");
+        }
+        if (existing.find(broadcast_axes.at(i)) != existing.end())
+        {
+            CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: Duplicate axes numbers was found in broadcast_axes.");
+        }
+        existing.insert(broadcast_axes.at(i));
+    }
+    for (size_t i = 0; i < input_index; ++i)
+    {
+        CLDNN_ERROR_NOT_EQUAL(node.id(), "Input size on dimension number " + std::to_string(i), input_dims.at(i), "", 1, "Must be equal 1.");
+    }
+    //bfyx format
+    for (size_t i = 0; i < 4; ++i)
+    {
+        if (std::find(broadcast_axes.begin(), broadcast_axes.end(), i) != broadcast_axes.end())
+        {
+            reordered_input_dims.at(i) = input_dims.at(index);
+            ++index;
+        }
+        else
+        {
+            reordered_input_dims.at(i) = input_dims.at(input_index);
+            ++input_index;
+        }
+    }
+    tensor input_sizes_to_compare(reordered_input_dims.at(0), reordered_input_dims.at(1), reordered_input_dims.at(3), reordered_input_dims.at(2));
+
+    CLDNN_ERROR_TENSOR_SIZES_NOT_DIVIDABLE(node.id(), "Broadcast sizes", output_sizes, "input sizes", input_sizes_to_compare,
                                            "Invalid broadcast size: not dividable by input size");
 }
-}
\ No newline at end of file
+}