1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "ngraph_functions/low_precision_transformations/max_pool_function.hpp"
7 #include <ngraph/opsets/opset1.hpp>
8 #include <ngraph_ops/type_relaxed.hpp>
9 #include "ngraph_functions/subgraph_builders.hpp"
10 #include "low_precision/network_helper.hpp"
// Builds the "original" (pre-transformation) test graph: a Parameter in
// `values.lowPrecision`, a Convert to `originalFunctionPrecision`, an optional
// Subtract, a Multiply and finally a MaxPool. This looks like an explicit
// dequantization sequence feeding MaxPool — confirm against the transformation
// under test.
//
// @param originalFunctionPrecision target element type of the Convert and the
//        element type of the Subtract/Multiply constants.
// @param inputShape  shape of the single input Parameter.
// @param values      `lowPrecision` is the Parameter element type;
//                    `subtractValues` (empty -> no Subtract node) and
//                    `mutliplyValues` are stored as 1-D constants whose length
//                    equals the number of values.
// @return ngraph::Function named "MaxPoolTransformation" with one Parameter and
//         one Result fed by the MaxPool node (friendly name "output").
//
// NOTE(review): this listing does not show the lines that chain each node into
// the next (e.g. `parent = convert;`, the first argument of Subtract /
// Multiply / MaxPool) nor the MaxPool strides/pads/kernel — confirm them in
// the full file.
16 std::shared_ptr<ngraph::Function> MaxPoolFunction::getOriginal(
17 const ngraph::element::Type originalFunctionPrecision,
18 const ngraph::Shape& inputShape,
19 const ActualValues& values) {
20 const auto input = std::make_shared<ngraph::opset1::Parameter>(values.lowPrecision, ngraph::Shape(inputShape));
21 std::shared_ptr<ngraph::Node> parent = input;
// Convert the low-precision input to the original function precision.
23 const std::shared_ptr<ngraph::Node> convert = std::make_shared<ngraph::opset1::Convert>(parent, originalFunctionPrecision);
// The Subtract node is only created when subtraction values were supplied.
26 if (!values.subtractValues.empty()) {
27 const std::shared_ptr<ngraph::Node> subtract = std::make_shared<ngraph::opset1::Subtract>(
29 std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.subtractValues.size() }), values.subtractValues));
// NOTE(review): `mutliplyValues` is misspelled, but it is a field of
// ActualValues declared outside this view (presumably in
// max_pool_function.hpp); renaming it requires updating that declaration and
// all of its users.
33 const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ngraph::opset1::Multiply>(
35 std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.mutliplyValues.size() }), values.mutliplyValues));
// MaxPool with FLOOR rounding; its other constructor arguments are not visible
// in this listing.
38 const std::shared_ptr<ngraph::Node> maxPool = std::make_shared<ngraph::opset1::MaxPool>(
44 op::RoundingType::FLOOR);
45 maxPool->set_friendly_name("output");
// Single-output function: one Result on MaxPool, one input Parameter.
47 ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(maxPool) };
48 return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MaxPoolTransformation");
// Builds the "original" test graph variant in which the input is quantized by a
// FakeQuantize node (created via ngraph::builder::makeFakeQuantize) rather than
// by an explicit Convert/Subtract/Multiply sequence, followed by MaxPool.
//
// @param originalFunctionPrecision element type of the input Parameter; also
//        passed to makeFakeQuantize (presumably as the type of its constants —
//        confirm against the builder's signature).
// @param inputShape  shape of the single input Parameter.
// @param fakeQuantizeOnData quantization level count, constant shape and the
//        input/output low/high values forwarded to makeFakeQuantize.
// @return ngraph::Function named "MaxPoolTransformation" with one Parameter and
//         one Result fed by the MaxPool node.
//
// NOTE(review): unlike the ActualValues overload and getReference, no
// friendly name ("output") is set on the output node in the visible lines —
// confirm whether any test looks the output up by that name.
// NOTE(review): MaxPool's input (presumably `fakeQuantize`) and its
// strides/pads/kernel are not visible in this listing.
51 std::shared_ptr<ngraph::Function> MaxPoolFunction::getOriginal(
52 const ngraph::element::Type originalFunctionPrecision,
53 const ngraph::Shape& inputShape,
54 const FakeQuantizeOnData& fakeQuantizeOnData) {
55 const auto input = std::make_shared<ngraph::opset1::Parameter>(originalFunctionPrecision, ngraph::Shape(inputShape));
// Quantize the input with the levels/intervals described by fakeQuantizeOnData.
57 const auto fakeQuantize = ngraph::builder::makeFakeQuantize(
58 input, originalFunctionPrecision, fakeQuantizeOnData.quantizationLevel, fakeQuantizeOnData.constantShape,
59 fakeQuantizeOnData.inputLowValues, fakeQuantizeOnData.inputHighValues, fakeQuantizeOnData.outputLowValues, fakeQuantizeOnData.outputHighValues);
// MaxPool with FLOOR rounding; its other constructor arguments are not visible
// in this listing.
61 const std::shared_ptr<ngraph::Node> maxPool = std::make_shared<ngraph::opset1::MaxPool>(
67 op::RoundingType::FLOOR);
// Single-output function: one Result on MaxPool, one input Parameter.
69 ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(maxPool) };
70 return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MaxPoolTransformation");
// Builds the "reference" graph to compare against. MaxPool is created first on
// a Parameter in `values.activationPrecision`, then a Convert (only if needed),
// an optional Subtract and a Multiply follow it. This is presumably the graph
// the low-precision MaxPool transformation is expected to produce (operations
// moved below MaxPool) — confirm against the transformation under test.
//
// @param originalFunctionPrecision target element type of the Convert and the
//        element type of the Subtract/Multiply constants.
// @param inputShape  shape of the single input Parameter.
// @param values      `activationPrecision` is the Parameter element type;
//                    `subtractValues` (empty -> no Subtract node) and
//                    `mutliplyValues` are stored as 1-D constants whose length
//                    equals the number of values.
// @return ngraph::Function named "MaxPoolTransformation" with one Parameter and
//         one Result fed by the Multiply node (friendly name "output").
//
// NOTE(review): this listing does not show the statements that chain nodes
// together (e.g. `parent = maxPool;`, the first argument of MaxPool / Subtract
// / Multiply) nor the MaxPool strides/pads/kernel — confirm them in the full
// file.
73 std::shared_ptr<ngraph::Function> MaxPoolFunction::getReference(
74 const ngraph::element::Type originalFunctionPrecision,
75 const ngraph::Shape& inputShape,
76 const ExpectedValues& values) {
77 auto input = std::make_shared<ngraph::opset1::Parameter>(values.activationPrecision, ngraph::Shape(inputShape));
78 std::shared_ptr<ngraph::Node> parent = input;
// MaxPool with FLOOR rounding, placed before any Convert/Subtract/Multiply.
80 const std::shared_ptr<ngraph::Node> maxPool = std::make_shared<ngraph::opset1::MaxPool>(
86 op::RoundingType::FLOOR);
// Insert a Convert only when the current node's output element type differs
// from originalFunctionPrecision.
// NOTE(review): confirm `parent` has been reassigned to `maxPool` before this
// check (that line is not visible here).
89 if (parent->get_output_element_type(0) != originalFunctionPrecision) {
90 const std::shared_ptr<ngraph::Node> convert = std::make_shared<ngraph::opset1::Convert>(parent, originalFunctionPrecision);
// The Subtract node is only created when subtraction values were supplied.
94 if (!values.subtractValues.empty()) {
95 const std::shared_ptr<ngraph::Node> subtract = std::make_shared<ngraph::opset1::Subtract>(
97 std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.subtractValues.size() }), values.subtractValues));
// NOTE(review): `mutliplyValues` is a misspelled field of ExpectedValues
// declared outside this view; renaming it requires updating its declaration.
101 const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ngraph::opset1::Multiply>(
103 std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.mutliplyValues.size() }), values.mutliplyValues));
// In the reference graph the Multiply (not MaxPool) is the named output node.
104 multiply->set_friendly_name("output");
// Single-output function: one Result on Multiply, one input Parameter.
106 ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(multiply) };
107 return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MaxPoolTransformation");
110 } // namespace subgraph
111 } // namespace builder
112 } // namespace ngraph