1 // Copyright (c) 2019 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
16 #include "contract_inst.h"
18 #include "error_handler.h"
19 #include "json_object.h"
20 #include "primitive_type_base.h"
25 primitive_type_id contract_type_id()
27 static primitive_type_base<contract> instance;
31 layout contract_inst::calc_output_layout(contract_node const& node)
33 auto input_layout = node.input().get_output_layout();
34 const auto& input_sizes = input_layout.size;
35 auto desc = node.get_primitive();
36 auto reduction_axes = desc->reduction_axes;
38 std::vector<tensor::value_type> input_dims = { input_sizes.batch[0], input_sizes.feature[0],
39 input_sizes.spatial[1], input_sizes.spatial[0] };
40 std::vector<tensor::value_type> output_sizes(4, 0);
42 for (int i = 3; i >= 0; --i)
44 while (std::find(reduction_axes.begin(), reduction_axes.end(), cur_dim) != reduction_axes.end() && cur_dim >= 0)
46 output_sizes.at(i) = cur_dim >= 0 ? input_dims.at(cur_dim--) : 1;
49 return { input_layout.data_type, input_layout.format, cldnn::tensor(output_sizes[0], output_sizes[1], output_sizes[3], output_sizes[2]) };
52 std::string contract_inst::to_string(contract_node const& node)
54 auto desc = node.get_primitive();
55 auto node_info = node.desc_to_json();
56 const auto& reduction_axes = desc->reduction_axes;
57 auto& input = node.input();
59 std::stringstream primitive_description;
60 std::stringstream ss_reduction_axes;
62 for (size_t i = 0; i < reduction_axes.size(); ++i)
64 ss_reduction_axes << reduction_axes.at(i);
65 i != (reduction_axes.size() - 1) ? ss_reduction_axes << ", " : ss_reduction_axes << "";
71 case contract_mode::sum:
74 case contract_mode::prod:
77 case contract_mode::all:
80 case contract_mode::any:
83 case contract_mode::max:
87 str_mode = "not supported mode";
91 json_composite contract_info;
92 contract_info.add("input id", input.id());
93 contract_info.add("mode", str_mode);
94 contract_info.add("reduction axes", ss_reduction_axes.str());
96 node_info->add("contract info", contract_info);
97 node_info->dump(primitive_description);
99 return primitive_description.str();
102 contract_inst::typed_primitive_inst(network_impl& network, contract_node const& node)
103 : parent(network, node)
105 std::set<uint16_t> existing;
106 const auto& reduction_axes = node.get_primitive()->reduction_axes;
107 size_t reduction_axes_size = reduction_axes.size();
109 if (reduction_axes.empty())
111 CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: reduction_axes should not be empty.");
113 if (reduction_axes_size > 4)
115 CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: reduction_axes size should be less or equal 4.");
117 for (size_t i = 0; i < reduction_axes_size; ++i)
119 if (reduction_axes.at(i) >= 4)
121 CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: reduction_axes index should be within reduction_axes range.");
123 if (existing.find(reduction_axes.at(i)) != existing.end())
125 CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: Duplicate axes numbers was found in reduction_axes.");
127 existing.insert(reduction_axes.at(i));