layout broadcast_inst::calc_output_layout(broadcast_node const& node)
{
+ assert((bool)node.get_primitive()->output_data_type == false
+ && "Output data type forcing is not supported for broadcast_node!");
auto input_layout = node.input().get_output_layout();
auto desc = node.get_primitive();
- auto&& new_size = tensor::max(desc->broadcast_sizes, input_layout.size);
- return {input_layout.data_type, input_layout.format, new_size};
+ return {input_layout.data_type, input_layout.format, desc->broadcast_sizes};
}
std::string broadcast_inst::to_string(broadcast_node const& node)
{
- auto desc = node.get_primitive();
+ auto desc = node.get_primitive();
+ auto node_info = node.desc_to_json();
+ const auto& broadcast_sizes = desc->broadcast_sizes;
+ const auto& broadcast_axes = desc->broadcast_axes;
+ auto& input = node.input();
- const auto& broadcast_sizes = desc->broadcast_sizes;
+ std::stringstream primitive_description;
+ std::stringstream ss_broadcast_axes;
+
+ for (size_t i = 0; i < broadcast_axes.size(); ++i)
+ {
+ ss_broadcast_axes << broadcast_axes.at(i);
+ i != (broadcast_axes.size() - 1) ? ss_broadcast_axes << ", " : ss_broadcast_axes << "";
+ }
- auto node_info = node.desc_to_json();
-
json_composite broadcast_info;
- broadcast_info.add("broadcast sizes", broadcast_sizes.to_string());
+ broadcast_info.add("input id", input.id());
+ broadcast_info.add("broadcast_sizes", broadcast_sizes.to_string());
+ broadcast_info.add("broadcast axes", ss_broadcast_axes.str());
node_info->add("broadcast info", broadcast_info);
-
- std::stringstream primitive_description;
node_info->dump(primitive_description);
+
return primitive_description.str();
}
{
auto input_layout = node.input().get_output_layout();
- const auto input_format = input_layout.format;
const auto& input_sizes = input_layout.size;
-
- auto bc_sizes = argument.broadcast_sizes;
-
- CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "Input format", input_format.value, "supported broadcast primitive input formats",
- format::bfyx, format::yxfb, format::byxf);
-
-
- // Check if sizes of broadcast are in proper range.
- CLDNN_ERROR_TENSOR_SIZES_LESS_THAN(node.id(), "Broadcast sizes", bc_sizes, "0 value", {1, 1, 1, 1},
- "Invalid broadcast size: non-positive value");
-
- bc_sizes = tensor::max(bc_sizes, input_sizes);
-
- // Check if sizes of broadcast are compatible with sizes of input.
- CLDNN_ERROR_TENSOR_SIZES_NOT_DIVIDABLE(node.id(), "Broadcast sizes", bc_sizes, "input sizes", input_sizes,
+ const auto& output_sizes = argument.broadcast_sizes;
+
+ std::vector<tensor::value_type> input_dims = {input_sizes.batch[0], input_sizes.feature[0],
+ input_sizes.spatial[1], input_sizes.spatial[0]};
+ std::vector<tensor::value_type> reordered_input_dims(4, 0);
+ std::set<uint16_t> existing;
+
+ const auto& broadcast_axes = node.get_primitive()->broadcast_axes;
+ size_t broadcast_axes_size = broadcast_axes.size();
+ size_t index = 0;
+ size_t input_index = broadcast_axes_size;
+
+ if (broadcast_axes_size > 4)
+ {
+ CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: broadcast_axes size should be less or equal 4.");
+ }
+ for (size_t i = 0; i < broadcast_axes_size; ++i)
+ {
+ if (broadcast_axes.at(i) >= 4)
+ {
+ CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: broadcast_axes index should be within broadcast_sizes range.");
+ }
+ if (existing.find(broadcast_axes.at(i)) != existing.end())
+ {
+ CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: Duplicate axes numbers was found in broadcast_axes.");
+ }
+ existing.insert(broadcast_axes.at(i));
+ }
+ for (size_t i = 0; i < input_index; ++i)
+ {
+ CLDNN_ERROR_NOT_EQUAL(node.id(), "Input size on dimension number " + std::to_string(i), input_dims.at(i), "", 1, "Must be equal 1.");
+ }
+ //bfyx format
+ for (size_t i = 0; i < 4; ++i)
+ {
+ if (std::find(broadcast_axes.begin(), broadcast_axes.end(), i) != broadcast_axes.end())
+ {
+ reordered_input_dims.at(i) = input_dims.at(index);
+ ++index;
+ }
+ else
+ {
+ reordered_input_dims.at(i) = input_dims.at(input_index);
+ ++input_index;
+ }
+ }
+ tensor input_sizes_to_compare(reordered_input_dims.at(0), reordered_input_dims.at(1), reordered_input_dims.at(3), reordered_input_dims.at(2));
+
+ CLDNN_ERROR_TENSOR_SIZES_NOT_DIVIDABLE(node.id(), "Broadcast sizes", output_sizes, "input sizes", input_sizes_to_compare,
"Invalid broadcast size: not dividable by input size");
}
-}
\ No newline at end of file
+}