Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / core / src / op / quantize.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "ngraph/op/quantize.hpp"
18 #include "ngraph/shape_util.hpp"
19
20 NGRAPH_SUPPRESS_DEPRECATED_START
21
22 using namespace std;
23 using namespace ngraph;
24
25 constexpr NodeTypeInfo op::Quantize::type_info;
26
27 op::Quantize::Quantize(const Output<Node>& input,
28                        const Output<Node>& scale,
29                        const Output<Node>& zero_point,
30                        const element::Type& type,
31                        const AxisSet& axes,
32                        RoundMode round_mode)
33
34     : Op({input, scale, zero_point})
35     , m_type(type)
36     , m_axes(axes)
37     , m_round_mode(round_mode)
38 {
39     constructor_validate_and_infer_types();
40 }
41
42 void op::Quantize::validate_and_infer_types()
43 {
44     enum
45     {
46         INPUT,
47         SCALE,
48         ZERO_POINT
49     };
50
51     NODE_VALIDATION_CHECK(this, m_type.is_static(), "Output element type must not be dynamic");
52
53     NODE_VALIDATION_CHECK(
54         this, m_type.is_quantized(), "Output element type (", m_type, ") must be a quantized type");
55
56     element::Type unquantized_type;
57
58     NODE_VALIDATION_CHECK(this,
59                           element::Type::merge(unquantized_type,
60                                                get_input_element_type(INPUT),
61                                                get_input_element_type(SCALE)),
62                           "Scale element type (",
63                           get_input_element_type(SCALE),
64                           ") must match input element type (",
65                           get_input_element_type(INPUT),
66                           ")");
67
68     NODE_VALIDATION_CHECK(this,
69                           unquantized_type.is_dynamic() || unquantized_type.is_real(),
70                           "Scale / input element type (",
71                           unquantized_type,
72                           ") must be a floating point number");
73
74     element::Type quantized_type;
75
76     NODE_VALIDATION_CHECK(
77         this,
78         element::Type::merge(quantized_type, get_input_element_type(ZERO_POINT), m_type),
79         "Zero point element type (",
80         get_input_element_type(ZERO_POINT),
81         ") must match output element type (",
82         m_type,
83         ")");
84
85     PartialShape input_shape = get_input_partial_shape(0);
86     Dimension input_rank = input_shape.rank();
87
88     for (auto axis : m_axes)
89     {
90         NODE_VALIDATION_CHECK(this,
91                               input_rank.is_dynamic() || axis < input_rank.get_length(),
92                               "Quantization axis (",
93                               axis,
94                               ") must be less than input shape rank (",
95                               input_rank,
96                               ")");
97     }
98
99     PartialShape scale_zero_point_shape = get_input_partial_shape(SCALE);
100
101     NODE_VALIDATION_CHECK(
102         this,
103         PartialShape::merge_into(scale_zero_point_shape, get_input_partial_shape(ZERO_POINT)),
104         "Scale shape (",
105         get_input_partial_shape(SCALE),
106         ") and zero point shape (",
107         get_input_partial_shape(ZERO_POINT),
108         ") must match");
109
110     NODE_VALIDATION_CHECK(this,
111                           scale_zero_point_shape.rank().compatible(m_axes.size()),
112                           "Scale / zero point rank (",
113                           scale_zero_point_shape.rank(),
114                           ") does not match the number of ",
115                           "quantization axes (",
116                           m_axes.size(),
117                           ")");
118
119     set_output_size(1);
120
121     if (input_shape.rank().is_static() && scale_zero_point_shape.rank().is_static())
122     {
123         size_t i = 0;
124
125         vector<Dimension> injected_scale_zero_point_dims;
126
127         for (size_t j = 0; j < input_shape.rank().get_length(); j++)
128         {
129             if (m_axes.count(j) != 0)
130             {
131                 injected_scale_zero_point_dims.push_back(scale_zero_point_shape[i++]);
132             }
133             else
134             {
135                 injected_scale_zero_point_dims.push_back(Dimension::dynamic());
136             }
137         }
138
139         PartialShape result_shape = input_shape;
140         NODE_VALIDATION_CHECK(
141             this,
142             PartialShape::merge_into(result_shape, PartialShape{injected_scale_zero_point_dims}),
143             "Scale / zero point shape (",
144             scale_zero_point_shape,
145             ") must match input shape (",
146             input_shape,
147             ") at the quantization axes (",
148             m_axes,
149             ")");
150         set_output_type(0, quantized_type, result_shape);
151     }
152     else
153     {
154         set_output_type(0, quantized_type, PartialShape::dynamic());
155     }
156 }
157
158 shared_ptr<Node> op::Quantize::clone_with_new_inputs(const OutputVector& new_args) const
159 {
160     check_new_args_count(this, new_args);
161     return make_shared<Quantize>(
162         new_args.at(0), new_args.at(1), new_args.at(2), m_type, m_axes, m_round_mode);
163 }