Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / CPP / strided_slice.hpp
1 /*
2 // Copyright (c) 2019 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #pragma once
19 #include "../C/strided_slice.h"
20 #include "primitive.hpp"
21
22 namespace  cldnn
23 {
24 /// @addtogroup cpp_api C++ API
25 /// @{
26 /// @addtogroup cpp_topology Network Topology
27 /// @{
28 /// @addtogroup cpp_primitives Primitives
29 /// @{
30
31 /// @brief
32 /// @details
33 struct strided_slice : public primitive_base<strided_slice, CLDNN_PRIMITIVE_DESC(strided_slice)>
34 {
35     CLDNN_DECLARE_PRIMITIVE(strided_slice)
36
37     /// @brief Constructs strided_slice primitive.
38     /// @param id This primitive id.
39     /// @param input Input data primitive id.
40     /// @param begin_id Begin position primitive id.
41     /// @param end_id End position primitive id.
42     /// @param strides_id Step of slicing primitive id.
43     /// @param begin_mask Array of bits, that provide replace begin[i] to max possible range in that dimension.
44     /// @param end_mask Array of bits, that provide replace end[i] to max possible range in that dimension.
45     /// @param new_axis_mask Array of bits, that provide adding a new length 1 dimension at ith position in the output tensor.
46     /// @param shrink_axis_mask Array of bits, that provide shrinks the dimensionality by 1, taking on the value at index begin[i].
47     strided_slice(
48         const primitive_id& id,
49         const primitive_id& input,
50         const primitive_id& begin_id,
51         const primitive_id& end_id,
52         const primitive_id& strides_id,
53         std::vector<uint8_t> begin_mask,
54         std::vector<uint8_t> end_mask,
55         std::vector<uint8_t> new_axis_mask,
56         std::vector<uint8_t> shrink_axis_mask,
57         const padding& output_padding = padding()
58     )
59         : primitive_base(id, {input, begin_id, end_id, strides_id}, output_padding)
60         , begin_mask(begin_mask)
61         , end_mask(end_mask)
62         , new_axis_mask(new_axis_mask)
63         , shrink_axis_mask(shrink_axis_mask)
64     {
65     }
66
67     /// @brief Constructs a copy from C API @CLDNN_PRIMITIVE_DESC{strided_slice}
68     strided_slice(const dto* dto)
69         : primitive_base(dto)
70         , begin_mask(uint8_t_arr_to_vector(dto->begin_mask))
71         , end_mask(uint8_t_arr_to_vector(dto->end_mask))
72         , new_axis_mask(uint8_t_arr_to_vector(dto->new_axis_mask))
73         , shrink_axis_mask(uint8_t_arr_to_vector(dto->shrink_axis_mask))
74     {
75     }
76
77     /// @param begin_mask Array of bits, that provide replace begin[i] to max possible range in that dimension.
78     std::vector<uint8_t> begin_mask;
79     /// @param end_mask Array of bits, that provide replace end[i] to max possible range in that dimension.
80     std::vector<uint8_t> end_mask;
81     /// @param new_axis_mask Array of bits, that provide adding a new length 1 dimension at ith position in the output tensor.
82     std::vector<uint8_t> new_axis_mask;
83     /// @param shrink_axis_mask Array of bits, that provide shrinks the dimensionality by 1, taking on the value at index begin[i].
84     std::vector<uint8_t> shrink_axis_mask;
85
86 protected:
87
88     void update_dto(dto& dto) const override
89     {
90         dto.begin_mask = uint8_t_vector_to_arr(begin_mask);
91         dto.end_mask = uint8_t_vector_to_arr(end_mask);
92         dto.new_axis_mask = uint8_t_vector_to_arr(new_axis_mask);
93         dto.shrink_axis_mask = uint8_t_vector_to_arr(shrink_axis_mask);
94     }
95 };
96 /// @}
97 /// @}
98 /// @}
99 }