1 // Copyright (c) 2018 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
16 #include "broadcast_inst.h"
18 #include "error_handler.h"
19 #include "json_object.h"
20 #include "primitive_type_base.h"
25 primitive_type_id broadcast_type_id()
27 static primitive_type_base<broadcast> instance;
31 layout broadcast_inst::calc_output_layout(broadcast_node const& node)
33 assert((bool)node.get_primitive()->output_data_type == false
34 && "Output data type forcing is not supported for broadcast_node!");
35 auto input_layout = node.input().get_output_layout();
36 auto desc = node.get_primitive();
38 return {input_layout.data_type, input_layout.format, desc->broadcast_sizes};
41 std::string broadcast_inst::to_string(broadcast_node const& node)
43 auto desc = node.get_primitive();
44 auto node_info = node.desc_to_json();
45 const auto& broadcast_sizes = desc->broadcast_sizes;
46 const auto& broadcast_axes = desc->broadcast_axes;
47 auto& input = node.input();
49 std::stringstream primitive_description;
50 std::stringstream ss_broadcast_axes;
52 for (size_t i = 0; i < broadcast_axes.size(); ++i)
54 ss_broadcast_axes << broadcast_axes.at(i);
55 i != (broadcast_axes.size() - 1) ? ss_broadcast_axes << ", " : ss_broadcast_axes << "";
58 json_composite broadcast_info;
59 broadcast_info.add("input id", input.id());
60 broadcast_info.add("broadcast_sizes", broadcast_sizes.to_string());
61 broadcast_info.add("broadcast axes", ss_broadcast_axes.str());
63 node_info->add("broadcast info", broadcast_info);
64 node_info->dump(primitive_description);
66 return primitive_description.str();
69 broadcast_inst::typed_primitive_inst(network_impl& network, broadcast_node const& node)
70 : parent(network, node)
72 auto input_layout = node.input().get_output_layout();
74 const auto& input_sizes = input_layout.size;
75 const auto& output_sizes = argument.broadcast_sizes;
77 std::vector<tensor::value_type> input_dims = {input_sizes.batch[0], input_sizes.feature[0],
78 input_sizes.spatial[1], input_sizes.spatial[0]};
79 std::vector<tensor::value_type> reordered_input_dims(4, 0);
80 std::set<uint16_t> existing;
82 const auto& broadcast_axes = node.get_primitive()->broadcast_axes;
83 size_t broadcast_axes_size = broadcast_axes.size();
85 size_t input_index = broadcast_axes_size;
87 if (broadcast_axes_size > 4)
89 CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: broadcast_axes size should be less or equal 4.");
91 for (size_t i = 0; i < broadcast_axes_size; ++i)
93 if (broadcast_axes.at(i) >= 4)
95 CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: broadcast_axes index should be within broadcast_sizes range.");
97 if (existing.find(broadcast_axes.at(i)) != existing.end())
99 CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: Duplicate axes numbers was found in broadcast_axes.");
101 existing.insert(broadcast_axes.at(i));
103 for (size_t i = 0; i < input_index; ++i)
105 CLDNN_ERROR_NOT_EQUAL(node.id(), "Input size on dimension number " + std::to_string(i), input_dims.at(i), "", 1, "Must be equal 1.");
108 for (size_t i = 0; i < 4; ++i)
110 if (std::find(broadcast_axes.begin(), broadcast_axes.end(), i) != broadcast_axes.end())
112 reordered_input_dims.at(i) = input_dims.at(index);
117 reordered_input_dims.at(i) = input_dims.at(input_index);
121 tensor input_sizes_to_compare(reordered_input_dims.at(0), reordered_input_dims.at(1), reordered_input_dims.at(3), reordered_input_dims.at(2));
123 CLDNN_ERROR_TENSOR_SIZES_NOT_DIVIDABLE(node.id(), "Broadcast sizes", output_sizes, "input sizes", input_sizes_to_compare,
124 "Invalid broadcast size: not dividable by input size");