Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / test / type_prop / dequantize.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "gtest/gtest.h"
18 #include "ngraph/ngraph.hpp"
19 #include "util/type_prop.hpp"
20
21 NGRAPH_SUPPRESS_DEPRECATED_START
22
23 using namespace std;
24 using namespace ngraph;
25
26 TEST(type_prop, dequantize_f32_from_i8_nchw_per_channel_ok)
27 {
28     Shape batch_shape{64, 3, 480, 640};
29     Shape scale_shape{3};
30     Shape zero_point_shape{3};
31     element::Type unquantized_type = element::f32;
32     element::Type quantized_type = element::i8;
33     element::Type batch_type = quantized_type;
34     element::Type scale_type = unquantized_type;
35     element::Type zero_point_type = quantized_type;
36     AxisSet axes{1};
37
38     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
39     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
40     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
41     auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
42
43     ASSERT_EQ(quant->get_output_element_type(0), unquantized_type);
44     ASSERT_EQ(quant->get_output_shape(0), batch_shape);
45 }
46
47 TEST(type_prop, dequantize_f32_from_i8_nchw_per_image_ok)
48 {
49     Shape batch_shape{64, 3, 480, 640};
50     Shape scale_shape{64};
51     Shape zero_point_shape{64};
52     element::Type unquantized_type = element::f32;
53     element::Type quantized_type = element::i8;
54     element::Type batch_type = quantized_type;
55     element::Type scale_type = unquantized_type;
56     element::Type zero_point_type = quantized_type;
57     AxisSet axes{0};
58
59     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
60     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
61     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
62     auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
63
64     ASSERT_EQ(quant->get_output_element_type(0), unquantized_type);
65     ASSERT_EQ(quant->get_output_shape(0), batch_shape);
66 }
67
68 TEST(type_prop, dequantize_f32_from_i8_nchw_per_row_ok)
69 {
70     Shape batch_shape{64, 3, 480, 640};
71     Shape scale_shape{480};
72     Shape zero_point_shape{480};
73     element::Type unquantized_type = element::f32;
74     element::Type quantized_type = element::i8;
75     element::Type batch_type = quantized_type;
76     element::Type scale_type = unquantized_type;
77     element::Type zero_point_type = quantized_type;
78     AxisSet axes{2};
79
80     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
81     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
82     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
83     auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
84
85     ASSERT_EQ(quant->get_output_element_type(0), unquantized_type);
86     ASSERT_EQ(quant->get_output_shape(0), batch_shape);
87 }
88
89 TEST(type_prop, dequantize_f32_from_i8_nchw_per_image_channel_ok)
90 {
91     Shape batch_shape{64, 3, 480, 640};
92     Shape scale_shape{64, 3};
93     Shape zero_point_shape{64, 3};
94     element::Type unquantized_type = element::f32;
95     element::Type quantized_type = element::i8;
96     element::Type batch_type = quantized_type;
97     element::Type scale_type = unquantized_type;
98     element::Type zero_point_type = quantized_type;
99     AxisSet axes{0, 1};
100
101     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
102     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
103     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
104     auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
105
106     ASSERT_EQ(quant->get_output_element_type(0), unquantized_type);
107     ASSERT_EQ(quant->get_output_shape(0), batch_shape);
108 }
109
110 TEST(type_prop, dequantize_f32_from_i8_nchw_whole_batch_ok)
111 {
112     Shape batch_shape{64, 3, 480, 640};
113     Shape scale_shape{};
114     Shape zero_point_shape{};
115     element::Type unquantized_type = element::f32;
116     element::Type quantized_type = element::i8;
117     element::Type batch_type = quantized_type;
118     element::Type scale_type = unquantized_type;
119     element::Type zero_point_type = quantized_type;
120     AxisSet axes{};
121
122     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
123     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
124     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
125     auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
126
127     ASSERT_EQ(quant->get_output_element_type(0), unquantized_type);
128     ASSERT_EQ(quant->get_output_shape(0), batch_shape);
129 }
130
131 TEST(type_prop, dequantize_f64_from_i8_ok)
132 {
133     Shape batch_shape{64, 3, 480, 640};
134     Shape scale_shape{};
135     Shape zero_point_shape{};
136     element::Type unquantized_type = element::f64;
137     element::Type quantized_type = element::i8;
138     element::Type batch_type = quantized_type;
139     element::Type scale_type = unquantized_type;
140     element::Type zero_point_type = quantized_type;
141     AxisSet axes{};
142
143     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
144     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
145     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
146     auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
147
148     ASSERT_EQ(quant->get_output_element_type(0), unquantized_type);
149     ASSERT_EQ(quant->get_output_shape(0), batch_shape);
150 }
151
152 TEST(type_prop, dequantize_f64_to_u8_ok)
153 {
154     Shape batch_shape{64, 3, 480, 640};
155     Shape scale_shape{};
156     Shape zero_point_shape{};
157     element::Type unquantized_type = element::f64;
158     element::Type quantized_type = element::u8;
159     element::Type batch_type = quantized_type;
160     element::Type scale_type = unquantized_type;
161     element::Type zero_point_type = quantized_type;
162     AxisSet axes{};
163
164     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
165     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
166     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
167     auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
168
169     ASSERT_EQ(quant->get_output_element_type(0), unquantized_type);
170     ASSERT_EQ(quant->get_output_shape(0), batch_shape);
171 }
172
173 TEST(type_prop, dequantize_i8_from_u8_fails)
174 {
175     Shape batch_shape{64, 3, 480, 640};
176     Shape scale_shape{};
177     Shape zero_point_shape{};
178     element::Type unquantized_type = element::i8;
179     element::Type quantized_type = element::u8;
180     element::Type batch_type = quantized_type;
181     element::Type scale_type = unquantized_type;
182     element::Type zero_point_type = quantized_type;
183     AxisSet axes{};
184
185     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
186     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
187     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
188
189     try
190     {
191         auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
192         FAIL() << "Attempt to dequantize to non-floating point type not detected";
193     }
194     catch (const NodeValidationFailure& error)
195     {
196         EXPECT_HAS_SUBSTRING(error.what(),
197                              "Output element type (i8) must be a floating point type");
198     }
199     catch (...)
200     {
201         FAIL() << "Deduced type check failed for unexpected reason";
202     }
203 }
204
205 TEST(type_prop, dequantize_f32_from_f32_fails)
206 {
207     Shape batch_shape{64, 3, 480, 640};
208     Shape scale_shape{};
209     Shape zero_point_shape{};
210     element::Type unquantized_type = element::f32;
211     element::Type quantized_type = element::f32;
212     element::Type batch_type = quantized_type;
213     element::Type scale_type = unquantized_type;
214     element::Type zero_point_type = quantized_type;
215     AxisSet axes{};
216
217     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
218     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
219     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
220
221     try
222     {
223         auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
224         FAIL() << "Attempt to dequantize from non-quantized type not detected";
225     }
226     catch (const NodeValidationFailure& error)
227     {
228         EXPECT_HAS_SUBSTRING(error.what(),
229                              "Zero point / input element type (f32) must be a quantized type");
230     }
231     catch (...)
232     {
233         FAIL() << "Deduced type check failed for unexpected reason";
234     }
235 }
236
237 TEST(type_prop, dequantize_batch_zero_point_type_mismatch_fails)
238 {
239     Shape batch_shape{64, 3, 480, 640};
240     Shape scale_shape{};
241     Shape zero_point_shape{};
242     element::Type unquantized_type = element::f32;
243     element::Type quantized_type = element::i8;
244     element::Type batch_type = quantized_type;
245     element::Type scale_type = unquantized_type;
246     element::Type zero_point_type = element::u8;
247     AxisSet axes{};
248
249     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
250     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
251     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
252
253     try
254     {
255         auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
256         FAIL() << "Mismatch of batch and zero point element types not detected";
257     }
258     catch (const NodeValidationFailure& error)
259     {
260         EXPECT_HAS_SUBSTRING(error.what(),
261                              "Zero point element type (u8) must match input element type (i8)");
262     }
263     catch (...)
264     {
265         FAIL() << "Deduced type check failed for unexpected reason";
266     }
267 }
268
269 TEST(type_prop, dequantize_scale_type_mismatch_fails)
270 {
271     Shape batch_shape{64, 3, 480, 640};
272     Shape scale_shape{};
273     Shape zero_point_shape{};
274     element::Type unquantized_type = element::f32;
275     element::Type quantized_type = element::i8;
276     element::Type batch_type = quantized_type;
277     element::Type scale_type = element::f64;
278     element::Type zero_point_type = quantized_type;
279     AxisSet axes{};
280
281     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
282     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
283     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
284
285     try
286     {
287         auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
288         FAIL() << "Mismatch of scale element type with scale argument not detected";
289     }
290     catch (const NodeValidationFailure& error)
291     {
292         EXPECT_HAS_SUBSTRING(error.what(),
293                              "Scale element type (f64) must match output element type (f32)");
294     }
295     catch (...)
296     {
297         FAIL() << "Deduced type check failed for unexpected reason";
298     }
299 }
300
301 TEST(type_prop, dequantize_oob_axis_fails)
302 {
303     Shape batch_shape{64, 3, 480, 640};
304     Shape scale_shape{320};
305     Shape zero_point_shape{320};
306     element::Type unquantized_type = element::f32;
307     element::Type quantized_type = element::i8;
308     element::Type batch_type = quantized_type;
309     element::Type scale_type = unquantized_type;
310     element::Type zero_point_type = quantized_type;
311     AxisSet axes{3, 4};
312
313     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
314     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
315     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
316
317     try
318     {
319         auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
320         FAIL() << "Out-of-bounds quantization axis not detected";
321     }
322     catch (const NodeValidationFailure& error)
323     {
324         EXPECT_HAS_SUBSTRING(error.what(),
325                              "Quantization axis (4) must be less than input shape rank (4)");
326     }
327     catch (...)
328     {
329         FAIL() << "Deduced type check failed for unexpected reason";
330     }
331 }
332
333 TEST(type_prop, dequantize_scale_shape_mismatch_same_rank_fails)
334 {
335     Shape batch_shape{64, 3, 480, 640};
336     Shape scale_shape{64, 4};
337     Shape zero_point_shape{64, 3};
338     element::Type unquantized_type = element::f32;
339     element::Type quantized_type = element::i8;
340     element::Type batch_type = quantized_type;
341     element::Type scale_type = unquantized_type;
342     element::Type zero_point_type = quantized_type;
343     AxisSet axes{0, 1};
344
345     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
346     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
347     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
348
349     try
350     {
351         auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
352         FAIL() << "Mismatch of scale argument shape with required shape not detected";
353     }
354     catch (const NodeValidationFailure& error)
355     {
356         EXPECT_HAS_SUBSTRING(error.what(),
357                              "Scale shape ({64,4}) and zero point shape ({64,3}) must match");
358     }
359     catch (...)
360     {
361         FAIL() << "Deduced type check failed for unexpected reason";
362     }
363 }
364
365 TEST(type_prop, dequantize_scale_shape_mismatch_different_rank_fails)
366 {
367     Shape batch_shape{64, 3, 480, 640};
368     Shape scale_shape{64, 3, 2};
369     Shape zero_point_shape{64, 3};
370     element::Type unquantized_type = element::f32;
371     element::Type quantized_type = element::i8;
372     element::Type batch_type = quantized_type;
373     element::Type scale_type = unquantized_type;
374     element::Type zero_point_type = quantized_type;
375     AxisSet axes{0, 1};
376
377     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
378     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
379     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
380
381     try
382     {
383         auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
384         FAIL() << "Mismatch of scale argument shape with required shape not detected";
385     }
386     catch (const NodeValidationFailure& error)
387     {
388         EXPECT_HAS_SUBSTRING(error.what(),
389                              "Scale shape ({64,3,2}) and zero point shape ({64,3}) must match");
390     }
391     catch (...)
392     {
393         FAIL() << "Deduced type check failed for unexpected reason";
394     }
395 }
396
397 TEST(type_prop, dequantize_zero_point_shape_mismatch_same_rank_fails)
398 {
399     Shape batch_shape{64, 3, 480, 640};
400     Shape scale_shape{64, 3};
401     Shape zero_point_shape{64, 4};
402     element::Type unquantized_type = element::f32;
403     element::Type quantized_type = element::i8;
404     element::Type batch_type = quantized_type;
405     element::Type scale_type = unquantized_type;
406     element::Type zero_point_type = quantized_type;
407     AxisSet axes{0, 1};
408
409     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
410     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
411     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
412
413     try
414     {
415         auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
416         FAIL() << "Mismatch of zero point argument shape with required shape not detected";
417     }
418     catch (const NodeValidationFailure& error)
419     {
420         EXPECT_HAS_SUBSTRING(error.what(),
421                              "Scale shape ({64,3}) and zero point shape ({64,4}) must match");
422     }
423     catch (...)
424     {
425         FAIL() << "Deduced type check failed for unexpected reason";
426     }
427 }
428
429 TEST(type_prop, dequantize_zero_point_shape_mismatch_different_rank_fails)
430 {
431     Shape batch_shape{64, 3, 480, 640};
432     Shape scale_shape{64, 3};
433     Shape zero_point_shape{64, 3, 2};
434     element::Type unquantized_type = element::f32;
435     element::Type quantized_type = element::i8;
436     element::Type batch_type = quantized_type;
437     element::Type scale_type = unquantized_type;
438     element::Type zero_point_type = quantized_type;
439     AxisSet axes{0, 1};
440
441     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
442     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
443     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
444
445     try
446     {
447         auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
448         FAIL() << "Mismatch of zero point argument shape with required shape not detected";
449     }
450     catch (const NodeValidationFailure& error)
451     {
452         EXPECT_HAS_SUBSTRING(error.what(),
453                              "Scale shape ({64,3}) and zero point shape ({64,3,2}) must match");
454     }
455     catch (...)
456     {
457         FAIL() << "Deduced type check failed for unexpected reason";
458     }
459 }
460
461 TEST(type_prop, dequantize_partial_all_rank_dynamic_ok)
462 {
463     PartialShape batch_shape{PartialShape::dynamic()};
464     PartialShape scale_shape{PartialShape::dynamic()};
465     PartialShape zero_point_shape{PartialShape::dynamic()};
466     element::Type unquantized_type = element::f32;
467     element::Type quantized_type = element::i8;
468     element::Type batch_type = quantized_type;
469     element::Type scale_type = unquantized_type;
470     element::Type zero_point_type = quantized_type;
471     AxisSet axes{0, 1, 2000};
472
473     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
474     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
475     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
476     auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
477
478     ASSERT_EQ(quant->get_output_element_type(0), unquantized_type);
479     ASSERT_TRUE(quant->get_output_partial_shape(0).rank().is_dynamic());
480 }
481
482 TEST(type_prop,
483      dequantize_partial_input_rank_dynamic_scale_rank_static_dynamic_zero_point_rank_dynamic_ok)
484 {
485     PartialShape batch_shape{PartialShape::dynamic()};
486     PartialShape scale_shape{64, Dimension::dynamic(), 96};
487     PartialShape zero_point_shape{PartialShape::dynamic()};
488     element::Type unquantized_type = element::f32;
489     element::Type quantized_type = element::i8;
490     element::Type batch_type = quantized_type;
491     element::Type scale_type = unquantized_type;
492     element::Type zero_point_type = quantized_type;
493     AxisSet axes{0, 1, 2000};
494
495     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
496     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
497     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
498     auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
499
500     ASSERT_EQ(quant->get_output_element_type(0), unquantized_type);
501     ASSERT_TRUE(quant->get_output_partial_shape(0).rank().is_dynamic());
502 }
503
504 TEST(
505     type_prop,
506     dequantize_partial_input_rank_dynamic_scale_rank_static_dynamic_zero_point_rank_dynamic_axis_count_inconsistent)
507 {
508     PartialShape batch_shape{PartialShape::dynamic()};
509     PartialShape scale_shape{64, Dimension::dynamic(), 96};
510     PartialShape zero_point_shape{PartialShape::dynamic()};
511     element::Type unquantized_type = element::f32;
512     element::Type quantized_type = element::i8;
513     element::Type batch_type = quantized_type;
514     element::Type scale_type = unquantized_type;
515     element::Type zero_point_type = quantized_type;
516     AxisSet axes{0, 1};
517
518     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
519     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
520     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
521
522     try
523     {
524         auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
525         FAIL() << "Mismatch of scale / zero point rank with axis count not detected";
526     }
527     catch (const NodeValidationFailure& error)
528     {
529         EXPECT_HAS_SUBSTRING(
530             error.what(),
531             "Scale / zero point rank (3) does not match the number of quantization axes (2)");
532     }
533     catch (...)
534     {
535         FAIL() << "Deduced type check failed for unexpected reason";
536     }
537 }
538
539 TEST(
540     type_prop,
541     dequantize_partial_input_rank_dynamic_scale_rank_static_dynamic_zero_point_rank_static_dynamic_ok)
542 {
543     PartialShape batch_shape{PartialShape::dynamic()};
544     PartialShape scale_shape{64, Dimension::dynamic(), 96, Dimension::dynamic()};
545     PartialShape zero_point_shape{64, 22, Dimension::dynamic(), Dimension::dynamic()};
546     element::Type unquantized_type = element::f32;
547     element::Type quantized_type = element::i8;
548     element::Type batch_type = quantized_type;
549     element::Type scale_type = unquantized_type;
550     element::Type zero_point_type = quantized_type;
551     AxisSet axes{0, 1, 5, 88};
552
553     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
554     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
555     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
556     auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
557
558     ASSERT_EQ(quant->get_output_element_type(0), unquantized_type);
559     ASSERT_TRUE(quant->get_output_partial_shape(0).rank().is_dynamic());
560 }
561
562 TEST(
563     type_prop,
564     dequantize_partial_input_rank_dynamic_scale_rank_static_dynamic_zero_point_rank_static_dynamic_ranks_inconsistent)
565 {
566     PartialShape batch_shape{PartialShape::dynamic()};
567     PartialShape scale_shape{64, Dimension::dynamic(), 96, Dimension::dynamic()};
568     PartialShape zero_point_shape{64, 22, Dimension::dynamic(), Dimension::dynamic(), 3};
569     element::Type unquantized_type = element::f32;
570     element::Type quantized_type = element::i8;
571     element::Type batch_type = quantized_type;
572     element::Type scale_type = unquantized_type;
573     element::Type zero_point_type = quantized_type;
574     AxisSet axes{0, 1, 5, 88};
575
576     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
577     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
578     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
579
580     try
581     {
582         auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
583         FAIL() << "Inconsistent scale / zero point ranks not detected";
584     }
585     catch (const NodeValidationFailure& error)
586     {
587         EXPECT_HAS_SUBSTRING(
588             error.what(),
589             "Scale shape ({64,?,96,?}) and zero point shape ({64,22,?,?,3}) must match");
590     }
591     catch (...)
592     {
593         FAIL() << "Deduced type check failed for unexpected reason";
594     }
595 }
596
597 TEST(
598     type_prop,
599     dequantize_partial_input_rank_dynamic_scale_rank_static_dynamic_zero_point_rank_static_dynamic_dims_inconsistent)
600 {
601     PartialShape batch_shape{PartialShape::dynamic()};
602     PartialShape scale_shape{64, Dimension::dynamic(), 96, Dimension::dynamic()};
603     PartialShape zero_point_shape{65, 22, Dimension::dynamic(), Dimension::dynamic()};
604     element::Type unquantized_type = element::f32;
605     element::Type quantized_type = element::i8;
606     element::Type batch_type = quantized_type;
607     element::Type scale_type = unquantized_type;
608     element::Type zero_point_type = quantized_type;
609     AxisSet axes{0, 1, 5, 88};
610
611     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
612     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
613     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
614
615     try
616     {
617         auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
618         FAIL() << "Inconsistent scale / zero point dims not detected";
619     }
620     catch (const NodeValidationFailure& error)
621     {
622         EXPECT_HAS_SUBSTRING(
623             error.what(),
624             "Scale shape ({64,?,96,?}) and zero point shape ({65,22,?,?}) must match");
625     }
626     catch (...)
627     {
628         FAIL() << "Deduced type check failed for unexpected reason";
629     }
630 }
631
632 TEST(
633     type_prop,
634     dequantize_partial_input_static_rank_dynamic_scale_rank_static_dynamic_zero_point_rank_static_dynamic_ok)
635 {
636     PartialShape batch_shape{2, 4, 6, Dimension::dynamic(), 10, Dimension::dynamic()};
637     PartialShape scale_shape{4, Dimension::dynamic(), Dimension::dynamic()};
638     PartialShape zero_point_shape{Dimension::dynamic(), 8, Dimension::dynamic()};
639     element::Type unquantized_type = element::f32;
640     element::Type quantized_type = element::i8;
641     element::Type batch_type = quantized_type;
642     element::Type scale_type = unquantized_type;
643     element::Type zero_point_type = quantized_type;
644     AxisSet axes{1, 3, 5};
645
646     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
647     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
648     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
649     auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
650
651     ASSERT_EQ(quant->get_output_element_type(0), unquantized_type);
652     ASSERT_TRUE(quant->get_output_partial_shape(0).same_scheme(
653         PartialShape{2, 4, 6, 8, 10, Dimension::dynamic()}));
654 }
655
656 TEST(
657     type_prop,
658     dequantize_partial_input_static_rank_dynamic_scale_rank_static_dynamic_zero_point_rank_static_dynamic_axis_oob)
659 {
660     PartialShape batch_shape{2, 4, 6, Dimension::dynamic(), 10, Dimension::dynamic()};
661     PartialShape scale_shape{4, Dimension::dynamic(), Dimension::dynamic()};
662     PartialShape zero_point_shape{Dimension::dynamic(), 8, Dimension::dynamic()};
663     element::Type unquantized_type = element::f32;
664     element::Type quantized_type = element::i8;
665     element::Type batch_type = quantized_type;
666     element::Type scale_type = unquantized_type;
667     element::Type zero_point_type = quantized_type;
668     AxisSet axes{1, 3, 6};
669
670     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
671     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
672     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
673
674     try
675     {
676         auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
677         FAIL() << "Out-of-bound quantization axis not detected";
678     }
679     catch (const NodeValidationFailure& error)
680     {
681         EXPECT_HAS_SUBSTRING(error.what(),
682                              "Quantization axis (6) must be less than input shape rank (6)");
683     }
684     catch (...)
685     {
686         FAIL() << "Deduced type check failed for unexpected reason";
687     }
688 }
689
690 TEST(
691     type_prop,
692     dequantize_partial_input_static_rank_dynamic_scale_rank_static_dynamic_zero_point_rank_static_dynamic_dims_inconsistent)
693 {
694     PartialShape batch_shape{2, 5, 6, Dimension::dynamic(), 10, Dimension::dynamic()};
695     PartialShape scale_shape{4, Dimension::dynamic(), Dimension::dynamic()};
696     PartialShape zero_point_shape{Dimension::dynamic(), 8, Dimension::dynamic()};
697     element::Type unquantized_type = element::f32;
698     element::Type quantized_type = element::i8;
699     element::Type batch_type = quantized_type;
700     element::Type scale_type = unquantized_type;
701     element::Type zero_point_type = quantized_type;
702     AxisSet axes{1, 3, 5};
703
704     auto batch = make_shared<op::Parameter>(batch_type, batch_shape);
705     auto scale = make_shared<op::Parameter>(scale_type, scale_shape);
706     auto zero_point = make_shared<op::Parameter>(zero_point_type, zero_point_shape);
707
708     try
709     {
710         auto quant = make_shared<op::Dequantize>(batch, scale, zero_point, unquantized_type, axes);
711         FAIL() << "Inconsistent dimensions not detected";
712     }
713     catch (const NodeValidationFailure& error)
714     {
715         EXPECT_HAS_SUBSTRING(
716             error.what(),
717             "Scale / zero point shape ({4,8,?}) must match input shape ({2,5,6,?,10,?}) "
718             "at the quantization axes (AxisSet{1, 3, 5})");
719     }
720     catch (...)
721     {
722         FAIL() << "Deduced type check failed for unexpected reason";
723     }
724 }