[LPT] integration branch: Reshape fix, Concat generalization, runtime info usage...
[platform/upstream/dldt.git] / inference-engine / tests / ngraph_functions / src / low_precision_transformations / move_dequantization_after_function.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ngraph_functions/low_precision_transformations/move_dequantization_after_function.hpp"
6 #include "low_precision/network_helper.hpp"
7
8 #include <ngraph/opsets/opset1.hpp>
9 #include "ngraph_functions/subgraph_builders.hpp"
10 #include "ngraph_functions/low_precision_transformations/common/builders.hpp"
11
12 using namespace ngraph::pass::low_precision;
13
14 namespace ngraph {
15 namespace builder {
16 namespace subgraph {
17     std::shared_ptr<ngraph::Function> MoveDequantizationAfterFunction::getOriginal(
18         const ngraph::element::Type precision,
19         const ngraph::Shape& inputShape,
20         const ngraph::builder::subgraph::DequantizationOperations dequantization) {
21         const auto input = std::make_shared<ngraph::op::v0::Parameter>(precision, inputShape);
22
23         const auto deq = makeDequantization(input, dequantization);
24         const auto op = ngraph::opset1::MaxPool(
25             deq,
26             Strides{ 1, 1 },
27             Shape{ 1, 1 },
28             Shape{ 0, 0 },
29             Shape{ 2, 2 },
30             op::RoundingType::FLOOR);
31         const auto targetOp = std::make_shared<op::TypeRelaxed<opset1::MaxPool>>(
32             op,
33             std::vector<element::Type>{ element::f32, element::f32 },
34             std::vector<element::Type>{});
35         auto& rtInfo = targetOp->get_rt_info();
36         rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("targetOp");
37
38         return std::make_shared<ngraph::Function>(
39             ngraph::ResultVector{ std::make_shared<ngraph::opset1::Result>(targetOp) },
40             ngraph::ParameterVector{ input },
41             "MoveDequantizationAfterFunction");
42     }
43
44     std::shared_ptr<ngraph::Function> MoveDequantizationAfterFunction::getReference(
45         const ngraph::element::Type precision,
46         const ngraph::Shape& inputShape,
47         const ngraph::builder::subgraph::DequantizationOperations dequantizationBefore,
48         const ngraph::element::Type precisionAfterOperation,
49         const ngraph::builder::subgraph::DequantizationOperations dequantizationAfter) {
50         const auto input = std::make_shared<ngraph::op::v0::Parameter>(precision, inputShape);
51
52         const auto deqBefore = makeDequantization(input, dequantizationBefore);
53         const auto op = ngraph::opset1::MaxPool(
54             deqBefore,
55             Strides{ 1, 1 },
56             Shape{ 1, 1 },
57             Shape{ 0, 0 },
58             Shape{ 2, 2 },
59             op::RoundingType::FLOOR);
60         const auto targetOp = std::make_shared<op::TypeRelaxed<opset1::MaxPool>>(
61             op,
62             std::vector<element::Type>{ element::f32, element::f32 },
63             std::vector<element::Type>{});
64         ngraph::pass::low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(targetOp, precisionAfterOperation);
65         auto& rtInfo = targetOp->get_rt_info();
66         rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("targetOp");
67
68         const auto deqAfter = makeDequantization(targetOp, dequantizationAfter);
69
70         return std::make_shared<ngraph::Function>(
71             ngraph::ResultVector{ std::make_shared<ngraph::opset1::Result>(deqAfter) },
72             ngraph::ParameterVector{ input },
73             "MoveDequantizationAfterFunction");
74     }
75
76 }  // namespace subgraph
77 }  // namespace builder
78 }  // namespace ngraph