[LPT] POT support: absent convert fix & element-wise empty dequantization data (...
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / lp_transformations / max_pool_transformation.cpp
index 7629827..e3650e3 100644 (file)
 #include "common_test_utils/ngraph_test_utils.hpp"
 #include "simple_low_precision_transformer.hpp"
 #include "ngraph_functions/low_precision_transformations/max_pool_function.hpp"
+#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
+
 
 using namespace testing;
 using namespace ngraph::pass;
 
 class MaxPoolTransformationTestValues {
 public:
-    low_precision::LayerTransformation::Params params;
-    std::vector<float> subtractValues;
-    std::vector<float> mutliplyValues;
+    class Actual {
+    public:
+        ngraph::element::Type precisionBeforeDequantization;
+        ngraph::builder::subgraph::DequantizationOperations dequantization1;
+        ngraph::builder::subgraph::DequantizationOperations dequantization2;
+    };
+
+    class Expected {
+    public:
+        ngraph::element::Type precisionBeforeDequantization;
+        ngraph::builder::subgraph::DequantizationOperations dequantization1;
+        ngraph::builder::subgraph::DequantizationOperations dequantization2;
+    };
+
+    ngraph::pass::low_precision::LayerTransformation::Params params;
+    Actual actual;
+    Expected expected;
 };
 
 typedef std::tuple<
-    ngraph::element::Type,
     ngraph::Shape,
     MaxPoolTransformationTestValues> MaxPoolTransformationParams;
 
 class MaxPoolTransformation : public LayerTransformation, public testing::WithParamInterface<MaxPoolTransformationParams> {
 public:
     void SetUp() override {
-        const ngraph::element::Type precision = std::get<0>(GetParam());
-        const ngraph::Shape shape = std::get<1>(GetParam());
-        const MaxPoolTransformationTestValues testValues = std::get<2>(GetParam());
+        const ngraph::Shape shape = std::get<0>(GetParam());
+        const MaxPoolTransformationTestValues testValues = std::get<1>(GetParam());
 
-        actualFunction = ngraph::builder::subgraph::MaxPoolFunction::getOriginal(
-            precision,
+        actualFunction = ngraph::builder::subgraph::MaxPoolFunction::get(
             shape,
-            {
-                testValues.params.updatePrecisions ? testValues.params.precisionsOnActivations[0] : precision,
-                testValues.subtractValues,
-                testValues.mutliplyValues
-            });
+            testValues.actual.precisionBeforeDequantization,
+            testValues.actual.dequantization1,
+            testValues.actual.dequantization2);
 
         SimpleLowPrecisionTransformer transform;
         transform.add<ngraph::pass::low_precision::MaxPoolTransformation, ngraph::opset1::MaxPool>(testValues.params);
         transform.transform(actualFunction);
 
-        referenceFunction = ngraph::builder::subgraph::MaxPoolFunction::getReference(
-            precision,
+        referenceFunction = ngraph::builder::subgraph::MaxPoolFunction::get(
             shape,
-            {
-                testValues.params.updatePrecisions ? testValues.params.precisionsOnActivations[0] : precision,
-                testValues.subtractValues,
-                testValues.mutliplyValues
-            });
+            testValues.expected.precisionBeforeDequantization,
+            testValues.expected.dequantization1,
+            testValues.expected.dequantization2);
     }
 
     static std::string getTestCaseName(testing::TestParamInfo<MaxPoolTransformationParams> obj) {
-        const ngraph::element::Type precision = std::get<0>(obj.param);
-        const ngraph::Shape shape = std::get<1>(obj.param);
-        const MaxPoolTransformationTestValues testValues = std::get<2>(obj.param);
+        const ngraph::Shape shape = std::get<0>(obj.param);
+        const MaxPoolTransformationTestValues testValues = std::get<1>(obj.param);
 
         std::ostringstream result;
         result <<
-            LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" <<
-            testValues.subtractValues.size() << "_" <<
-            testValues.mutliplyValues.size() << "_";
+            LayerTransformation::getTestCaseNameByParams(testValues.actual.precisionBeforeDequantization, shape, testValues.params) << "_" <<
+            testValues.actual.dequantization1 << "_" <<
+            testValues.actual.dequantization2 << "_" <<
+            testValues.expected.dequantization1 << "_" <<
+            testValues.expected.dequantization2 << "_";
         return result.str();
     }
 };
 
 TEST_P(MaxPoolTransformation, CompareFunctions) {
     actualFunction->validate_nodes_and_infer_types();
-    auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
+    auto res = compare_functions(referenceFunction, actualFunction, true, false, true);
     ASSERT_TRUE(res.first) << res.second;
 }
 
-const std::vector<ngraph::element::Type> precisions = {
-    ngraph::element::f32,
-    // ngraph::element::f16
-};
-
 const std::vector<ngraph::Shape> shapes = {
-    { 1, 32, 72, 48 }
+    { 1, 32, 72, 48 },
+    { 4, 32, 72, 48 }
 };
 
 const std::vector<MaxPoolTransformationTestValues> testValues = {
-    { LayerTransformation::createParamsU8I8(), { 128 }, { 0.02f } },
-    { LayerTransformation::createParamsU8I8(), {}, { 0.02f } },
-    { LayerTransformation::createParamsU8I8().setUpdatePrecisions(false), { 128 }, { 0.02f } },
-    { LayerTransformation::createParamsU8I8().setUpdatePrecisions(false), {}, { 0.02f } },
-    { LayerTransformation::createParamsI8I8(), { 128 }, { 0.02f } },
+    // Multiply
+    {
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::u8,
+            { {}, {}, { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }},
+            {}
+        },
+        {
+            ngraph::element::u8,
+            {},
+            { ngraph::element::f32, {}, { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }}
+        }
+    },
+    // Subtract + Multiply
+    {
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::u8,
+            {
+                {},
+                { {128.f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 },
+                { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }
+            },
+            {}
+        },
+        {
+            ngraph::element::u8,
+            {},
+            {
+                ngraph::element::f32,
+                { {128.f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 },
+                { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }
+            }
+        }
+    },
+    // Convert + Subtract + Multiply
+    {
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::u8,
+            { ngraph::element::f32, { 128 }, { 0.02f }},
+            {}
+        },
+        {
+            ngraph::element::u8,
+            {},
+            { ngraph::element::f32, { 128 }, { 0.02f }}
+        }
+    },
+    // Convert + Subtract + Multiply
+    {
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::u8,
+            { ngraph::element::f32, {}, { 0.02f }},
+            {}
+        },
+        {
+            ngraph::element::u8,
+            {},
+            { ngraph::element::f32, {}, { 0.02f }}
+        }
+    },
+    // Convert + Subtract + Multiply
+    {
+        LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
+        {
+            ngraph::element::u8,
+            { ngraph::element::f32, { 128 }, { 0.02f }},
+            {}
+        },
+        {
+            ngraph::element::u8,
+            {},
+            { ngraph::element::f32, { 128 }, { 0.02f }}
+        }
+    },
+    // Convert + Subtract + Multiply
+    {
+        LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
+        {
+            ngraph::element::u8,
+            { ngraph::element::f32, {}, { 0.02f }},
+            {}
+        },
+        {
+            ngraph::element::u8,
+            {},
+            { ngraph::element::f32, {}, { 0.02f }}
+        }
+    }
 };
 
 INSTANTIATE_TEST_CASE_P(
     LPT,
     MaxPoolTransformation,
     ::testing::Combine(
-        ::testing::ValuesIn(precisions),
         ::testing::ValuesIn(shapes),
         ::testing::ValuesIn(testValues)),
     MaxPoolTransformation::getTestCaseName);