[LPT] POT support: absent convert fix & element-wise empty dequantization data (...
[platform/upstream/dldt.git] / inference-engine / tests / ngraph_functions / src / low_precision_transformations / max_pool_function.cpp
index 64eebf8..7296028 100644 (file)
@@ -17,41 +17,6 @@ namespace subgraph {
 std::shared_ptr<ngraph::Function> MaxPoolFunction::getOriginal(
     const ngraph::element::Type originalFunctionPrecision,
     const ngraph::Shape& inputShape,
-    const ActualValues& values) {
-    const auto input = std::make_shared<ngraph::opset1::Parameter>(values.lowPrecision, ngraph::Shape(inputShape));
-    std::shared_ptr<ngraph::Node> parent = input;
-
-    const std::shared_ptr<ngraph::Node> convert = std::make_shared<ngraph::opset1::Convert>(parent, originalFunctionPrecision);
-    parent = convert;
-
-    if (!values.subtractValues.empty()) {
-        const std::shared_ptr<ngraph::Node> subtract = std::make_shared<ngraph::opset1::Subtract>(
-            parent,
-            std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.subtractValues.size() }), values.subtractValues));
-        parent = subtract;
-    }
-
-    const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ngraph::opset1::Multiply>(
-        parent,
-        std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.mutliplyValues.size() }), values.mutliplyValues));
-    parent = multiply;
-
-    const std::shared_ptr<ngraph::Node> maxPool = std::make_shared<ngraph::opset1::MaxPool>(
-        parent,
-        Strides{ 1, 1 },
-        Shape{ 1, 1 },
-        Shape{ 0, 0 },
-        Shape{ 2, 2 },
-        op::RoundingType::FLOOR);
-    maxPool->set_friendly_name("output");
-
-    ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(maxPool) };
-    return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MaxPoolTransformation");
-}
-
-std::shared_ptr<ngraph::Function> MaxPoolFunction::getOriginal(
-    const ngraph::element::Type originalFunctionPrecision,
-    const ngraph::Shape& inputShape,
     const FakeQuantizeOnData& fakeQuantizeOnData) {
     const auto input = std::make_shared<ngraph::opset1::Parameter>(originalFunctionPrecision, ngraph::Shape(inputShape));
 
@@ -71,13 +36,16 @@ std::shared_ptr<ngraph::Function> MaxPoolFunction::getOriginal(
     return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MaxPoolTransformation");
 }
 
-std::shared_ptr<ngraph::Function> MaxPoolFunction::getReference(
-    const ngraph::element::Type originalFunctionPrecision,
+std::shared_ptr<ngraph::Function> MaxPoolFunction::get(
     const ngraph::Shape& inputShape,
-    const ExpectedValues& values) {
-    auto input = std::make_shared<ngraph::opset1::Parameter>(values.activationPrecision, ngraph::Shape(inputShape));
+    const ngraph::element::Type precisionBeforeDequantization,
+    const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore,
+    const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter) {
+    const auto input = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization, ngraph::Shape(inputShape));
     std::shared_ptr<ngraph::Node> parent = input;
 
+    parent = dequantizationBefore.empty() ? parent : makeDequantization(parent, dequantizationBefore);
+
     const std::shared_ptr<ngraph::Node> maxPool = std::make_shared<ngraph::opset1::MaxPool>(
         parent,
         Strides{ 1, 1 },
@@ -87,25 +55,16 @@ std::shared_ptr<ngraph::Function> MaxPoolFunction::getReference(
         op::RoundingType::FLOOR);
     parent = maxPool;
 
-    if (parent->get_output_element_type(0) != originalFunctionPrecision) {
-        const std::shared_ptr<ngraph::Node> convert = std::make_shared<ngraph::pass::low_precision::DequantizationConvert>(parent, originalFunctionPrecision);
-        parent = convert;
-    }
+    parent = dequantizationAfter.empty() ? maxPool : makeDequantization(maxPool, dequantizationAfter);
+    maxPool->set_friendly_name("maxPool");
 
-    if (!values.subtractValues.empty()) {
-        const std::shared_ptr<ngraph::Node> subtract = std::make_shared<ngraph::pass::low_precision::DequantizationSubtract>(
-            parent,
-            std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.subtractValues.size() }), values.subtractValues));
-        parent = subtract;
-    }
+    const std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(parent);
 
-    const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ngraph::pass::low_precision::DequantizationMultiply>(
-        parent,
-        std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.mutliplyValues.size() }), values.mutliplyValues));
-    multiply->set_friendly_name("output");
-
-    ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(multiply) };
-    return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MaxPoolTransformation");
+    const std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
+        ngraph::ResultVector{ result },
+        std::vector<std::shared_ptr<ngraph::op::Parameter>> { input },
+        "MaxPoolTransformation");
+    return function;
 }
 
 }  // namespace subgraph