4feed74d98ac9ca9962591e759fc08f5dae3e378
[platform/upstream/dldt.git] / inference-engine / tests / ngraph_functions / include / ngraph_functions / low_precision_transformations / common / builders.hpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <memory>
8 #include <ngraph/ngraph.hpp>
9 #include "ngraph_ops/type_relaxed.hpp"
10
11 #include "low_precision/network_helper.hpp"
12 #include "low_precision/common/dequantization_op.hpp"
13
14 #include "ngraph_functions/low_precision_transformations/common/add.hpp"
15 #include "ngraph_functions/low_precision_transformations/common/fake_quantize_on_data.hpp"
16 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
17
18 namespace ngraph {
19 namespace builder {
20 namespace subgraph {
21
22 template <typename Operation, typename OperationDesc>
23 std::shared_ptr<Node> makeElementwise(const std::shared_ptr<ngraph::Node> data, const OperationDesc& description) {
24     std::vector<size_t> shape;
25     if (description.constantShapeIsDefined) {
26         shape = description.constantShape;
27     } else {
28         if (description.values.size() == 1ul) {
29             shape = std::vector<size_t>({});
30         } else {
31             shape = std::vector<size_t>(data->get_output_shape(0).size(), 1ul);
32             shape[shape.size() >= 2 ? 1ul : 0] = description.values.size();
33         }
34     }
35
36     const auto operationConst = std::make_shared<ngraph::opset1::Constant>(
37         description.outPrecision,
38         shape,
39         description.values);
40
41     std::shared_ptr<Operation> operation;
42     if ((description.outPrecision == element::undefined) || (description.outPrecision == data->get_output_element_type(0))) {
43         operation = std::make_shared<Operation>(data, operationConst);
44     } else {
45         operation = std::make_shared<op::TypeRelaxed<Operation>>(
46             std::vector<element::Type>{element::f32, element::f32}, std::vector<element::Type>{},
47             ngraph::op::TemporaryReplaceOutputType(data, element::f32).get(),
48             ngraph::op::TemporaryReplaceOutputType(operationConst, element::f32).get());
49         ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(operation, description.outPrecision);
50     }
51
52     if (is_type<ngraph::opset1::Subtract>(operation) || is_type<ngraph::opset1::Add>(operation)) {
53         replace_node(
54             operationConst,
55             ngraph::pass::low_precision::fold<ngraph::opset1::Convert>(operationConst, data->get_output_element_type(0)));
56     }
57
58     return operation;
59 }
60
61 std::shared_ptr<Node> makeDequantization(
62     const Output<Node>& data,
63     const DequantizationOperations& dequantizationOperations);
64
65 std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantize(
66     const Output<Node>& input,
67     const ngraph::element::Type precision,
68     const FakeQuantizeOnData& fqOnData);
69
70 std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantizeTypeRelaxed(
71     const std::shared_ptr<ngraph::Node>& input,
72     const ngraph::element::Type precision,
73     const FakeQuantizeOnData& fqOnData);
74
75 } // namespace subgraph
76 } // namespace builder
77 } // namespace ngraph