Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / CPP / shuffle_channels.hpp
1 /*
2 // Copyright (c) 2019 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #pragma once
19
20 #include "../C/shuffle_channels.h"
21 #include "primitive.hpp"
22
23 namespace  cldnn
24 {
25 /// @addtogroup cpp_api C++ API
26 /// @{
27 /// @addtogroup cpp_topology Network Topology
28 /// @{
29 /// @addtogroup cpp_primitives Primitives
30 /// @{
31
32 /// @brief
33 /// @details
34 struct shuffle_channels : public primitive_base<shuffle_channels, CLDNN_PRIMITIVE_DESC(shuffle_channels)>
35 {
36     CLDNN_DECLARE_PRIMITIVE(shuffle_channels)
37
38     /// @brief Constructs shuffle_channels primitive.
39     /// @param id This primitive id.
40     /// @param input Input dictionary primitive id.
41     /// @param group The number of groups to split the channel dimension.
42     /// @param axis The index of the channel dimension.
43     shuffle_channels(
44             const primitive_id& id,
45             const primitive_id& input,
46             const int32_t group,
47             const int32_t axis = 1,
48             const padding& output_padding = padding()
49     )
50             : primitive_base(id, {input}, output_padding)
51             , group(group)
52             , axis(axis)
53     {
54     }
55
56     /// @brief Constructs a copy from C API @CLDNN_PRIMITIVE_DESC{shuffle_channels}
57     shuffle_channels(const dto* dto)
58             : primitive_base(dto)
59             , group(dto->group)
60             , axis(dto->axis)
61     {
62     }
63
64     /// @brief The number of groups to split the channel dimension. This number must evenly divide the channel dimension size.
65     int32_t group;
66     /// @brief The index of the channel dimension (default is 1).
67     int32_t axis;
68 protected:
69
70     void update_dto(dto& dto) const override
71     {
72         dto.group = group;
73         dto.axis = axis;
74     }
75 };
76 /// @}
77 /// @}
78 /// @}
79 }