Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / core / reference / include / ngraph / runtime / reference / quantize.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #pragma once
18
19 #include "ngraph/coordinate_transform.hpp"
20 #include "ngraph/op/quantize.hpp"
21 #include "ngraph/shape_util.hpp"
22
23 NGRAPH_SUPPRESS_DEPRECATED_START
24
25 namespace ngraph
26 {
27     namespace runtime
28     {
29         namespace reference
30         {
31             template <typename REAL, typename QUANT>
32             void quantize(const REAL* input,
33                           const REAL* scale,
34                           const QUANT* zero_point,
35                           QUANT* output,
36                           const Shape& input_shape,
37                           const Shape& scale_zero_point_shape,
38                           const AxisSet& axes,
39                           op::Quantize::RoundMode round_mode)
40             {
41                 CoordinateTransform input_transform(input_shape);
42                 CoordinateTransform scale_zero_point_transform(scale_zero_point_shape);
43
44                 for (const Coordinate& input_coord : input_transform)
45                 {
46                     Coordinate scale_zero_point_coord = project(input_coord, axes);
47
48                     // apply scale
49                     REAL qvalue = input[input_transform.index(input_coord)] /
50                                   scale[scale_zero_point_transform.index(scale_zero_point_coord)];
51
52                     // round
53                     if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_INFINITY)
54                     {
55                         REAL abs_qvalue = std::fabs(qvalue);
56                         REAL abs_qvalue_toward_inf =
57                             std::floor(abs_qvalue + static_cast<REAL>(0.5));
58                         qvalue = (qvalue < static_cast<REAL>(0.0)) ? -abs_qvalue_toward_inf
59                                                                    : abs_qvalue_toward_inf;
60                     }
61                     else if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_ZERO)
62                     {
63                         auto abs_qvalue = std::fabs(qvalue);
64                         auto abs_qvalue_toward_zero =
65                             std::ceil(abs_qvalue - static_cast<REAL>(0.5));
66                         qvalue = (qvalue < static_cast<REAL>(0.0)) ? -abs_qvalue_toward_zero
67                                                                    : abs_qvalue_toward_zero;
68                     }
69                     else if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_UPWARD)
70                     {
71                         qvalue = std::floor(qvalue + static_cast<REAL>(0.5));
72                     }
73                     else if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_DOWNWARD)
74                     {
75                         qvalue = std::ceil(qvalue - static_cast<REAL>(0.5));
76                     }
77                     else if (round_mode == op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN)
78                     {
79                         auto up_qvalue = std::floor(qvalue + static_cast<REAL>(0.5));
80                         auto dn_qvalue = std::ceil(qvalue - static_cast<REAL>(0.5));
81                         auto rem = std::fmod(up_qvalue, 2.0);
82                         qvalue = (rem == 0.0) ? up_qvalue : dn_qvalue;
83                     }
84                     else if (round_mode == op::Quantize::RoundMode::ROUND_TOWARD_INFINITY)
85                     {
86                         auto abs_qvalue = std::fabs(qvalue);
87                         auto abs_qvalue_toward_inf = std::ceil(abs_qvalue);
88                         qvalue = (qvalue < static_cast<REAL>(0.0)) ? -abs_qvalue_toward_inf
89                                                                    : abs_qvalue_toward_inf;
90                     }
91                     else if (round_mode == op::Quantize::RoundMode::ROUND_TOWARD_ZERO)
92                     {
93                         auto abs_qvalue = std::fabs(qvalue);
94                         auto abs_qvalue_toward_zero = std::floor(abs_qvalue);
95                         qvalue = (qvalue < static_cast<REAL>(0.0)) ? -abs_qvalue_toward_zero
96                                                                    : abs_qvalue_toward_zero;
97                     }
98                     else if (round_mode == op::Quantize::RoundMode::ROUND_UP)
99                     {
100                         qvalue = std::ceil(qvalue);
101                     }
102                     else if (round_mode == op::Quantize::RoundMode::ROUND_DOWN)
103                     {
104                         qvalue = std::floor(qvalue);
105                     }
106
107                     // apply zero_point
108                     qvalue += zero_point[scale_zero_point_transform.index(scale_zero_point_coord)];
109
110                     // clamp
111                     qvalue = std::max<REAL>(qvalue,
112                                             static_cast<REAL>(std::numeric_limits<QUANT>::min()));
113                     qvalue = std::min<REAL>(qvalue,
114                                             static_cast<REAL>(std::numeric_limits<QUANT>::max()));
115
116                     // cast
117                     output[input_transform.index(input_coord)] = static_cast<QUANT>(qvalue);
118                 }
119             }
120         }
121     }
122 }
123
124 NGRAPH_SUPPRESS_DEPRECATED_END