1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "ngraph_functions/low_precision_transformations/move_dequantization_after_function.hpp"
6 #include "low_precision/network_helper.hpp"
8 #include <ngraph/opsets/opset1.hpp>
9 #include "ngraph_functions/subgraph_builders.hpp"
10 #include "ngraph_functions/low_precision_transformations/common/builders.hpp"
12 using namespace ngraph::pass::low_precision;
// Builds the "original" test function: an input Parameter, the node returned by
// makeDequantization for that input, and a type-relaxed MaxPool tagged "targetOp"
// whose output feeds the single Result.
//
// precision      - element type of the input Parameter.
// inputShape     - shape of the input Parameter.
// dequantization - dequantization operations passed to makeDequantization together with the input.
//
// Returns a Function named "MoveDequantizationAfterFunction" with one Parameter and one Result.
17 std::shared_ptr<ngraph::Function> MoveDequantizationAfterFunction::getOriginal(
18 const ngraph::element::Type precision,
19 const ngraph::Shape& inputShape,
20 const ngraph::builder::subgraph::DequantizationOperations dequantization) {
21 const auto input = std::make_shared<ngraph::op::v0::Parameter>(precision, inputShape);
// makeDequantization is declared outside this file (presumably in common/builders.hpp - TODO confirm).
23 const auto deq = makeDequantization(input, dequantization);
// MaxPool configured with FLOOR rounding.
24 const auto op = ngraph::opset1::MaxPool(
30 op::RoundingType::FLOOR);
// Type-relaxed MaxPool. The two vectors presumably are the overridden input element
// types (f32, f32) and the overridden output element types (none) - TODO confirm
// against the TypeRelaxed constructor.
31 const auto targetOp = std::make_shared<op::TypeRelaxed<opset1::MaxPool>>(
33 std::vector<element::Type>{ element::f32, element::f32 },
34 std::vector<element::Type>{});
// Tag the node in its runtime info: key "Variant::std::string", value "targetOp".
// NOTE(review): presumably this lets the test locate the operation under test - confirm with callers.
35 auto& rtInfo = targetOp->get_rt_info();
36 rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("targetOp");
// The tagged MaxPool is the only Result; `input` is the only Parameter.
38 return std::make_shared<ngraph::Function>(
39 ngraph::ResultVector{ std::make_shared<ngraph::opset1::Result>(targetOp) },
40 ngraph::ParameterVector{ input },
41 "MoveDequantizationAfterFunction");
// Builds the "reference" test function: an input Parameter, the node returned by
// makeDequantization for dequantizationBefore, a type-relaxed MaxPool tagged "targetOp"
// with its output precision set to precisionAfterOperation, and a second
// makeDequantization call applied to that MaxPool whose result feeds the single Result.
//
// precision               - element type of the input Parameter.
// inputShape              - shape of the input Parameter.
// dequantizationBefore    - dequantization operations passed to makeDequantization together with the input.
// precisionAfterOperation - output precision set on the type-relaxed MaxPool.
// dequantizationAfter     - dequantization operations passed to makeDequantization together with the MaxPool.
//
// Returns a Function named "MoveDequantizationAfterFunction" with one Parameter and one Result.
44 std::shared_ptr<ngraph::Function> MoveDequantizationAfterFunction::getReference(
45 const ngraph::element::Type precision,
46 const ngraph::Shape& inputShape,
47 const ngraph::builder::subgraph::DequantizationOperations dequantizationBefore,
48 const ngraph::element::Type precisionAfterOperation,
49 const ngraph::builder::subgraph::DequantizationOperations dequantizationAfter) {
50 const auto input = std::make_shared<ngraph::op::v0::Parameter>(precision, inputShape);
// makeDequantization is declared outside this file (presumably in common/builders.hpp - TODO confirm).
52 const auto deqBefore = makeDequantization(input, dequantizationBefore);
// MaxPool configured with FLOOR rounding.
53 const auto op = ngraph::opset1::MaxPool(
59 op::RoundingType::FLOOR);
// Type-relaxed MaxPool. The two vectors presumably are the overridden input element
// types (f32, f32) and the overridden output element types (none) - TODO confirm
// against the TypeRelaxed constructor.
60 const auto targetOp = std::make_shared<op::TypeRelaxed<opset1::MaxPool>>(
62 std::vector<element::Type>{ element::f32, element::f32 },
63 std::vector<element::Type>{});
// Set the output precision of the type-relaxed node to the requested post-operation precision.
64 ngraph::pass::low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(targetOp, precisionAfterOperation);
// Tag the node in its runtime info: key "Variant::std::string", value "targetOp".
// NOTE(review): presumably this lets the test locate the operation under test - confirm with callers.
65 auto& rtInfo = targetOp->get_rt_info();
66 rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("targetOp");
// Second makeDequantization call, this time taking the tagged MaxPool as its input.
68 const auto deqAfter = makeDequantization(targetOp, dequantizationAfter);
// The node returned for dequantizationAfter is the only Result; `input` is the only Parameter.
70 return std::make_shared<ngraph::Function>(
71 ngraph::ResultVector{ std::make_shared<ngraph::opset1::Result>(deqAfter) },
72 ngraph::ParameterVector{ input },
73 "MoveDequantizationAfterFunction");
76 } // namespace subgraph
77 } // namespace builder