Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / contract.cpp
1 // Copyright (c) 2019 Intel Corporation
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15
#include "contract_inst.h"

#include <algorithm>
#include <cstdint>
#include <set>
#include <sstream>
#include <string>
#include <vector>

#include "error_handler.h"
#include "json_object.h"
#include "primitive_type_base.h"

23 namespace cldnn
24 {
25     primitive_type_id contract_type_id()
26     {
27         static primitive_type_base<contract> instance;
28         return &instance;
29     }
30
31     layout contract_inst::calc_output_layout(contract_node const& node)
32     {
33         auto input_layout = node.input().get_output_layout();
34         const auto& input_sizes = input_layout.size;
35         auto desc = node.get_primitive();
36         auto reduction_axes = desc->reduction_axes;
37
38         std::vector<tensor::value_type> input_dims = { input_sizes.batch[0], input_sizes.feature[0],
39             input_sizes.spatial[1], input_sizes.spatial[0] };
40         std::vector<tensor::value_type> output_sizes(4, 0);
41         int cur_dim = 3;
42         for (int i = 3; i >= 0; --i)
43         {
44             while (std::find(reduction_axes.begin(), reduction_axes.end(), cur_dim) != reduction_axes.end() && cur_dim >= 0)
45                 --cur_dim;
46             output_sizes.at(i) = cur_dim >= 0 ? input_dims.at(cur_dim--) : 1;
47         }
48
49         return { input_layout.data_type, input_layout.format, cldnn::tensor(output_sizes[0], output_sizes[1], output_sizes[3], output_sizes[2]) };
50     }
51
52     std::string contract_inst::to_string(contract_node const& node)
53     {
54         auto desc = node.get_primitive();
55         auto node_info = node.desc_to_json();
56         const auto& reduction_axes = desc->reduction_axes;
57         auto& input = node.input();
58
59         std::stringstream primitive_description;
60         std::stringstream ss_reduction_axes;
61
62         for (size_t i = 0; i < reduction_axes.size(); ++i)
63         {
64             ss_reduction_axes << reduction_axes.at(i);
65             i != (reduction_axes.size() - 1) ? ss_reduction_axes << ", " : ss_reduction_axes << "";
66         }
67
68         std::string str_mode;
69         switch (desc->mode)
70         {
71         case contract_mode::sum:
72             str_mode = "sum";
73             break;
74         case contract_mode::prod:
75             str_mode = "product";
76             break;
77         case contract_mode::all:
78             str_mode = "all";
79             break;
80         case contract_mode::any:
81             str_mode = "any";
82             break;
83         case contract_mode::max:
84             str_mode = "max";
85             break;
86         default:
87             str_mode = "not supported mode";
88             break;
89         }
90
91         json_composite contract_info;
92         contract_info.add("input id", input.id());
93         contract_info.add("mode", str_mode);
94         contract_info.add("reduction axes", ss_reduction_axes.str());
95
96         node_info->add("contract info", contract_info);
97         node_info->dump(primitive_description);
98
99         return primitive_description.str();
100     }
101
102     contract_inst::typed_primitive_inst(network_impl& network, contract_node const& node)
103         : parent(network, node)
104     {
105         std::set<uint16_t> existing;
106         const auto& reduction_axes = node.get_primitive()->reduction_axes;
107         size_t reduction_axes_size = reduction_axes.size();
108
109         if (reduction_axes.empty())
110         {
111             CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: reduction_axes should not be empty.");
112         }
113         if (reduction_axes_size > 4)
114         {
115             CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: reduction_axes size should be less or equal 4.");
116         }
117         for (size_t i = 0; i < reduction_axes_size; ++i)
118         {
119             if (reduction_axes.at(i) >= 4)
120             {
121                 CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: reduction_axes index should be within reduction_axes range.");
122             }
123             if (existing.find(reduction_axes.at(i)) != existing.end())
124             {
125                 CLDNN_ERROR_MESSAGE(node.id(), "Incorrect parameters configuration: Duplicate axes numbers was found in reduction_axes.");
126             }
127             existing.insert(reduction_axes.at(i));
128         }
129     }
130 }