Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / core / reference / include / ngraph / runtime / reference / topk.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #pragma once
18
19 #include <algorithm>
20 #include <cmath>
21 #include <numeric>
22
23 #include "ngraph/coordinate_transform.hpp"
24 #include "ngraph/op/topk.hpp"
25
26 NGRAPH_SUPPRESS_DEPRECATED_START
27
28 namespace ngraph
29 {
30     namespace runtime
31     {
32         namespace reference
33         {
34             // Had to split out these two functions. They used to be lambda expressions but
35             // MSVC had difficulty compiling. This way is more explicit.
36             template <typename T, typename U>
37             inline bool compare_max(const std::tuple<T, U>& a, const std::tuple<T, U>& b)
38             {
39 // this is intentional to be able to compare floats directly
40 // without using relative or absolute tolerance
41 #if defined(__GNUC__)
42 #pragma GCC diagnostic push
43 #pragma GCC diagnostic ignored "-Wfloat-equal"
44 #endif
45                 if (std::get<0>(a) == std::get<0>(b))
46                 {
47                     return std::get<1>(a) < std::get<1>(b);
48                 }
49 #if defined(__GNUC__)
50 #pragma GCC diagnostic pop
51 #endif
52                 return a > b;
53             }
54
55             template <typename T, typename U>
56             inline bool compare_min(const std::tuple<T, U>& a, const std::tuple<T, U>& b)
57             {
58                 return a < b;
59             }
60
61             template <typename T, typename U>
62             inline bool sort_indices_descending(const std::tuple<T, U>& a,
63                                                 const std::tuple<T, U>& b)
64             {
65                 return std::get<1>(a) < std::get<1>(b);
66             }
67
68             template <typename T, typename U>
69             inline bool sort_indices_ascending(const std::tuple<T, U>& a, const std::tuple<T, U>& b)
70             {
71                 return std::get<1>(a) > std::get<1>(b);
72             }
73
74             template <typename T, typename U>
75             void topk(const T* arg,
76                       U* out_indices,
77                       T* out_values,
78                       const Shape& in_shape,
79                       const Shape& out_shape,
80                       size_t axis,
81                       size_t k,
82                       bool compute_max,
83                       op::TopK::SortType sort = op::TopK::SortType::NONE)
84             {
85                 using namespace std;
86                 // reorder source axis visit order and make "axis" inner most
87                 size_t ndim = static_cast<size_t>(in_shape.size());
88                 Coordinate start_corner(ndim, 0);
89                 Coordinate end_corner(in_shape);
90                 end_corner[axis] = 1;
91                 Strides strides(ndim, 1);
92                 AxisVector axis_order(ndim);
93                 iota(axis_order.begin(), axis_order.end(), 0);
94                 axis_order.erase(axis_order.begin() + axis);
95                 axis_order.push_back(axis);
96                 // Create CoordinateTransforms that visits only the first element along "axis"
97                 CoordinateTransform input_transform(
98                     in_shape, start_corner, end_corner, strides, axis_order);
99                 CoordinateTransform output_transform(
100                     out_shape, start_corner, end_corner, strides, axis_order);
101                 // Create temp vector for sorting.
102                 vector<tuple<T, U>> workspace(in_shape[axis]);
103                 vector<size_t> in_strides = ngraph::row_major_strides(in_shape);
104                 vector<size_t> out_strides = ngraph::row_major_strides(out_shape);
105                 auto in_axis_stride = in_strides[axis];
106                 auto out_axis_stride = out_strides[axis];
107                 for (const Coordinate& coord : input_transform)
108                 {
109                     auto arg_index = input_transform.index(coord);
110                     auto out_index = output_transform.index(coord);
111                     // Fill the temp vector
112                     U i = 0;
113                     for (tuple<T, U>& entry : workspace)
114                     {
115                         get<0>(entry) = arg[arg_index];
116                         get<1>(entry) = i;
117                         arg_index += in_axis_stride;
118                         i++;
119                     }
120                     // Sort the temp vector
121                     if (compute_max)
122                     {
123                         nth_element(workspace.begin(),
124                                     workspace.begin() + k,
125                                     workspace.end(),
126                                     compare_max<T, U>);
127                     }
128                     else
129                     {
130                         nth_element(workspace.begin(),
131                                     workspace.begin() + k,
132                                     workspace.end(),
133                                     compare_min<T, U>);
134                     }
135                     // Write temp vector to output
136                     if (compute_max)
137                     {
138                         switch (sort)
139                         {
140                         case op::TopK::SortType::NONE: break;
141                         case op::TopK::SortType::SORT_INDICES:
142                             std::sort(workspace.begin(),
143                                       workspace.begin() + k,
144                                       sort_indices_descending<T, U>);
145                             break;
146                         case op::TopK::SortType::SORT_VALUES:
147                             std::sort(workspace.begin(), workspace.begin() + k, compare_max<T, U>);
148                             break;
149                         }
150                     }
151                     else
152                     {
153                         switch (sort)
154                         {
155                         case op::TopK::SortType::NONE: break;
156                         case op::TopK::SortType::SORT_INDICES:
157                             std::sort(workspace.begin(),
158                                       workspace.begin() + k,
159                                       sort_indices_ascending<T, U>);
160                             break;
161                         case op::TopK::SortType::SORT_VALUES:
162                             std::sort(workspace.begin(), workspace.begin() + k, compare_min<T, U>);
163                             break;
164                         }
165                     }
166                     for (size_t j = 0; j < k; j++)
167                     {
168                         tuple<T, U> entry = workspace[j];
169                         out_values[out_index] = get<0>(entry);
170                         out_indices[out_index] = get<1>(entry);
171                         out_index += out_axis_stride;
172                     }
173                 }
174             }
175         }
176     }
177 }
178
179 NGRAPH_SUPPRESS_DEPRECATED_END