Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / shuffle_channels.cpp
1 /*
2 // Copyright (c) 2019 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "shuffle_channels_inst.h"
18
19 #include "primitive_type_base.h"
20 #include "error_handler.h"
21 #include "json_object.h"
22
23 namespace cldnn
24 {
25 primitive_type_id shuffle_channels_type_id()
26 {
27     static primitive_type_base<shuffle_channels> instance;
28     return &instance;
29 }
30
31 layout shuffle_channels_inst::calc_output_layout(shuffle_channels_node const& node)
32 {
33     auto desc = node.get_primitive();
34
35     auto input_layout = node.input(0).get_output_layout();
36     auto input_format = input_layout.format;
37
38     const int32_t number_of_dims = 4;
39     const int32_t group = desc->group;
40     int32_t axis = desc->axis;
41
42     if (axis < 0)
43         axis += number_of_dims;
44
45     if (axis < 0 || axis >= number_of_dims)
46         CLDNN_ERROR_MESSAGE(node.id(), "Incorrect axis value! Actual axis is" + std::to_string(group));
47
48     if (group < 1)
49         CLDNN_ERROR_MESSAGE(node.id(), "Invalid group size value (should equal at least one). Actual block size is" +
50                                        std::to_string(group));
51
52     if (input_layout.size.sizes(format::bfyx)[axis] % group != 0)
53         CLDNN_ERROR_MESSAGE(node.id(), "Group parameter must evenly divide the channel dimension. Actual group size is " +
54                                        std::to_string(group));
55
56     return layout{input_layout.data_type, input_format, input_layout.size};
57 }
58
59 std::string shuffle_channels_inst::to_string(shuffle_channels_node const& node)
60 {
61     auto desc = node.get_primitive();
62     auto node_info = node.desc_to_json();
63     auto& input = node.input();
64
65     std::stringstream primitive_description;
66
67     json_composite shuffle_channels_info;
68     shuffle_channels_info.add("input id", input.id());
69     shuffle_channels_info.add("groups number", desc->group);
70     shuffle_channels_info.add("axis", desc->axis);
71
72     node_info->add("shuffle_channels info", shuffle_channels_info);
73     node_info->dump(primitive_description);
74
75     return primitive_description.str();
76 }
77
78 shuffle_channels_inst::typed_primitive_inst(network_impl& network, shuffle_channels_node const& node)
79 : parent(network, node)
80 {
81 }
82
83 }