From 4a362bddc531b907c1d7fa3cd6e65740806c28b6 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Fri, 13 Nov 2020 10:32:59 +0300 Subject: [PATCH] [LPT] POT support: absent convert fix & element-wise empty dequantization data (#3067) --- .../src/common/eltwise_base_transformation.cpp | 6 +- .../src/common/network_helper.cpp | 5 +- ..._multi_parent_dequantization_transformation.cpp | 161 +++++++++++++++++++ .../lp_transformations/max_pool_transformation.cpp | 171 ++++++++++++++++----- ...e_with_multi_parent_dequantization_function.hpp | 71 +++++++++ .../max_pool_function.hpp | 27 +--- ...e_with_multi_parent_dequantization_function.cpp | 60 ++++++++ .../max_pool_function.cpp | 71 ++------- 8 files changed, 451 insertions(+), 121 deletions(-) create mode 100644 inference-engine/tests/functional/inference_engine/lp_transformations/elementwise_with_multi_parent_dequantization_transformation.cpp create mode 100644 inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.hpp create mode 100644 inference-engine/tests/ngraph_functions/src/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.cpp diff --git a/inference-engine/src/low_precision_transformations/src/common/eltwise_base_transformation.cpp b/inference-engine/src/low_precision_transformations/src/common/eltwise_base_transformation.cpp index aa4a869..155e4ba 100644 --- a/inference-engine/src/low_precision_transformations/src/common/eltwise_base_transformation.cpp +++ b/inference-engine/src/low_precision_transformations/src/common/eltwise_base_transformation.cpp @@ -69,11 +69,13 @@ bool EltwiseBaseTransformation::canBeTransformed(const TransformationContext& co return false; } - if (dequantization1.empty() && !is_type(dequantization1.data.get_node_shared_ptr())) { + if ((dequantization1.data.get_node() == nullptr) || + (dequantization1.empty() && !is_type(dequantization1.data.get_node_shared_ptr()))) { return false; } - if (dequantization2.empty() && !is_type(dequantization2.data.get_node_shared_ptr())) { + if ((dequantization2.data.get_node() == nullptr) || + (dequantization2.empty() && !is_type(dequantization2.data.get_node_shared_ptr()))) { return false; } diff --git a/inference-engine/src/low_precision_transformations/src/common/network_helper.cpp b/inference-engine/src/low_precision_transformations/src/common/network_helper.cpp index 86159ed..8bd84d5 100644 --- a/inference-engine/src/low_precision_transformations/src/common/network_helper.cpp +++ b/inference-engine/src/low_precision_transformations/src/common/network_helper.cpp @@ -948,7 +948,10 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter auto parent = newOperation; if (shouldConvert) { - parent = std::make_shared(parent, dequantization.convert->get_output_element_type(0)); + const auto convertOutputPrecision = dequantization.convert != nullptr ? + dequantization.convert->get_output_element_type(0) : + dequantization.multiply->get_output_element_type(0); + parent = std::make_shared(parent, convertOutputPrecision); ngraph::copy_runtime_info({ newOperation, parent }, parent); } diff --git a/inference-engine/tests/functional/inference_engine/lp_transformations/elementwise_with_multi_parent_dequantization_transformation.cpp b/inference-engine/tests/functional/inference_engine/lp_transformations/elementwise_with_multi_parent_dequantization_transformation.cpp new file mode 100644 index 0000000..490b0ca --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/lp_transformations/elementwise_with_multi_parent_dequantization_transformation.cpp @@ -0,0 +1,161 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "layer_transformation.hpp" + +#include +#include +#include + +#include + +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" +#include "simple_low_precision_transformer.hpp" + +#include +#include "ngraph_functions/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.hpp" +#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp" + +using namespace testing; +using namespace ngraph::pass; +using namespace ngraph::builder::subgraph; + +class ElementwiseWithMultiParentDequantizationTransformationTestValues { +public: + class Actual { + public: + ngraph::element::Type precision1; + ngraph::builder::subgraph::DequantizationOperations dequantization1; + ngraph::element::Type precision2; + ngraph::builder::subgraph::DequantizationOperations dequantization2; + }; + + class Expected { + public: + ngraph::element::Type precision1; + ngraph::builder::subgraph::DequantizationOperations dequantization1; + ngraph::element::Type precision2; + ngraph::builder::subgraph::DequantizationOperations dequantization2; + }; + + ngraph::element::Type precision; + ngraph::Shape inputShape; + ngraph::pass::low_precision::LayerTransformation::Params params; + Actual actual; + Expected expected; +}; + +template +inline std::ostream& operator<<(std::ostream& os, const std::vector& values) { + os << "{ "; + for (size_t i = 0; i < values.size(); ++i) { + os << values[i]; + if (i != (values.size() - 1ul)) { + os << ", "; + } + } + os << " }"; + return os; +} + +class ElementwiseWithMultiParentDequantizationTransformation : + public LayerTransformation, + public testing::WithParamInterface { +public: + void SetUp() override { + const ElementwiseWithMultiParentDequantizationTransformationTestValues testValues = GetParam(); + + actualFunction = ElementwiseWithMultiParentDequantizationFunction::get( + testValues.precision, + testValues.inputShape, + testValues.params, + testValues.actual.precision1, + testValues.actual.dequantization1, + testValues.actual.precision2, + testValues.actual.dequantization2); + + SimpleLowPrecisionTransformer transform; + transform.add( + low_precision::LayerTransformation::Params(testValues.params)); + transform.transform(actualFunction); + + referenceFunction = ElementwiseWithMultiParentDequantizationFunction::get( + testValues.precision, + testValues.inputShape, + testValues.params, + testValues.expected.precision1, + testValues.expected.dequantization1, + testValues.expected.precision2, + testValues.expected.dequantization2); + } + + static std::string getTestCaseName(testing::TestParamInfo obj) { + const ElementwiseWithMultiParentDequantizationTransformationTestValues testValues = obj.param; + + std::ostringstream result; + result << + testValues.precision << "_" << + testValues.inputShape << "_" << + testValues.actual.precision1 << "_" << + testValues.actual.dequantization1 << "_" << + testValues.actual.precision2 << "_" << + testValues.actual.dequantization2; + return result.str(); + } +}; + +TEST_P(ElementwiseWithMultiParentDequantizationTransformation, CompareFunctions) { + actualFunction->validate_nodes_and_infer_types(); + auto res = compare_functions(referenceFunction, actualFunction, true, true, true); + ASSERT_TRUE(res.first) << res.second; +} + +const std::vector addTransformationTestValues = { + // U8 + { + ngraph::element::f32, + ngraph::Shape{1, 4, 16, 16}, + LayerTransformation::createParamsU8I8(), + { + ngraph::element::u8, + { {ngraph::element::f32}, { 7.f }, { 10.f }}, + ngraph::element::u8, + {}, + }, + { + ngraph::element::u8, + { {ngraph::element::f32}, { 7.f }, { 10.f }}, + ngraph::element::u8, + {}, + } + }, + // U8 + { + ngraph::element::f32, + ngraph::Shape{1, 4, 16, 16}, + LayerTransformation::createParamsU8I8(), + { + ngraph::element::u8, + {}, + ngraph::element::u8, + { {ngraph::element::f32}, { 7.f }, { 10.f }} + }, + { + ngraph::element::u8, + {}, + ngraph::element::u8, + { {ngraph::element::f32}, { 7.f }, { 10.f }} + } + } +}; + +INSTANTIATE_TEST_CASE_P( + LPT, + ElementwiseWithMultiParentDequantizationTransformation, + ::testing::ValuesIn(addTransformationTestValues), + ElementwiseWithMultiParentDequantizationTransformation::getTestCaseName); diff --git a/inference-engine/tests/functional/inference_engine/lp_transformations/max_pool_transformation.cpp b/inference-engine/tests/functional/inference_engine/lp_transformations/max_pool_transformation.cpp index 7629827..e3650e3 100644 --- a/inference-engine/tests/functional/inference_engine/lp_transformations/max_pool_transformation.cpp +++ b/inference-engine/tests/functional/inference_engine/lp_transformations/max_pool_transformation.cpp @@ -17,94 +17,185 @@ #include "common_test_utils/ngraph_test_utils.hpp" #include "simple_low_precision_transformer.hpp" #include "ngraph_functions/low_precision_transformations/max_pool_function.hpp" +#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp" + using namespace testing; using namespace ngraph::pass; class MaxPoolTransformationTestValues { public: - low_precision::LayerTransformation::Params params; - std::vector subtractValues; - std::vector mutliplyValues; + class Actual { + public: + ngraph::element::Type precisionBeforeDequantization; + ngraph::builder::subgraph::DequantizationOperations dequantization1; + ngraph::builder::subgraph::DequantizationOperations dequantization2; + }; + + class Expected { + public: + ngraph::element::Type precisionBeforeDequantization; + ngraph::builder::subgraph::DequantizationOperations dequantization1; + ngraph::builder::subgraph::DequantizationOperations dequantization2; + }; + + ngraph::pass::low_precision::LayerTransformation::Params params; + Actual actual; + Expected expected; }; typedef std::tuple< - ngraph::element::Type, ngraph::Shape, MaxPoolTransformationTestValues> MaxPoolTransformationParams; class MaxPoolTransformation : public LayerTransformation, public testing::WithParamInterface { public: void SetUp() override { - const ngraph::element::Type precision = std::get<0>(GetParam()); - const ngraph::Shape shape = std::get<1>(GetParam()); - const MaxPoolTransformationTestValues testValues = std::get<2>(GetParam()); + const ngraph::Shape shape = std::get<0>(GetParam()); + const MaxPoolTransformationTestValues testValues = std::get<1>(GetParam()); - actualFunction = ngraph::builder::subgraph::MaxPoolFunction::getOriginal( - precision, + actualFunction = ngraph::builder::subgraph::MaxPoolFunction::get( shape, - { - testValues.params.updatePrecisions ? testValues.params.precisionsOnActivations[0] : precision, - testValues.subtractValues, - testValues.mutliplyValues - }); + testValues.actual.precisionBeforeDequantization, + testValues.actual.dequantization1, + testValues.actual.dequantization2); SimpleLowPrecisionTransformer transform; transform.add(testValues.params); transform.transform(actualFunction); - referenceFunction = ngraph::builder::subgraph::MaxPoolFunction::getReference( - precision, + referenceFunction = ngraph::builder::subgraph::MaxPoolFunction::get( shape, - { - testValues.params.updatePrecisions ? testValues.params.precisionsOnActivations[0] : precision, - testValues.subtractValues, - testValues.mutliplyValues - }); + testValues.expected.precisionBeforeDequantization, + testValues.expected.dequantization1, + testValues.expected.dequantization2); } static std::string getTestCaseName(testing::TestParamInfo obj) { - const ngraph::element::Type precision = std::get<0>(obj.param); - const ngraph::Shape shape = std::get<1>(obj.param); - const MaxPoolTransformationTestValues testValues = std::get<2>(obj.param); + const ngraph::Shape shape = std::get<0>(obj.param); + const MaxPoolTransformationTestValues testValues = std::get<1>(obj.param); std::ostringstream result; result << - LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" << - testValues.subtractValues.size() << "_" << - testValues.mutliplyValues.size() << "_"; + LayerTransformation::getTestCaseNameByParams(testValues.actual.precisionBeforeDequantization, shape, testValues.params) << "_" << + testValues.actual.dequantization1 << "_" << + testValues.actual.dequantization2 << "_" << + testValues.expected.dequantization1 << "_" << + testValues.expected.dequantization2 << "_"; return result.str(); } }; TEST_P(MaxPoolTransformation, CompareFunctions) { actualFunction->validate_nodes_and_infer_types(); - auto res = compare_functions(referenceFunction, actualFunction, true, true, true); + auto res = compare_functions(referenceFunction, actualFunction, true, false, true); ASSERT_TRUE(res.first) << res.second; } -const std::vector precisions = { - ngraph::element::f32, - // ngraph::element::f16 -}; - const std::vector shapes = { - { 1, 32, 72, 48 } + { 1, 32, 72, 48 }, + { 4, 32, 72, 48 } }; const std::vector testValues = { - { LayerTransformation::createParamsU8I8(), { 128 }, { 0.02f } }, - { LayerTransformation::createParamsU8I8(), {}, { 0.02f } }, - { LayerTransformation::createParamsU8I8().setUpdatePrecisions(false), { 128 }, { 0.02f } }, - { LayerTransformation::createParamsU8I8().setUpdatePrecisions(false), {}, { 0.02f } }, - { LayerTransformation::createParamsI8I8(), { 128 }, { 0.02f } }, + // Multiply + { + LayerTransformation::createParamsU8I8(), + { + ngraph::element::u8, + { {}, {}, { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }}, + {} + }, + { + ngraph::element::u8, + {}, + { ngraph::element::f32, {}, { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }} + } + }, + // Subtract + Multiply + { + LayerTransformation::createParamsU8I8(), + { + ngraph::element::u8, + { + {}, + { {128.f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }, + { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 } + }, + {} + }, + { + ngraph::element::u8, + {}, + { + ngraph::element::f32, + { {128.f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }, + { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 } + } + } + }, + // Convert + Subtract + Multiply + { + LayerTransformation::createParamsU8I8(), + { + ngraph::element::u8, + { ngraph::element::f32, { 128 }, { 0.02f }}, + {} + }, + { + ngraph::element::u8, + {}, + { ngraph::element::f32, { 128 }, { 0.02f }} + } + }, + // Convert + Subtract + Multiply + { + LayerTransformation::createParamsU8I8(), + { + ngraph::element::u8, + { ngraph::element::f32, {}, { 0.02f }}, + {} + }, + { + ngraph::element::u8, + {}, + { ngraph::element::f32, {}, { 0.02f }} + } + }, + // Convert + Subtract + Multiply + { + LayerTransformation::createParamsU8I8().setUpdatePrecisions(false), + { + ngraph::element::u8, + { ngraph::element::f32, { 128 }, { 0.02f }}, + {} + }, + { + ngraph::element::u8, + {}, + { ngraph::element::f32, { 128 }, { 0.02f }} + } + }, + // Convert + Subtract + Multiply + { + LayerTransformation::createParamsU8I8().setUpdatePrecisions(false), + { + ngraph::element::u8, + { ngraph::element::f32, {}, { 0.02f }}, + {} + }, + { + ngraph::element::u8, + {}, + { ngraph::element::f32, {}, { 0.02f }} + } + } }; INSTANTIATE_TEST_CASE_P( LPT, MaxPoolTransformation, ::testing::Combine( - ::testing::ValuesIn(precisions), ::testing::ValuesIn(shapes), ::testing::ValuesIn(testValues)), MaxPoolTransformation::getTestCaseName); diff --git a/inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.hpp b/inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.hpp new file mode 100644 index 0000000..6f55b2d --- /dev/null +++ b/inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.hpp @@ -0,0 +1,71 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "functional_test_utils/low_precision_transformations/layer_transformation.hpp" +#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp" +#include "ngraph_functions/subgraph_builders.hpp" +#include "ngraph_functions/low_precision_transformations/common/builders.hpp" + +namespace ngraph { +namespace builder { +namespace subgraph { + +class AddActualValues { +public: + ngraph::element::Type precision1; + std::vector subtractValues1; + std::vector mutliplyValues1; + ngraph::element::Type precision2; + std::vector subtractValues2; + std::vector mutliplyValues2; +}; + +inline std::ostream& operator<<(std::ostream& out, const AddActualValues& values) { + return out << + "_" << values.precision1 << + "_subtract" << values.subtractValues1.size() << + "_mutliply" << values.mutliplyValues1.size() << + "_" << values.precision2 << + "_subtract" << values.subtractValues2.size() << + "_mutliply" << values.mutliplyValues2.size(); +} + +class AddExpectedValues { +public: + ngraph::element::Type precision1; + std::vector subtractValues1; + std::vector mutliplyValues1; + ngraph::element::Type precision2; + std::vector mutliplyValuesAfter; +}; + +inline std::ostream& operator<<(std::ostream& out, const AddExpectedValues& values) { + return out << + "_" << values.precision1 << + "_subtract" << values.subtractValues1.size() << + "_mutliply" << values.mutliplyValues1.size() << + "_" << values.precision2 << + "_mutliply" << values.mutliplyValuesAfter.size(); +} + +class ElementwiseWithMultiParentDequantizationFunction { +public: + static std::shared_ptr get( + const ngraph::element::Type precision, + const ngraph::Shape& inputShape, + const ngraph::pass::low_precision::LayerTransformation::Params& params, + const ngraph::element::Type& precision1, + const ngraph::builder::subgraph::DequantizationOperations& dequantization1, + const ngraph::element::Type& precision2, + const ngraph::builder::subgraph::DequantizationOperations& dequantization2); +}; + +} // namespace subgraph +} // namespace builder +} // namespace ngraph diff --git a/inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/max_pool_function.hpp b/inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/max_pool_function.hpp index 20c3026..80f61af 100644 --- a/inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/max_pool_function.hpp +++ b/inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/max_pool_function.hpp @@ -8,6 +8,7 @@ #include #include "common/fake_quantize_on_data.hpp" #include "low_precision/layer_transformation.hpp" +#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp" namespace ngraph { namespace builder { @@ -15,34 +16,16 @@ namespace subgraph { class MaxPoolFunction { public: - class ActualValues { - public: - ngraph::element::Type lowPrecision; - std::vector subtractValues; - std::vector mutliplyValues; - }; - - class ExpectedValues { - public: - ngraph::element::Type activationPrecision; - std::vector subtractValues; - std::vector mutliplyValues; - }; - - static std::shared_ptr getOriginal( - const ngraph::element::Type originalFunctionPrecision, - const ngraph::Shape& inputShape, - const ActualValues& values); - static std::shared_ptr getOriginal( const ngraph::element::Type originalFunctionPrecision, const ngraph::Shape& inputShape, const FakeQuantizeOnData& fakeQuantizeOnData); - static std::shared_ptr getReference( - const ngraph::element::Type originalFunctionPrecision, + static std::shared_ptr get( const ngraph::Shape& inputShape, - const ExpectedValues& values); + const ngraph::element::Type precisionBeforeDequantization, + const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore, + const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter); }; } // namespace subgraph diff --git a/inference-engine/tests/ngraph_functions/src/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.cpp b/inference-engine/tests/ngraph_functions/src/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.cpp new file mode 100644 index 0000000..2bc5bf0 --- /dev/null +++ b/inference-engine/tests/ngraph_functions/src/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.cpp @@ -0,0 +1,60 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "ngraph_functions/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.hpp" +#include "low_precision/network_helper.hpp" + +#include +#include "ngraph_functions/builders.hpp" +#include "ngraph_functions/subgraph_builders.hpp" + +using namespace ngraph::pass::low_precision; + +namespace ngraph { +namespace builder { +namespace subgraph { + +std::shared_ptr ElementwiseWithMultiParentDequantizationFunction::get( + const ngraph::element::Type precision, + const ngraph::Shape& inputShape, + const ngraph::pass::low_precision::LayerTransformation::Params& params, + const ngraph::element::Type& precision1, + const ngraph::builder::subgraph::DequantizationOperations& dequantization1, + const ngraph::element::Type& precision2, + const ngraph::builder::subgraph::DequantizationOperations& dequantization2) { + const auto input1_1 = std::make_shared(precision1, inputShape); + const auto input1_2 = std::make_shared(precision1, ngraph::Shape({ inputShape[0], inputShape[1], 1, 1 })); + const std::shared_ptr multiply1 = std::make_shared>( + DequantizationMultiply( + ngraph::op::TemporaryReplaceOutputType(input1_1, element::f32).get(), + ngraph::op::TemporaryReplaceOutputType(input1_2, element::f32).get()), + std::vector{element::f32, element::f32}, + std::vector{}); + + const std::shared_ptr parent1 = dequantization1.empty() ? multiply1 : makeDequantization(multiply1, dequantization1); + + const auto input2_1 = std::make_shared(precision1, inputShape); + const auto input2_2 = std::make_shared(precision1, ngraph::Shape({ inputShape[0], inputShape[1], 1, 1 })); + const std::shared_ptr multiply2 = std::make_shared>( + DequantizationMultiply( + ngraph::op::TemporaryReplaceOutputType(input2_1, element::f32).get(), + ngraph::op::TemporaryReplaceOutputType(input2_2, element::f32).get()), + std::vector{element::f32, element::f32}, + std::vector{}); + + const std::shared_ptr parent2 = dequantization2.empty() ? multiply2 : makeDequantization(multiply2, dequantization2); + + const auto add = std::make_shared(parent1, parent2); + add->set_friendly_name("output"); + auto& rtInfo = add->get_rt_info(); + rtInfo["Variant::std::string"] = std::make_shared>("add"); + + ngraph::ResultVector results{ std::make_shared(add) }; + ngraph::ParameterVector parameters = { input1_1, input1_2, input2_1, input2_2 }; + return std::make_shared(results, parameters, "ElementwiseWithMultiParentDequantization"); +} + +} // namespace subgraph +} // namespace builder +} // namespace ngraph diff --git a/inference-engine/tests/ngraph_functions/src/low_precision_transformations/max_pool_function.cpp b/inference-engine/tests/ngraph_functions/src/low_precision_transformations/max_pool_function.cpp index 64eebf8..7296028 100644 --- a/inference-engine/tests/ngraph_functions/src/low_precision_transformations/max_pool_function.cpp +++ b/inference-engine/tests/ngraph_functions/src/low_precision_transformations/max_pool_function.cpp @@ -17,41 +17,6 @@ namespace subgraph { std::shared_ptr MaxPoolFunction::getOriginal( const ngraph::element::Type originalFunctionPrecision, const ngraph::Shape& inputShape, - const ActualValues& values) { - const auto input = std::make_shared(values.lowPrecision, ngraph::Shape(inputShape)); - std::shared_ptr parent = input; - - const std::shared_ptr convert = std::make_shared(parent, originalFunctionPrecision); - parent = convert; - - if (!values.subtractValues.empty()) { - const std::shared_ptr subtract = std::make_shared( - parent, - std::make_shared(originalFunctionPrecision, Shape({ values.subtractValues.size() }), values.subtractValues)); - parent = subtract; - } - - const std::shared_ptr multiply = std::make_shared( - parent, - std::make_shared(originalFunctionPrecision, Shape({ values.mutliplyValues.size() }), values.mutliplyValues)); - parent = multiply; - - const std::shared_ptr maxPool = std::make_shared( - parent, - Strides{ 1, 1 }, - Shape{ 1, 1 }, - Shape{ 0, 0 }, - Shape{ 2, 2 }, - op::RoundingType::FLOOR); - maxPool->set_friendly_name("output"); - - ngraph::ResultVector results{ std::make_shared(maxPool) }; - return std::make_shared(results, ngraph::ParameterVector{ input }, "MaxPoolTransformation"); -} - -std::shared_ptr MaxPoolFunction::getOriginal( - const ngraph::element::Type originalFunctionPrecision, - const ngraph::Shape& inputShape, const FakeQuantizeOnData& fakeQuantizeOnData) { const auto input = std::make_shared(originalFunctionPrecision, ngraph::Shape(inputShape)); @@ -71,13 +36,16 @@ std::shared_ptr MaxPoolFunction::getOriginal( return std::make_shared(results, ngraph::ParameterVector{ input }, "MaxPoolTransformation"); } -std::shared_ptr MaxPoolFunction::getReference( - const ngraph::element::Type originalFunctionPrecision, +std::shared_ptr MaxPoolFunction::get( const ngraph::Shape& inputShape, - const ExpectedValues& values) { - auto input = std::make_shared(values.activationPrecision, ngraph::Shape(inputShape)); + const ngraph::element::Type precisionBeforeDequantization, + const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore, + const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter) { + const auto input = std::make_shared(precisionBeforeDequantization, ngraph::Shape(inputShape)); std::shared_ptr parent = input; + parent = dequantizationBefore.empty() ? parent : makeDequantization(parent, dequantizationBefore); + const std::shared_ptr maxPool = std::make_shared( parent, Strides{ 1, 1 }, @@ -87,25 +55,16 @@ std::shared_ptr MaxPoolFunction::getReference( op::RoundingType::FLOOR); parent = maxPool; - if (parent->get_output_element_type(0) != originalFunctionPrecision) { - const std::shared_ptr convert = std::make_shared(parent, originalFunctionPrecision); - parent = convert; - } + parent = dequantizationAfter.empty() ? maxPool : makeDequantization(maxPool, dequantizationAfter); + maxPool->set_friendly_name("maxPool"); - if (!values.subtractValues.empty()) { - const std::shared_ptr subtract = std::make_shared( - parent, - std::make_shared(originalFunctionPrecision, Shape({ values.subtractValues.size() }), values.subtractValues)); - parent = subtract; - } + const std::shared_ptr result = std::make_shared(parent); - const std::shared_ptr multiply = std::make_shared( - parent, - std::make_shared(originalFunctionPrecision, Shape({ values.mutliplyValues.size() }), values.mutliplyValues)); - multiply->set_friendly_name("output"); - - ngraph::ResultVector results{ std::make_shared(multiply) }; - return std::make_shared(results, ngraph::ParameterVector{ input }, "MaxPoolTransformation"); + const std::shared_ptr function = std::make_shared( + ngraph::ResultVector{ result }, + std::vector> { input }, + "MaxPoolTransformation"); + return function; } } // namespace subgraph -- 2.7.4