31f8f74a8ea8631207b8b4f8e2005cfa96c3a913
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / CPP / reorder.hpp
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #pragma once
19 #include "../C/reorder.h"
20 #include "primitive.hpp"
21 #include "memory.hpp"
22 #include <vector>
23
24 namespace cldnn {
25 /// @addtogroup cpp_api C++ API
26 /// @{
27 /// @addtogroup cpp_topology Network Topology
28 /// @{
29 /// @addtogroup cpp_primitives Primitives
30 /// @{
31
32 /// @brief Changes how data is ordered in memory. Value type is not changed & all information is preserved.
33 /// @details Corresponding values are bitwise equal before/after reorder.
34 /// Also merged with subtraction layer, which can subtract, multiply or divide values based on mean_mode value, while doing reordering.
35 /// NOTE THAT THIS WILL SUBTRACT THE SAME VALUES FROM EACH BATCH.
36 struct reorder : public primitive_base<reorder, CLDNN_PRIMITIVE_DESC(reorder)> {
37     CLDNN_DECLARE_PRIMITIVE(reorder)
38
39     /// @brief Constructs reorder primitive with directly provided mean subtract values.
40     /// @param id This primitive id.
41     /// @param input Input primitive id.
42     /// @param output_layout Requested memory layout.
43     /// @param values_to_subtract Array of mean subtract values.
44     reorder(const primitive_id& id,
45             const primitive_id& input,
46             const layout& output_layout,
47             const std::vector<float>& values_to_subtract = {},
48             const cldnn_reorder_mean_mode mode = cldnn_reorder_mean_mode::mean_subtract)
49         : primitive_base(id, {input}, output_layout.data_padding, optional_data_type {output_layout.data_type}),
50           output_format(output_layout.format),
51           mean(""),
52           subtract_per_feature(values_to_subtract),
53           mean_mode(mode) {}
54
55     /// @brief Constructs reorder primitive which takes mean subtract values from another primitive.
56     /// @param id This primitive id.
57     /// @param input Input primitive id.
58     /// @param output_layout Requested memory layout.
59     /// @param mean Primitive id to get mean subtract values.
60     reorder(const primitive_id& id,
61             const primitive_id& input,
62             const layout& output_layout,
63             primitive_id const& mean,
64             const cldnn_reorder_mean_mode mode = cldnn_reorder_mean_mode::mean_subtract)
65         : primitive_base(id, {input}, output_layout.data_padding, optional_data_type {output_layout.data_type}),
66           output_format(output_layout.format),
67           mean(mean),
68           subtract_per_feature(0),
69           mean_mode(mode) {}
70
71     /// @brief Constructs reorder primitive with directly provided mean subtract values.
72     /// @param id This primitive id.
73     /// @param input Input primitive id.
74     /// @param output_layout Requested memory layout.
75     /// @param values_to_subtract Array of mean subtract values.
76     reorder(const primitive_id& id,
77             const primitive_id& input,
78             format output_format,
79             data_types output_data_type,
80             const std::vector<float>& values_to_subtract = {},
81             const cldnn_reorder_mean_mode mode = cldnn_reorder_mean_mode::mean_subtract,
82             const padding& output_padding = padding())
83         : primitive_base(id, {input}, output_padding, optional_data_type{output_data_type}),
84           output_format(output_format),
85           mean(""),
86           subtract_per_feature(values_to_subtract),
87           mean_mode(mode) {}
88
89     /// @brief Constructs reorder primitive which takes mean subtract values from another primitive.
90     /// @param id This primitive id.
91     /// @param input Input primitive id.
92     /// @param output_layout Requested memory layout.
93     /// @param mean Primitive id to get mean subtract values.
94     reorder(const primitive_id& id,
95             const primitive_id& input,
96             format output_format,
97             data_types output_data_type,
98             primitive_id const& mean,
99             const cldnn_reorder_mean_mode mode = cldnn_reorder_mean_mode::mean_subtract,
100             const padding& output_padding = padding())
101         : primitive_base(id, {input}, output_padding, optional_data_type {output_data_type}),
102           output_format(output_format),
103           mean(mean),
104           subtract_per_feature(0),
105           mean_mode(mode) {}
106
107     /// @brief Constructs a copy from basic C API @CLDNN_PRIMITIVE_DESC{reorder}
108     reorder(const dto* dto)
109         : primitive_base(dto),
110           output_format(dto->output_format),
111           mean(dto->mean_subtract),
112           subtract_per_feature(float_arr_to_vector(dto->subtract_per_feature)),
113           mean_mode(dto->mean_mode) {}
114
115     /// @brief Requested memory format.
116     format output_format;
117     /// @brief Primitive id to get mean subtract values. Ignored if subtract_per_featrue is set.
118     primitive_id mean;
119     /// @brief Array of mean subtract values.
120     std::vector<float> subtract_per_feature;
121     /// @brief Mode of mean execution
122     cldnn_reorder_mean_mode mean_mode;
123
124 protected:
125     std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
126         if (mean.empty())
127             return {};
128         return {mean};
129     }
130
131     void update_dto(dto& dto) const override {
132         dto.output_format = static_cast<cldnn_format_type>(output_format.value);
133         dto.mean_subtract = mean.c_str();
134         dto.subtract_per_feature = float_vector_to_arr(subtract_per_feature);
135         dto.mean_mode = mean_mode;
136     }
137 };
138 /// @}
139 /// @}
140 /// @}
141 }  // namespace cldnn