Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / core / src / op / gather.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "ngraph/op/gather.hpp"
18 #include "itt.hpp"
19 #include "ngraph/op/constant.hpp"
20 #include "ngraph/runtime/host_tensor.hpp"
21 #include "ngraph/runtime/reference/gather.hpp"
22 #include "ngraph/shape.hpp"
23
24 #include <limits>
25
26 NGRAPH_SUPPRESS_DEPRECATED_START
27
28 using namespace std;
29 using namespace ngraph;
30
// Positional input indices. v0 Gather has two inputs (params, indices);
// v1 Gather adds a third input carrying the gather axis.
static constexpr int PARAMS = 0;
static constexpr int INDICES = 1;
static constexpr int AXIS = 2;
34
35 constexpr NodeTypeInfo op::v0::Gather::type_info;
36
37 op::v0::Gather::Gather(const Output<Node>& params, const Output<Node>& indices, size_t axis)
38     : Op({params, indices})
39     , m_axis(axis)
40 {
41     constructor_validate_and_infer_types();
42 }
43
44 shared_ptr<Node> op::v0::Gather::clone_with_new_inputs(const OutputVector& new_args) const
45 {
46     check_new_args_count(this, new_args);
47     return make_shared<v0::Gather>(new_args.at(PARAMS), new_args.at(INDICES), m_axis);
48 }
49
50 void op::v0::Gather::validate_and_infer_types()
51 {
52     element::Type result_et = get_input_element_type(PARAMS);
53     element::Type indices_et = get_input_element_type(INDICES);
54
55     const PartialShape& params_shape = get_input_partial_shape(PARAMS);
56     const PartialShape& indices_shape = get_input_partial_shape(INDICES);
57
58     NODE_VALIDATION_CHECK(this,
59                           indices_et == element::i32 || indices_et == element::i64,
60                           "Indices element type must be i64 or i32");
61
62     // params rank must be at least (axis + 1)
63     // indices value must be in range [0, params.shape[axis]).
64     // output rank is rank(params) + rank(indices) - 1
65     NODE_VALIDATION_CHECK(this,
66                           params_shape.rank().is_dynamic() ||
67                               params_shape.rank().get_length() > static_cast<size_t>(m_axis),
68                           "params rank is expected to be at least axis + 1");
69
70     PartialShape result_shape;
71     if (params_shape.rank().is_static() && indices_shape.rank().is_static())
72     {
73         std::vector<Dimension> result_dims(params_shape.rank().get_length() +
74                                            indices_shape.rank().get_length() - 1);
75         size_t i = 0;
76         for (; i < static_cast<size_t>(m_axis); i++)
77         {
78             result_dims[i] = params_shape[i];
79         }
80         for (size_t j = 0; j < indices_shape.rank().get_length(); i++, j++)
81         {
82             result_dims[i] = indices_shape[j];
83         }
84         for (size_t j = static_cast<size_t>(m_axis) + 1; j < params_shape.rank().get_length();
85              i++, j++)
86         {
87             result_dims[i] = params_shape[j];
88         }
89
90         result_shape = PartialShape(result_dims);
91     }
92     else
93     {
94         result_shape = PartialShape::dynamic();
95     }
96
97     set_output_type(0, result_et, result_shape);
98 }
99
100 constexpr NodeTypeInfo op::v1::Gather::type_info;
101 const int64_t op::v1::Gather::AXIS_NOT_SET_VALUE;
102
103 op::v1::Gather::Gather(const Output<Node>& params,
104                        const Output<Node>& indices,
105                        const Output<Node>& axes)
106     : Op({params, indices, axes})
107 {
108     constructor_validate_and_infer_types();
109 }
110
111 bool ngraph::op::v1::Gather::visit_attributes(AttributeVisitor& visitor)
112 {
113     return true;
114 }
115
116 void op::v1::Gather::validate_and_infer_types()
117 {
118     const auto& input_rank = get_input_partial_shape(PARAMS).rank();
119     const auto& axis_shape = get_input_partial_shape(AXIS);
120     const auto& axis_rank = axis_shape.rank();
121
122     if (axis_rank.is_static() && axis_shape.is_static())
123     {
124         const auto axis_is_scalar = axis_rank.get_length() == 0;
125         const auto axis_has_one_elem =
126             axis_rank.get_length() == 1 && axis_shape[0].get_length() == 1;
127         NODE_VALIDATION_CHECK(this,
128                               axis_is_scalar || axis_has_one_elem,
129                               "Axes input must be scalar or have 1 element (shape: ",
130                               axis_shape,
131                               ").");
132     }
133
134     int64_t axis = get_axis();
135     if (input_rank.is_static() && axis != AXIS_NOT_SET_VALUE)
136     {
137         NODE_VALIDATION_CHECK(this,
138                               axis < input_rank.get_length(),
139                               "The axis must => 0 and <= input_rank (axis: ",
140                               axis,
141                               ").");
142     }
143
144     element::Type result_et = get_input_element_type(PARAMS);
145     element::Type indices_et = get_input_element_type(INDICES);
146
147     const PartialShape& params_shape = get_input_partial_shape(PARAMS);
148     const PartialShape& indices_shape = get_input_partial_shape(INDICES);
149
150     PartialShape result_shape;
151     if (params_shape.rank().is_static() && indices_shape.rank().is_static() &&
152         axis != AXIS_NOT_SET_VALUE)
153     {
154         std::vector<Dimension> result_dims(params_shape.rank().get_length() +
155                                            indices_shape.rank().get_length() - 1);
156         uint64_t i = 0;
157         for (; i < axis; i++)
158         {
159             result_dims[i] = params_shape[i];
160         }
161         for (uint64_t j = 0; j < indices_shape.rank().get_length(); i++, j++)
162         {
163             result_dims[i] = indices_shape[j];
164         }
165         for (uint64_t j = axis + 1; j < params_shape.rank().get_length(); i++, j++)
166         {
167             result_dims[i] = params_shape[j];
168         }
169
170         result_shape = PartialShape(result_dims);
171     }
172     else
173     {
174         result_shape = PartialShape::dynamic();
175     }
176
177     set_output_type(0, result_et, result_shape);
178 }
179
180 int64_t op::v1::Gather::get_axis() const
181 {
182     int64_t axis = AXIS_NOT_SET_VALUE;
183     auto axes_input_node = input_value(AXIS).get_node_shared_ptr();
184     if (auto const_op = as_type_ptr<op::Constant>(axes_input_node))
185     {
186         axis = const_op->cast_vector<int64_t>()[0];
187     }
188     if (axis < 0)
189     {
190         const auto& input_rank = get_input_partial_shape(PARAMS).rank();
191         if (input_rank.is_static())
192         {
193             axis += input_rank.get_length();
194         }
195     }
196     return axis;
197 }
198
199 shared_ptr<Node> op::v1::Gather::clone_with_new_inputs(const OutputVector& new_args) const
200 {
201     check_new_args_count(this, new_args);
202     return make_shared<v1::Gather>(new_args.at(PARAMS), new_args.at(INDICES), new_args.at(AXIS));
203 }
204
205 namespace
206 {
207     template <element::Type_t ET>
208     bool evaluate(const HostTensorPtr& arg0,
209                   const HostTensorPtr& arg1,
210                   const HostTensorPtr& out,
211                   size_t axis)
212     {
213         using T = typename element_type_traits<ET>::value_type;
214         Shape params_shape = arg0->get_shape();
215         Shape indices_shape = arg1->get_shape();
216         Shape out_shape(params_shape.size() + indices_shape.size() - 1);
217         uint64_t i = 0;
218         for (; i < axis; i++)
219         {
220             out_shape[i] = params_shape[i];
221         }
222         for (uint64_t j = 0; j < indices_shape.size(); i++, j++)
223         {
224             out_shape[i] = indices_shape[j];
225         }
226         for (uint64_t j = axis + 1; j < params_shape.size(); i++, j++)
227         {
228             out_shape[i] = params_shape[j];
229         }
230
231         out->set_shape(out_shape);
232
233         if (arg1->get_element_type() == element::i64)
234         {
235             runtime::reference::gather<T, int64_t>(arg0->get_data_ptr<ET>(),
236                                                    arg1->get_data_ptr<int64_t>(),
237                                                    out->get_data_ptr<ET>(),
238                                                    arg0->get_shape(),
239                                                    arg1->get_shape(),
240                                                    out->get_shape(),
241                                                    axis);
242         }
243         else if (arg1->get_element_type() == element::i32)
244         {
245             runtime::reference::gather<T, int32_t>(arg0->get_data_ptr<ET>(),
246                                                    arg1->get_data_ptr<int32_t>(),
247                                                    out->get_data_ptr<ET>(),
248                                                    arg0->get_shape(),
249                                                    arg1->get_shape(),
250                                                    out->get_shape(),
251                                                    axis);
252         }
253         else
254         {
255             throw ngraph_error("Unexpected type");
256         }
257
258         return true;
259     }
260
261     bool evaluate_gather(const HostTensorPtr& arg0,
262                          const HostTensorPtr& arg1,
263                          const HostTensorPtr& out,
264                          size_t axis)
265     {
266         bool rc = true;
267
268         switch (out->get_element_type())
269         {
270             TYPE_CASE(i32)(arg0, arg1, out, axis);
271             break;
272             TYPE_CASE(i64)(arg0, arg1, out, axis);
273             break;
274             TYPE_CASE(u32)(arg0, arg1, out, axis);
275             break;
276             TYPE_CASE(u64)(arg0, arg1, out, axis);
277             break;
278             TYPE_CASE(f16)(arg0, arg1, out, axis);
279             break;
280             TYPE_CASE(f32)(arg0, arg1, out, axis);
281             break;
282             TYPE_CASE(boolean)(arg0, arg1, out, axis);
283             break;
284         default: rc = false; break;
285         }
286         return rc;
287     }
288 }
289
290 bool op::v0::Gather::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
291 {
292     OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v0::Gather::evaluate");
293     return evaluate_gather(inputs[0], inputs[1], outputs[0], get_axis());
294 }
295
296 bool op::v1::Gather::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
297 {
298     OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v1::Gather::evaluate");
299     int64_t axis = 0;
300     switch (inputs[2]->get_element_type())
301     {
302     case element::Type_t::i8: axis = inputs[2]->get_data_ptr<element::Type_t::i8>()[0]; break;
303     case element::Type_t::i16: axis = inputs[2]->get_data_ptr<element::Type_t::i16>()[0]; break;
304     case element::Type_t::i32: axis = inputs[2]->get_data_ptr<element::Type_t::i32>()[0]; break;
305     case element::Type_t::i64: axis = inputs[2]->get_data_ptr<element::Type_t::i64>()[0]; break;
306     case element::Type_t::u8: axis = inputs[2]->get_data_ptr<element::Type_t::u8>()[0]; break;
307     case element::Type_t::u16: axis = inputs[2]->get_data_ptr<element::Type_t::u16>()[0]; break;
308     case element::Type_t::u32: axis = inputs[2]->get_data_ptr<element::Type_t::u32>()[0]; break;
309     case element::Type_t::u64: axis = inputs[2]->get_data_ptr<element::Type_t::u64>()[0]; break;
310     default: throw ngraph_error("axis element type is not integral data type");
311     }
312
313     if (axis < 0)
314     {
315         const auto& input_rank = get_input_partial_shape(PARAMS).rank();
316         if (input_rank.is_static())
317         {
318             axis += input_rank.get_length();
319         }
320     }
321     return evaluate_gather(inputs[0], inputs[1], outputs[0], axis);
322 }