Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / core / src / op / quantized_dot.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "quantized_dot.hpp"
18 #include <numeric>
19 #include "ngraph/coordinate_diff.hpp"
20 #include "ngraph/util.hpp"
21 #include "ngraph/validation_util.hpp"
22
23 NGRAPH_SUPPRESS_DEPRECATED_START
24
25 using namespace std;
26 using namespace ngraph;
27
28 constexpr NodeTypeInfo op::QuantizedDot::type_info;
29
30 op::QuantizedDot::QuantizedDot(const Output<Node>& input0,
31                                const Output<Node>& input1,
32                                size_t reduction_axes_count,
33                                const Output<Node>& input0_scale,
34                                const Output<Node>& input0_zero_point,
35                                const Output<Node>& input1_scale,
36                                const Output<Node>& input1_zero_point,
37                                const Output<Node>& output_scale,
38                                const Output<Node>& output_zero_point,
39                                const element::Type& output_type,
40                                const AxisSet& input0_axes,
41                                const AxisSet& input1_axes,
42                                const AxisSet& output_axes)
43     : Op({input0,
44           input1,
45           input0_scale,
46           input0_zero_point,
47           input1_scale,
48           input1_zero_point,
49           output_scale,
50           output_zero_point})
51     , m_reduction_axes_count(reduction_axes_count)
52     , m_output_type(output_type)
53     , m_input0_axes(input0_axes)
54     , m_input1_axes(input1_axes)
55     , m_output_axes(output_axes)
56 {
57     constructor_validate_and_infer_types();
58 }
59
60 void op::QuantizedDot::validate_and_infer_types()
61 {
62     enum
63     {
64         INPUT0,
65         INPUT1,
66         INPUT0_SCALE,
67         INPUT0_ZERO_POINT,
68         INPUT1_SCALE,
69         INPUT1_ZERO_POINT,
70         OUTPUT_SCALE,
71         OUTPUT_ZERO_POINT
72     };
73
74     NODE_VALIDATION_CHECK(
75         this, m_output_type.is_static(), "Output element type must not be dynamic");
76
77     NODE_VALIDATION_CHECK(this,
78                           m_output_type.is_quantized(),
79                           "Output element type (",
80                           m_output_type,
81                           ") must be a quantized type");
82     NODE_VALIDATION_CHECK(this,
83                           get_input_element_type(INPUT0).is_quantized(),
84                           "Input0 element type (",
85                           get_input_element_type(INPUT0),
86                           ") must be a quantized type");
87
88     NODE_VALIDATION_CHECK(this,
89                           get_input_element_type(INPUT1).is_quantized(),
90                           "Input1 element type (",
91                           get_input_element_type(INPUT1),
92                           ") must be a quantized type");
93
94     NODE_VALIDATION_CHECK(this,
95                           get_input_element_type(INPUT0_SCALE).is_real() ||
96                               get_input_element_type(INPUT0_SCALE).is_dynamic() ||
97                               get_input_element_type(INPUT1_SCALE).is_real() ||
98                               get_input_element_type(INPUT1_SCALE).is_dynamic() ||
99                               get_input_element_type(OUTPUT_SCALE).is_real() ||
100                               get_input_element_type(OUTPUT_SCALE).is_dynamic(),
101                           "Scale must be a floating point number");
102
103     NODE_VALIDATION_CHECK(
104         this,
105         get_input_element_type(INPUT0).compatible(get_input_element_type(INPUT0_ZERO_POINT)),
106         "Input0 Zero point element type (",
107         get_input_element_type(INPUT0_ZERO_POINT),
108         ") must match input0 element type (",
109         get_input_element_type(INPUT0),
110         ")");
111
112     NODE_VALIDATION_CHECK(
113         this,
114         get_input_element_type(INPUT1).compatible(get_input_element_type(INPUT1_ZERO_POINT)),
115         "Input1 Zero point element type (",
116         get_input_element_type(INPUT1_ZERO_POINT),
117         ") must match input1 element type (",
118         get_input_element_type(INPUT1),
119         ")");
120
121     // TODO Remove these checks once we support channelwise and vector of scales
122     NODE_VALIDATION_CHECK(this,
123                           get_input_partial_shape(2).compatible(PartialShape{}) &&
124                               get_input_partial_shape(3).compatible(PartialShape{}),
125                           "Input0 scale and input0 zero point shape must be same and 1");
126
127     NODE_VALIDATION_CHECK(this,
128                           get_input_partial_shape(4).compatible(PartialShape{}) &&
129                               get_input_partial_shape(5).compatible(PartialShape{}),
130                           "Input1 scale and input1 zero point shape must be same and 1");
131
132     NODE_VALIDATION_CHECK(this,
133                           get_input_partial_shape(6).compatible(PartialShape{}) &&
134                               get_input_partial_shape(7).compatible(PartialShape{}),
135                           "Output scale and output zero point shape must be same and 1");
136
137     // AxisSet should be empty till we support channel wise quantization
138     NODE_VALIDATION_CHECK(this,
139                           m_input0_axes == AxisSet{} && m_input1_axes == AxisSet{} &&
140                               m_output_axes == AxisSet{},
141                           "Input0, input1 and output AxisSet should be empty");
142
143     const PartialShape& arg0_shape = get_input_partial_shape(0);
144     const PartialShape& arg1_shape = get_input_partial_shape(1);
145
146     PartialShape result_shape;
147
148     if (arg0_shape.rank().is_static() && arg1_shape.rank().is_static())
149     {
150         for (size_t i = 0; i < m_reduction_axes_count; i++)
151         {
152             size_t axis_index_arg0 = arg0_shape.rank().get_length() - m_reduction_axes_count + i;
153             size_t axis_index_arg1 = i;
154
155             NODE_VALIDATION_CHECK(
156                 this,
157                 arg0_shape[axis_index_arg0].compatible(arg1_shape[axis_index_arg1]),
158                 "Paired axes (axis ",
159                 axis_index_arg0,
160                 " from arg0, axis ",
161                 axis_index_arg1,
162                 " from arg1) do not have same length (arg0 shape: ",
163                 arg0_shape,
164                 ", arg1 shape: ",
165                 arg1_shape,
166                 ", reduction axes count: ",
167                 m_reduction_axes_count,
168                 ").");
169         }
170
171         std::vector<Dimension> result_dims(arg0_shape.rank().get_length() +
172                                            arg1_shape.rank().get_length() -
173                                            2 * m_reduction_axes_count);
174
175         size_t i = 0;
176
177         for (size_t j = 0; j < arg0_shape.rank().get_length() - m_reduction_axes_count; j++)
178         {
179             result_dims[i++] = arg0_shape[j];
180         }
181         for (size_t j = m_reduction_axes_count; j < arg1_shape.rank().get_length(); j++)
182         {
183             result_dims[i++] = arg1_shape[j];
184         }
185
186         result_shape = PartialShape(result_dims);
187     }
188     else
189     {
190         result_shape = PartialShape::dynamic();
191     }
192
193     NODE_VALIDATION_CHECK(
194         this,
195         get_output_element_type(0).compatible(get_input_element_type(OUTPUT_ZERO_POINT)),
196         "Output Zero point element type (",
197         get_input_element_type(OUTPUT_ZERO_POINT),
198         ") must match output element type (",
199         get_output_element_type(0),
200         ")");
201
202     set_output_type(0, m_output_type, result_shape);
203 }
204
205 shared_ptr<Node> op::QuantizedDot::clone_with_new_inputs(const OutputVector& new_args) const
206 {
207     check_new_args_count(this, new_args);
208     return shared_ptr<Node>(new QuantizedDot(new_args.at(0),
209                                              new_args.at(1),
210                                              m_reduction_axes_count,
211                                              new_args.at(2),
212                                              new_args.at(3),
213                                              new_args.at(4),
214                                              new_args.at(5),
215                                              new_args.at(6),
216                                              new_args.at(7),
217                                              m_output_type,
218                                              m_input0_axes,
219                                              m_input1_axes,
220                                              m_output_axes));
221 }