[LPT] POT support: absent convert fix & element-wise empty dequantization data (...
authorEdward Shogulin <edward.shogulin@intel.com>
Fri, 13 Nov 2020 07:32:59 +0000 (10:32 +0300)
committerGitHub <noreply@github.com>
Fri, 13 Nov 2020 07:32:59 +0000 (10:32 +0300)
inference-engine/src/low_precision_transformations/src/common/eltwise_base_transformation.cpp
inference-engine/src/low_precision_transformations/src/common/network_helper.cpp
inference-engine/tests/functional/inference_engine/lp_transformations/elementwise_with_multi_parent_dequantization_transformation.cpp [new file with mode: 0644]
inference-engine/tests/functional/inference_engine/lp_transformations/max_pool_transformation.cpp
inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.hpp [new file with mode: 0644]
inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/max_pool_function.hpp
inference-engine/tests/ngraph_functions/src/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.cpp [new file with mode: 0644]
inference-engine/tests/ngraph_functions/src/low_precision_transformations/max_pool_function.cpp

index aa4a869..155e4ba 100644 (file)
@@ -69,11 +69,13 @@ bool EltwiseBaseTransformation::canBeTransformed(const TransformationContext& co
         return false;
     }
 
-    if (dequantization1.empty() && !is_type<opset1::Constant>(dequantization1.data.get_node_shared_ptr())) {
+    if ((dequantization1.data.get_node() == nullptr) ||
+        (dequantization1.empty() && !is_type<opset1::Constant>(dequantization1.data.get_node_shared_ptr()))) {
         return false;
     }
 
-    if (dequantization2.empty() && !is_type<opset1::Constant>(dequantization2.data.get_node_shared_ptr())) {
+    if ((dequantization2.data.get_node() == nullptr) ||
+        (dequantization2.empty() && !is_type<opset1::Constant>(dequantization2.data.get_node_shared_ptr()))) {
         return false;
     }
 
index 86159ed..8bd84d5 100644 (file)
@@ -948,7 +948,10 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
 
     auto parent = newOperation;
     if (shouldConvert) {
-        parent = std::make_shared<DequantizationConvert>(parent, dequantization.convert->get_output_element_type(0));
+        const auto convertOutputPrecision = dequantization.convert != nullptr ?
+            dequantization.convert->get_output_element_type(0) :
+            dequantization.multiply->get_output_element_type(0);
+        parent = std::make_shared<DequantizationConvert>(parent, convertOutputPrecision);
         ngraph::copy_runtime_info({ newOperation, parent }, parent);
     }
 
diff --git a/inference-engine/tests/functional/inference_engine/lp_transformations/elementwise_with_multi_parent_dequantization_transformation.cpp b/inference-engine/tests/functional/inference_engine/lp_transformations/elementwise_with_multi_parent_dequantization_transformation.cpp
new file mode 100644 (file)
index 0000000..490b0ca
--- /dev/null
@@ -0,0 +1,161 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "layer_transformation.hpp"
+
+#include <string>
+#include <sstream>
+#include <memory>
+
+#include <gtest/gtest.h>
+
+#include <utility>
+#include <transformations/utils/utils.hpp>
+#include <transformations/init_node_info.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+#include "simple_low_precision_transformer.hpp"
+
+#include <low_precision/add.hpp>
+#include "ngraph_functions/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.hpp"
+#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
+
+using namespace testing;
+using namespace ngraph::pass;
+using namespace ngraph::builder::subgraph;
+
+class ElementwiseWithMultiParentDequantizationTransformationTestValues {
+public:
+    class Actual {
+    public:
+        ngraph::element::Type precision1;
+        ngraph::builder::subgraph::DequantizationOperations dequantization1;
+        ngraph::element::Type precision2;
+        ngraph::builder::subgraph::DequantizationOperations dequantization2;
+    };
+
+    class Expected {
+    public:
+        ngraph::element::Type precision1;
+        ngraph::builder::subgraph::DequantizationOperations dequantization1;
+        ngraph::element::Type precision2;
+        ngraph::builder::subgraph::DequantizationOperations dequantization2;
+    };
+
+    ngraph::element::Type precision;
+    ngraph::Shape inputShape;
+    ngraph::pass::low_precision::LayerTransformation::Params params;
+    Actual actual;
+    Expected expected;
+};
+
+template <typename T>
+inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& values) {
+    os << "{ ";
+    for (size_t i = 0; i < values.size(); ++i) {
+        os << values[i];
+        if (i != (values.size() - 1ul)) {
+            os << ", ";
+        }
+    }
+    os << " }";
+    return os;
+}
+
+class ElementwiseWithMultiParentDequantizationTransformation :
+    public LayerTransformation,
+    public testing::WithParamInterface<ElementwiseWithMultiParentDequantizationTransformationTestValues> {
+public:
+    void SetUp() override {
+        const ElementwiseWithMultiParentDequantizationTransformationTestValues testValues = GetParam();
+
+        actualFunction = ElementwiseWithMultiParentDequantizationFunction::get(
+            testValues.precision,
+            testValues.inputShape,
+            testValues.params,
+            testValues.actual.precision1,
+            testValues.actual.dequantization1,
+            testValues.actual.precision2,
+            testValues.actual.dequantization2);
+
+        SimpleLowPrecisionTransformer transform;
+        transform.add<ngraph::pass::low_precision::AddTransformation, ngraph::opset1::Add>(
+            low_precision::LayerTransformation::Params(testValues.params));
+        transform.transform(actualFunction);
+
+        referenceFunction = ElementwiseWithMultiParentDequantizationFunction::get(
+            testValues.precision,
+            testValues.inputShape,
+            testValues.params,
+            testValues.expected.precision1,
+            testValues.expected.dequantization1,
+            testValues.expected.precision2,
+            testValues.expected.dequantization2);
+    }
+
+    static std::string getTestCaseName(testing::TestParamInfo<ElementwiseWithMultiParentDequantizationTransformationTestValues> obj) {
+        const ElementwiseWithMultiParentDequantizationTransformationTestValues testValues = obj.param;
+
+        std::ostringstream result;
+        result <<
+            testValues.precision << "_" <<
+            testValues.inputShape << "_" <<
+            testValues.actual.precision1 << "_" <<
+            testValues.actual.dequantization1 << "_" <<
+            testValues.actual.precision2 << "_" <<
+            testValues.actual.dequantization2;
+        return result.str();
+    }
+};
+
+TEST_P(ElementwiseWithMultiParentDequantizationTransformation, CompareFunctions) {
+    actualFunction->validate_nodes_and_infer_types();
+    auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+const std::vector<ElementwiseWithMultiParentDequantizationTransformationTestValues> addTransformationTestValues = {
+    // U8
+    {
+        ngraph::element::f32,
+        ngraph::Shape{1, 4, 16, 16},
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::u8,
+            { {ngraph::element::f32},  { 7.f }, { 10.f }},
+            ngraph::element::u8,
+            {},
+        },
+        {
+            ngraph::element::u8,
+            { {ngraph::element::f32},  { 7.f }, { 10.f }},
+            ngraph::element::u8,
+            {},
+        }
+    },
+    // U8
+    {
+        ngraph::element::f32,
+        ngraph::Shape{1, 4, 16, 16},
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::u8,
+            {},
+            ngraph::element::u8,
+            { {ngraph::element::f32},  { 7.f }, { 10.f }}
+        },
+        {
+            ngraph::element::u8,
+            {},
+            ngraph::element::u8,
+            { {ngraph::element::f32},  { 7.f }, { 10.f }}
+        }
+    }
+};
+
+INSTANTIATE_TEST_CASE_P(
+    LPT,
+    ElementwiseWithMultiParentDequantizationTransformation,
+    ::testing::ValuesIn(addTransformationTestValues),
+    ElementwiseWithMultiParentDequantizationTransformation::getTestCaseName);
index 7629827..e3650e3 100644 (file)
 #include "common_test_utils/ngraph_test_utils.hpp"
 #include "simple_low_precision_transformer.hpp"
 #include "ngraph_functions/low_precision_transformations/max_pool_function.hpp"
+#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
+
 
 using namespace testing;
 using namespace ngraph::pass;
 
 class MaxPoolTransformationTestValues {
 public:
-    low_precision::LayerTransformation::Params params;
-    std::vector<float> subtractValues;
-    std::vector<float> mutliplyValues;
+    class Actual {
+    public:
+        ngraph::element::Type precisionBeforeDequantization;
+        ngraph::builder::subgraph::DequantizationOperations dequantization1;
+        ngraph::builder::subgraph::DequantizationOperations dequantization2;
+    };
+
+    class Expected {
+    public:
+        ngraph::element::Type precisionBeforeDequantization;
+        ngraph::builder::subgraph::DequantizationOperations dequantization1;
+        ngraph::builder::subgraph::DequantizationOperations dequantization2;
+    };
+
+    ngraph::pass::low_precision::LayerTransformation::Params params;
+    Actual actual;
+    Expected expected;
 };
 
 typedef std::tuple<
-    ngraph::element::Type,
     ngraph::Shape,
     MaxPoolTransformationTestValues> MaxPoolTransformationParams;
 
 class MaxPoolTransformation : public LayerTransformation, public testing::WithParamInterface<MaxPoolTransformationParams> {
 public:
     void SetUp() override {
-        const ngraph::element::Type precision = std::get<0>(GetParam());
-        const ngraph::Shape shape = std::get<1>(GetParam());
-        const MaxPoolTransformationTestValues testValues = std::get<2>(GetParam());
+        const ngraph::Shape shape = std::get<0>(GetParam());
+        const MaxPoolTransformationTestValues testValues = std::get<1>(GetParam());
 
-        actualFunction = ngraph::builder::subgraph::MaxPoolFunction::getOriginal(
-            precision,
+        actualFunction = ngraph::builder::subgraph::MaxPoolFunction::get(
             shape,
-            {
-                testValues.params.updatePrecisions ? testValues.params.precisionsOnActivations[0] : precision,
-                testValues.subtractValues,
-                testValues.mutliplyValues
-            });
+            testValues.actual.precisionBeforeDequantization,
+            testValues.actual.dequantization1,
+            testValues.actual.dequantization2);
 
         SimpleLowPrecisionTransformer transform;
         transform.add<ngraph::pass::low_precision::MaxPoolTransformation, ngraph::opset1::MaxPool>(testValues.params);
         transform.transform(actualFunction);
 
-        referenceFunction = ngraph::builder::subgraph::MaxPoolFunction::getReference(
-            precision,
+        referenceFunction = ngraph::builder::subgraph::MaxPoolFunction::get(
             shape,
-            {
-                testValues.params.updatePrecisions ? testValues.params.precisionsOnActivations[0] : precision,
-                testValues.subtractValues,
-                testValues.mutliplyValues
-            });
+            testValues.expected.precisionBeforeDequantization,
+            testValues.expected.dequantization1,
+            testValues.expected.dequantization2);
     }
 
     static std::string getTestCaseName(testing::TestParamInfo<MaxPoolTransformationParams> obj) {
-        const ngraph::element::Type precision = std::get<0>(obj.param);
-        const ngraph::Shape shape = std::get<1>(obj.param);
-        const MaxPoolTransformationTestValues testValues = std::get<2>(obj.param);
+        const ngraph::Shape shape = std::get<0>(obj.param);
+        const MaxPoolTransformationTestValues testValues = std::get<1>(obj.param);
 
         std::ostringstream result;
         result <<
-            LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" <<
-            testValues.subtractValues.size() << "_" <<
-            testValues.mutliplyValues.size() << "_";
+            LayerTransformation::getTestCaseNameByParams(testValues.actual.precisionBeforeDequantization, shape, testValues.params) << "_" <<
+            testValues.actual.dequantization1 << "_" <<
+            testValues.actual.dequantization2 << "_" <<
+            testValues.expected.dequantization1 << "_" <<
+            testValues.expected.dequantization2 << "_";
         return result.str();
     }
 };
 
 TEST_P(MaxPoolTransformation, CompareFunctions) {
     actualFunction->validate_nodes_and_infer_types();
-    auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
+    auto res = compare_functions(referenceFunction, actualFunction, true, false, true);
     ASSERT_TRUE(res.first) << res.second;
 }
 
-const std::vector<ngraph::element::Type> precisions = {
-    ngraph::element::f32,
-    // ngraph::element::f16
-};
-
 const std::vector<ngraph::Shape> shapes = {
-    { 1, 32, 72, 48 }
+    { 1, 32, 72, 48 },
+    { 4, 32, 72, 48 }
 };
 
 const std::vector<MaxPoolTransformationTestValues> testValues = {
-    { LayerTransformation::createParamsU8I8(), { 128 }, { 0.02f } },
-    { LayerTransformation::createParamsU8I8(), {}, { 0.02f } },
-    { LayerTransformation::createParamsU8I8().setUpdatePrecisions(false), { 128 }, { 0.02f } },
-    { LayerTransformation::createParamsU8I8().setUpdatePrecisions(false), {}, { 0.02f } },
-    { LayerTransformation::createParamsI8I8(), { 128 }, { 0.02f } },
+    // Multiply
+    {
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::u8,
+            { {}, {}, { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }},
+            {}
+        },
+        {
+            ngraph::element::u8,
+            {},
+            { ngraph::element::f32, {}, { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }}
+        }
+    },
+    // Subtract + Multiply
+    {
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::u8,
+            {
+                {},
+                { {128.f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 },
+                { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }
+            },
+            {}
+        },
+        {
+            ngraph::element::u8,
+            {},
+            {
+                ngraph::element::f32,
+                { {128.f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 },
+                { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }
+            }
+        }
+    },
+    // Convert + Subtract + Multiply
+    {
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::u8,
+            { ngraph::element::f32, { 128 }, { 0.02f }},
+            {}
+        },
+        {
+            ngraph::element::u8,
+            {},
+            { ngraph::element::f32, { 128 }, { 0.02f }}
+        }
+    },
+    // Convert + Subtract + Multiply
+    {
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::u8,
+            { ngraph::element::f32, {}, { 0.02f }},
+            {}
+        },
+        {
+            ngraph::element::u8,
+            {},
+            { ngraph::element::f32, {}, { 0.02f }}
+        }
+    },
+    // Convert + Subtract + Multiply
+    {
+        LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
+        {
+            ngraph::element::u8,
+            { ngraph::element::f32, { 128 }, { 0.02f }},
+            {}
+        },
+        {
+            ngraph::element::u8,
+            {},
+            { ngraph::element::f32, { 128 }, { 0.02f }}
+        }
+    },
+    // Convert + Subtract + Multiply
+    {
+        LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
+        {
+            ngraph::element::u8,
+            { ngraph::element::f32, {}, { 0.02f }},
+            {}
+        },
+        {
+            ngraph::element::u8,
+            {},
+            { ngraph::element::f32, {}, { 0.02f }}
+        }
+    }
 };
 
 INSTANTIATE_TEST_CASE_P(
     LPT,
     MaxPoolTransformation,
     ::testing::Combine(
-        ::testing::ValuesIn(precisions),
         ::testing::ValuesIn(shapes),
         ::testing::ValuesIn(testValues)),
     MaxPoolTransformation::getTestCaseName);
diff --git a/inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.hpp b/inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.hpp
new file mode 100644 (file)
index 0000000..6f55b2d
--- /dev/null
@@ -0,0 +1,71 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <memory>
+#include <ngraph/ngraph.hpp>
+
+#include "functional_test_utils/low_precision_transformations/layer_transformation.hpp"
+#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
+#include "ngraph_functions/subgraph_builders.hpp"
+#include "ngraph_functions/low_precision_transformations/common/builders.hpp"
+
+namespace ngraph {
+namespace builder {
+namespace subgraph {
+
+class AddActualValues {
+public:
+    ngraph::element::Type precision1;
+    std::vector<float> subtractValues1;
+    std::vector<float> mutliplyValues1;
+    ngraph::element::Type precision2;
+    std::vector<float> subtractValues2;
+    std::vector<float> mutliplyValues2;
+};
+
+inline std::ostream& operator<<(std::ostream& out, const AddActualValues& values) {
+    return out <<
+        "_" << values.precision1 <<
+        "_subtract" << values.subtractValues1.size() <<
+        "_mutliply" << values.mutliplyValues1.size() <<
+        "_" << values.precision2 <<
+        "_subtract" << values.subtractValues2.size() <<
+        "_mutliply" << values.mutliplyValues2.size();
+}
+
+class AddExpectedValues {
+public:
+    ngraph::element::Type precision1;
+    std::vector<float> subtractValues1;
+    std::vector<float> mutliplyValues1;
+    ngraph::element::Type precision2;
+    std::vector<float> mutliplyValuesAfter;
+};
+
+inline std::ostream& operator<<(std::ostream& out, const AddExpectedValues& values) {
+    return out <<
+        "_" << values.precision1 <<
+        "_subtract" << values.subtractValues1.size() <<
+        "_mutliply" << values.mutliplyValues1.size() <<
+        "_" << values.precision2 <<
+        "_mutliply" << values.mutliplyValuesAfter.size();
+}
+
+class ElementwiseWithMultiParentDequantizationFunction {
+public:
+    static std::shared_ptr<ngraph::Function> get(
+        const ngraph::element::Type precision,
+        const ngraph::Shape& inputShape,
+        const ngraph::pass::low_precision::LayerTransformation::Params& params,
+        const ngraph::element::Type& precision1,
+        const ngraph::builder::subgraph::DequantizationOperations& dequantization1,
+        const ngraph::element::Type& precision2,
+        const ngraph::builder::subgraph::DequantizationOperations& dequantization2);
+};
+
+}  // namespace subgraph
+}  // namespace builder
+}  // namespace ngraph
index 20c3026..80f61af 100644 (file)
@@ -8,6 +8,7 @@
 #include <ngraph/ngraph.hpp>
 #include "common/fake_quantize_on_data.hpp"
 #include "low_precision/layer_transformation.hpp"
+#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
 
 namespace ngraph {
 namespace builder {
@@ -15,34 +16,16 @@ namespace subgraph {
 
 class MaxPoolFunction {
 public:
-    class ActualValues {
-    public:
-        ngraph::element::Type lowPrecision;
-        std::vector<float> subtractValues;
-        std::vector<float> mutliplyValues;
-    };
-
-    class ExpectedValues {
-    public:
-        ngraph::element::Type activationPrecision;
-        std::vector<float> subtractValues;
-        std::vector<float> mutliplyValues;
-    };
-
-    static std::shared_ptr<ngraph::Function> getOriginal(
-        const ngraph::element::Type originalFunctionPrecision,
-        const ngraph::Shape& inputShape,
-        const ActualValues& values);
-
     static std::shared_ptr<ngraph::Function> getOriginal(
         const ngraph::element::Type originalFunctionPrecision,
         const ngraph::Shape& inputShape,
         const FakeQuantizeOnData& fakeQuantizeOnData);
 
-    static std::shared_ptr<ngraph::Function> getReference(
-        const ngraph::element::Type originalFunctionPrecision,
+    static std::shared_ptr<ngraph::Function> get(
         const ngraph::Shape& inputShape,
-        const ExpectedValues& values);
+        const ngraph::element::Type precisionBeforeDequantization,
+        const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore,
+        const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter);
 };
 
 }  // namespace subgraph
diff --git a/inference-engine/tests/ngraph_functions/src/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.cpp b/inference-engine/tests/ngraph_functions/src/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.cpp
new file mode 100644 (file)
index 0000000..2bc5bf0
--- /dev/null
@@ -0,0 +1,60 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "ngraph_functions/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.hpp"
+#include "low_precision/network_helper.hpp"
+
+#include <ngraph/opsets/opset1.hpp>
+#include "ngraph_functions/builders.hpp"
+#include "ngraph_functions/subgraph_builders.hpp"
+
+using namespace ngraph::pass::low_precision;
+
+namespace ngraph {
+namespace builder {
+namespace subgraph {
+
+std::shared_ptr<ngraph::Function> ElementwiseWithMultiParentDequantizationFunction::get(
+    const ngraph::element::Type precision,
+    const ngraph::Shape& inputShape,
+    const ngraph::pass::low_precision::LayerTransformation::Params& params,
+    const ngraph::element::Type& precision1,
+    const ngraph::builder::subgraph::DequantizationOperations& dequantization1,
+    const ngraph::element::Type& precision2,
+    const ngraph::builder::subgraph::DequantizationOperations& dequantization2) {
+    const auto input1_1 = std::make_shared<ngraph::opset1::Parameter>(precision1, inputShape);
+    const auto input1_2 = std::make_shared<ngraph::opset1::Parameter>(precision1, ngraph::Shape({ inputShape[0], inputShape[1], 1, 1 }));
+    const std::shared_ptr<ngraph::Node> multiply1 = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Multiply>>(
+        DequantizationMultiply(
+            ngraph::op::TemporaryReplaceOutputType(input1_1, element::f32).get(),
+            ngraph::op::TemporaryReplaceOutputType(input1_2, element::f32).get()),
+        std::vector<element::Type>{element::f32, element::f32},
+        std::vector<element::Type>{});
+
+    const std::shared_ptr<ngraph::Node> parent1 = dequantization1.empty() ? multiply1 : makeDequantization(multiply1, dequantization1);
+
+    const auto input2_1 = std::make_shared<ngraph::opset1::Parameter>(precision1, inputShape);
+    const auto input2_2 = std::make_shared<ngraph::opset1::Parameter>(precision1, ngraph::Shape({ inputShape[0], inputShape[1], 1, 1 }));
+    const std::shared_ptr<ngraph::Node> multiply2 = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Multiply>>(
+        DequantizationMultiply(
+            ngraph::op::TemporaryReplaceOutputType(input2_1, element::f32).get(),
+            ngraph::op::TemporaryReplaceOutputType(input2_2, element::f32).get()),
+        std::vector<element::Type>{element::f32, element::f32},
+        std::vector<element::Type>{});
+
+    const std::shared_ptr<ngraph::Node> parent2 = dequantization2.empty() ? multiply2 : makeDequantization(multiply2, dequantization2);
+
+    const auto add = std::make_shared<ngraph::opset1::Add>(parent1, parent2);
+    add->set_friendly_name("output");
+    auto& rtInfo = add->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("add");
+
+    ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(add) };
+    ngraph::ParameterVector parameters = { input1_1, input1_2, input2_1, input2_2 };
+    return std::make_shared<ngraph::Function>(results, parameters, "ElementwiseWithMultiParentDequantization");
+}
+
+}  // namespace subgraph
+}  // namespace builder
+}  // namespace ngraph
index 64eebf8..7296028 100644 (file)
@@ -17,41 +17,6 @@ namespace subgraph {
 std::shared_ptr<ngraph::Function> MaxPoolFunction::getOriginal(
     const ngraph::element::Type originalFunctionPrecision,
     const ngraph::Shape& inputShape,
-    const ActualValues& values) {
-    const auto input = std::make_shared<ngraph::opset1::Parameter>(values.lowPrecision, ngraph::Shape(inputShape));
-    std::shared_ptr<ngraph::Node> parent = input;
-
-    const std::shared_ptr<ngraph::Node> convert = std::make_shared<ngraph::opset1::Convert>(parent, originalFunctionPrecision);
-    parent = convert;
-
-    if (!values.subtractValues.empty()) {
-        const std::shared_ptr<ngraph::Node> subtract = std::make_shared<ngraph::opset1::Subtract>(
-            parent,
-            std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.subtractValues.size() }), values.subtractValues));
-        parent = subtract;
-    }
-
-    const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ngraph::opset1::Multiply>(
-        parent,
-        std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.mutliplyValues.size() }), values.mutliplyValues));
-    parent = multiply;
-
-    const std::shared_ptr<ngraph::Node> maxPool = std::make_shared<ngraph::opset1::MaxPool>(
-        parent,
-        Strides{ 1, 1 },
-        Shape{ 1, 1 },
-        Shape{ 0, 0 },
-        Shape{ 2, 2 },
-        op::RoundingType::FLOOR);
-    maxPool->set_friendly_name("output");
-
-    ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(maxPool) };
-    return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MaxPoolTransformation");
-}
-
-std::shared_ptr<ngraph::Function> MaxPoolFunction::getOriginal(
-    const ngraph::element::Type originalFunctionPrecision,
-    const ngraph::Shape& inputShape,
     const FakeQuantizeOnData& fakeQuantizeOnData) {
     const auto input = std::make_shared<ngraph::opset1::Parameter>(originalFunctionPrecision, ngraph::Shape(inputShape));
 
@@ -71,13 +36,16 @@ std::shared_ptr<ngraph::Function> MaxPoolFunction::getOriginal(
     return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MaxPoolTransformation");
 }
 
-std::shared_ptr<ngraph::Function> MaxPoolFunction::getReference(
-    const ngraph::element::Type originalFunctionPrecision,
+std::shared_ptr<ngraph::Function> MaxPoolFunction::get(
     const ngraph::Shape& inputShape,
-    const ExpectedValues& values) {
-    auto input = std::make_shared<ngraph::opset1::Parameter>(values.activationPrecision, ngraph::Shape(inputShape));
+    const ngraph::element::Type precisionBeforeDequantization,
+    const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore,
+    const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter) {
+    const auto input = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization, ngraph::Shape(inputShape));
     std::shared_ptr<ngraph::Node> parent = input;
 
+    parent = dequantizationBefore.empty() ? parent : makeDequantization(parent, dequantizationBefore);
+
     const std::shared_ptr<ngraph::Node> maxPool = std::make_shared<ngraph::opset1::MaxPool>(
         parent,
         Strides{ 1, 1 },
@@ -87,25 +55,16 @@ std::shared_ptr<ngraph::Function> MaxPoolFunction::getReference(
         op::RoundingType::FLOOR);
     parent = maxPool;
 
-    if (parent->get_output_element_type(0) != originalFunctionPrecision) {
-        const std::shared_ptr<ngraph::Node> convert = std::make_shared<ngraph::pass::low_precision::DequantizationConvert>(parent, originalFunctionPrecision);
-        parent = convert;
-    }
+    parent = dequantizationAfter.empty() ? maxPool : makeDequantization(maxPool, dequantizationAfter);
+    maxPool->set_friendly_name("maxPool");
 
-    if (!values.subtractValues.empty()) {
-        const std::shared_ptr<ngraph::Node> subtract = std::make_shared<ngraph::pass::low_precision::DequantizationSubtract>(
-            parent,
-            std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.subtractValues.size() }), values.subtractValues));
-        parent = subtract;
-    }
+    const std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(parent);
 
-    const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ngraph::pass::low_precision::DequantizationMultiply>(
-        parent,
-        std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.mutliplyValues.size() }), values.mutliplyValues));
-    multiply->set_friendly_name("output");
-
-    ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(multiply) };
-    return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MaxPoolTransformation");
+    const std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
+        ngraph::ResultVector{ result },
+        std::vector<std::shared_ptr<ngraph::op::Parameter>> { input },
+        "MaxPoolTransformation");
+    return function;
 }
 
 }  // namespace subgraph