1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "ngraph_functions/low_precision_transformations/convert_mul_or_add_finally_with_dequantization_function.hpp"
9 #include <ngraph/ngraph.hpp>
12 #include <ngraph/opsets/opset1.hpp>
13 #include "ngraph_functions/subgraph_builders.hpp"
14 #include "ngraph_functions/low_precision_transformations/common/builders.hpp"
15 #include "low_precision/network_helper.hpp"
16 #include <legacy/ngraph_ops/scaleshift.hpp>
17 #include "low_precision/common/dequantization_op.hpp"
23 using namespace ngraph::pass;
25 std::shared_ptr<ngraph::Function> ConvertMulOrAddWithDequantizationFunction::getOriginal(
26 const ngraph::Shape& inputShape,
27 const ngraph::element::Type inputPrecision,
28 const std::vector<float>& multiplyConst) {
29 const auto input = std::make_shared<ngraph::opset1::Parameter>(inputPrecision, inputShape);
30 const auto reluOriginal = ngraph::opset1::Relu(
31 ngraph::op::TemporaryReplaceOutputType(input, element::f32).get());
33 std::shared_ptr<ngraph::opset1::Relu> relu = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Relu>>(
35 std::vector<element::Type>{ element::f32, element::f32 },
36 std::vector<element::Type>{});
39 const auto multiply = std::make_shared<ngraph::pass::low_precision::DequantizationMultiply>(relu,
40 std::make_shared<opset1::Constant>(element::f32, inputShape, multiplyConst));
42 multiply->set_friendly_name("output");
44 ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(multiply) };
45 return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input },
46 "ConvertMulOrAddTransformationWithDequantization");
49 std::shared_ptr<ngraph::Function> ConvertMulOrAddWithDequantizationFunction::getReference(
50 const ngraph::Shape& inputShape,
51 const ngraph::element::Type inputPrecision,
52 const std::vector<float>& multiplyConst) {
53 const auto input = std::make_shared<ngraph::opset1::Parameter>(inputPrecision, inputShape);
54 const auto reluOriginal = ngraph::opset1::Relu(
55 ngraph::op::TemporaryReplaceOutputType(input, element::f32).get());
57 std::shared_ptr<ngraph::opset1::Relu> relu = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Relu>>(
59 std::vector<element::Type>{ element::f32, element::f32 },
60 std::vector<element::Type>{});
62 const auto weights = std::make_shared<opset1::Constant>(element::f32, inputShape, multiplyConst);
63 const auto bias = std::make_shared<opset1::Constant>(element::f32, inputShape, 0.0);
64 std::shared_ptr<Node> scaleShift = std::make_shared<ngraph::op::ScaleShiftIE>(relu, weights, bias);
66 scaleShift = low_precision::NetworkHelper::markAsDequantizationOp(scaleShift);
68 scaleShift->set_friendly_name("output");
70 ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(scaleShift) };
71 return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "ConvertMulOrAddTransformationWithDequantization");
74 } // namespace subgraph
75 } // namespace builder