Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / core / src / op / slice.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "ngraph/op/slice.hpp"
18
19 #include "ngraph/runtime/host_tensor.hpp"
20 #include "ngraph/runtime/reference/slice.hpp"
21
22 NGRAPH_SUPPRESS_DEPRECATED_START
23
24 using namespace std;
25 using namespace ngraph;
26
// Namespace-scope definition of the static constexpr member `type_info`. Before C++17 this
// definition is required whenever the member is odr-used; from C++17 on, static constexpr
// data members are implicitly inline, so it is redundant there but still accepted.
constexpr NodeTypeInfo op::Slice::type_info;
28
// Constructs a Slice of `arg`. On each axis i it selects indices starting at lower_bounds[i]
// and stopping before upper_bounds[i], stepping by strides[i]. validate_and_infer_types()
// sizes each output axis as ceil((upper - lower) / stride).
// An empty `strides` is accepted; validate_and_infer_types() expands it to unit strides.
// constructor_validate_and_infer_types() is invoked so that attribute checks and output-type
// inference happen as part of construction.
op::Slice::Slice(const Output<Node>& arg,
                 const Coordinate& lower_bounds,
                 const Coordinate& upper_bounds,
                 const Strides& strides)
    : Op({arg})
    , m_lower_bounds(lower_bounds)
    , m_upper_bounds(upper_bounds)
    , m_strides(strides)
{
    constructor_validate_and_infer_types();
}
38     constructor_validate_and_infer_types();
39 }
40
41 op::Slice::Slice(const Output<Node>& arg,
42                  const Coordinate& lower_bounds,
43                  const Coordinate& upper_bounds)
44     : Op({arg})
45     , m_lower_bounds(lower_bounds)
46     , m_upper_bounds(upper_bounds)
47     , m_strides(Strides())
48 {
49     constructor_validate_and_infer_types();
50 }
51
52 void op::Slice::validate_and_infer_types()
53 {
54     // An empty stride vector with lower_bounds/upper_bounds filled in means that we need to
55     // construct the default value.
56     if (m_strides.size() == 0)
57     {
58         m_strides = Strides(m_lower_bounds.size(), 1);
59     }
60
61     NODE_VALIDATION_CHECK(this,
62                           m_lower_bounds.size() == m_upper_bounds.size() &&
63                               m_lower_bounds.size() == m_strides.size(),
64                           "Ranks of lower bounds (",
65                           m_lower_bounds,
66                           "), upper bounds (",
67                           m_upper_bounds,
68                           ") and strides (",
69                           m_strides,
70                           ") do not match.");
71
72     size_t output_rank = m_upper_bounds.size();
73
74     for (size_t i = 0; i < output_rank; i++)
75     {
76         NODE_VALIDATION_CHECK(this,
77                               m_lower_bounds[i] <= m_upper_bounds[i],
78                               "Lower bound for slice is greater than upper bound at axis ",
79                               i,
80                               " (lower bounds: ",
81                               m_lower_bounds,
82                               ", upper bounds: ",
83                               m_upper_bounds,
84                               ").");
85
86         NODE_VALIDATION_CHECK(this,
87                               m_strides[i] != 0,
88                               "Stride for slice is zero at axis ",
89                               i,
90                               " (strides: ",
91                               m_strides,
92                               ").");
93     }
94
95     const PartialShape& input_shape = get_input_partial_shape(0);
96     Dimension input_rank = input_shape.rank();
97
98     NODE_VALIDATION_CHECK(this,
99                           input_rank.is_dynamic() || input_rank.get_length() == output_rank,
100                           "Input rank does not match the rank of the lower bounds (",
101                           m_lower_bounds,
102                           "), upper bounds (",
103                           m_upper_bounds,
104                           "), and strides (",
105                           m_strides,
106                           ").");
107
108     std::vector<Dimension> result_dims(output_rank);
109
110     for (size_t i = 0; i < output_rank; i++)
111     {
112         NODE_VALIDATION_CHECK(this,
113                               input_rank.is_dynamic() || input_shape[i].is_dynamic() ||
114                                   m_upper_bounds[i] <= input_shape[i].get_length(),
115                               "Upper bound for slice at axis ",
116                               i,
117                               " is out of range ",
118                               "(upper bounds: ",
119                               m_upper_bounds,
120                               ", argument shape: ",
121                               input_shape,
122                               ").");
123
124         size_t result_axis_size = m_upper_bounds[i] - m_lower_bounds[i];
125         result_axis_size =
126             result_axis_size / m_strides[i] + ((result_axis_size % m_strides[i] == 0) ? 0 : 1);
127         result_dims[i] = result_axis_size;
128     }
129
130     set_output_type(0, get_input_element_type(0), PartialShape{result_dims});
131 }
132
133 shared_ptr<Node> op::Slice::clone_with_new_inputs(const OutputVector& new_args) const
134 {
135     check_new_args_count(this, new_args);
136     return make_shared<Slice>(new_args.at(0), m_lower_bounds, m_upper_bounds, m_strides);
137 }
138
139 namespace
140 {
141     bool evaluate_slice(const HostTensorPtr& in,
142                         const HostTensorPtr& out,
143                         const Coordinate& lower_bounds,
144                         const Coordinate& upper_bounds,
145                         const Strides& strides)
146     {
147         runtime::reference::slice(in->get_data_ptr<const char>(),
148                                   out->get_data_ptr<char>(),
149                                   in->get_shape(),
150                                   lower_bounds,
151                                   upper_bounds,
152                                   strides,
153                                   out->get_shape(),
154                                   in->get_element_type().size());
155
156         return true;
157     }
158 }
159
160 bool op::Slice::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
161 {
162     const auto& data = inputs[0];
163     const auto& output = outputs[0];
164
165     return evaluate_slice(data, output, m_lower_bounds, m_upper_bounds, m_strides);
166 }