1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "low_precision/normalize_l2.hpp"
12 #include "ngraph/type/element_type.hpp"
13 #include "ngraph/type/element_type_traits.hpp"
14 #include "low_precision/network_helper.hpp"
15 #include "low_precision/common/dequantization_op.hpp"
17 using namespace ngraph;
18 using namespace ngraph::pass;
19 using namespace ngraph::pass::low_precision;
// Helper namespace for the NormalizeL2 low-precision transformation; its only visible member,
// createNewScalesConst, is called from NormalizeL2Transformation::transform.
21 namespace normalize_l2 {
// Builds a constant with the same element type and shape as `originalConst` in which every
// element is replaced by its sign: T{-1} where the source value is negative, T{1} otherwise
// (so zero maps to +1).
//
// T is the C++ type the constant data is read as (via cast_vector<T>) and the type of the
// generated values.
// NOTE(review): the `template` header that declares T is not visible in this listing — confirm.
//
// @param originalConst constant holding the original dequantization scales.
// @return newly created constant holding only the signs of the scales.
24 std::shared_ptr<ngraph::op::Constant> createNewScalesConst(const ngraph::op::Constant& originalConst) {
25 std::vector<T> source = originalConst.cast_vector<T>();
// One output value per source element.
27 std::vector<T> newData(source.size());
28 for (size_t i = 0; i < source.size(); ++i) {
// Keep only the sign of the scale; the magnitude is discarded.
29 newData[i] = source[i] < 0 ? T{-1} : T{1};
// The result reuses the element type and shape of the original constant.
32 const ngraph::element::Type type = originalConst.get_output_element_type(0);
33 return ngraph::op::Constant::create(type, originalConst.get_shape(), newData);
36 } // namespace normalize_l2
// Checks whether `operation` satisfies the preconditions of this transformation.
//
// The conditions visible here are, in order:
//   1. the base-class LayerTransformation::canBeTransformed check;
//   2. canSubtractBeHandled(operation);
//   3. the node feeding input 0 has a Constant on its input 1 or input 0 (the scales);
//   4. the axes constant (input 1) holds exactly { 1 } or { 1, 2, 3 };
//   5. the scales constant has either one element or as many elements as dimension 1 of the
//      operation's output shape;
//   6. NetworkHelper::isScalarLike(scalesConst).
//
// NOTE(review): the statements executed when a condition fails, and the final return, are not
// visible in this listing; presumably each failure returns false and success returns true — confirm.
38 bool NormalizeL2Transformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
39 if (!LayerTransformation::canBeTransformed(context, operation)) {
43 if (!canSubtractBeHandled(operation)) {
// Producer of input 0; the pattern built in registerMatcherIn lists a Multiply there.
47 const std::shared_ptr<Node> multiply = operation->get_input_node_shared_ptr(0);
// The scales constant may sit on either input of the multiply: try input 1 first, then input 0.
48 auto scalesConst = as_type_ptr<ngraph::opset1::Constant>(multiply->get_input_node_shared_ptr(1));
49 if (scalesConst == nullptr) {
50 scalesConst = as_type_ptr<ngraph::opset1::Constant>(multiply->get_input_node_shared_ptr(0));
// Neither multiply input is a Constant.
52 if (scalesConst == nullptr) {
56 // TODO: Expand transformation for all cases of axes values
// NOTE(review): `axes` is dereferenced below without a null check; this relies on input 1 being
// a Constant, as the pattern in registerMatcherIn lists — confirm for direct callers.
57 const auto axes = as_type_ptr<opset1::Constant>(operation->get_input_node_shared_ptr(1));
58 const std::vector<int64_t> axesAcrossSpatial = { 1 };
59 const std::vector<int64_t> axesByChannels = { 1, 2, 3 };
// Only these two exact axes sets are accepted (the comparison is element-wise and order-sensitive).
61 std::vector<int64_t> axesValues = axes->cast_vector<int64_t>();
62 if (!(axesValues == axesAcrossSpatial || axesValues == axesByChannels)) {
// Compare the number of scale values with dimension 1 of the operation's output shape.
// NOTE(review): the output shape is indexed at [1] without a rank check — assumes the output
// has at least two dimensions and a static shape; confirm.
66 const ngraph::Shape outputShape = scalesConst->get_output_shape(0);
67 const size_t size = ngraph::shape_size(outputShape);
68 const size_t channels = operation->get_output_shape(0)[1];
70 if (size != channels && size != 1) {
// NOTE(review): isScalarLike is defined elsewhere; presumably it requires all scale values to be
// equal, which would make per-channel scales with differing values unsupported — verify.
74 if (!NetworkHelper::isScalarLike(scalesConst)) {
// Builds the graph pattern this transformation is matched against: a NormalizeL2 operation
// whose listed inputs are a Multiply label followed by a Constant label.
// NOTE(review): the call that receives this pattern (and uses `pass` / `context`) is not visible
// in this listing — confirm how the pattern is registered.
81 void NormalizeL2Transformation::registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const {
85 make_op_pattern<ngraph::opset1::NormalizeL2>({
// Listed first: the Multiply feeding NormalizeL2 (read as input 0 in canBeTransformed).
86 make_op_label<ngraph::opset1::Multiply>(),
// Listed second: the Constant that canBeTransformed / transform read as the axes (input 1).
87 make_op_label<ngraph::opset1::Constant>()
// Rewrites a matched NormalizeL2 so that the dequantization multiply is applied after the
// normalization instead of before it:
//   - a new NormalizeL2 is created on the dequantization's subtract output (or on its data
//     input when there is no subtract);
//   - a new DequantizationMultiply multiplies that result by a constant holding only the
//     signs (+1 / -1) of the original scales;
//   - the original NormalizeL2 is replaced by the new multiply.
// Rationale (presumably): L2 normalization cancels the magnitude of a uniform scale, so only
// its sign has to be re-applied — confirm, in particular with respect to eps handling.
//
// NOTE(review): the `switch` header, `break` statements, `default` label and the function's
// return statements are not visible in this listing.
91 bool NormalizeL2Transformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
// Re-validate the matched root before touching the graph.
92 std::shared_ptr<Node> operation = m.get_match_root();
93 if (!canBeTransformed(context, operation)) {
// NOTE(review): separateInStandaloneBranch is defined elsewhere; presumably it gives this node
// its own copy of a dequantization chain shared with other consumers — verify.
97 auto normalize = as_type_ptr<opset1::NormalizeL2>(separateInStandaloneBranch(operation));
99 const auto axes = as_type_ptr<opset1::Constant>(normalize->get_input_node_shared_ptr(1));
100 FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(normalize);
// Same lookup order as in canBeTransformed: scales on multiply input 1, otherwise input 0.
// NOTE(review): scalesConst is dereferenced below without a null check after the fallback;
// this relies on canBeTransformed having found a Constant on the same inputs.
101 auto scalesConst = as_type_ptr<opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(1));
102 if (scalesConst == nullptr) {
103 scalesConst = as_type_ptr<opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(0));
// Build the sign-only scales constant, dispatching on the element type of the original scales.
// Only f16 and f32 have a visible case; the throw below presumably sits in the default branch.
106 std::shared_ptr<opset1::Constant> newScalesConst;
107 const auto type = scalesConst->get_output_element_type(0);
109 case ngraph::element::Type_t::f16: {
110 newScalesConst = normalize_l2::createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f16>::value_type>(*scalesConst);
113 case ngraph::element::Type_t::f32: {
114 newScalesConst = normalize_l2::createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f32>::value_type>(*scalesConst);
118 THROW_TRANSFORMATION_EXCEPTION << "unexpected element type " << type;
// New NormalizeL2, type-relaxed with f32 declared for both inputs and for the output. Its data
// input bypasses the dequantization multiply: subtract output if present, otherwise the raw
// dequantization data. Axes are cloned; eps and eps_mode are taken from the original node.
122 auto newNormalize = std::make_shared<op::TypeRelaxed<opset1::NormalizeL2>>(
123 std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{element::f32},
124 ngraph::op::TemporaryReplaceOutputType(dequantization.subtract == nullptr ? dequantization.data : dequantization.subtract, element::f32).get(),
125 ngraph::op::TemporaryReplaceOutputType(axes->clone_with_new_inputs({}), element::f32).get(),
126 normalize->get_eps(),
127 normalize->get_eps_mode());
// NOTE(review): copyInfo is defined elsewhere; presumably it carries over the node name and
// runtime info from the original node — verify.
128 NetworkHelper::copyInfo(normalize, newNormalize);
// Re-apply the sign of the scales after normalization, again type-relaxed to f32.
130 auto newMultiply = std::make_shared<op::TypeRelaxed<DequantizationMultiply>>(
131 std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{element::f32},
132 ngraph::op::TemporaryReplaceOutputType(newNormalize, element::f32).get(),
133 ngraph::op::TemporaryReplaceOutputType(newScalesConst, element::f32).get());
// Swap the original NormalizeL2 for the new NormalizeL2 -> Multiply chain.
135 replace_node(normalize, newMultiply);
137 updateOutput(context, newMultiply, normalize);
141 bool NormalizeL2Transformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {