Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / core / include / ngraph / op / topk.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #pragma once
18
19 #include <memory>
20
21 #include "ngraph/axis_set.hpp"
22 #include "ngraph/op/constant.hpp"
23 #include "ngraph/op/op.hpp"
24
25 namespace ngraph
26 {
27     namespace op
28     {
29         namespace v0
30         {
31             // \brief Computes indices of top k maximum/minimum index along a specified axis for a
32             //        given tensor
33             class NGRAPH_DEPRECATED(
34                 "This operation is deprecated and will be removed soon. "
35                 "Use v1::TopK instead of it.") NGRAPH_API TopK : public Op
36             {
37                 NGRAPH_SUPPRESS_DEPRECATED_START
38             public:
39                 using SortType = TopKSortType;
40
41                 static constexpr NodeTypeInfo type_info{"TopK", 0};
42                 const NodeTypeInfo& get_type_info() const override { return type_info; }
43                 /// \brief Constructs a TopK operation
44                 TopK() = default;
45                 /// \brief Constructs a TopK operation.
46                 ///
47                 /// \param arg The input tensor
48                 /// \param top_k_axis The axis along which to compute top k indices
49                 /// \param index_element_type produce indices. Currently, only int64 or int32 are
50                 ///                           supported
51                 /// \param k Number of top indices to compute. Compute all indices if k = 0
52                 /// \param compute_max Compute top k max or top k min?
53                 /// \param sort SortType for sorting results, default - SORT_VALUES
54                 TopK(const Output<Node>& arg,
55                      size_t top_k_axis,
56                      const element::Type& index_element_type,
57                      size_t k = 0,
58                      bool compute_max = true,
59                      SortType sort = SortType::SORT_VALUES);
60                 /// \brief Constructs a TopK operation.
61                 ///
62                 /// \param arg The input tensor
63                 /// \param k Number of top indices to compute. Compute all indices if k = 0
64                 /// \param top_k_axis The axis along which to compute top k indices
65                 /// \param index_element_type produce indices. Currently, only int64 or int32 are
66                 ///                           supported
67                 /// \param compute_max Compute top k max or top k min?
68                 /// \param sort SortType for sorting results, default - SORT_VALUES
69                 TopK(const Output<Node>& arg,
70                      const Output<Node>& k,
71                      size_t top_k_axis,
72                      const element::Type& index_element_type,
73                      bool compute_max = true,
74                      SortType sort = SortType::SORT_VALUES);
75
76                 /// \brief Constructs a TopK operation.
77                 ///
78                 /// \param arg The input tensor
79                 /// \param k Number of top indices to compute. Compute all indices if k = 0
80                 /// \param top_k_axis The axis along which to compute top k indices
81                 /// \param index_element_type produce indices. Currently, only int64 or int32 are
82                 /// supported
83                 /// \param compute_max Compute top k max or top k min?
84                 /// \param sort SortType for sorting results, default - NONE
85                 TopK(const Output<Node>& arg,
86                      const Output<Node>& k,
87                      const Output<Node>& top_k_axis,
88                      const element::Type& index_element_type,
89                      bool compute_max = true,
90                      SortType sort = SortType::NONE);
91
92                 void validate_and_infer_types() override;
93
94                 virtual std::shared_ptr<Node>
95                     clone_with_new_inputs(const OutputVector& new_args) const override;
96
97                 size_t get_k() const;
98                 void set_k(size_t k);
99
100                 size_t get_top_k_axis() const;
101                 Dimension get_top_k_axis_dynamic() const;
102                 void set_top_k_axis(size_t k);
103
104                 element::Type get_index_element_type() const { return m_index_element_type; }
105                 bool get_compute_max() const { return m_compute_max; }
106                 SortType get_sort() const { return m_sort; }
107                 size_t get_default_output_index() const override { return no_default_index(); }
108                 bool evaluate(const HostTensorVector& outputs,
109                               const HostTensorVector& inputs) const override;
110
111             protected:
112                 element::Type m_index_element_type;
113                 bool m_compute_max{false};
114                 SortType m_sort{SortType::NONE};
115                 Shape compute_output_shape(const Shape input_shape,
116                                            const int64_t k,
117                                            const size_t axis) const;
118                 NGRAPH_SUPPRESS_DEPRECATED_END
119             };
120         } // namespace v0
121
122         namespace v1
123         {
124             /// \brief Computes indices and values of the k maximum/minimum values
125             ///        for each slice along specified axis.
126             class NGRAPH_API TopK : public Op
127             {
128             public:
129                 using SortType = TopKSortType;
130                 using Mode = TopKMode;
131
132                 static constexpr NodeTypeInfo type_info{"TopK", 1};
133                 const NodeTypeInfo& get_type_info() const override { return type_info; }
134                 /// \brief Constructs a TopK operation
135                 TopK() = default;
136                 /// \brief Constructs a TopK operation with two outputs: values and indices.
137                 ///        By default the indices output is described by i32 data type.
138                 ///
139                 /// \param data The input tensor
140                 /// \param k Specifies how many maximum/minimum elements should be computed
141                 ///          (note: scalar input tensor)
142                 /// \param axis The axis along which to compute top k indices
143                 /// \param mode Specifies which operation (min or max) is used to select
144                 ///             the biggest element of two.
145                 /// \param sort Specifies order of output elements and/or indices
146                 ///             Accepted values: none, index, value
147                 /// \param index_element_type Specyfies type of produced indices
148                 TopK(const Output<Node>& data,
149                      const Output<Node>& k,
150                      const int64_t axis,
151                      const std::string& mode,
152                      const std::string& sort,
153                      const element::Type& index_element_type = element::i32);
154
155                 TopK(const Output<Node>& data,
156                      const Output<Node>& k,
157                      const int64_t axis,
158                      const Mode mode,
159                      const SortType sort,
160                      const element::Type& index_element_type = element::i32);
161
162                 bool visit_attributes(AttributeVisitor& visitor) override;
163                 void validate_and_infer_types() override;
164
165                 virtual std::shared_ptr<Node>
166                     clone_with_new_inputs(const OutputVector& new_args) const override;
167
168                 virtual size_t get_version() const override { return 1; }
169                 /// \brief Returns axis value after normalization
170                 /// \note If input rank required to normalization is dynamic, the exception is
171                 /// thrown
172                 uint64_t get_axis() const;
173                 /// \brief Returns axis value before normalization
174                 int64_t get_provided_axis() const { return m_axis; }
175                 void set_axis(const int64_t axis);
176                 Mode get_mode() const { return m_mode; }
177                 void set_mode(const Mode mode) { m_mode = mode; }
178                 SortType get_sort_type() const { return m_sort; }
179                 void set_sort_type(const SortType sort) { m_sort = sort; }
180                 element::Type get_index_element_type() const { return m_index_element_type; }
181                 void set_index_element_type(const element::Type& index_element_type)
182                 {
183                     m_index_element_type = index_element_type;
184                 }
185                 /// \brief Returns the value of K, if available
186                 ///
187                 /// \note If the second input to this op is a constant, the value is retrieved
188                 ///       and returned. If the input is not constant(dynamic) this method returns 0
189                 size_t get_k() const;
190                 void set_k(size_t k);
191                 size_t get_default_output_index() const override { return no_default_index(); }
192                 bool evaluate(const HostTensorVector& outputs,
193                               const HostTensorVector& inputs) const override;
194
195             protected:
196                 int64_t m_axis;
197                 uint64_t m_normalized_axis;
198                 Mode m_mode;
199                 SortType m_sort;
200                 element::Type m_index_element_type{element::i32};
201
202                 virtual size_t read_k_from_constant_node(const std::shared_ptr<Node>& node,
203                                                          const element::Type& k_element_type) const;
204
205                 template <typename T>
206                 size_t validate_and_get_k(const std::shared_ptr<op::Constant>& k_constant) const;
207                 Shape compute_output_shape(const std::string& node_description,
208                                            const PartialShape input_partial_shape,
209                                            const int64_t k) const;
210                 void set_axis(const Rank input_rank, const int64_t axis);
211             };
212         } // namespace v1
213
214         namespace v3
215         {
216             /// \brief Computes indices and values of the k maximum/minimum values
217             ///        for each slice along specified axis.
218             class NGRAPH_API TopK : public v1::TopK
219             {
220             public:
221                 static constexpr NodeTypeInfo type_info{"TopK", 3};
222                 const NodeTypeInfo& get_type_info() const override { return type_info; }
223                 /// \brief Constructs a TopK operation
224                 TopK() = default;
225                 /// \brief Constructs a TopK operation with two outputs: values and indices.
226                 ///        By default the indices output is described by i32 data type.
227                 ///
228                 /// \param data The input tensor
229                 /// \param k Specifies how many maximum/minimum elements should be computed
230                 ///          (note: scalar input tensor)
231                 /// \param axis The axis along which to compute top k indices
232                 /// \param mode Specifies which operation (min or max) is used to select
233                 ///             the biggest element of two.
234                 /// \param sort Specifies order of output elements and/or indices
235                 ///             Accepted values: none, index, value
236                 /// \param index_element_type Specyfies type of produced indices
237                 TopK(const Output<Node>& data,
238                      const Output<Node>& k,
239                      const int64_t axis,
240                      const std::string& mode,
241                      const std::string& sort,
242                      const element::Type& index_element_type = element::i32);
243
244                 TopK(const Output<Node>& data,
245                      const Output<Node>& k,
246                      const int64_t axis,
247                      const Mode mode,
248                      const SortType sort,
249                      const element::Type& index_element_type = element::i32);
250                 bool visit_attributes(AttributeVisitor& visitor) override;
251                 void validate_and_infer_types() override;
252                 virtual std::shared_ptr<Node>
253                     clone_with_new_inputs(const OutputVector& new_args) const override;
254
255                 bool evaluate(const HostTensorVector& outputs,
256                               const HostTensorVector& inputs) const override;
257
258             protected:
259                 virtual size_t
260                     read_k_from_constant_node(const std::shared_ptr<Node>& node,
261                                               const element::Type& k_element_type) const override;
262             };
263         } // namespace v3
264
265         NGRAPH_SUPPRESS_DEPRECATED_START
266         using v0::TopK;
267         NGRAPH_SUPPRESS_DEPRECATED_END
268     } // op
269 } // ngraph