return false;
}
- if (dequantization1.empty() && !is_type<opset1::Constant>(dequantization1.data.get_node_shared_ptr())) {
+ if ((dequantization1.data.get_node() == nullptr) ||
+ (dequantization1.empty() && !is_type<opset1::Constant>(dequantization1.data.get_node_shared_ptr()))) {
return false;
}
- if (dequantization2.empty() && !is_type<opset1::Constant>(dequantization2.data.get_node_shared_ptr())) {
+ if ((dequantization2.data.get_node() == nullptr) ||
+ (dequantization2.empty() && !is_type<opset1::Constant>(dequantization2.data.get_node_shared_ptr()))) {
return false;
}
auto parent = newOperation;
if (shouldConvert) {
- parent = std::make_shared<DequantizationConvert>(parent, dequantization.convert->get_output_element_type(0));
+ const auto convertOutputPrecision = dequantization.convert != nullptr ?
+ dequantization.convert->get_output_element_type(0) :
+ dequantization.multiply->get_output_element_type(0);
+ parent = std::make_shared<DequantizationConvert>(parent, convertOutputPrecision);
ngraph::copy_runtime_info({ newOperation, parent }, parent);
}
--- /dev/null
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "layer_transformation.hpp"
+
+#include <string>
+#include <sstream>
+#include <memory>
+
+#include <gtest/gtest.h>
+
+#include <utility>
+#include <transformations/utils/utils.hpp>
+#include <transformations/init_node_info.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+#include "simple_low_precision_transformer.hpp"
+
+#include <low_precision/add.hpp>
+#include "ngraph_functions/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.hpp"
+#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
+
+using namespace testing;
+using namespace ngraph::pass;
+using namespace ngraph::builder::subgraph;
+
+class ElementwiseWithMultiParentDequantizationTransformationTestValues {
+public:
+ class Actual {
+ public:
+ ngraph::element::Type precision1;
+ ngraph::builder::subgraph::DequantizationOperations dequantization1;
+ ngraph::element::Type precision2;
+ ngraph::builder::subgraph::DequantizationOperations dequantization2;
+ };
+
+ class Expected {
+ public:
+ ngraph::element::Type precision1;
+ ngraph::builder::subgraph::DequantizationOperations dequantization1;
+ ngraph::element::Type precision2;
+ ngraph::builder::subgraph::DequantizationOperations dequantization2;
+ };
+
+ ngraph::element::Type precision;
+ ngraph::Shape inputShape;
+ ngraph::pass::low_precision::LayerTransformation::Params params;
+ Actual actual;
+ Expected expected;
+};
+
+template <typename T>
+inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& values) {
+ os << "{ ";
+ for (size_t i = 0; i < values.size(); ++i) {
+ os << values[i];
+ if (i != (values.size() - 1ul)) {
+ os << ", ";
+ }
+ }
+ os << " }";
+ return os;
+}
+
+class ElementwiseWithMultiParentDequantizationTransformation :
+ public LayerTransformation,
+ public testing::WithParamInterface<ElementwiseWithMultiParentDequantizationTransformationTestValues> {
+public:
+ void SetUp() override {
+ const ElementwiseWithMultiParentDequantizationTransformationTestValues testValues = GetParam();
+
+ actualFunction = ElementwiseWithMultiParentDequantizationFunction::get(
+ testValues.precision,
+ testValues.inputShape,
+ testValues.params,
+ testValues.actual.precision1,
+ testValues.actual.dequantization1,
+ testValues.actual.precision2,
+ testValues.actual.dequantization2);
+
+ SimpleLowPrecisionTransformer transform;
+ transform.add<ngraph::pass::low_precision::AddTransformation, ngraph::opset1::Add>(
+ low_precision::LayerTransformation::Params(testValues.params));
+ transform.transform(actualFunction);
+
+ referenceFunction = ElementwiseWithMultiParentDequantizationFunction::get(
+ testValues.precision,
+ testValues.inputShape,
+ testValues.params,
+ testValues.expected.precision1,
+ testValues.expected.dequantization1,
+ testValues.expected.precision2,
+ testValues.expected.dequantization2);
+ }
+
+ static std::string getTestCaseName(testing::TestParamInfo<ElementwiseWithMultiParentDequantizationTransformationTestValues> obj) {
+ const ElementwiseWithMultiParentDequantizationTransformationTestValues testValues = obj.param;
+
+ std::ostringstream result;
+ result <<
+ testValues.precision << "_" <<
+ testValues.inputShape << "_" <<
+ testValues.actual.precision1 << "_" <<
+ testValues.actual.dequantization1 << "_" <<
+ testValues.actual.precision2 << "_" <<
+ testValues.actual.dequantization2;
+ return result.str();
+ }
+};
+
+TEST_P(ElementwiseWithMultiParentDequantizationTransformation, CompareFunctions) {
+ actualFunction->validate_nodes_and_infer_types();
+ auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
+ ASSERT_TRUE(res.first) << res.second;
+}
+
+const std::vector<ElementwiseWithMultiParentDequantizationTransformationTestValues> addTransformationTestValues = {
+ // U8
+ {
+ ngraph::element::f32,
+ ngraph::Shape{1, 4, 16, 16},
+ LayerTransformation::createParamsU8I8(),
+ {
+ ngraph::element::u8,
+ { {ngraph::element::f32}, { 7.f }, { 10.f }},
+ ngraph::element::u8,
+ {},
+ },
+ {
+ ngraph::element::u8,
+ { {ngraph::element::f32}, { 7.f }, { 10.f }},
+ ngraph::element::u8,
+ {},
+ }
+ },
+ // U8
+ {
+ ngraph::element::f32,
+ ngraph::Shape{1, 4, 16, 16},
+ LayerTransformation::createParamsU8I8(),
+ {
+ ngraph::element::u8,
+ {},
+ ngraph::element::u8,
+ { {ngraph::element::f32}, { 7.f }, { 10.f }}
+ },
+ {
+ ngraph::element::u8,
+ {},
+ ngraph::element::u8,
+ { {ngraph::element::f32}, { 7.f }, { 10.f }}
+ }
+ }
+};
+
+INSTANTIATE_TEST_CASE_P(
+ LPT,
+ ElementwiseWithMultiParentDequantizationTransformation,
+ ::testing::ValuesIn(addTransformationTestValues),
+ ElementwiseWithMultiParentDequantizationTransformation::getTestCaseName);
#include "common_test_utils/ngraph_test_utils.hpp"
#include "simple_low_precision_transformer.hpp"
#include "ngraph_functions/low_precision_transformations/max_pool_function.hpp"
+#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
+
using namespace testing;
using namespace ngraph::pass;
class MaxPoolTransformationTestValues {
public:
- low_precision::LayerTransformation::Params params;
- std::vector<float> subtractValues;
- std::vector<float> mutliplyValues;
+ class Actual {
+ public:
+ ngraph::element::Type precisionBeforeDequantization;
+ ngraph::builder::subgraph::DequantizationOperations dequantization1;
+ ngraph::builder::subgraph::DequantizationOperations dequantization2;
+ };
+
+ class Expected {
+ public:
+ ngraph::element::Type precisionBeforeDequantization;
+ ngraph::builder::subgraph::DequantizationOperations dequantization1;
+ ngraph::builder::subgraph::DequantizationOperations dequantization2;
+ };
+
+ ngraph::pass::low_precision::LayerTransformation::Params params;
+ Actual actual;
+ Expected expected;
};
typedef std::tuple<
- ngraph::element::Type,
ngraph::Shape,
MaxPoolTransformationTestValues> MaxPoolTransformationParams;
class MaxPoolTransformation : public LayerTransformation, public testing::WithParamInterface<MaxPoolTransformationParams> {
public:
void SetUp() override {
- const ngraph::element::Type precision = std::get<0>(GetParam());
- const ngraph::Shape shape = std::get<1>(GetParam());
- const MaxPoolTransformationTestValues testValues = std::get<2>(GetParam());
+ const ngraph::Shape shape = std::get<0>(GetParam());
+ const MaxPoolTransformationTestValues testValues = std::get<1>(GetParam());
- actualFunction = ngraph::builder::subgraph::MaxPoolFunction::getOriginal(
- precision,
+ actualFunction = ngraph::builder::subgraph::MaxPoolFunction::get(
shape,
- {
- testValues.params.updatePrecisions ? testValues.params.precisionsOnActivations[0] : precision,
- testValues.subtractValues,
- testValues.mutliplyValues
- });
+ testValues.actual.precisionBeforeDequantization,
+ testValues.actual.dequantization1,
+ testValues.actual.dequantization2);
SimpleLowPrecisionTransformer transform;
transform.add<ngraph::pass::low_precision::MaxPoolTransformation, ngraph::opset1::MaxPool>(testValues.params);
transform.transform(actualFunction);
- referenceFunction = ngraph::builder::subgraph::MaxPoolFunction::getReference(
- precision,
+ referenceFunction = ngraph::builder::subgraph::MaxPoolFunction::get(
shape,
- {
- testValues.params.updatePrecisions ? testValues.params.precisionsOnActivations[0] : precision,
- testValues.subtractValues,
- testValues.mutliplyValues
- });
+ testValues.expected.precisionBeforeDequantization,
+ testValues.expected.dequantization1,
+ testValues.expected.dequantization2);
}
static std::string getTestCaseName(testing::TestParamInfo<MaxPoolTransformationParams> obj) {
- const ngraph::element::Type precision = std::get<0>(obj.param);
- const ngraph::Shape shape = std::get<1>(obj.param);
- const MaxPoolTransformationTestValues testValues = std::get<2>(obj.param);
+ const ngraph::Shape shape = std::get<0>(obj.param);
+ const MaxPoolTransformationTestValues testValues = std::get<1>(obj.param);
std::ostringstream result;
result <<
- LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" <<
- testValues.subtractValues.size() << "_" <<
- testValues.mutliplyValues.size() << "_";
+ LayerTransformation::getTestCaseNameByParams(testValues.actual.precisionBeforeDequantization, shape, testValues.params) << "_" <<
+ testValues.actual.dequantization1 << "_" <<
+ testValues.actual.dequantization2 << "_" <<
+ testValues.expected.dequantization1 << "_" <<
+ testValues.expected.dequantization2 << "_";
return result.str();
}
};
TEST_P(MaxPoolTransformation, CompareFunctions) {
actualFunction->validate_nodes_and_infer_types();
- auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
+ auto res = compare_functions(referenceFunction, actualFunction, true, false, true);
ASSERT_TRUE(res.first) << res.second;
}
-const std::vector<ngraph::element::Type> precisions = {
- ngraph::element::f32,
- // ngraph::element::f16
-};
-
const std::vector<ngraph::Shape> shapes = {
- { 1, 32, 72, 48 }
+ { 1, 32, 72, 48 },
+ { 4, 32, 72, 48 }
};
const std::vector<MaxPoolTransformationTestValues> testValues = {
- { LayerTransformation::createParamsU8I8(), { 128 }, { 0.02f } },
- { LayerTransformation::createParamsU8I8(), {}, { 0.02f } },
- { LayerTransformation::createParamsU8I8().setUpdatePrecisions(false), { 128 }, { 0.02f } },
- { LayerTransformation::createParamsU8I8().setUpdatePrecisions(false), {}, { 0.02f } },
- { LayerTransformation::createParamsI8I8(), { 128 }, { 0.02f } },
+ // Multiply
+ {
+ LayerTransformation::createParamsU8I8(),
+ {
+ ngraph::element::u8,
+ { {}, {}, { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }},
+ {}
+ },
+ {
+ ngraph::element::u8,
+ {},
+ { ngraph::element::f32, {}, { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }}
+ }
+ },
+ // Subtract + Multiply
+ {
+ LayerTransformation::createParamsU8I8(),
+ {
+ ngraph::element::u8,
+ {
+ {},
+ { {128.f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 },
+ { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }
+ },
+ {}
+ },
+ {
+ ngraph::element::u8,
+ {},
+ {
+ ngraph::element::f32,
+ { {128.f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 },
+ { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }
+ }
+ }
+ },
+ // Convert + Subtract + Multiply
+ {
+ LayerTransformation::createParamsU8I8(),
+ {
+ ngraph::element::u8,
+ { ngraph::element::f32, { 128 }, { 0.02f }},
+ {}
+ },
+ {
+ ngraph::element::u8,
+ {},
+ { ngraph::element::f32, { 128 }, { 0.02f }}
+ }
+ },
+ // Convert + Subtract + Multiply
+ {
+ LayerTransformation::createParamsU8I8(),
+ {
+ ngraph::element::u8,
+ { ngraph::element::f32, {}, { 0.02f }},
+ {}
+ },
+ {
+ ngraph::element::u8,
+ {},
+ { ngraph::element::f32, {}, { 0.02f }}
+ }
+ },
+ // Convert + Subtract + Multiply
+ {
+ LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
+ {
+ ngraph::element::u8,
+ { ngraph::element::f32, { 128 }, { 0.02f }},
+ {}
+ },
+ {
+ ngraph::element::u8,
+ {},
+ { ngraph::element::f32, { 128 }, { 0.02f }}
+ }
+ },
+ // Convert + Subtract + Multiply
+ {
+ LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
+ {
+ ngraph::element::u8,
+ { ngraph::element::f32, {}, { 0.02f }},
+ {}
+ },
+ {
+ ngraph::element::u8,
+ {},
+ { ngraph::element::f32, {}, { 0.02f }}
+ }
+ }
};
INSTANTIATE_TEST_CASE_P(
LPT,
MaxPoolTransformation,
::testing::Combine(
- ::testing::ValuesIn(precisions),
::testing::ValuesIn(shapes),
::testing::ValuesIn(testValues)),
MaxPoolTransformation::getTestCaseName);
--- /dev/null
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <memory>
+#include <ngraph/ngraph.hpp>
+
+#include "functional_test_utils/low_precision_transformations/layer_transformation.hpp"
+#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
+#include "ngraph_functions/subgraph_builders.hpp"
+#include "ngraph_functions/low_precision_transformations/common/builders.hpp"
+
+namespace ngraph {
+namespace builder {
+namespace subgraph {
+
+class AddActualValues {
+public:
+ ngraph::element::Type precision1;
+ std::vector<float> subtractValues1;
+ std::vector<float> mutliplyValues1;
+ ngraph::element::Type precision2;
+ std::vector<float> subtractValues2;
+ std::vector<float> mutliplyValues2;
+};
+
+inline std::ostream& operator<<(std::ostream& out, const AddActualValues& values) {
+ return out <<
+ "_" << values.precision1 <<
+ "_subtract" << values.subtractValues1.size() <<
+ "_mutliply" << values.mutliplyValues1.size() <<
+ "_" << values.precision2 <<
+ "_subtract" << values.subtractValues2.size() <<
+ "_mutliply" << values.mutliplyValues2.size();
+}
+
+class AddExpectedValues {
+public:
+ ngraph::element::Type precision1;
+ std::vector<float> subtractValues1;
+ std::vector<float> mutliplyValues1;
+ ngraph::element::Type precision2;
+ std::vector<float> mutliplyValuesAfter;
+};
+
+inline std::ostream& operator<<(std::ostream& out, const AddExpectedValues& values) {
+ return out <<
+ "_" << values.precision1 <<
+ "_subtract" << values.subtractValues1.size() <<
+ "_mutliply" << values.mutliplyValues1.size() <<
+ "_" << values.precision2 <<
+ "_mutliply" << values.mutliplyValuesAfter.size();
+}
+
+class ElementwiseWithMultiParentDequantizationFunction {
+public:
+ static std::shared_ptr<ngraph::Function> get(
+ const ngraph::element::Type precision,
+ const ngraph::Shape& inputShape,
+ const ngraph::pass::low_precision::LayerTransformation::Params& params,
+ const ngraph::element::Type& precision1,
+ const ngraph::builder::subgraph::DequantizationOperations& dequantization1,
+ const ngraph::element::Type& precision2,
+ const ngraph::builder::subgraph::DequantizationOperations& dequantization2);
+};
+
+} // namespace subgraph
+} // namespace builder
+} // namespace ngraph
#include <ngraph/ngraph.hpp>
#include "common/fake_quantize_on_data.hpp"
#include "low_precision/layer_transformation.hpp"
+#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
namespace ngraph {
namespace builder {
class MaxPoolFunction {
public:
- class ActualValues {
- public:
- ngraph::element::Type lowPrecision;
- std::vector<float> subtractValues;
- std::vector<float> mutliplyValues;
- };
-
- class ExpectedValues {
- public:
- ngraph::element::Type activationPrecision;
- std::vector<float> subtractValues;
- std::vector<float> mutliplyValues;
- };
-
- static std::shared_ptr<ngraph::Function> getOriginal(
- const ngraph::element::Type originalFunctionPrecision,
- const ngraph::Shape& inputShape,
- const ActualValues& values);
-
static std::shared_ptr<ngraph::Function> getOriginal(
const ngraph::element::Type originalFunctionPrecision,
const ngraph::Shape& inputShape,
const FakeQuantizeOnData& fakeQuantizeOnData);
- static std::shared_ptr<ngraph::Function> getReference(
- const ngraph::element::Type originalFunctionPrecision,
+ static std::shared_ptr<ngraph::Function> get(
const ngraph::Shape& inputShape,
- const ExpectedValues& values);
+ const ngraph::element::Type precisionBeforeDequantization,
+ const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore,
+ const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter);
};
} // namespace subgraph
--- /dev/null
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "ngraph_functions/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.hpp"
+#include "low_precision/network_helper.hpp"
+
+#include <ngraph/opsets/opset1.hpp>
+#include "ngraph_functions/builders.hpp"
+#include "ngraph_functions/subgraph_builders.hpp"
+
+using namespace ngraph::pass::low_precision;
+
+namespace ngraph {
+namespace builder {
+namespace subgraph {
+
+std::shared_ptr<ngraph::Function> ElementwiseWithMultiParentDequantizationFunction::get(
+ const ngraph::element::Type precision,
+ const ngraph::Shape& inputShape,
+ const ngraph::pass::low_precision::LayerTransformation::Params& params,
+ const ngraph::element::Type& precision1,
+ const ngraph::builder::subgraph::DequantizationOperations& dequantization1,
+ const ngraph::element::Type& precision2,
+ const ngraph::builder::subgraph::DequantizationOperations& dequantization2) {
+ const auto input1_1 = std::make_shared<ngraph::opset1::Parameter>(precision1, inputShape);
+ const auto input1_2 = std::make_shared<ngraph::opset1::Parameter>(precision1, ngraph::Shape({ inputShape[0], inputShape[1], 1, 1 }));
+ const std::shared_ptr<ngraph::Node> multiply1 = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Multiply>>(
+ DequantizationMultiply(
+ ngraph::op::TemporaryReplaceOutputType(input1_1, element::f32).get(),
+ ngraph::op::TemporaryReplaceOutputType(input1_2, element::f32).get()),
+ std::vector<element::Type>{element::f32, element::f32},
+ std::vector<element::Type>{});
+
+ const std::shared_ptr<ngraph::Node> parent1 = dequantization1.empty() ? multiply1 : makeDequantization(multiply1, dequantization1);
+
+ const auto input2_1 = std::make_shared<ngraph::opset1::Parameter>(precision1, inputShape);
+ const auto input2_2 = std::make_shared<ngraph::opset1::Parameter>(precision1, ngraph::Shape({ inputShape[0], inputShape[1], 1, 1 }));
+ const std::shared_ptr<ngraph::Node> multiply2 = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Multiply>>(
+ DequantizationMultiply(
+ ngraph::op::TemporaryReplaceOutputType(input2_1, element::f32).get(),
+ ngraph::op::TemporaryReplaceOutputType(input2_2, element::f32).get()),
+ std::vector<element::Type>{element::f32, element::f32},
+ std::vector<element::Type>{});
+
+ const std::shared_ptr<ngraph::Node> parent2 = dequantization2.empty() ? multiply2 : makeDequantization(multiply2, dequantization2);
+
+ const auto add = std::make_shared<ngraph::opset1::Add>(parent1, parent2);
+ add->set_friendly_name("output");
+ auto& rtInfo = add->get_rt_info();
+ rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("add");
+
+ ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(add) };
+ ngraph::ParameterVector parameters = { input1_1, input1_2, input2_1, input2_2 };
+ return std::make_shared<ngraph::Function>(results, parameters, "ElementwiseWithMultiParentDequantization");
+}
+
+} // namespace subgraph
+} // namespace builder
+} // namespace ngraph
std::shared_ptr<ngraph::Function> MaxPoolFunction::getOriginal(
const ngraph::element::Type originalFunctionPrecision,
const ngraph::Shape& inputShape,
- const ActualValues& values) {
- const auto input = std::make_shared<ngraph::opset1::Parameter>(values.lowPrecision, ngraph::Shape(inputShape));
- std::shared_ptr<ngraph::Node> parent = input;
-
- const std::shared_ptr<ngraph::Node> convert = std::make_shared<ngraph::opset1::Convert>(parent, originalFunctionPrecision);
- parent = convert;
-
- if (!values.subtractValues.empty()) {
- const std::shared_ptr<ngraph::Node> subtract = std::make_shared<ngraph::opset1::Subtract>(
- parent,
- std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.subtractValues.size() }), values.subtractValues));
- parent = subtract;
- }
-
- const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ngraph::opset1::Multiply>(
- parent,
- std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.mutliplyValues.size() }), values.mutliplyValues));
- parent = multiply;
-
- const std::shared_ptr<ngraph::Node> maxPool = std::make_shared<ngraph::opset1::MaxPool>(
- parent,
- Strides{ 1, 1 },
- Shape{ 1, 1 },
- Shape{ 0, 0 },
- Shape{ 2, 2 },
- op::RoundingType::FLOOR);
- maxPool->set_friendly_name("output");
-
- ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(maxPool) };
- return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MaxPoolTransformation");
-}
-
-std::shared_ptr<ngraph::Function> MaxPoolFunction::getOriginal(
- const ngraph::element::Type originalFunctionPrecision,
- const ngraph::Shape& inputShape,
const FakeQuantizeOnData& fakeQuantizeOnData) {
const auto input = std::make_shared<ngraph::opset1::Parameter>(originalFunctionPrecision, ngraph::Shape(inputShape));
return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MaxPoolTransformation");
}
-std::shared_ptr<ngraph::Function> MaxPoolFunction::getReference(
- const ngraph::element::Type originalFunctionPrecision,
+std::shared_ptr<ngraph::Function> MaxPoolFunction::get(
const ngraph::Shape& inputShape,
- const ExpectedValues& values) {
- auto input = std::make_shared<ngraph::opset1::Parameter>(values.activationPrecision, ngraph::Shape(inputShape));
+ const ngraph::element::Type precisionBeforeDequantization,
+ const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore,
+ const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter) {
+ const auto input = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization, ngraph::Shape(inputShape));
std::shared_ptr<ngraph::Node> parent = input;
+ parent = dequantizationBefore.empty() ? parent : makeDequantization(parent, dequantizationBefore);
+
const std::shared_ptr<ngraph::Node> maxPool = std::make_shared<ngraph::opset1::MaxPool>(
parent,
Strides{ 1, 1 },
op::RoundingType::FLOOR);
parent = maxPool;
- if (parent->get_output_element_type(0) != originalFunctionPrecision) {
- const std::shared_ptr<ngraph::Node> convert = std::make_shared<ngraph::pass::low_precision::DequantizationConvert>(parent, originalFunctionPrecision);
- parent = convert;
- }
+ parent = dequantizationAfter.empty() ? maxPool : makeDequantization(maxPool, dequantizationAfter);
+ maxPool->set_friendly_name("maxPool");
- if (!values.subtractValues.empty()) {
- const std::shared_ptr<ngraph::Node> subtract = std::make_shared<ngraph::pass::low_precision::DequantizationSubtract>(
- parent,
- std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.subtractValues.size() }), values.subtractValues));
- parent = subtract;
- }
+ const std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(parent);
- const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ngraph::pass::low_precision::DequantizationMultiply>(
- parent,
- std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.mutliplyValues.size() }), values.mutliplyValues));
- multiply->set_friendly_name("output");
-
- ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(multiply) };
- return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MaxPoolTransformation");
+ const std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
+ ngraph::ResultVector{ result },
+ std::vector<std::shared_ptr<ngraph::op::Parameter>> { input },
+ "MaxPoolTransformation");
+ return function;
}
} // namespace subgraph