1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
17 #include "ngraph/op/slice.hpp"
19 #include "ngraph/runtime/host_tensor.hpp"
20 #include "ngraph/runtime/reference/slice.hpp"
22 NGRAPH_SUPPRESS_DEPRECATED_START
25 using namespace ngraph;
// Out-of-line definition of the static constexpr member op::Slice::type_info, which is
// declared inside the class (presumably in ngraph/op/slice.hpp). Before C++17 this
// definition is required whenever the member is odr-used. From C++17 on, the in-class
// declaration is implicitly inline and this line is redundant but harmless.
constexpr NodeTypeInfo op::Slice::type_info;
29 op::Slice::Slice(const Output<Node>& arg,
30 const Coordinate& lower_bounds,
31 const Coordinate& upper_bounds,
32 const Strides& strides)
34 , m_lower_bounds(lower_bounds)
35 , m_upper_bounds(upper_bounds)
38 constructor_validate_and_infer_types();
41 op::Slice::Slice(const Output<Node>& arg,
42 const Coordinate& lower_bounds,
43 const Coordinate& upper_bounds)
45 , m_lower_bounds(lower_bounds)
46 , m_upper_bounds(upper_bounds)
47 , m_strides(Strides())
49 constructor_validate_and_infer_types();
52 void op::Slice::validate_and_infer_types()
54 // An empty stride vector with lower_bounds/upper_bounds filled in means that we need to
55 // construct the default value.
56 if (m_strides.size() == 0)
58 m_strides = Strides(m_lower_bounds.size(), 1);
61 NODE_VALIDATION_CHECK(this,
62 m_lower_bounds.size() == m_upper_bounds.size() &&
63 m_lower_bounds.size() == m_strides.size(),
64 "Ranks of lower bounds (",
72 size_t output_rank = m_upper_bounds.size();
74 for (size_t i = 0; i < output_rank; i++)
76 NODE_VALIDATION_CHECK(this,
77 m_lower_bounds[i] <= m_upper_bounds[i],
78 "Lower bound for slice is greater than upper bound at axis ",
86 NODE_VALIDATION_CHECK(this,
88 "Stride for slice is zero at axis ",
95 const PartialShape& input_shape = get_input_partial_shape(0);
96 Dimension input_rank = input_shape.rank();
98 NODE_VALIDATION_CHECK(this,
99 input_rank.is_dynamic() || input_rank.get_length() == output_rank,
100 "Input rank does not match the rank of the lower bounds (",
108 std::vector<Dimension> result_dims(output_rank);
110 for (size_t i = 0; i < output_rank; i++)
112 NODE_VALIDATION_CHECK(this,
113 input_rank.is_dynamic() || input_shape[i].is_dynamic() ||
114 m_upper_bounds[i] <= input_shape[i].get_length(),
115 "Upper bound for slice at axis ",
120 ", argument shape: ",
124 size_t result_axis_size = m_upper_bounds[i] - m_lower_bounds[i];
126 result_axis_size / m_strides[i] + ((result_axis_size % m_strides[i] == 0) ? 0 : 1);
127 result_dims[i] = result_axis_size;
130 set_output_type(0, get_input_element_type(0), PartialShape{result_dims});
133 shared_ptr<Node> op::Slice::clone_with_new_inputs(const OutputVector& new_args) const
135 check_new_args_count(this, new_args);
136 return make_shared<Slice>(new_args.at(0), m_lower_bounds, m_upper_bounds, m_strides);
141 bool evaluate_slice(const HostTensorPtr& in,
142 const HostTensorPtr& out,
143 const Coordinate& lower_bounds,
144 const Coordinate& upper_bounds,
145 const Strides& strides)
147 runtime::reference::slice(in->get_data_ptr<const char>(),
148 out->get_data_ptr<char>(),
154 in->get_element_type().size());
160 bool op::Slice::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
162 const auto& data = inputs[0];
163 const auto& output = outputs[0];
165 return evaluate_slice(data, output, m_lower_bounds, m_upper_bounds, m_strides);