1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
17 #include "ngraph/op/quantize.hpp"
18 #include "ngraph/shape_util.hpp"
// NOTE(review): judging by its name, this macro opens a region in which
// deprecation warnings are suppressed; its definition is not visible here - confirm.
20 NGRAPH_SUPPRESS_DEPRECATED_START
23 using namespace ngraph;
// Namespace-scope definition (no initializer) of the static constexpr member
// Quantize::type_info; the initializer is expected to live with the in-class declaration.
25 constexpr NodeTypeInfo op::Quantize::type_info;
// Constructs a Quantize node.
//
// `input`, `scale` and `zero_point` are forwarded, in that order, to the Op base
// class as the node's three inputs. `type` is an element type supplied by the caller.
//
// NOTE(review): parts of this definition (the remaining parameters, some member
// initializers and the braces) are not visible in this view; only `m_round_mode`
// is visibly initialized here - confirm the rest against the full file.
27 op::Quantize::Quantize(const Output<Node>& input,
28 const Output<Node>& scale,
29 const Output<Node>& zero_point,
30 const element::Type& type,
34 : Op({input, scale, zero_point})
37 , m_round_mode(round_mode)
// Presumably runs validate_and_infer_types() for the freshly built node - the
// helper is defined elsewhere; verify.
39 constructor_validate_and_infer_types();
// Checks the element types and partial shapes of the node's inputs against
// m_type and m_axes, then sets the element type and shape of output 0.
//
// NOTE(review): parts of this definition are not visible in this view - braces,
// the declarations of INPUT / SCALE / ZERO_POINT and of the counter `i`, and
// some macro arguments. The comments below describe only the visible lines.
// INPUT, SCALE and ZERO_POINT are presumably the input indices 0, 1, 2 - confirm.
42 void op::Quantize::validate_and_infer_types()
// The requested output element type has to be static ...
51 NODE_VALIDATION_CHECK(this, m_type.is_static(), "Output element type must not be dynamic");
// ... and has to be a quantized type.
53 NODE_VALIDATION_CHECK(
54 this, m_type.is_quantized(), "Output element type (", m_type, ") must be a quantized type");
// Merge the element types of the input and the scale into unquantized_type;
// the merge result is the checked condition, so a failed merge fails validation.
56 element::Type unquantized_type;
58 NODE_VALIDATION_CHECK(this,
59 element::Type::merge(unquantized_type,
60 get_input_element_type(INPUT),
61 get_input_element_type(SCALE)),
62 "Scale element type (",
63 get_input_element_type(SCALE),
64 ") must match input element type (",
65 get_input_element_type(INPUT),
// The merged input / scale type may remain dynamic; otherwise it must be a real type.
68 NODE_VALIDATION_CHECK(this,
69 unquantized_type.is_dynamic() || unquantized_type.is_real(),
70 "Scale / input element type (",
72 ") must be a floating point number");
// Merge the zero point's element type with m_type into quantized_type, which is
// the element type later assigned to output 0.
74 element::Type quantized_type;
76 NODE_VALIDATION_CHECK(
78 element::Type::merge(quantized_type, get_input_element_type(ZERO_POINT), m_type),
79 "Zero point element type (",
80 get_input_element_type(ZERO_POINT),
81 ") must match output element type (",
// Partial shape of input 0 and its rank; the rank may be dynamic.
85 PartialShape input_shape = get_input_partial_shape(0);
86 Dimension input_rank = input_shape.rank();
// Each quantization axis must be smaller than the input rank; the check passes
// unconditionally while the rank is dynamic.
88 for (auto axis : m_axes)
90 NODE_VALIDATION_CHECK(this,
91 input_rank.is_dynamic() || axis < input_rank.get_length(),
92 "Quantization axis (",
94 ") must be less than input shape rank (",
// Start from the scale's partial shape and merge the zero point's shape into it;
// the merge result is the checked condition.
99 PartialShape scale_zero_point_shape = get_input_partial_shape(SCALE);
101 NODE_VALIDATION_CHECK(
103 PartialShape::merge_into(scale_zero_point_shape, get_input_partial_shape(ZERO_POINT)),
105 get_input_partial_shape(SCALE),
106 ") and zero point shape (",
107 get_input_partial_shape(ZERO_POINT),
// The rank of the merged scale / zero point shape must be compatible with the
// number of quantization axes.
110 NODE_VALIDATION_CHECK(this,
111 scale_zero_point_shape.rank().compatible(m_axes.size()),
112 "Scale / zero point rank (",
113 scale_zero_point_shape.rank(),
114 ") does not match the number of ",
115 "quantization axes (",
// Only when both ranks are static can the scale / zero point dimensions be
// compared with the input dimensions.
121 if (input_shape.rank().is_static() && scale_zero_point_shape.rank().is_static())
// Build a dimension list with one entry per input dimension: positions listed in
// m_axes take the next scale / zero point dimension, every other position is dynamic.
125 vector<Dimension> injected_scale_zero_point_dims;
127 for (size_t j = 0; j < input_shape.rank().get_length(); j++)
129 if (m_axes.count(j) != 0)
// `i` walks the scale / zero point dimensions in order; its declaration is not
// visible in this view (presumably a size_t starting at 0 - confirm).
131 injected_scale_zero_point_dims.push_back(scale_zero_point_shape[i++]);
135 injected_scale_zero_point_dims.push_back(Dimension::dynamic());
// Merge that shape into a copy of the input shape; a failed merge fails the check
// and is reported as a mismatch at the quantization axes.
139 PartialShape result_shape = input_shape;
140 NODE_VALIDATION_CHECK(
142 PartialShape::merge_into(result_shape, PartialShape{injected_scale_zero_point_dims}),
143 "Scale / zero point shape (",
144 scale_zero_point_shape,
145 ") must match input shape (",
147 ") at the quantization axes (",
// Output 0 receives the merged quantized element type and the merged shape.
150 set_output_type(0, quantized_type, result_shape);
// Presumably the else branch (braces / `else` not visible here): without static
// ranks, output 0 gets a fully dynamic shape.
154 set_output_type(0, quantized_type, PartialShape::dynamic());
// Creates a new Quantize node that uses the first three entries of `new_args` as
// its inputs and copies m_type, m_axes and m_round_mode from this node.
// Returns the new node as a shared_ptr<Node>.
// NOTE(review): the braces of this definition are not visible in this view.
158 shared_ptr<Node> op::Quantize::clone_with_new_inputs(const OutputVector& new_args) const
// Presumably verifies that new_args has the expected number of entries - the
// helper is defined elsewhere; verify.
160 check_new_args_count(this, new_args);
161 return make_shared<Quantize>(
162 new_args.at(0), new_args.at(1), new_args.at(2), m_type, m_axes, m_round_mode);