Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / core / src / op / dequantize.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "ngraph/op/dequantize.hpp"
18 #include "ngraph/shape_util.hpp"
19
20 NGRAPH_SUPPRESS_DEPRECATED_START
21
22 using namespace std;
23 using namespace ngraph;
24
25 constexpr NodeTypeInfo op::Dequantize::type_info;
26
27 op::Dequantize::Dequantize(const Output<Node>& input,
28                            const Output<Node>& scale,
29                            const Output<Node>& zero_point,
30                            const element::Type& type,
31                            const AxisSet& axes)
32
33     : Op({input, scale, zero_point})
34     , m_type(type)
35     , m_axes(axes)
36 {
37     constructor_validate_and_infer_types();
38 }
39
40 void op::Dequantize::validate_and_infer_types()
41 {
42     enum
43     {
44         INPUT,
45         SCALE,
46         ZERO_POINT
47     };
48
49     NODE_VALIDATION_CHECK(this, m_type.is_static(), "Output element type must not be dynamic");
50
51     NODE_VALIDATION_CHECK(
52         this, m_type.is_real(), "Output element type (", m_type, ") must be a floating point type");
53
54     element::Type quantized_type;
55
56     NODE_VALIDATION_CHECK(this,
57                           element::Type::merge(quantized_type,
58                                                get_input_element_type(INPUT),
59                                                get_input_element_type(ZERO_POINT)),
60                           "Zero point element type (",
61                           get_input_element_type(ZERO_POINT),
62                           ") must match input element type (",
63                           get_input_element_type(INPUT),
64                           ")");
65
66     NODE_VALIDATION_CHECK(this,
67                           quantized_type.is_dynamic() || quantized_type.is_quantized(),
68                           "Zero point / input element type (",
69                           quantized_type,
70                           ") must be a quantized type");
71
72     element::Type unquantized_type;
73
74     NODE_VALIDATION_CHECK(
75         this,
76         element::Type::merge(unquantized_type, get_input_element_type(SCALE), m_type),
77         "Scale element type (",
78         get_input_element_type(SCALE),
79         ") must match output element type (",
80         m_type,
81         ")");
82
83     PartialShape input_shape = get_input_partial_shape(0);
84     Dimension input_rank = input_shape.rank();
85
86     for (auto axis : m_axes)
87     {
88         NODE_VALIDATION_CHECK(this,
89                               input_rank.is_dynamic() || axis < input_rank.get_length(),
90                               "Quantization axis (",
91                               axis,
92                               ") must be less than input shape rank (",
93                               input_rank,
94                               ")");
95     }
96
97     PartialShape scale_zero_point_shape = get_input_partial_shape(SCALE);
98
99     NODE_VALIDATION_CHECK(
100         this,
101         PartialShape::merge_into(scale_zero_point_shape, get_input_partial_shape(ZERO_POINT)),
102         "Scale shape (",
103         get_input_partial_shape(SCALE),
104         ") and zero point shape (",
105         get_input_partial_shape(ZERO_POINT),
106         ") must match");
107
108     NODE_VALIDATION_CHECK(this,
109                           scale_zero_point_shape.rank().compatible(m_axes.size()),
110                           "Scale / zero point rank (",
111                           scale_zero_point_shape.rank(),
112                           ") does not match the number of ",
113                           "quantization axes (",
114                           m_axes.size(),
115                           ")");
116
117     set_output_size(1);
118
119     if (input_shape.rank().is_static() && scale_zero_point_shape.rank().is_static())
120     {
121         size_t i = 0;
122
123         vector<Dimension> injected_scale_zero_point_dims;
124
125         for (size_t j = 0; j < input_shape.rank().get_length(); j++)
126         {
127             if (m_axes.count(j) != 0)
128             {
129                 injected_scale_zero_point_dims.push_back(scale_zero_point_shape[i++]);
130             }
131             else
132             {
133                 injected_scale_zero_point_dims.push_back(Dimension::dynamic());
134             }
135         }
136
137         PartialShape result_shape = input_shape;
138         NODE_VALIDATION_CHECK(
139             this,
140             PartialShape::merge_into(result_shape, PartialShape{injected_scale_zero_point_dims}),
141             "Scale / zero point shape (",
142             scale_zero_point_shape,
143             ") must match input shape (",
144             input_shape,
145             ") at the quantization axes (",
146             m_axes,
147             ")");
148         set_output_type(0, unquantized_type, result_shape);
149     }
150     else
151     {
152         set_output_type(0, unquantized_type, PartialShape::dynamic());
153     }
154 }
155
156 shared_ptr<Node> op::Dequantize::clone_with_new_inputs(const OutputVector& new_args) const
157 {
158     check_new_args_count(this, new_args);
159     return make_shared<Dequantize>(new_args.at(0), new_args.at(1), new_args.at(2), m_type, m_axes);
160 }