1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "low_precision/prelu.hpp"
11 #include "low_precision/common/ie_lpt_exception.hpp"
12 #include "low_precision/network_helper.hpp"
16 namespace low_precision {
18 void PReluTransformation::registerMatcherIn(GraphRewrite &pass, TransformationContext &context) const {
22 make_op_pattern<opset1::PRelu>({ make_op_label<opset1::Multiply>(), make_op_label<opset1::Constant>() }));
25 bool PReluTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) const {
26 std::shared_ptr<Node> prelu = m.get_match_root();
27 if (!canBeTransformed(context, prelu)) {
31 prelu = separateInStandaloneBranch(prelu);
32 const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(prelu, 0);
33 moveDequantizationAfter(context, prelu, dequantization, false, false);
37 bool PReluTransformation::isPrecisionPreserved(std::shared_ptr<Node> op) const noexcept {
41 bool PReluTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const {
42 const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(op, 0);
43 if (dequantization.empty() || (dequantization.subtract != nullptr)) {
47 const std::shared_ptr<opset1::Constant> constant = as_type_ptr<opset1::Constant>(dequantization.multiply->input_value(1).get_node_shared_ptr());
48 const auto scales = constant->cast_vector<float>();
49 if (std::any_of(scales.begin(), scales.end(), [](const float value) { return value < 0.f; })) {
56 } // namespace low_precision