1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "transformations/low_precision/normalize_l2.hpp"
12 #include "ngraph/type/element_type.hpp"
13 #include "ngraph/type/element_type_traits.hpp"
14 #include "transformations/low_precision/network_helper.hpp"
16 using namespace ngraph;
17 using namespace ngraph::pass;
18 using namespace ngraph::pass::low_precision;
// Helpers local to the NormalizeL2 low-precision transformation.
20 namespace normalize_l2 {
// Builds a constant that holds only the sign of every element of `originalConst`.
//
// The values of `originalConst` are read as a vector of T; each element maps to -1 when
// it is negative and to 1 otherwise (so zero maps to 1). The returned constant is created
// with the element type of output 0 of `originalConst` and with its shape.
//
// NOTE(review): T is not declared in the visible lines; the call sites pass it as an
// explicit template argument, so a template header is presumably elided here - confirm.
23 std::shared_ptr<ngraph::op::Constant> createNewScalesConst(const ngraph::op::Constant& originalConst) {
// Materialize the original scale values as T.
24 std::vector<T> source = originalConst.cast_vector<T>();
// One output element per input element.
26 std::vector<T> newData(source.size());
27 for (size_t i = 0; i < source.size(); ++i) {
// Keep the sign only; the magnitude of the scale is dropped.
28 newData[i] = source[i] < 0 ? -1 : 1;
// Preserve the element type and shape of the original constant.
31 const ngraph::element::Type type = originalConst.get_output_element_type(0);
32 return ngraph::op::Constant::create(type, originalConst.get_shape(), newData);
35 } // namespace normalize_l2
// Checks whether `operation` qualifies for this transformation.
//
// Visible checks, in order: the base LayerTransformation check, canSubtractBeHandled,
// presence of a Constant on the node feeding input 0, the accepted axes values, and the
// element count / scalar-likeness of that constant.
//
// NOTE(review): the bodies of the guard `if`s and the final return are not visible in
// this excerpt; presumably each failed check returns false and the fall-through returns
// true - confirm.
37 bool NormalizeL2Transformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
// Generic eligibility check from the base class.
38 if (!LayerTransformation::canBeTransformed(context, operation)) {
42 if (!canSubtractBeHandled(operation)) {
// The node on input 0 is treated as the dequantization multiply.
// NOTE(review): its type is not verified here; the matcher pattern registered by this
// class asks for a Multiply on that input - confirm that is always sufficient.
46 const std::shared_ptr<Node> multiply = operation->get_input_node_shared_ptr(0);
// The scales constant may sit on either input of the multiply: try input 1, then input 0.
47 auto scalesConst = as_type_ptr<ngraph::opset1::Constant>(multiply->get_input_node_shared_ptr(1));
48 if (scalesConst == nullptr) {
49 scalesConst = as_type_ptr<ngraph::opset1::Constant>(multiply->get_input_node_shared_ptr(0));
// Neither input of the multiply is a Constant.
51 if (scalesConst == nullptr) {
55 // TODO: Expand transformation for all cases of axes values
// NOTE(review): as_type_ptr yields nullptr when input 1 is not a Constant, and `axes` is
// dereferenced below without a null check - confirm the matcher guarantees a Constant.
56 const auto axes = as_type_ptr<opset1::Constant>(operation->get_input_node_shared_ptr(1));
// The only two axes sets accepted by this transformation.
57 const std::vector<int64_t> axesAcrossSpatial = { 1 };
58 const std::vector<int64_t> axesByChannels = { 1, 2, 3 };
// Vector equality is exact: same length, same values, same order.
60 std::vector<int64_t> axesValues = axes->cast_vector<int64_t>();
61 if (!(axesValues == axesAcrossSpatial || axesValues == axesByChannels)) {
// The scales constant must hold either a single element or one element per entry of
// dimension 1 of the operation's output shape.
// NOTE(review): indexing [1] assumes the output shape has at least two dimensions - confirm.
65 const ngraph::Shape outputShape = scalesConst->get_output_shape(0);
66 const size_t size = ngraph::shape_size(outputShape);
67 const size_t channels = operation->get_output_shape(0)[1];
69 if (size != channels && size != 1) {
// NOTE(review): isScalarLike is defined elsewhere; presumably it requires all elements of
// the constant to be equal, which would narrow the per-channel case above - confirm.
73 if (!NetworkHelper::isScalarLike(scalesConst)) {
// Registers the graph pattern this transformation reacts to: an opset1 NormalizeL2 whose
// two inputs are an opset1 Multiply and an opset1 Constant.
// NOTE(review): the call that consumes this pattern (and uses `pass` / `context`) is not
// visible in this excerpt - confirm against the full source.
80 void NormalizeL2Transformation::registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const {
// Root of the pattern: NormalizeL2 with the two labelled inputs listed below.
84 make_op_pattern<ngraph::opset1::NormalizeL2>({
// Input 0: a Multiply.
85 make_op_label<ngraph::opset1::Multiply>(),
// Input 1: a Constant.
86 make_op_label<ngraph::opset1::Constant>()
// Rewrites a matched NormalizeL2 so that the dequantization multiply is applied after it.
//
// A new NormalizeL2 is built on the dequantization's subtract (or on its data input when
// there is no subtract), and a new DequantizationMultiply is placed on its output. The
// constant of that multiply holds only the signs of the original scales, as produced by
// normalize_l2::createNewScalesConst. Throws via THROW_TRANSFORMATION_EXCEPTION when the
// scales constant is neither f16 nor f32.
//
// NOTE(review): the early-exit bodies, the statement dispatching on `type`, and the return
// statements are not visible in this excerpt - confirm against the full source.
// NOTE(review): presumably only the sign is kept because L2 normalization cancels the
// magnitude of a uniform scale (eps aside) - confirm.
90 bool NormalizeL2Transformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
91 std::shared_ptr<Node> operation = m.get_match_root();
// Re-validate the matched node before touching the graph.
92 if (!canBeTransformed(context, operation)) {
// NOTE(review): separateInStandaloneBranch is defined elsewhere; its result is cast to
// NormalizeL2 and used below without a null check - confirm it cannot fail here.
96 auto normalize = as_type_ptr<opset1::NormalizeL2>(separateInStandaloneBranch(operation));
98 const auto axes = as_type_ptr<opset1::Constant>(normalize->get_input_node_shared_ptr(1));
99 FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(normalize);
// The scales constant may sit on either input of the multiply: try input 1, then input 0.
100 auto scalesConst = as_type_ptr<opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(1));
101 if (scalesConst == nullptr) {
102 scalesConst = as_type_ptr<opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(0));
// Build the sign-only scales constant, instantiating the helper for the element type of
// the original scales.
105 std::shared_ptr<opset1::Constant> newScalesConst;
106 const auto type = scalesConst->get_output_element_type(0);
108 case ngraph::element::Type_t::f16: {
109 newScalesConst = normalize_l2::createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f16>::value_type>(*scalesConst);
112 case ngraph::element::Type_t::f32: {
113 newScalesConst = normalize_l2::createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f32>::value_type>(*scalesConst);
// Any other element type is rejected.
117 THROW_TRANSFORMATION_EXCEPTION << "unexpected element type " << type;
// New NormalizeL2, type-relaxed with f32 declared for both inputs and for the output.
// Its data input skips the original multiply; eps and eps_mode come from the original node.
121 auto newNormalize = std::make_shared<op::TypeRelaxed<opset1::NormalizeL2>>(
122 std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{element::f32},
123 ngraph::op::TemporaryReplaceOutputType(dequantization.subtract == nullptr ? dequantization.data : dequantization.subtract, element::f32).get(),
// The axes constant is cloned rather than shared with the original node.
124 ngraph::op::TemporaryReplaceOutputType(axes->clone_with_new_inputs({}), element::f32).get(),
125 normalize->get_eps(),
126 normalize->get_eps_mode());
127 NetworkHelper::copyInfo(normalize, newNormalize);
// New dequantization multiply applied to the output of the new NormalizeL2, using the
// sign-only scales.
129 auto newMultiply = std::make_shared<op::TypeRelaxed<DequantizationMultiply>>(
130 std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{element::f32},
131 ngraph::op::TemporaryReplaceOutputType(newNormalize, element::f32).get(),
132 ngraph::op::TemporaryReplaceOutputType(newScalesConst, element::f32).get());
// Swap the original NormalizeL2 for the new multiply in the graph.
134 replace_node(normalize, newMultiply);
136 updateOutput(context, newMultiply, normalize);
140 bool NormalizeL2Transformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {