Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / strided_slice.cpp
1 /*
2 // Copyright (c) 2019 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "strided_slice_inst.h"
18 #include "primitive_type_base.h"
19 #include "error_handler.h"
20 #include "json_object.h"
21 #include "data_inst.h"
22
23 namespace cldnn
24 {
25 primitive_type_id strided_slice_type_id()
26 {
27     static primitive_type_base<strided_slice> instance;
28     return &instance;
29 }
30
31 layout strided_slice_inst::calc_output_layout(strided_slice_node const& node) {
32     const size_t numberOfDims = 4;
33     auto desc = node.get_primitive();
34     auto input_layout = node.input(0).get_output_layout();
35     auto input_format = input_layout.format;
36
37     auto completeStridedSliceParams = [&](std::vector<int32_t>& param) {
38         for (size_t i = param.size(); i < numberOfDims; ++i)
39             param.push_back(1);
40     };
41
42     auto completeStridedSliceMasks = [&](std::vector<uint8_t>& mask) {
43         for (size_t i = mask.size(); i < numberOfDims; ++i)
44             mask.push_back(0);
45     };
46
47     auto maskStridedSliceParams = [&](std::vector<int32_t>& param, const std::vector<uint8_t>& mask) {
48         for (size_t i = 0; i < param.size(); ++i)
49             if (mask[i])
50                 param[i] = input_layout.size.sizes(format::bfyx)[i];
51     };
52
53     // Getting data from constant inputs. There are 3 args: Begin, End, Stride
54     std::vector<std::vector<int32_t>> stridedSliceArgs;
55     for (size_t i = 1; i < node.get_dependencies().size(); ++i) {
56         auto& input = node.get_dependency(i).as<data>();
57         auto& mem = input.get_attached_memory();
58         int32_t* data = static_cast<int32_t*>(mem.lock());
59         std::vector<int32_t> vData = std::vector<int32_t>(data, data + input.get_output_layout().count());
60         completeStridedSliceParams(vData);
61         stridedSliceArgs.push_back(vData);
62         mem.unlock();
63     }
64
65     std::vector<uint8_t> beginMask(desc->begin_mask);
66     completeStridedSliceMasks(beginMask);
67     std::vector<uint8_t> endMask(desc->end_mask);
68     completeStridedSliceMasks(endMask);
69
70     auto& begin = stridedSliceArgs[0];
71     auto& end = stridedSliceArgs[1];
72     const auto& strides = stridedSliceArgs[2];
73     std::vector<int32_t> outputDimsSizes;
74
75     // If the ith bit of begin_mask is set, begin[i] is ignored and the fullest possible range in that dimension is used instead.
76     maskStridedSliceParams(begin, beginMask);
77     // end_mask works analogously
78     maskStridedSliceParams(end, endMask);
79
80     auto isShiftPossible = [] (std::vector<int32_t>& dims) -> bool {
81         if (dims[dims.size() - 1] == 1)
82             return true;
83         else
84             return false;
85     };
86
87     // If the new_axis_mask is set, then begin, end, and stride are ignored
88     if (std::find(desc->new_axis_mask.begin(), desc->new_axis_mask.end(), 1) == desc->new_axis_mask.end()) {
89         for (size_t i = 0; i < numberOfDims; ++i) {
90             int32_t outputDimSize = (end[i] - begin[i]) / strides[i];
91             if ((end[i] - begin[i]) % strides[i] != 0)
92                 outputDimSize++;
93             outputDimsSizes.push_back(outputDimSize);
94         }
95     } else {
96         outputDimsSizes = input_layout.size.sizes(format::bfyx);
97         for (size_t i = 0; i < desc->new_axis_mask.size(); ++i)
98             if (desc->new_axis_mask[desc->new_axis_mask.size() - i - 1] == 1)
99                 if (isShiftPossible(outputDimsSizes)) {
100                     for (size_t j = outputDimsSizes.size() - 1; j > i; --j)
101                         outputDimsSizes[j] = outputDimsSizes[j - 1];
102                     outputDimsSizes[i] = 1;
103                 }
104     }
105
106     return layout{input_layout.data_type, input_format, tensor(outputDimsSizes[0], outputDimsSizes[1], outputDimsSizes[3], outputDimsSizes[2])};
107 }
108
109 std::string strided_slice_inst::to_string(strided_slice_node const& node)
110 {
111     auto desc = node.get_primitive();
112     auto node_info = node.desc_to_json();
113     auto& input = node.input();
114
115     std::stringstream primitive_description;
116
117     json_composite strided_slice_info;
118     strided_slice_info.add("input id", input.id());
119     strided_slice_info.add("begin_param id", node.get_dependency(1).id());
120     strided_slice_info.add("end_param id", node.get_dependency(2).id());
121     strided_slice_info.add("stride_param id", node.get_dependency(3).id());
122     strided_slice_info.add("begin mask", node.get_primitive()->begin_mask);
123     strided_slice_info.add("end mask", node.get_primitive()->end_mask);
124     strided_slice_info.add("new axis mask", node.get_primitive()->new_axis_mask);
125     strided_slice_info.add("shrink axis mask", node.get_primitive()->shrink_axis_mask);
126     strided_slice_info.add("begin_param shape", node.get_dependency(1).get_output_layout().size.to_string());
127     strided_slice_info.add("end_param shape", node.get_dependency(2).get_output_layout().size.to_string());
128     strided_slice_info.add("stride_param shape", node.get_dependency(3).get_output_layout().size.to_string());
129
130     node_info->add("strided_slice info", strided_slice_info);
131     node_info->dump(primitive_description);
132
133     return primitive_description.str();
134 }
135
136 strided_slice_inst::typed_primitive_inst(network_impl& network, strided_slice_node const& node)
137     : parent(network, node)
138 {
139 }
140
141 }