[LPT] integration branch: Reshape fix, Concat generalization, runtime info usage...
authorEdward Shogulin <edward.shogulin@intel.com>
Fri, 6 Nov 2020 13:15:27 +0000 (16:15 +0300)
committerGitHub <noreply@github.com>
Fri, 6 Nov 2020 13:15:27 +0000 (16:15 +0300)
* [LPT] Concat transformation generalization

* [LPT] Reshape transformation fix

* [LPT] Legacy callback fix

* [LPT] * added rt_info propagation
      * functional tests: added rt_info
      * functional tests: added MoveDequatnizationAfter tests

Co-authored-by: Vladislav Golubev <vladislav.golubev@intel.com>
36 files changed:
inference-engine/src/low_precision_transformations/include/low_precision/network_helper.hpp
inference-engine/src/low_precision_transformations/src/common/add.cpp
inference-engine/src/low_precision_transformations/src/common/concat.cpp
inference-engine/src/low_precision_transformations/src/common/convolution.cpp
inference-engine/src/low_precision_transformations/src/common/fake_quantize.cpp
inference-engine/src/low_precision_transformations/src/common/fuse_convert.cpp
inference-engine/src/low_precision_transformations/src/common/mat_mul.cpp
inference-engine/src/low_precision_transformations/src/common/mvn.cpp
inference-engine/src/low_precision_transformations/src/common/network_helper.cpp
inference-engine/src/low_precision_transformations/src/common/normalize_l2.cpp
inference-engine/src/low_precision_transformations/src/common/reshape.cpp
inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp
inference-engine/tests/functional/inference_engine/lp_transformations/concat_transformation.cpp
inference-engine/tests/functional/inference_engine/lp_transformations/mat_mul_with_constant_transformation.cpp
inference-engine/tests/functional/inference_engine/lp_transformations/max_pool_transformation.cpp
inference-engine/tests/functional/inference_engine/lp_transformations/move_dequantization_after_transformation.cpp [new file with mode: 0644]
inference-engine/tests/functional/inference_engine/lp_transformations/normalize_l2_transformation.cpp
inference-engine/tests/functional/inference_engine/lp_transformations/relu_transformation.cpp
inference-engine/tests/functional/inference_engine/lp_transformations/reshape_transformation.cpp
inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/common/builders.hpp
inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/common/fake_quantize_on_data.hpp
inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/concat_function.hpp
inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/move_dequantization_after_function.hpp [new file with mode: 0644]
inference-engine/tests/ngraph_functions/src/low_precision_transformations/add_function.cpp
inference-engine/tests/ngraph_functions/src/low_precision_transformations/common/builders.cpp
inference-engine/tests/ngraph_functions/src/low_precision_transformations/concat_function.cpp
inference-engine/tests/ngraph_functions/src/low_precision_transformations/convert_mul_or_add_finally_with_dequantization_function.cpp
inference-engine/tests/ngraph_functions/src/low_precision_transformations/convolution_function.cpp
inference-engine/tests/ngraph_functions/src/low_precision_transformations/fake_quantize_function.cpp
inference-engine/tests/ngraph_functions/src/low_precision_transformations/mat_mul_function.cpp
inference-engine/tests/ngraph_functions/src/low_precision_transformations/max_pool_function.cpp
inference-engine/tests/ngraph_functions/src/low_precision_transformations/move_dequantization_after_function.cpp [new file with mode: 0644]
inference-engine/tests/ngraph_functions/src/low_precision_transformations/mul_add_to_scaleshift_or_power_function.cpp
inference-engine/tests/ngraph_functions/src/low_precision_transformations/multiply_function.cpp
inference-engine/tests/ngraph_functions/src/low_precision_transformations/mvn_function.cpp
inference-engine/tests/ngraph_functions/src/low_precision_transformations/normalize_l2_function.cpp

index 306ba73..f27462b 100644 (file)
@@ -160,8 +160,6 @@ public:
     // handles only specific case: Constant -> [dequantization operations] -> [node]
     static void foldDequantization(std::shared_ptr<Node>& node, const size_t branchIndex, const bool inPlace = false);
 
-    static std::shared_ptr<Node> markAsDequantizationOp(std::shared_ptr<Node> op);
-
 private:
     static std::shared_ptr<Node> foldFakeQuantize(const std::shared_ptr<opset1::FakeQuantize>& fq, const bool roundValues, const bool roundValuesWasSet);
 
index 54b2384..ce9d3a8 100644 (file)
@@ -112,7 +112,7 @@ bool AddTransformation::transform(TransformationContext& context, ngraph::patter
         }
 
         newMultiply = NetworkHelper::swapMultiplyAndAdd(add, multiplyBranch.first);
-
+        ngraph::copy_runtime_info({ add, newMultiply }, newMultiply);
         if (is_type<opset1::Add>(newMultiply->get_input_node_shared_ptr(0))) {
             newAddOrSubtract = newMultiply->get_input_node_shared_ptr(0);
 
@@ -186,6 +186,7 @@ bool AddTransformation::transform(TransformationContext& context, ngraph::patter
 
         replace_node(add, newMultiply);
         NetworkHelper::copyInfo(add, newAddOrSubtract);
+        ngraph::copy_runtime_info({ add, newMultiply }, newMultiply);
     }
 
     updateOutput(context, newMultiply, newAddOrSubtract);
index 3aac198..99bf48d 100644 (file)
@@ -261,18 +261,15 @@ void ConcatTransformation::addDequantizationLayers(
 
                         if (layerDequantizations.size() > 1ul) {
                             auto broadcastElementWiseConst = [](
+                                // FakeQuantize constant shape must be broadcastable to the shape on data.
                                 std::shared_ptr<ngraph::opset1::Constant> operation,
                                 const ngraph::Shape targetShape) -> std::shared_ptr<Node> {
-                                auto unsqueeze = ngraph::pass::low_precision::fold<ngraph::opset1::Unsqueeze>(
-                                    operation->shared_from_this(),
-                                    std::make_shared<ngraph::opset1::Constant>(element::i64, ngraph::Shape{ 4 }, std::vector<size_t>{ 0, 1, 2, 3 }));
-
                                 auto targetShapeConst = std::make_shared<ngraph::opset1::Constant>(
                                     element::i64, ngraph::Shape{ targetShape.size() },
                                     targetShape);
 
                                 auto broadcast = ngraph::pass::low_precision::fold<ngraph::opset1::Broadcast>(
-                                    unsqueeze,
+                                    operation,
                                     targetShapeConst,
                                     ngraph::op::AutoBroadcastType::NUMPY);
 
@@ -342,6 +339,7 @@ void ConcatTransformation::addDequantizationLayers(
                             std::shared_ptr<ngraph::Node> convert =
                                 convertNodes[0]->clone_with_new_inputs({ destination->get_input_source_output(sourceOutputIdx) });
                             insert_new_node_between(source, destination, convert);
+                            ngraph::copy_runtime_info({ layer, convert }, convert);
                             source = convert;
                         }
 
@@ -354,6 +352,7 @@ void ConcatTransformation::addDequantizationLayers(
                                     subtractNodes[0] :
                                     ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(subtractNodes, 1)));
                             insert_new_node_between(source, destination, subtract);
+                            ngraph::copy_runtime_info({ layer, subtract }, subtract);
                             source = subtract;
                         }
 
@@ -365,6 +364,7 @@ void ConcatTransformation::addDequantizationLayers(
                                     multiplyNodes[0] :
                                     ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(multiplyNodes, 1)));
                             insert_new_node_between(source, destination, multiply);
+                            ngraph::copy_runtime_info({ layer, multiply }, multiply);
                             source = multiply;
                         }
                     }
index 88a2bdd..734dd17 100644 (file)
@@ -234,7 +234,7 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
 
     std::shared_ptr<ngraph::opset1::Multiply> finalDequantization = NetworkHelper::optimizeMultipliesAfter(
         convolution->output(0).get_target_inputs().begin()->get_node()->shared_from_this());
-
+    ngraph::copy_runtime_info({ convolution, finalDequantization }, finalDequantization);
     updateOutput(context, finalDequantization, convolution);
 
     auto onWeights = convolution->get_input_node_shared_ptr(1);
index f0ec5be..e872a2c 100644 (file)
@@ -260,8 +260,9 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
         fakeQuantize->input_value(4) }));
 
     replace_node(fakeQuantize, newFakeQuantize);
-    NetworkHelper::copyInfo(fakeQuantize, newFakeQuantize);
-
+    ngraph::copy_runtime_info({ fakeQuantize, eltwise }, newFakeQuantize);
+    newFakeQuantize->set_friendly_name(fakeQuantize->get_friendly_name());
+    NetworkHelper::cleanRunTimeInfo(newFakeQuantize);
     return newFakeQuantize;
 }
 
index 9195d13..df65e15 100644 (file)
@@ -90,7 +90,8 @@ bool FuseConvertTransformation::transform(TransformationContext& context, ngraph
         }
 
         if (newOp != nullptr) {
-            NetworkHelper::copyInfo(op, newOp);
+            ngraph::copy_runtime_info({ convert, op }, newOp);
+            newOp->set_friendly_name(op->get_friendly_name());
         }
     }
 
index 0f9b29a..5fc9ac4 100644 (file)
@@ -64,6 +64,7 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat
         matMul->get_transpose_a(),
         matMul->get_transpose_b());
     NetworkHelper::setOutDataPrecisionForTypeRelaxed(newMatMul, matMul->get_output_element_type(0));
+    NetworkHelper::copyInfo(matMul, newMatMul);
 
     auto transpose = [](const std::shared_ptr<Node>& node) -> std::shared_ptr<Node> {
         const Shape outputShape = node->get_output_shape(0);
@@ -95,6 +96,7 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat
                 NetworkHelper::toScalar(as_type_ptr<opset1::Constant>(const1)),
                 const2)));
     replace_node(matMul, newMultiply);
+    ngraph::copy_runtime_info({ newMultiply, matMul }, newMultiply);
 
     updateOutput(context, newMultiply, matMul);
 
index 5998edf..b54abdc 100644 (file)
@@ -115,9 +115,10 @@ bool MVNTransformation::transform(TransformationContext &context, ngraph::patter
                 mvn->get_normalize_variance(),
                 mvn->get_eps()),
         type);
+    NetworkHelper::copyInfo(mvn, newMVN);
 
     auto newMultiply = std::make_shared<DequantizationMultiply>(newMVN, newScalesConst);
-    newMVN->set_friendly_name(mvn->get_friendly_name());
+    ngraph::copy_runtime_info({ mvn, newMultiply }, newMultiply);
 
     replace_node(mvn, newMultiply);
 
index 1acc8d9..6eb5bc1 100644 (file)
@@ -632,7 +632,7 @@ std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> NetworkHelper::decompos
             fq->get_levels(),
             fq->get_auto_broadcast()),
         true);
-    newFQ->set_friendly_name(fq->get_friendly_name());
+    NetworkHelper::copyInfo(fq, newFQ);
 
     std::shared_ptr<ngraph::Node> convert2;
     if (updatePrecision) {
@@ -650,10 +650,12 @@ std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> NetworkHelper::decompos
 
         convert2 = std::make_shared<DequantizationConvert>(convert, element::f32);
         convert2->set_friendly_name(convert->get_friendly_name() + "/DequantizationConvert");
+        ngraph::copy_runtime_info({ newFQ, convert2 }, convert2);
     } else {
         if (newFQ->get_output_element_type(0) != element::f32) {
             convert2 = std::make_shared<DequantizationConvert>(newFQ, element::f32);
             convert2->set_friendly_name(newFQ->get_friendly_name() + "/DequantizationConvert");
+            ngraph::copy_runtime_info({ newFQ, convert2 }, convert2);
         }
     }
 
@@ -663,12 +665,14 @@ std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> NetworkHelper::decompos
         std::make_shared<ngraph::op::TypeRelaxed<DequantizationSubtract>>(convert2 == nullptr ? newFQ : convert2, shift);
     if (sub != nullptr) {
         sub->set_friendly_name(newFQ->get_friendly_name() + "/DequantizationSubtract");
+        ngraph::copy_runtime_info({ newFQ, sub }, sub);
     }
 
     const std::shared_ptr<ngraph::opset1::Multiply> dequantize = std::make_shared<DequantizationMultiply>(
         sub == nullptr ? (convert2 == nullptr ? newFQ : convert2) : sub,
         scale);
     dequantize->set_friendly_name(newFQ->get_friendly_name() + "/DequantizationMultiply");
+    ngraph::copy_runtime_info({ newFQ, dequantize }, dequantize);
 
     replace_node(fq, dequantize);
 
@@ -929,7 +933,7 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
 
     const std::shared_ptr<ngraph::Node> newOperation = operation->clone_with_new_inputs(inputs);
     newOperation->set_friendly_name(operation->get_friendly_name());
-    // copyInfo(operation, newOperation);
+    ngraph::copy_runtime_info(operation, newOperation);
 
     if (updatePrecision) {
         auto op = std::dynamic_pointer_cast<ngraph::op::TypeRelaxedBase>(newOperation);
@@ -945,18 +949,22 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
     auto parent = newOperation;
     if (shouldConvert) {
         parent = std::make_shared<DequantizationConvert>(parent, dequantization.convert->get_output_element_type(0));
+        ngraph::copy_runtime_info({ newOperation, parent }, parent);
     }
     if (moveSubtract && (dequantization.subtract != nullptr)) {
         auto subtractConstant = dequantization.subtract->get_input_node_shared_ptr(1);
         parent = std::make_shared<DequantizationSubtract>(parent, subtractConstant);
+        ngraph::copy_runtime_info({ newOperation, parent }, parent);
     }
     if (dequantization.multiply != nullptr) {
         auto multiplyConstant = dequantization.multiply->get_input_node_shared_ptr(1);
         parent = std::make_shared<DequantizationMultiply>(parent, multiplyConstant);
+        ngraph::copy_runtime_info({ newOperation, parent }, parent);
     }
     replace_node(operation, parent);
 
     if ((!moveSubtract) && (dequantization.convert != nullptr) && (dequantization.subtract != nullptr)) {
+        NetworkHelper::cleanRunTimeInfo(dequantization.subtract);
         optimizeSubtract(dequantization.subtract);
     }
 
@@ -1036,13 +1044,6 @@ std::shared_ptr<Node> NetworkHelper::toScalarIfPossible(std::shared_ptr<Node> no
     return NetworkHelper::toScalar(constant);
 }
 
-std::shared_ptr<Node> NetworkHelper::markAsDequantizationOp(std::shared_ptr<Node> op) {
-    auto opCopy = op->clone_with_new_inputs(op->input_values());
-    auto& rtInfo = opCopy->get_rt_info();
-    rtInfo["DEQUANTIZATION"] = std::make_shared<VariantWrapper<DequantizationAttr>>(DequantizationAttr());
-    return opCopy;
-}
-
 }  // namespace low_precision
 }  // namespace pass
 }  // namespace ngraph
index 15999cc..da156c7 100644 (file)
@@ -133,6 +133,7 @@ bool NormalizeL2Transformation::transform(TransformationContext &context, ngraph
         ngraph::op::TemporaryReplaceOutputType(newScalesConst, element::f32).get());
 
     replace_node(normalize, newMultiply);
+    ngraph::copy_runtime_info({ normalize, newMultiply }, newMultiply);
 
     updateOutput(context, newMultiply, normalize);
     return true;
index 58d01b7..ede8fef 100644 (file)
@@ -28,65 +28,109 @@ void ReshapeTransformation::registerMatcherIn(GraphRewrite &pass, Transformation
 void reshapeDequantizationConstant(const std::shared_ptr<opset1::Reshape>& reshape) {
     const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(reshape, 0);
     if (dequantization.multiply->get_input_node_ptr(1)->get_output_shape(0).size() > 1ul) {
+        // Reshape Subtract or Multiply operation Constant.
+        //    1. modify reshape parameters to avoid reshape by spatial dimensions
+        //    2. broadcast element-wise constant if channels are changed
+        //    3. reshape element-wise constant with modified reshape parameters
         auto replaceConstant = [](const std::shared_ptr<opset1::Reshape>& reshape, const std::shared_ptr<Node>& op) {
-            if (reshape->output(0).get_shape().size() == 2ul) {
-                const auto inputShape = reshape->input(0).get_shape();
-
-                Shape shape(inputShape);
-                shape[0] = 1ul;
-
-                const std::shared_ptr<Node> broadcastedConstant = fold<opset1::Broadcast>(
-                    op->get_input_node_shared_ptr(1),
-                    std::make_shared<opset1::Constant>(element::i32, Shape{ shape.size() }, shape));
-
-                const std::shared_ptr<Node> reshapedConstant = fold<opset1::Reshape>(
-                    broadcastedConstant,
-                    reshape->get_input_node_shared_ptr(1),
-                    reshape->get_special_zero());
-
-                replace_node(op->get_input_node_shared_ptr(1), reshapedConstant);
-            } else {
-                // Original Reshape operation is used to update operation Constant.
-                // But original Reshape operation output data shape constant should be changed before reshape.
-
-                // simple broadcast operation Constant shape to shape on activations
-                auto newOperationConstantShape = op->input(1).get_shape();
-                auto const reshapeInputShape = reshape->input(0).get_shape();
-                if ((reshapeInputShape.size() - newOperationConstantShape.size()) == 1ul) {
-                    newOperationConstantShape.insert(newOperationConstantShape.begin(), 1ul);
-                }
-                const std::shared_ptr<opset1::Constant> originalConstant = as_type_ptr<opset1::Constant>(op->get_input_node_shared_ptr(1));
-                const std::shared_ptr<opset1::Constant> newOperationConstant = std::make_shared<opset1::Constant>(
-                    op->input(1).get_element_type(),
-                    newOperationConstantShape,
-                    originalConstant->cast_vector<float>());
-
-                // update Reshape constant
-                const std::vector<int> reshapeConstValues = as_type_ptr<opset1::Constant>(reshape->get_input_node_shared_ptr(1))->cast_vector<int>();
-                std::vector<int> newReshapeConstValues(reshapeConstValues);
-                for (int i = static_cast<int>(newReshapeConstValues.size() - 1); i >= 0; --i) {
-                    if (newOperationConstantShape.size() <= i) {
-                        newReshapeConstValues[i] = 1;
-                    } else if (newOperationConstantShape[i] == 1ul) {
-                        // not used dimension
-                        newReshapeConstValues[i] = 1;
-                    } else {
-                        break;
+            const size_t constantIndex = as_type<ngraph::opset1::Constant>(op->get_input_node_ptr(1)) ? 1 : 0;
+            const Shape constantShape = op->input(constantIndex).get_shape();
+            // reshape for element-wise constant is not required
+            if (constantShape.empty() || (constantShape.size() == 1ul)) {
+                return;
+            }
+
+            // simple broadcast operation Constant shape to shape on activations
+            auto newOperationConstantShape = op->input(1).get_shape();
+            auto const reshapeInputShape = reshape->input(0).get_shape();
+            Shape newOperationConstantBroadcastedShape(reshapeInputShape);
+            newOperationConstantBroadcastedShape[0] = 1ul;
+
+            if ((reshapeInputShape.size() - newOperationConstantShape.size()) == 1ul) {
+                newOperationConstantShape.insert(newOperationConstantShape.begin(), 1ul);
+            }
+            const std::shared_ptr<opset1::Constant> originalConstant = as_type_ptr<opset1::Constant>(op->get_input_node_shared_ptr(1));
+            const std::shared_ptr<opset1::Constant> newOperationConstant = std::make_shared<opset1::Constant>(
+                op->input(1).get_element_type(),
+                newOperationConstantShape,
+                originalConstant->cast_vector<float>());
+
+            // reshape -1 value hanling
+            auto getOverallValue = [](const Shape& shape, const std::vector<int>& reshapeValues, const bool specialZero) -> size_t {
+                size_t overallValue = shape_size(shape);
+                for (size_t i = 0; i < reshapeValues.size(); ++i) {
+                    auto reshapeValue = reshapeValues[i];
+                    if ((reshapeValue == 1ul) || (reshapeValue == -1) || ((reshapeValue == 0ul) && !specialZero)) {
+                        continue;
                     }
-                }
 
-                const std::shared_ptr<opset1::Constant> newReshapedConstant = std::make_shared<opset1::Constant>(
-                    reshape->input(1).get_element_type(),
-                    Shape({ newReshapeConstValues.size() }),
-                    newReshapeConstValues);
+                    if ((reshapeValue == 0ul) && specialZero) {
+                        reshapeValue = shape[i];
+                    }
 
-                const std::shared_ptr<Node> resultConstant = fold<opset1::Reshape>(
-                    newOperationConstant,
-                    newReshapedConstant,
-                    reshape->get_special_zero());
+                    overallValue = overallValue / reshapeValue;
+                }
+                return overallValue;
+            };
+
+            // modify reshape constant for element-wise constant reshape
+            // element-wise constant doesn't have spatial dimensions, as result we should remove spatial dimensions from reshape parameters
+            const std::vector<int> reshapeConstValues = as_type_ptr<opset1::Constant>(reshape->get_input_node_shared_ptr(1))->cast_vector<int>();
+
+            size_t overallValue = 0;
+            for (size_t i = 0; i < reshapeConstValues.size(); ++i) {
+                if (reshapeConstValues[i] == -1) {
+                    overallValue = getOverallValue(
+                        reshapeInputShape,
+                        reshapeConstValues,
+                        as_type_ptr<opset1::Reshape>(reshape)->get_special_zero());
+                    break;
+                }
+            }
 
-                replace_node(op->get_input_node_shared_ptr(1), resultConstant);
+            std::vector<int> newReshapeConstValues(reshapeConstValues);
+            for (int i = static_cast<int>(newReshapeConstValues.size() - 1); i >= 0; --i) {
+                if (newOperationConstantShape.size() <= i) {
+                    // new dimension was added
+                    newReshapeConstValues[i] = 1;
+                } else if (newOperationConstantShape[i] == 1ul) {
+                    // keep the same
+                    newReshapeConstValues[i] = 1;
+                } else if (newReshapeConstValues[i] == -1) {
+                    // modified reshape parameters are different, but value instead '-1' has to be equal as original reshape
+                    newReshapeConstValues[i] = overallValue;
+                }
             }
+
+            const std::shared_ptr<opset1::Constant> newReshapeConstant = std::make_shared<opset1::Constant>(
+                reshape->input(1).get_element_type(),
+                Shape({ newReshapeConstValues.size() }),
+                newReshapeConstValues);
+
+            // if channels are different then broadcast spatial dimensions to reshape channels correctly
+            // limitation which has to be covered by canBeTransformed:
+            //    1. spatial dimensions have to be absent or equal to 1 after reshape
+            //    2. only second dimension can be changed
+
+            const bool shouldBroadcast = (shape_size(newReshapeConstValues) != 1ul) && (reshapeConstValues[1] != 0) &&
+                (((reshapeConstValues[1] != -1) && (constantShape[1] != reshapeConstValues[1])) ||
+                ((reshapeConstValues[1] == -1) && (constantShape[1] != overallValue)));
+
+            const std::shared_ptr<Node> broadcastedConstant = shouldBroadcast ?
+                fold<opset1::Broadcast>(
+                    newOperationConstant,
+                    std::make_shared<opset1::Constant>(
+                        element::i32,
+                        Shape({newOperationConstantBroadcastedShape.size()}),
+                        newOperationConstantBroadcastedShape)) :
+                newOperationConstant;
+
+            const std::shared_ptr<Node> resultConstant = fold<opset1::Reshape>(
+                broadcastedConstant,
+                newReshapeConstant,
+                reshape->get_special_zero());
+
+            replace_node(op->get_input_node_shared_ptr(1), resultConstant);
         };
 
         if (dequantization.subtract != nullptr) {
index 0cc8b2c..2c9f8b2 100644 (file)
@@ -181,7 +181,7 @@ static void Transformation(ICNNNetwork::Ptr& clonedNetwork, const Config& conf)
     // not legacy actually, but it should be the last transformation in the transformation pipeline
     legacyManager.register_pass<ngraph::pass::UnrollTensorIterator>();
 
-    auto legacyPassConfig = manager.get_pass_config();
+    auto legacyPassConfig = legacyManager.get_pass_config();
     legacyPassConfig->set_callback<ngraph::pass::AddMultiplyFusion>([](const_node_ptr &node) -> bool {
         if (auto mul_op = std::dynamic_pointer_cast<const ngraph::opset1::Multiply>(node)) {
             auto add_op = std::dynamic_pointer_cast<const ngraph::opset1::Add>(mul_op->get_input_node_shared_ptr(0));
index 29af334..42d132d 100644 (file)
@@ -29,8 +29,8 @@ namespace {
 
 class ConcatTransformationActualValues {
 public:
-    ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1;
-    ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2;
+    ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize1;
+    ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize2;
 };
 
 inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationActualValues& values) {
@@ -39,8 +39,8 @@ inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationAct
 
 class ConcatTransformationResultValues {
 public:
-    ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1;
-    ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2;
+    ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize1;
+    ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize2;
     ngraph::builder::subgraph::DequantizationOperations dequantizationOperations;
 };
 
@@ -86,6 +86,7 @@ public:
             shape,
             testValues.actual.fakeQuantize1,
             testValues.actual.fakeQuantize2);
+
         SimpleLowPrecisionTransformer transform;
         if (testValues.multiChannels) {
             transform.add<ngraph::pass::low_precision::ConcatMultiChannelsTransformation, ngraph::opset1::Concat>(testValues.params);
@@ -138,12 +139,40 @@ const std::vector<ConcatTransformationTestValues> testValues = {
         LayerTransformation::createParamsU8I8(),
         false,
         {
-            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
-            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }
+            { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
+            { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} }
+        },
+        {
+            { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
+            { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
+            { ngraph::element::f32, {}, { 0.01f } }
+        }
+    },
+    // U8: concat
+    {
+        LayerTransformation::createParamsU8I8(),
+        false,
+        {
+            { 256ul, {{1}, {1}, {1}, {1}}, {0.f}, {2.55f}, {0.f}, {2.55f} },
+            { 256ul, {{1}, {1}, {1}, {1}}, {0.f}, {2.55f}, {0.f}, {2.55f} }
+        },
+        {
+            { 256ul, {{1}, {1}, {}, {}}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
+            { 256ul, {{1}, {1}, {}, {}}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
+            { ngraph::element::f32, {}, { 0.01f } }
+        }
+    },
+    // U8: concat
+    {
+        LayerTransformation::createParamsU8I8(),
+        false,
+        {
+            { 256ul, {{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}}, {0.f}, {2.55f}, {0.f}, {2.55f} },
+            { 256ul, {{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}}, {0.f}, {2.55f}, {0.f}, {2.55f} }
         },
         {
-            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
-            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
+            { 256ul, {{1, 1, 1, 1}, {1, 1, 1, 1}, {}, {}}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
+            { 256ul, {{1, 1, 1, 1}, {1, 1, 1, 1}, {}, {}}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
             { ngraph::element::f32, {}, { 0.01f } }
         }
     },
@@ -152,26 +181,74 @@ const std::vector<ConcatTransformationTestValues> testValues = {
         LayerTransformation::createParamsU8I8(),
         true,
         {
-            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
-            { 256ul, ngraph::Shape({}), {0.f}, {1.275f}, {0.f}, {1.275f} }
+            { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
+            { 256ul, {}, {0.f}, {1.275f}, {0.f}, {1.275f} }
+        },
+        {
+            { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
+            { 256ul, {}, {0.f}, {1.275f}, {0.f}, {255.f}, ngraph::element::u8 },
+            { ngraph::element::f32, {}, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} }
+        }
+    },
+    // U8: concat multi channels
+    {
+        LayerTransformation::createParamsU8I8(),
+        true,
+        {
+            { 256ul, {{1}, {1}, {1}, {1}}, {0.f}, {2.55f}, {0.f}, {2.55f} },
+            { 256ul, {{1}, {1}, {1}, {1}}, {0.f}, {1.275f}, {0.f}, {1.275f} }
         },
         {
-            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
-            { 256ul, ngraph::Shape({}), {0.f}, {1.275f}, {0.f}, {255.f}, ngraph::element::u8 },
+            { 256ul, {{1}, {1}, {}, {}}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
+            { 256ul, {{1}, {1}, {}, {}}, {0.f}, {1.275f}, {0.f}, {255.f}, ngraph::element::u8 },
             { ngraph::element::f32, {}, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} }
         }
     },
+    // U8: concat multi channels
+    {
+        LayerTransformation::createParamsU8I8(),
+        true,
+        {
+            {
+                256ul,
+                {{1, 3, 1, 1}, {1, 3, 1, 1}, {1, 3, 1, 1}, {1, 3, 1, 1}},
+                {0.f, 0.f, 0.f}, {2.55f, 2.55f, 2.55f}, {0.f, 0.f, 0.f}, {2.55f / 1.f, 2.55f / 2.f, 2.55f / 3.f},
+                ngraph::element::f32
+            },
+            {
+                256ul,
+                {{1, 3, 1, 1}, {1, 3, 1, 1}, {1, 3, 1, 1}, {1, 3, 1, 1}},
+                {0.f, 0.f, 0.f}, {1.275f, 1.275f, 1.275f}, {0.f, 0.f, 0.f}, {1.275f / 1.f, 1.275f / 2.f, 1.275f / 3.f},
+                ngraph::element::f32
+            }
+        },
+        {
+            {
+                256ul,
+                {{1, 3, 1, 1}, {1, 3, 1, 1}, {}, {}},
+                {0.f, 0.f, 0.f}, {2.55f, 2.55f, 2.55f}, {0.f}, {255.f},
+                ngraph::element::u8
+            },
+            {
+                256ul,
+                {{1, 3, 1, 1}, {1, 3, 1, 1}, {}, {}},
+                {0.f, 0.f, 0.f}, {1.275f, 1.275f, 1.275f}, {0.f}, {255.f},
+                ngraph::element::u8
+            },
+            { ngraph::element::f32, {}, {{ 0.01f / 1.f, 0.01f / 2.f, 0.01f / 3.f, 0.005f / 1.f, 0.005f / 2.f, 0.005f / 3.f }} }
+        }
+    },
     // U8: concat multi channels with subtract
     {
         LayerTransformation::createParamsU8I8(),
         true,
         {
-            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
-            { 256ul, ngraph::Shape({}), {1.275f}, {2.55f}, {1.275f}, {2.55f} }
+            { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
+            { 256ul, {}, {1.275f}, {2.55f}, {1.275f}, {2.55f} }
         },
         {
-            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
-            { 256ul, ngraph::Shape({}), {1.275f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
+            { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
+            { 256ul, {}, {1.275f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
             {
                 ngraph::element::f32,
                 {{ 0.f, 0.f, 0.f, -255.f, -255.f, -255.f }},
@@ -184,12 +261,12 @@ const std::vector<ConcatTransformationTestValues> testValues = {
         LayerTransformation::createParamsI8I8(),
         false,
         {
-            { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
-            { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
+            { 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
+            { 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
         },
         {
-            { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-128.f}, {127.f}, ngraph::element::i8 },
-            { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-128.f}, {127.f}, ngraph::element::i8 },
+            { 256ul, {}, {-1.28f}, {1.27f}, {-128.f}, {127.f}, ngraph::element::i8 },
+            { 256ul, {}, {-1.28f}, {1.27f}, {-128.f}, {127.f}, ngraph::element::i8 },
             { ngraph::element::f32, {}, { 0.01f } }
         }
     },
@@ -198,12 +275,12 @@ const std::vector<ConcatTransformationTestValues> testValues = {
         LayerTransformation::createParamsU8I8(),
         false,
         {
-            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
-            { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
+            { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
+            { 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
         },
         {
-            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {85.f}, {255.f}, ngraph::element::u8 },
-            { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {0.f}, {170.f}, ngraph::element::u8 },
+            { 256ul, {}, {0.f}, {2.55f}, {85.f}, {255.f}, ngraph::element::u8 },
+            { 256ul, {}, {-1.28f}, {1.27f}, {0.f}, {170.f}, ngraph::element::u8 },
             { ngraph::element::f32, { 85 }, { 0.015f } }
         }
     },
@@ -212,12 +289,12 @@ const std::vector<ConcatTransformationTestValues> testValues = {
         LayerTransformation::createParamsU8I8(),
         true,
         {
-            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
-            { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
+            { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
+            { 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
         },
         {
-            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
-            { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {0.f}, {255.f}, ngraph::element::u8 },
+            { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
+            { 256ul, {}, {-1.28f}, {1.27f}, {0.f}, {255.f}, ngraph::element::u8 },
             { ngraph::element::f32, {{ 0.f, 0.f, 0.f, 128.f, 128.f, 128.f }}, { 0.01f } }
         }
     },
@@ -226,29 +303,29 @@ const std::vector<ConcatTransformationTestValues> testValues = {
         LayerTransformation::createParamsU8I8(),
         false,
         {
-            { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
-            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }
+            { 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
+            { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} }
         },
         {
-            { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {0.f}, {170.f}, ngraph::element::u8 },
-            { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {85.f}, {255.f}, ngraph::element::u8 },
+            { 256ul, {}, {-1.28f}, {1.27f}, {0.f}, {170.f}, ngraph::element::u8 },
+            { 256ul, {}, {0.f}, {2.55f}, {85.f}, {255.f}, ngraph::element::u8 },
             { ngraph::element::f32, { 85 }, { 0.015f } }
         }
     },
     // real case from ctdet_coco_dlav0_384 model, coverage bad rounding
     {
-            LayerTransformation::createParamsU8I8(),
-            false,
-            {
-                    { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {0.f}, {2.3007815f} },
-                    { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {-3.873046875f}, {3.84375} }
-            },
-            {
-                    { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {128.f}, {204.f}, ngraph::element::u8 },
-                    { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
-                    { ngraph::element::f32, { 128 }, { 0.0302619f } }
-            }
-    },
+        LayerTransformation::createParamsU8I8(),
+        false,
+        {
+            { 256ul, {}, {-1.28f}, {1.27f}, {0.f}, {2.3007815f} },
+            { 256ul, {}, {0.f}, {2.55f}, {-3.873046875f}, {3.84375} }
+        },
+        {
+            { 256ul, {}, {-1.28f}, {1.27f}, {128.f}, {204.f}, ngraph::element::u8 },
+            { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8 },
+            { ngraph::element::f32, { 128 }, { 0.0302619f } }
+        }
+    }
 };
 
 const std::vector<ngraph::Shape> shapes = {
index 934adb2..ef3e4af 100644 (file)
@@ -140,11 +140,8 @@ public:
 };
 
 TEST_P(MatMulWithConstantTransformation, CompareFunctions) {
-    InitNodeInfo().run_on_function(actualFunction);
-
     actualFunction->validate_nodes_and_infer_types();
-
-    auto res = compare_functions(referenceFunction, actualFunction, true, true);
+    auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
     ASSERT_TRUE(res.first) << res.second;
 }
 
index 342f138..7629827 100644 (file)
@@ -78,10 +78,8 @@ public:
 };
 
 TEST_P(MaxPoolTransformation, CompareFunctions) {
-    InitNodeInfo().run_on_function(actualFunction);
     actualFunction->validate_nodes_and_infer_types();
-
-    auto res = compare_functions(referenceFunction, actualFunction, true, true);
+    auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
     ASSERT_TRUE(res.first) << res.second;
 }
 
diff --git a/inference-engine/tests/functional/inference_engine/lp_transformations/move_dequantization_after_transformation.cpp b/inference-engine/tests/functional/inference_engine/lp_transformations/move_dequantization_after_transformation.cpp
new file mode 100644 (file)
index 0000000..f528121
--- /dev/null
@@ -0,0 +1,265 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "layer_transformation.hpp"
+
+#include <string>
+#include <sstream>
+#include <memory>
+
+#include <gtest/gtest.h>
+
+#include <utility>
+#include <transformations/utils/utils.hpp>
+#include <transformations/init_node_info.hpp>
+#include <low_precision/network_helper.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+#include "ngraph_functions/low_precision_transformations/move_dequantization_after_function.hpp"
+#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
+
+using namespace testing;
+using namespace ngraph::pass;
+using namespace ngraph::builder::subgraph;
+
+class MoveDequantizationAfterTransformationParams {
+public:
+    class Actual {
+    public:
+        ngraph::builder::subgraph::DequantizationOperations dequantization;
+    };
+
+    class Expected {
+    public:
+        ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
+        ngraph::element::Type precisionAfterOperation;
+        ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
+    };
+
+    ngraph::element::Type originalPrecision;
+    ngraph::pass::low_precision::LayerTransformation::Params params;
+    bool updatePrecision;
+    bool moveSubtract;
+    Actual actual;
+    Expected expected;
+};
+
+typedef std::tuple<
+    ngraph::Shape,
+    MoveDequantizationAfterTransformationParams> MoveDequantizationAfterTransformationTestValues;
+
+class MoveDequantizationAfterTransformation :
+    public LayerTransformation,
+    public testing::WithParamInterface<MoveDequantizationAfterTransformationTestValues> {
+public:
+    void SetUp() override {
+        const auto inputShape = std::get<0>(GetParam());
+        const auto testValues = std::get<1>(GetParam());
+        actualFunction = ngraph::builder::subgraph::MoveDequantizationAfterFunction::getOriginal(
+            testValues.originalPrecision,
+            inputShape,
+            testValues.actual.dequantization);
+
+        const auto targetNode = actualFunction->get_output_op(0)->get_input_node_shared_ptr(0);
+        const auto dequantization = ngraph::pass::low_precision::NetworkHelper::getDequantization(targetNode);
+        ngraph::pass::low_precision::NetworkHelper::moveDequantizationAfter(
+            targetNode,
+            dequantization,
+            testValues.updatePrecision,
+            testValues.moveSubtract);
+
+        referenceFunction = ngraph::builder::subgraph::MoveDequantizationAfterFunction::getReference(
+            testValues.originalPrecision,
+            inputShape,
+            testValues.expected.dequantizationBefore,
+            testValues.expected.precisionAfterOperation,
+            testValues.expected.dequantizationAfter);
+    }
+
+    static std::string getTestCaseName(testing::TestParamInfo<MoveDequantizationAfterTransformationTestValues> obj) {
+        const auto inputShape = std::get<0>(obj.param);
+        const auto testValues = std::get<1>(obj.param);
+
+        std::ostringstream result;
+        result <<
+            testValues.originalPrecision << "_" <<
+            inputShape << "_" <<
+            testValues.actual.dequantization << "_" <<
+            (testValues.moveSubtract ? "move_subtract_" : "don't_move_subtract_") <<
+            (testValues.updatePrecision ? "updatePrecision" : "don't_update_precision");
+        return result.str();
+    }
+};
+
+TEST_P(MoveDequantizationAfterTransformation, CompareFunctions) {
+    actualFunction->validate_nodes_and_infer_types();
+    auto res = compare_functions(referenceFunction, actualFunction, true, false, true);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+const std::vector<ngraph::Shape> inputShapes = {
+    { 1, 3, 16, 16 },
+    { 4, 3, 16, 16 }
+};
+
+const std::vector<MoveDequantizationAfterTransformationParams> testValues = {
+    // U8
+    {
+        ngraph::element::u8,
+        LayerTransformation::createParamsU8I8(),
+        true,
+        true,
+        {
+            { {ngraph::element::f32},  { 7.f }, { 10.f } },
+        },
+        {
+            { {},  {}, {} },
+            ngraph::element::u8,
+            { {ngraph::element::f32},  { 7.f }, { 10.f } },
+        },
+    },
+    // moveSubtract = false
+    {
+        ngraph::element::u8,
+        LayerTransformation::createParamsU8I8(),
+        true,
+        false,
+        {
+            { {ngraph::element::f32},  { 7.f }, { 10.f } },
+        },
+        {
+            { {},  { { 7.f }, ngraph::element::f32, {}, false }, {} },
+            ngraph::element::f32,
+            { {},  {}, { 10.f } },
+        },
+    },
+    // updatePrecision = false
+    {
+        ngraph::element::u8,
+        LayerTransformation::createParamsU8I8(),
+        false,
+        true,
+        {
+            { {ngraph::element::f32},  { 7.f }, { 10.f } },
+        },
+        {
+            { {},  {}, {} },
+            ngraph::element::f32,
+            { {},  { 7.f }, { 10.f } },
+        },
+    },
+    // moveSubtract = false & updatePrecision = false
+    {
+        ngraph::element::u8,
+        LayerTransformation::createParamsU8I8(),
+        false,
+        false,
+        {
+            { {ngraph::element::f32},  { 7.f }, { 10.f } },
+        },
+        {
+            { {},  { { 7.f }, ngraph::element::f32, {}, false }, {} },
+            ngraph::element::f32,
+            { {},  {}, { 10.f } },
+        },
+    },
+    // I8
+    {
+        ngraph::element::i8,
+        LayerTransformation::createParamsI8I8(),
+        true,
+        true,
+        {
+            { {ngraph::element::f32},  { 7.f }, { 10.f } },
+        },
+        {
+            { {},  {}, {} },
+            ngraph::element::i8,
+            { {ngraph::element::f32},  { 7.f }, { 10.f } },
+        },
+    },
+    // moveSubtract = false
+    {
+        ngraph::element::i8,
+        LayerTransformation::createParamsI8I8(),
+        true,
+        false,
+        {
+            { {ngraph::element::f32},  { 7.f }, { 10.f } },
+        },
+        {
+            { {},  { { 7.f }, ngraph::element::f32, {}, false }, {} },
+            ngraph::element::f32,
+            { {},  {}, { 10.f } },
+        },
+    },
+    // updatePrecision = false
+    {
+        ngraph::element::i8,
+        LayerTransformation::createParamsI8I8(),
+        false,
+        true,
+        {
+            { {ngraph::element::f32},  { 7.f }, { 10.f } },
+        },
+        {
+            { {},  {}, {} },
+            ngraph::element::f32,
+            { {},  { 7.f }, { 10.f } },
+        },
+    },
+    // moveSubtract = false & updatePrecision = false
+    {
+        ngraph::element::i8,
+        LayerTransformation::createParamsI8I8(),
+        false,
+        false,
+        {
+            { {ngraph::element::f32},  { 7.f }, { 10.f } },
+        },
+        {
+            { {},  { { 7.f }, ngraph::element::f32, {}, false }, {} },
+            ngraph::element::f32,
+            { {},  {}, { 10.f } },
+        },
+    },
+    // per-channel quantizations with the same values
+    {
+        ngraph::element::u8,
+        LayerTransformation::createParamsU8I8(),
+        false,
+        false,
+        {
+            { {ngraph::element::f32},  { { 7.f, 7.f, 7.f } }, { { 10.f, 10.f, 10.f } } },
+        },
+        {
+            { {},  { { 7.f, 7.f, 7.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {} },
+            ngraph::element::f32,
+            { {},  {}, { { 10.f, 10.f, 10.f } } },
+        },
+    },
+    // per-channel quantizations with the same values
+    {
+        ngraph::element::u8,
+        LayerTransformation::createParamsU8I8(),
+        false,
+        false,
+        {
+            { {ngraph::element::f32},  { { 7.f, 8.f, 9.f } }, { { 10.f, 12.f, 16.f } } },
+        },
+        {
+            { {},  { { 7.f, 8.f, 9.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {} },
+            ngraph::element::f32,
+            { {},  {}, { { 10.f, 12.f, 16.f } } },
+        },
+    },
+};
+
+INSTANTIATE_TEST_CASE_P(
+    LPT,
+    MoveDequantizationAfterTransformation,
+    ::testing::Combine(
+        ::testing::ValuesIn(inputShapes),
+        ::testing::ValuesIn(testValues)),
+    MoveDequantizationAfterTransformation::getTestCaseName);
index 0ca7df8..e9efaef 100644 (file)
@@ -48,7 +48,6 @@ public:
             shape,
             epsMode,
             params.actual);
-
         SimpleLowPrecisionTransformer transform;
         transform.add<low_precision::NormalizeL2Transformation, ngraph::opset1::NormalizeL2>(
             low_precision::LayerTransformation::Params(params.transformationParams));
index ce9d927..0259edd 100644 (file)
@@ -166,7 +166,7 @@ const std::vector<ReluTransformationTestValues> testValues = {
         },
         {
             ngraph::element::u8,
-            {{}, { {128}, ngraph::element::f32 }, {}},
+            {{}, { {128}, ngraph::element::f32, {}, false }, {}},
             ngraph::element::f32,
             {{}, {}, {0.1f}}
         }
@@ -181,7 +181,7 @@ const std::vector<ReluTransformationTestValues> testValues = {
         },
         {
             ngraph::element::i8,
-            {{}, { {127}, ngraph::element::f32 }, {}},
+            {{}, { {127}, ngraph::element::f32, {}, false }, {}},
             ngraph::element::f32,
             {{}, {}, {0.1f}}
         }
index 6e652e4..9e48de1 100644 (file)
@@ -116,6 +116,22 @@ const std::vector<ReshapeTransformationTestValues> testValues = {
             {{ngraph::element::f32}, {}, {0.1f}}
         }
     },
+    // U8: no subtract 3D -> 4D: channels are not affected
+    {
+        ngraph::Shape({ 4, 384, 1024 }),
+        { 4, 384, 16, 64},
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::u8,
+            {{ngraph::element::f32}, {}, {0.1f}}
+        },
+        {
+            ngraph::element::u8,
+            {{}, {}, {}},
+            ngraph::element::u8,
+            {{ngraph::element::f32}, {}, {0.1f}}
+        }
+    },
     // U8: no subtract 3D -> 4D: channels are not affected: no subtract
     {
         ngraph::Shape({ 1, 3, 20 }),
@@ -132,6 +148,22 @@ const std::vector<ReshapeTransformationTestValues> testValues = {
             {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}}
         }
     },
+    // U8: no subtract 3D -> 4D: channels are not affected: no subtract
+    {
+        ngraph::Shape({ 4, 3, 20 }),
+        { 4, 3, 4, 5},
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::u8,
+            {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1}}}
+        },
+        {
+            ngraph::element::u8,
+            {{}, {}, {}},
+            ngraph::element::u8,
+            {{ngraph::element::f32}, {}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}}
+        }
+    },
     // U8: no subtract 3D -> 4D: channels are not affected: with subtract
     {
         ngraph::Shape({ 1, 3, 20 }),
@@ -156,7 +188,31 @@ const std::vector<ReshapeTransformationTestValues> testValues = {
             }
         }
     },
-    // U8: no subtract 4D -> 3D: channels are not affected: no subtract
+    // U8: no subtract 3D -> 4D: channels are not affected: with subtract
+    {
+        ngraph::Shape({ 1, 3, 20 }),
+        { 1, -1, 4, 5},
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::u8,
+            {
+                {ngraph::element::f32},
+                {{32, 64, 128}, ngraph::element::f32, {1, 3, 1}},
+                {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1}}
+            }
+        },
+        {
+            ngraph::element::u8,
+            {{}, {}, {}},
+            ngraph::element::u8,
+            {
+                {ngraph::element::f32},
+                {{32, 64, 128}, ngraph::element::f32, {1, 3, 1, 1}},
+                {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}
+            }
+        }
+    },
+    // U8: no subtract 4D -> 6D: channels are not affected: no subtract
     {
         ngraph::Shape({ 1, 3, 4, 5 }),
         { 1, 3, 20, 1, 1, 1},
@@ -172,7 +228,7 @@ const std::vector<ReshapeTransformationTestValues> testValues = {
             {}
         }
     },
-    // U8: no subtract 4D -> 3D: channels are not affected: with subtract
+    // U8: no subtract 4D -> 6D: channels are not affected: with subtract
     {
         ngraph::Shape({ 1, 3, 4, 5 }),
         { 1, 3, 20, 1, 1, 1},
@@ -277,7 +333,7 @@ const std::vector<ReshapeTransformationTestValues> testValues = {
             {}
         }
     },
-    // U8: no subtract 4D -> 2D: channels are not affected: no subtract
+    // U8: no subtract 4D -> 6D: channels are not affected: no subtract
     {
         ngraph::Shape({ 1, 3, 1, 1 }),
         { 1, 3, 1, 1, 1, 1 },
@@ -293,12 +349,12 @@ const std::vector<ReshapeTransformationTestValues> testValues = {
             {}
         }
     },
-    // U8: no subtract 2D -> 4D: channels are not affected: per tensor quantization
+    // U8: no subtract 4D -> 2D: channels are not affected: per tensor quantization
     // TODO: story 38439
     {
         ngraph::Shape({ 1, 3, 4, 5 }),
         { 0, -1 },
-            LayerTransformation::createParamsU8I8(),
+        LayerTransformation::createParamsU8I8(),
         {
             ngraph::element::u8,
             {{ngraph::element::f32}, {{128.f}, ngraph::element::f32, {}}, {{0.1f}, ngraph::element::f32, {}}}
@@ -310,7 +366,7 @@ const std::vector<ReshapeTransformationTestValues> testValues = {
             {{ngraph::element::f32}, {{128.f}, ngraph::element::f32, {}}, {{0.1f}, ngraph::element::f32, {}}}
         }
     },
-    // U8: no subtract 2D -> 4D: channels are not affected: per tensor quantization
+    // U8: no subtract 4D -> 2D: channels are not affected: per tensor quantization
     {
         ngraph::Shape({ 1, 3, 2, 2 }),
         { 0, -1 },
@@ -326,10 +382,31 @@ const std::vector<ReshapeTransformationTestValues> testValues = {
             {
                 {ngraph::element::f32},
                 {{0.f, 0.f, 0.f, 0.f, 128.f, 128.f, 128.f, 128.f, 255.f, 255.f, 255.f, 255.f}, ngraph::element::f32, {1, 12}},
-                {{0.1f, 0.1f, 0.1f, 0.1f, 0.2f, 0.2f, 0.2f, 0.2f, 0.3f, 0.3f, 0.3f, 0.3f}, ngraph::element::f32, {1, 12}}}
+                {{0.1f, 0.1f, 0.1f, 0.1f, 0.2f, 0.2f, 0.2f, 0.2f, 0.3f, 0.3f, 0.3f, 0.3f}, ngraph::element::f32, {1, 12}}
+            }
+        }
+    },
+    // U8: no subtract 4D -> 2D: channels are not affected: per tensor quantization
+    {
+        ngraph::Shape({ 4, 3, 2, 2 }),
+        { 0, -1 },
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::u8,
+            {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {1, 3, 1, 1}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1, 1}}}
+        },
+        {
+            ngraph::element::u8,
+            {},
+            ngraph::element::u8,
+            {
+                {ngraph::element::f32},
+                {{0.f, 0.f, 0.f, 0.f, 128.f, 128.f, 128.f, 128.f, 255.f, 255.f, 255.f, 255.f}, ngraph::element::f32, {1, 12}},
+                {{0.1f, 0.1f, 0.1f, 0.1f, 0.2f, 0.2f, 0.2f, 0.2f, 0.3f, 0.3f, 0.3f, 0.3f}, ngraph::element::f32, {1, 12}}
+            }
         }
     },
-    // U8: no subtract 2D -> 4D: channels are not affected: per channel quantization: case #1: dequantization operation constant needs broadcast
+    // U8: no subtract 4D -> 2D: channels are not affected: per channel quantization: case #1: dequantization operation constant needs broadcast
     {
         ngraph::Shape({ 1, 3, 1, 1 }),
         { 0, -1 },
@@ -345,7 +422,7 @@ const std::vector<ReshapeTransformationTestValues> testValues = {
             {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {1, 3}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3}}},
         }
     },
-    // U8: no subtract 2D -> 4D: channels are not affected: per channel quantization: case #2: dequantization operation constant doesn't need broadcast
+    // U8: no subtract 4D -> 2D: channels are not affected: per channel quantization: case #2: dequantization operation constant doesn't need broadcast
     {
         ngraph::Shape({ 1, 3, 1, 1 }),
         { 0, -1 },
@@ -361,7 +438,7 @@ const std::vector<ReshapeTransformationTestValues> testValues = {
             {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {1, 3}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3}}},
         }
     },
-    // U8: no subtract 2D -> 4D: channels are affected: per tensor quantization: case #1: dequantization operation constant needs broadcast
+    // U8: no subtract 4D -> 3D: channels are affected: per tensor quantization: case #1: dequantization operation constant needs broadcast
     {
         ngraph::Shape({ 1, 3, 4, 5 }),
         { 0, 0, -1 },
@@ -377,7 +454,7 @@ const std::vector<ReshapeTransformationTestValues> testValues = {
             {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {1, 3, 1}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1}}},
         }
     },
-    // U8: no subtract 2D -> 4D: channels are affected: per tensor quantization: case #2: dequantization operation constant doesn't need broadcast
+    // U8: no subtract 4D -> 3D: channels are affected: per tensor quantization: case #2: dequantization operation constant doesn't need broadcast
     {
         ngraph::Shape({ 1, 3, 4, 5 }),
         { 0, 0, -1 },
@@ -393,6 +470,70 @@ const std::vector<ReshapeTransformationTestValues> testValues = {
             {{ngraph::element::f32}, {{0.f, 128.f, 255.f}, ngraph::element::f32, {1, 3, 1}}, {{0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 3, 1}}},
         }
     },
+    // U8: no subtract 4D -> 2D
+    {
+        ngraph::Shape({ 1, 2048, 1, 1 }),
+        { 1, -1 },
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::u8,
+            {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {}}}
+        },
+        {
+            ngraph::element::u8,
+            {{}, {}, {}},
+            ngraph::element::u8,
+            {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {}}}
+        }
+    },
+    // U8: no subtract 4D -> 2D
+    {
+        ngraph::Shape({ 2, 2048, 1, 1 }),
+        { 2, -1 },
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::u8,
+            {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {1ul}}}
+        },
+        {
+            ngraph::element::u8,
+            {{}, {}, {}},
+            ngraph::element::u8,
+            {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {1ul}}}
+        }
+    },
+    // U8: no subtract 4D -> 2D
+    {
+        ngraph::Shape({ 1, 2048, 1, 1 }),
+        { 1, -1 },
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::u8,
+            {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {1, 1, 1, 1}}}
+        },
+        {
+            ngraph::element::u8,
+            {{}, {}, {}},
+            ngraph::element::u8,
+            {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {1, 1}}}
+        }
+    },
+    // U8: no subtract 4D -> 2D: channels are not affected
+    {
+        ngraph::Shape({ 2, 2048, 1, 1 }),
+        { 2, -1},
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::u8,
+            {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {1, 1, 1, 1}}}
+        },
+        {
+            ngraph::element::u8,
+            {{}, {}, {}},
+            ngraph::element::u8,
+            {{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {1, 1}}}
+        }
+    }
 };
 
 TEST_P(ReshapeTransformation, CompareFunctions) {
index 4feed74..80b2a29 100644 (file)
@@ -72,6 +72,18 @@ std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantizeTypeRelaxed(
     const ngraph::element::Type precision,
     const FakeQuantizeOnData& fqOnData);
 
+std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantize(
+    const Output<Node>& input,
+    const ngraph::element::Type precision,
+    const FakeQuantizeOnDataWithConstant& fqOnData);
+
+std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantizeTypeRelaxed(
+    const std::shared_ptr<ngraph::Node>& input,
+    const ngraph::element::Type precision,
+    const FakeQuantizeOnDataWithConstant& fqOnData);
+
+std::shared_ptr<Node> addDequantizationAttribute(const std::shared_ptr<Node>& op);
+
 } // namespace subgraph
 } // namespace builder
 } // namespace ngraph
index e8b5fb6..1a69e39 100644 (file)
@@ -58,6 +58,25 @@ inline std::ostream& operator<<(std::ostream& out, const FakeQuantizeOnData& dat
         (data.outputPrecision == ngraph::element::undefined ? "" : data.outputPrecision.get_type_name());
 }
 
+class FakeQuantizeOnDataWithConstant {
+public:
+    size_t quantizationLevel;
+    std::vector<ngraph::Shape> constantShapes;
+    std::vector<float> inputLowValues;
+    std::vector<float> inputHighValues;
+    std::vector<float> outputLowValues;
+    std::vector<float> outputHighValues;
+    ngraph::element::Type outputPrecision;
+};
+
+inline std::ostream& operator<<(std::ostream& out, const FakeQuantizeOnDataWithConstant& data) {
+    return out <<  "_" << data.quantizationLevel <<
+        (data.constantShapes.empty() ? ngraph::Shape{} : data.constantShapes[0]) << "_" <<
+        data.inputLowValues << "_" << data.inputHighValues << "_" <<
+        data.outputLowValues << "_" << data.outputHighValues << "_" <<
+        (data.outputPrecision == ngraph::element::undefined ? "" : data.outputPrecision.get_type_name());
+}
+
 }  // namespace subgraph
 }  // namespace builder
 }  // namespace ngraph
index 72d6f6e..5e178d6 100644 (file)
@@ -23,6 +23,12 @@ public:
         const FakeQuantizeOnData& fakeQuantize1,
         const FakeQuantizeOnData& fakeQuantize2);
 
+    static std::shared_ptr<ngraph::Function> getOriginal(
+        const ngraph::element::Type precision,
+        const ngraph::Shape& inputShape,
+        const FakeQuantizeOnDataWithConstant& fakeQuantize1,
+        const FakeQuantizeOnDataWithConstant& fakeQuantize2);
+
     static std::shared_ptr<ngraph::Function> getOriginalWithNeighbors(
         const ngraph::element::Type precision,
         const ngraph::Shape& inputShape,
@@ -70,6 +76,13 @@ public:
         const FakeQuantizeOnData& fakeQuantize2,
         const DequantizationOperations& dequantizationOperations);
 
+    static std::shared_ptr<ngraph::Function> getReference(
+        const ngraph::element::Type precision,
+        const ngraph::Shape& inputShape,
+        const FakeQuantizeOnDataWithConstant& fakeQuantize1,
+        const FakeQuantizeOnDataWithConstant& fakeQuantize2,
+        const DequantizationOperations& dequantizationOperations);
+
     static std::shared_ptr<ngraph::Function> getReferenceWithNeighbors(
         const ngraph::element::Type precision,
         const ngraph::Shape& inputShape,
diff --git a/inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/move_dequantization_after_function.hpp b/inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/move_dequantization_after_function.hpp
new file mode 100644 (file)
index 0000000..2296771
--- /dev/null
@@ -0,0 +1,34 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <memory>
+#include <ngraph/ngraph.hpp>
+
+#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
+#include "ngraph_functions/subgraph_builders.hpp"
+
+namespace ngraph {
+namespace builder {
+namespace subgraph {
+
+class MoveDequantizationAfterFunction {
+public:
+    static std::shared_ptr<ngraph::Function> getOriginal(
+        const ngraph::element::Type precision,
+        const ngraph::Shape& inputShape,
+        const ngraph::builder::subgraph::DequantizationOperations dequantization);
+
+    static std::shared_ptr<ngraph::Function> getReference(
+        const ngraph::element::Type precision,
+        const ngraph::Shape& inputShape,
+        const ngraph::builder::subgraph::DequantizationOperations dequantizationBefore,
+        const ngraph::element::Type precisionAfterOperation,
+        const ngraph::builder::subgraph::DequantizationOperations dequantizationAfter);
+};
+
+}  // namespace subgraph
+}  // namespace builder
+}  // namespace ngraph
index bd89cef..def311c 100644 (file)
@@ -90,8 +90,9 @@ std::shared_ptr<ngraph::Function> AddFunction::getOriginal(
     const auto dequantizationOp2 = is_type<ngraph::opset1::Constant>(parent) ? parent : makeDequantization(parent, dequantization2);
 
     const auto add = std::make_shared<ngraph::opset1::Add>(dequantizationOp1, dequantizationOp2);
-
     add->set_friendly_name("output");
+    auto& rtInfo = add->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("add");
 
     ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(add) };
     ngraph::ParameterVector parameters;
@@ -196,6 +197,8 @@ std::shared_ptr<ngraph::Function> AddFunction::getReference(
             ngraph::op::TemporaryReplaceOutputType(dequantizationOp2, element::f32).get());
 
     NetworkHelper::setOutDataPrecisionForTypeRelaxed(add, precision);
+    auto& rtInfo = add->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("add");
 
     const auto dequantizationOpAfter = makeDequantization(add, dequantizationAfter);
 
index db9dd07..294c4d0 100644 (file)
@@ -25,6 +25,7 @@ std::shared_ptr<Node> makeDequantization(
         std::shared_ptr<ngraph::opset1::Convert> convert = std::make_shared<ngraph::pass::low_precision::DequantizationConvert>(
             data,
             dequantizationOperations.convert.outPrecision);
+        ngraph::copy_runtime_info({ data.get_node_shared_ptr(), convert }, convert);
         parent = convert;
     }
 
@@ -64,6 +65,7 @@ std::shared_ptr<Node> makeDequantization(
         if (!dequantizationOperations.subtract.addDequantizationAttribute) {
             ngraph::pass::low_precision::NetworkHelper::cleanRunTimeInfo(subtract);
         }
+        ngraph::copy_runtime_info({ data.get_node_shared_ptr(), subtract }, subtract);
         parent = subtract;
     }
 
@@ -111,7 +113,7 @@ std::shared_ptr<Node> makeDequantization(
                     ngraph::op::TemporaryReplaceOutputType(constant, element::f32).get(),
                     ngraph::op::TemporaryReplaceOutputType(parent, element::f32).get());
         }
-
+        ngraph::copy_runtime_info({ data.get_node_shared_ptr(), multiply }, multiply);
         parent = multiply;
     }
 
@@ -141,6 +143,52 @@ std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantizeTypeRelaxed(
     return std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::FakeQuantize>>(*fq, fqOnData.outputPrecision);
 }
 
+std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantize(
+    const Output<Node>& input,
+    const ngraph::element::Type precision,
+    const FakeQuantizeOnDataWithConstant& fqOnData) {
+    const auto inputLowNode = ngraph::builder::makeConstant(
+        precision,
+        fqOnData.constantShapes.empty() ? ngraph::Shape{} : fqOnData.constantShapes[0],
+        fqOnData.inputLowValues,
+        fqOnData.inputLowValues.empty());
+
+    const auto inputHighNode = ngraph::builder::makeConstant(
+        precision,
+        fqOnData.constantShapes.empty() ? ngraph::Shape{} : fqOnData.constantShapes[1],
+        fqOnData.inputHighValues,
+        fqOnData.inputHighValues.empty());
+
+    const auto outputLowNode = ngraph::builder::makeConstant(
+        precision,
+        fqOnData.constantShapes.empty() ? ngraph::Shape{} : fqOnData.constantShapes[2],
+        fqOnData.outputLowValues,
+        fqOnData.outputLowValues.empty());
+
+    const auto outputHighNode = ngraph::builder::makeConstant(
+        precision,
+        fqOnData.constantShapes.empty() ? ngraph::Shape{} : fqOnData.constantShapes[3],
+        fqOnData.outputHighValues,
+        fqOnData.outputHighValues.empty());
+
+    auto fq = std::make_shared<ngraph::opset1::FakeQuantize>(input, inputLowNode, inputHighNode, outputLowNode, outputHighNode, fqOnData.quantizationLevel);
+    return fq;
+}
+
+std::shared_ptr<ngraph::opset1::FakeQuantize> makeFakeQuantizeTypeRelaxed(
+    const std::shared_ptr<ngraph::Node>& input,
+    const ngraph::element::Type precision,
+    const FakeQuantizeOnDataWithConstant& fqOnData) {
+    const std::shared_ptr<ngraph::opset1::FakeQuantize> fq = makeFakeQuantize(input, precision, fqOnData);
+    return std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::FakeQuantize>>(*fq, fqOnData.outputPrecision);
+}
+
+std::shared_ptr<Node> addDequantizationAttribute(const std::shared_ptr<Node>& op) {
+    auto& rtInfo = op->get_rt_info();
+    rtInfo["DEQUANTIZATION"] = std::make_shared<VariantWrapper<DequantizationAttr>>(DequantizationAttr());
+    return op;
+}
+
 } // namespace subgraph
 } // namespace builder
 } // namespace ngraph
index d37a65b..ffdc619 100644 (file)
@@ -36,6 +36,38 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getOriginal(
     const std::shared_ptr<ngraph::opset1::Concat> concat = std::make_shared<ngraph::opset1::Concat>(
         ngraph::OutputVector{ fakeQuantize1->output(0), fakeQuantize2->output(0) }, 1);
     concat->set_friendly_name("output");
+    auto& rtInfo = concat->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
+
+    ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(concat) };
+    std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
+        results,
+        ngraph::ParameterVector{ input1, input2 },
+        "ConcatTransformation");
+
+    return function;
+}
+
+std::shared_ptr<ngraph::Function> ConcatFunction::getOriginal(
+    const ngraph::element::Type precision,
+    const ngraph::Shape& inputShape,
+    const FakeQuantizeOnDataWithConstant& fqOnData1,
+    const FakeQuantizeOnDataWithConstant& fqOnData2) {
+    const auto input1 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
+    input1->set_friendly_name("input1");
+    const auto fakeQuantize1 = makeFakeQuantize(input1, precision, fqOnData1);
+
+    const std::vector<size_t> inputShape2 = inputShape;
+    const auto input2 = std::make_shared<ngraph::opset1::Parameter>(precision, ngraph::Shape(inputShape2));
+    input2->set_friendly_name("input2");
+    const auto fakeQuantize2 = makeFakeQuantize(input2, precision, fqOnData2);
+
+    const std::shared_ptr<ngraph::opset1::Concat> concat = std::make_shared<ngraph::opset1::Concat>(
+        ngraph::OutputVector{ fakeQuantize1->output(0), fakeQuantize2->output(0) }, 1);
+    concat->set_friendly_name("output");
+
+    auto& rtInfo = concat->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
 
     ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(concat) };
     std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
@@ -72,11 +104,17 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithNeighbors(
         1ull);
     concat1->set_friendly_name("concat1");
 
+    auto& rtInfo1 = concat1->get_rt_info();
+    rtInfo1["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat1");
+
     const auto concat2 = std::make_shared<ngraph::opset1::Concat>(
         ngraph::OutputVector { fakeQuantize2->output(0), fakeQuantize3->output(0) },
         1ull);
     concat2->set_friendly_name("concat2");
 
+    auto& rtInfo2 = concat2->get_rt_info();
+    rtInfo2["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat2");
+
     const ngraph::ResultVector results {
         std::make_shared<ngraph::opset1::Result>(concat1),
         std::make_shared<ngraph::opset1::Result>(concat2)
@@ -153,6 +191,8 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithIntermediate(
         ngraph::OutputVector{ fakeQuantize1->output(0), intermediateOp->output(0) }, 1);
     concat->set_friendly_name("concat");
 
+    auto& rtInfo = concat->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
 
     auto weights = ngraph::opset1::Constant::create(precision, ngraph::Shape{ inputShape[1], inputShape[1], 1, 1 }, { 1 });
     auto convolution = std::make_shared<ngraph::opset1::Convolution>(
@@ -216,6 +256,9 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithSplitedIntermed
         ngraph::OutputVector{ fakeQuantize1->output(0), intermediateOp->output(0) }, splitedAxis);
     concat->set_friendly_name("concat");
 
+    auto& rtInfo = concat->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
+
     auto weights = ngraph::opset1::Constant::create(precision, ngraph::Shape{ inputShape[1] / numSplit, inputShape[1] / numSplit, 1, 1 }, { 1 });
     auto convolution = std::make_shared<ngraph::opset1::Convolution>(
         intermediateOp->output(1),
@@ -302,6 +345,8 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalSelectionWithInterm
         ngraph::OutputVector{ fakeQuantize1->output(0), intermediateOp->output(0) }, 1);
     concat->set_friendly_name("concat");
 
+    auto& rtInfo = concat->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
 
     auto weights = ngraph::opset1::Constant::create(precision, ngraph::Shape{ inputShape[1], inputShape[1], 1, 1 }, { 1 });
     auto convolution = std::make_shared<ngraph::opset1::Convolution>(
@@ -343,6 +388,9 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithDifferentPrecis
     const std::shared_ptr<ngraph::opset1::Concat> concat = std::make_shared<ngraph::opset1::Concat>(
         ngraph::OutputVector{ fakeQuantize1->output(0), fakeQuantize2->output(0) }, 1);
 
+    auto& rtInfo = concat->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
+
     const std::vector<size_t> kernel = { 3, 3 };
     const std::vector<size_t> stride = { 1, 1 };
     const std::vector<size_t> padBegin = { 0, 0 };
@@ -439,6 +487,9 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithIntermediateWit
         ngraph::OutputVector{ fakeQuantize2->output(0), intermediateOp->output(0) }, 1);
     concat->set_friendly_name("concat");
 
+    auto& rtInfo = concat->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
+
     ngraph::ResultVector results{
         std::make_shared<ngraph::opset1::Result>(concat),
     };
@@ -468,6 +519,58 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getReference(
 
     const std::shared_ptr<ngraph::opset1::Concat> concat = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Concat>>(
         ngraph::OutputVector{ fakeQuantize1->output(0), fakeQuantize2->output(0) }, 1);
+    auto& rtInfo = concat->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
+
+    const std::shared_ptr<ngraph::Node> lastDequantization = makeDequantization(concat, dequantizationOperations);
+    lastDequantization->set_friendly_name("output");
+
+    ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(lastDequantization) };
+    std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
+        results,
+        ngraph::ParameterVector{ input1, input2 },
+        "ConcatTransformation");
+
+    if (fqOnData1.outputPrecision != fqOnData2.outputPrecision) {
+        THROW_IE_EXCEPTION << "FakeQuantize expected precisions are different";
+    }
+    const ngraph::element::Type fqOnDataPrecision = fqOnData1.outputPrecision;
+    if (fqOnDataPrecision != ngraph::element::undefined) {
+        if (fakeQuantize1->get_output_element_type(0) != fakeQuantize2->get_output_element_type(0)) {
+            THROW_IE_EXCEPTION << "FakeQuantize operation precisions are different";
+        }
+        const ngraph::element::Type fakeQuantizePrecision = fakeQuantize1->get_output_element_type(0);
+
+        if (fqOnDataPrecision != fakeQuantizePrecision) {
+            ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(fakeQuantize1, fqOnDataPrecision);
+            ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(fakeQuantize2, fqOnDataPrecision);
+            ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(concat, fqOnDataPrecision);
+        }
+    }
+
+    return function;
+}
+
+std::shared_ptr<ngraph::Function> ConcatFunction::getReference(
+    const ngraph::element::Type precision,
+    const ngraph::Shape& inputShape,
+    const FakeQuantizeOnDataWithConstant& fqOnData1,
+    const FakeQuantizeOnDataWithConstant& fqOnData2,
+    const DequantizationOperations& dequantizationOperations) {
+    const auto input1 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
+    input1->set_friendly_name("input1");
+    const auto fakeQuantize1 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(input1, precision, fqOnData1);
+
+    const std::vector<size_t> inputShape2 = inputShape;
+    const auto input2 = std::make_shared<ngraph::opset1::Parameter>(precision, ngraph::Shape(inputShape2));
+    input2->set_friendly_name("input2");
+    const auto fakeQuantize2 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(input2, precision, fqOnData2);
+
+    const std::shared_ptr<ngraph::opset1::Concat> concat = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Concat>>(
+        ngraph::OutputVector{ fakeQuantize1->output(0), fakeQuantize2->output(0) }, 1);
+
+    auto& rtInfo = concat->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
 
     const std::shared_ptr<ngraph::Node> lastDequantization = makeDequantization(concat, dequantizationOperations);
     lastDequantization->set_friendly_name("output");
@@ -526,11 +629,17 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithNeighbors(
         1ull);
     concat1->set_friendly_name("concat1");
 
+    auto& rtInfo1 = concat1->get_rt_info();
+    rtInfo1["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat1");
+
     const auto concat2 = std::make_shared<ngraph::opset1::Concat>(
         ngraph::OutputVector { fakeQuantize2->output(0), fakeQuantize3->output(0) },
         1ull);
     concat2->set_friendly_name("concat2");
 
+    auto& rtInfo2 = concat2->get_rt_info();
+    rtInfo2["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat2");
+
     const std::shared_ptr<ngraph::Node> lastDequantization1 = makeDequantization(concat1, dequantizationOperations1);
     lastDequantization1->set_friendly_name("concat1");
 
@@ -636,6 +745,9 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithIntermediate(
         1);
     concat->set_friendly_name("concat");
 
+    auto& rtInfo = concat->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
+
     const std::shared_ptr<ngraph::Node> lastDequantization1 = dequantizationOperations1.empty() ?
         concat :
         makeDequantization(concat, dequantizationOperations1);
@@ -741,6 +853,9 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithSplitedInterme
         ngraph::OutputVector{ fakeQuantize1->output(0), intermediateOp->output(0) }, splitedAxis);
     concat->set_friendly_name("concat");
 
+    auto& rtInfo = concat->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
+
     const std::shared_ptr<ngraph::Node> lastDequantization1 = dequantizationOperations1.empty() ?
         concat :
         makeDequantization(concat, dequantizationOperations1);
@@ -850,6 +965,9 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceSelectionWithInter
         1);
     concat->set_friendly_name("concat");
 
+    auto& rtInfo = concat->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
+
     const std::shared_ptr<ngraph::Node> lastDequantization1 = dequantizationOperations1.empty() ?
         concat :
         makeDequantization(concat, dequantizationOperations1);
@@ -932,6 +1050,9 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithDifferentPreci
         ngraph::OutputVector{ fakeQuantize1->output(0), fakeQuantize2->output(0) }, 1);
     concat->set_friendly_name("concat");
 
+    auto& rtInfo = concat->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
+
     const auto lastDequantization1 = makeDequantization(concat->output(0), dequantizationOperations1);
 
     const std::vector<size_t> kernel = { 3, 3 };
@@ -1053,6 +1174,9 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithIntermediateWi
     concat->set_friendly_name("concat");
     ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(concat, precisionAfterOperation);
 
+    auto& rtInfo = concat->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
+
     const auto deqAfter = makeDequantization(concat->output(0), dequantizationAfter);
     deqAfter->set_friendly_name("concat");
 
index f4733ac..03d0c91 100644 (file)
@@ -62,8 +62,7 @@ std::shared_ptr<ngraph::Function> ConvertMulOrAddWithDequantizationFunction::get
     const auto weights = std::make_shared<opset1::Constant>(element::f32, inputShape, multiplyConst);
     const auto bias = std::make_shared<opset1::Constant>(element::f32, inputShape, 0.0);
     std::shared_ptr<Node> scaleShift = std::make_shared<ngraph::op::ScaleShiftIE>(relu, weights, bias);
-
-    scaleShift = low_precision::NetworkHelper::markAsDequantizationOp(scaleShift);
+    addDequantizationAttribute(scaleShift);
 
     scaleShift->set_friendly_name("output");
 
index cbee456..97e7814 100644 (file)
@@ -66,6 +66,8 @@ std::shared_ptr<ngraph::Function> ConvolutionFunction::getOriginal(
         std::vector<element::Type>{ element::f32, element::f32 },
         std::vector<element::Type>{});
     convolution->set_friendly_name("output");
+    auto& rtInfo = convolution->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("convolution");
 
     ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(convolution) };
     return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "ConvolutionTransformation");
@@ -261,6 +263,8 @@ std::shared_ptr<ngraph::Function> ConvolutionFunction::getReference(
         std::vector<element::Type>{});
 
     ngraph::pass::low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(convolution, precisionAfterOperation);
+    auto& rtInfo = convolution->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("convolution");
 
     const auto deqAfter = makeDequantization(convolution, dequantizationAfter);
     deqAfter->set_friendly_name("output");
index 0fadd17..d868aa9 100644 (file)
@@ -29,6 +29,8 @@ std::shared_ptr<ngraph::Function> FakeQuantizeFunction::getOriginal(
         input, element::f32, fakeQuantizeOnData.quantizationLevel, fakeQuantizeOnData.constantShape,
         fakeQuantizeOnData.inputLowValues, fakeQuantizeOnData.inputHighValues, fakeQuantizeOnData.outputLowValues, fakeQuantizeOnData.outputHighValues);
     fakeQuantize->set_friendly_name("fakeQuantize");
+    auto& rtInfo = fakeQuantize->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("fakeQuantize");
 
     ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(fakeQuantize) };
     return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "FakeQuantizeFunction");
@@ -55,15 +57,18 @@ std::shared_ptr<ngraph::Function> FakeQuantizeFunction::getReference(
         fakeQuantizeOnData.outputLowValues,
         fakeQuantizeOnData.outputHighValues));
     std::shared_ptr<Node> parent = fakeQuantize;
+    auto& rtInfo = fakeQuantize->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("fakeQuantize");
 
     if (updatePrecisions) {
         const std::shared_ptr<ngraph::opset1::Convert> convert = std::make_shared<DequantizationConvert>(parent, element::f32);
+        ngraph::copy_runtime_info({ fakeQuantize, convert }, convert);
         parent = convert;
-
         ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(fakeQuantize, fakeQuantizeOutputPrecision);
     } else {
         if (fakeQuantize->get_output_element_type(0) != element::f32) {
             const std::shared_ptr<ngraph::opset1::Convert> convert = std::make_shared<DequantizationConvert>(parent, element::f32);
+            ngraph::copy_runtime_info({ fakeQuantize, convert }, convert);
             parent = convert;
         }
     }
@@ -78,6 +83,7 @@ std::shared_ptr<ngraph::Function> FakeQuantizeFunction::getReference(
                 expectedSubtractValues),
             ngraph::op::AutoBroadcastSpec::NUMPY);
     if (subtract != nullptr) {
+        ngraph::copy_runtime_info({ fakeQuantize, subtract }, subtract);
         parent = subtract;
     }
 
@@ -90,10 +96,10 @@ std::shared_ptr<ngraph::Function> FakeQuantizeFunction::getReference(
                 expectedMultiplyValues.size() == 1ul ? ngraph::Shape{ } : ngraph::Shape{ expectedMultiplyValues.size() },
                 expectedMultiplyValues));
     if (multiply != nullptr) {
+        ngraph::copy_runtime_info({ fakeQuantize, multiply }, multiply);
         parent = multiply;
     }
     parent->set_friendly_name("fakeQuantize");
-
     ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(parent) };
     return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "FakeQuantizeFunction");
 }
index a2a62cf..5bca24e 100644 (file)
@@ -115,6 +115,8 @@ std::shared_ptr<ngraph::Function> MatMulFunction::getOriginal(
         false,
         false);
     matMul->set_friendly_name("matMul");
+    auto& rtInfo = matMul->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("matMul");
 
     std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(matMul);
 
@@ -200,6 +202,8 @@ std::shared_ptr<ngraph::Function> MatMulFunction::getReference(
         false,
         false);
     matMul->set_friendly_name("matMul");
+    auto& rtInfo = matMul->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("matMul");
     ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(matMul, precision);
 
     const std::shared_ptr<ngraph::Node> lastDequantizationAfter = makeDequantization(matMul, resultDequantization);
index 42b3ef9..64eebf8 100644 (file)
@@ -6,8 +6,9 @@
 
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph_ops/type_relaxed.hpp>
-#include "ngraph_functions/subgraph_builders.hpp"
 #include "low_precision/network_helper.hpp"
+#include "ngraph_functions/subgraph_builders.hpp"
+#include "ngraph_functions/low_precision_transformations/common/builders.hpp"
 
 namespace ngraph {
 namespace builder {
@@ -87,18 +88,18 @@ std::shared_ptr<ngraph::Function> MaxPoolFunction::getReference(
     parent = maxPool;
 
     if (parent->get_output_element_type(0) != originalFunctionPrecision) {
-        const std::shared_ptr<ngraph::Node> convert = std::make_shared<ngraph::opset1::Convert>(parent, originalFunctionPrecision);
+        const std::shared_ptr<ngraph::Node> convert = std::make_shared<ngraph::pass::low_precision::DequantizationConvert>(parent, originalFunctionPrecision);
         parent = convert;
     }
 
     if (!values.subtractValues.empty()) {
-        const std::shared_ptr<ngraph::Node> subtract = std::make_shared<ngraph::opset1::Subtract>(
+        const std::shared_ptr<ngraph::Node> subtract = std::make_shared<ngraph::pass::low_precision::DequantizationSubtract>(
             parent,
             std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.subtractValues.size() }), values.subtractValues));
         parent = subtract;
     }
 
-    const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ngraph::opset1::Multiply>(
+    const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ngraph::pass::low_precision::DequantizationMultiply>(
         parent,
         std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.mutliplyValues.size() }), values.mutliplyValues));
     multiply->set_friendly_name("output");
diff --git a/inference-engine/tests/ngraph_functions/src/low_precision_transformations/move_dequantization_after_function.cpp b/inference-engine/tests/ngraph_functions/src/low_precision_transformations/move_dequantization_after_function.cpp
new file mode 100644 (file)
index 0000000..cd8d8d8
--- /dev/null
@@ -0,0 +1,78 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "ngraph_functions/low_precision_transformations/move_dequantization_after_function.hpp"
+#include "low_precision/network_helper.hpp"
+
+#include <ngraph/opsets/opset1.hpp>
+#include "ngraph_functions/subgraph_builders.hpp"
+#include "ngraph_functions/low_precision_transformations/common/builders.hpp"
+
+using namespace ngraph::pass::low_precision;
+
+namespace ngraph {
+namespace builder {
+namespace subgraph {
+    std::shared_ptr<ngraph::Function> MoveDequantizationAfterFunction::getOriginal(
+        const ngraph::element::Type precision,
+        const ngraph::Shape& inputShape,
+        const ngraph::builder::subgraph::DequantizationOperations dequantization) {
+        const auto input = std::make_shared<ngraph::op::v0::Parameter>(precision, inputShape);
+
+        const auto deq = makeDequantization(input, dequantization);
+        const auto op = ngraph::opset1::MaxPool(
+            deq,
+            Strides{ 1, 1 },
+            Shape{ 1, 1 },
+            Shape{ 0, 0 },
+            Shape{ 2, 2 },
+            op::RoundingType::FLOOR);
+        const auto targetOp = std::make_shared<op::TypeRelaxed<opset1::MaxPool>>(
+            op,
+            std::vector<element::Type>{ element::f32, element::f32 },
+            std::vector<element::Type>{});
+        auto& rtInfo = targetOp->get_rt_info();
+        rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("targetOp");
+
+        return std::make_shared<ngraph::Function>(
+            ngraph::ResultVector{ std::make_shared<ngraph::opset1::Result>(targetOp) },
+            ngraph::ParameterVector{ input },
+            "MoveDequantizationAfterFunction");
+    }
+
+    std::shared_ptr<ngraph::Function> MoveDequantizationAfterFunction::getReference(
+        const ngraph::element::Type precision,
+        const ngraph::Shape& inputShape,
+        const ngraph::builder::subgraph::DequantizationOperations dequantizationBefore,
+        const ngraph::element::Type precisionAfterOperation,
+        const ngraph::builder::subgraph::DequantizationOperations dequantizationAfter) {
+        const auto input = std::make_shared<ngraph::op::v0::Parameter>(precision, inputShape);
+
+        const auto deqBefore = makeDequantization(input, dequantizationBefore);
+        const auto op = ngraph::opset1::MaxPool(
+            deqBefore,
+            Strides{ 1, 1 },
+            Shape{ 1, 1 },
+            Shape{ 0, 0 },
+            Shape{ 2, 2 },
+            op::RoundingType::FLOOR);
+        const auto targetOp = std::make_shared<op::TypeRelaxed<opset1::MaxPool>>(
+            op,
+            std::vector<element::Type>{ element::f32, element::f32 },
+            std::vector<element::Type>{});
+        ngraph::pass::low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(targetOp, precisionAfterOperation);
+        auto& rtInfo = targetOp->get_rt_info();
+        rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("targetOp");
+
+        const auto deqAfter = makeDequantization(targetOp, dequantizationAfter);
+
+        return std::make_shared<ngraph::Function>(
+            ngraph::ResultVector{ std::make_shared<ngraph::opset1::Result>(deqAfter) },
+            ngraph::ParameterVector{ input },
+            "MoveDequantizationAfterFunction");
+    }
+
+}  // namespace subgraph
+}  // namespace builder
+}  // namespace ngraph
index e30526f..d04c97c 100644 (file)
@@ -13,6 +13,7 @@
 #include <legacy/ngraph_ops/scaleshift.hpp>
 
 #include "ngraph_functions/subgraph_builders.hpp"
+#include "ngraph_functions/low_precision_transformations/common/builders.hpp"
 #include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
 
 namespace ngraph {
@@ -62,7 +63,7 @@ namespace subgraph {
         std::shared_ptr<ngraph::Node> lastNode;
         if (isDequantization) {
             std::shared_ptr<Node> scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(input, weights, biases, precisionAfterOperation);
-            scaleshift = low_precision::NetworkHelper::markAsDequantizationOp(scaleshift);
+            addDequantizationAttribute(scaleshift);
             scaleshift->set_friendly_name("add");
             lastNode = scaleshift;
         } else {
index f0496d8..e7125c0 100644 (file)
@@ -64,7 +64,8 @@ std::shared_ptr<ngraph::Function> MultiplyFunction::get(
         multiplyOriginal,
         std::vector<element::Type>{element::f32, element::f32},
         std::vector<element::Type>{});
-
+    auto& rtInfo = multiply->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("multiply");
     multiply->set_friendly_name("output");
 
     ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(multiply) };
index 54a595d..8868756 100644 (file)
@@ -25,6 +25,8 @@ std::shared_ptr<ngraph::Function> MVNFunction::getOriginal(
     const auto dequantizationOp = makeDequantization(input, dequantization);
     const auto mvn = std::make_shared<ngraph::op::MVN>(dequantizationOp, reductionAxes, normalizeVariance);
     mvn->set_friendly_name("output");
+    auto& rtInfo = mvn->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("mvn");
 
     ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(mvn) };
     return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MVNFunction");
@@ -62,6 +64,9 @@ std::shared_ptr<ngraph::Function> MVNFunction::getReference(
     const std::shared_ptr<Node> dequantizationOpBefore = makeDequantization(input, dequantizationBefore);
     const auto mvn = std::make_shared<ngraph::op::TypeRelaxed<ngraph::op::MVN>>(
         op::MVN(dequantizationOpBefore, reductionAxes, normalizeVariance), precisionAfterOperation);
+    auto& rtInfo = mvn->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("mvn");
+
     const std::shared_ptr<Node> dequantizationOpAfter = makeDequantization(mvn, dequantizationAfter);
     dequantizationOpAfter->set_friendly_name("output");
 
index 89255db..582deb6 100644 (file)
@@ -88,6 +88,8 @@ std::shared_ptr<ngraph::Function> NormalizeL2Function::getOriginal(
     const auto axesNode = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{ actualValues.axes.size() }, actualValues.axes);
     const auto normalizeL2 = std::make_shared<ngraph::opset1::NormalizeL2>(parent, axesNode, 1e-6, epsMode);
     normalizeL2->set_friendly_name("output");
+    auto& rtInfo = normalizeL2->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("normalizeL2");
 
     ngraph::ResultVector results = { std::make_shared<ngraph::opset1::Result>(normalizeL2) };
     const auto function = std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "NormalizeL2Transformation");
@@ -123,6 +125,8 @@ std::shared_ptr<ngraph::Function> NormalizeL2Function::getReference(
         ngraph::op::TemporaryReplaceOutputType(axesNode, element::f32).get(),
         1e-6,
         epsMode);
+    auto& rtInfo = normalizeL2->get_rt_info();
+    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("normalizeL2");
     std::shared_ptr<ngraph::Node> output = normalizeL2;
 
     if (!expectedValues.mutliplyValues.empty()) {
@@ -131,6 +135,7 @@ std::shared_ptr<ngraph::Function> NormalizeL2Function::getReference(
             ngraph::op::TemporaryReplaceOutputType(output, element::f32).get(),
             ngraph::op::TemporaryReplaceOutputType(std::make_shared<ngraph::opset1::Constant>(
                 precision, Shape({ 1, expectedValues.mutliplyValues.size(), 1, 1 }), expectedValues.mutliplyValues), element::f32).get());
+        multiply->get_rt_info()["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("normalizeL2");
         output = multiply;
     }
     output->set_friendly_name("output");