// Copyright (c) 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "strided_slice_inst.h"

#include <algorithm>
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "primitive_type_base.h"
#include "error_handler.h"
#include "json_object.h"
#include "data_inst.h"
25 primitive_type_id strided_slice_type_id()
27 static primitive_type_base<strided_slice> instance;
31 layout strided_slice_inst::calc_output_layout(strided_slice_node const& node) {
32 const size_t numberOfDims = 4;
33 auto desc = node.get_primitive();
34 auto input_layout = node.input(0).get_output_layout();
35 auto input_format = input_layout.format;
37 auto completeStridedSliceParams = [&](std::vector<int32_t>& param) {
38 for (size_t i = param.size(); i < numberOfDims; ++i)
42 auto completeStridedSliceMasks = [&](std::vector<uint8_t>& mask) {
43 for (size_t i = mask.size(); i < numberOfDims; ++i)
47 auto maskStridedSliceParams = [&](std::vector<int32_t>& param, const std::vector<uint8_t>& mask) {
48 for (size_t i = 0; i < param.size(); ++i)
50 param[i] = input_layout.size.sizes(format::bfyx)[i];
53 // Getting data from constant inputs. There are 3 args: Begin, End, Stride
54 std::vector<std::vector<int32_t>> stridedSliceArgs;
55 for (size_t i = 1; i < node.get_dependencies().size(); ++i) {
56 auto& input = node.get_dependency(i).as<data>();
57 auto& mem = input.get_attached_memory();
58 int32_t* data = static_cast<int32_t*>(mem.lock());
59 std::vector<int32_t> vData = std::vector<int32_t>(data, data + input.get_output_layout().count());
60 completeStridedSliceParams(vData);
61 stridedSliceArgs.push_back(vData);
65 std::vector<uint8_t> beginMask(desc->begin_mask);
66 completeStridedSliceMasks(beginMask);
67 std::vector<uint8_t> endMask(desc->end_mask);
68 completeStridedSliceMasks(endMask);
70 auto& begin = stridedSliceArgs[0];
71 auto& end = stridedSliceArgs[1];
72 const auto& strides = stridedSliceArgs[2];
73 std::vector<int32_t> outputDimsSizes;
75 // If the ith bit of begin_mask is set, begin[i] is ignored and the fullest possible range in that dimension is used instead.
76 maskStridedSliceParams(begin, beginMask);
77 // end_mask works analogously
78 maskStridedSliceParams(end, endMask);
80 auto isShiftPossible = [] (std::vector<int32_t>& dims) -> bool {
81 if (dims[dims.size() - 1] == 1)
87 // If the new_axis_mask is set, then begin, end, and stride are ignored
88 if (std::find(desc->new_axis_mask.begin(), desc->new_axis_mask.end(), 1) == desc->new_axis_mask.end()) {
89 for (size_t i = 0; i < numberOfDims; ++i) {
90 int32_t outputDimSize = (end[i] - begin[i]) / strides[i];
91 if ((end[i] - begin[i]) % strides[i] != 0)
93 outputDimsSizes.push_back(outputDimSize);
96 outputDimsSizes = input_layout.size.sizes(format::bfyx);
97 for (size_t i = 0; i < desc->new_axis_mask.size(); ++i)
98 if (desc->new_axis_mask[desc->new_axis_mask.size() - i - 1] == 1)
99 if (isShiftPossible(outputDimsSizes)) {
100 for (size_t j = outputDimsSizes.size() - 1; j > i; --j)
101 outputDimsSizes[j] = outputDimsSizes[j - 1];
102 outputDimsSizes[i] = 1;
106 return layout{input_layout.data_type, input_format, tensor(outputDimsSizes[0], outputDimsSizes[1], outputDimsSizes[3], outputDimsSizes[2])};
109 std::string strided_slice_inst::to_string(strided_slice_node const& node)
111 auto desc = node.get_primitive();
112 auto node_info = node.desc_to_json();
113 auto& input = node.input();
115 std::stringstream primitive_description;
117 json_composite strided_slice_info;
118 strided_slice_info.add("input id", input.id());
119 strided_slice_info.add("begin_param id", node.get_dependency(1).id());
120 strided_slice_info.add("end_param id", node.get_dependency(2).id());
121 strided_slice_info.add("stride_param id", node.get_dependency(3).id());
122 strided_slice_info.add("begin mask", node.get_primitive()->begin_mask);
123 strided_slice_info.add("end mask", node.get_primitive()->end_mask);
124 strided_slice_info.add("new axis mask", node.get_primitive()->new_axis_mask);
125 strided_slice_info.add("shrink axis mask", node.get_primitive()->shrink_axis_mask);
126 strided_slice_info.add("begin_param shape", node.get_dependency(1).get_output_layout().size.to_string());
127 strided_slice_info.add("end_param shape", node.get_dependency(2).get_output_layout().size.to_string());
128 strided_slice_info.add("stride_param shape", node.get_dependency(3).get_output_layout().size.to_string());
130 node_info->add("strided_slice info", strided_slice_info);
131 node_info->dump(primitive_description);
133 return primitive_description.str();
136 strided_slice_inst::typed_primitive_inst(network_impl& network, strided_slice_node const& node)
137 : parent(network, node)