[LPT] integration: issue #42391 & issue #43001 (#3201)
authorEdward Shogulin <edward.shogulin@intel.com>
Mon, 23 Nov 2020 14:08:41 +0000 (17:08 +0300)
committerGitHub <noreply@github.com>
Mon, 23 Nov 2020 14:08:41 +0000 (17:08 +0300)
* [LPT] NetworkHelper::roundWithTolerance: removed tolerance & rename to round
[LPT] NetworkHelper::round functional tests
[LPT] ieFuncTests: updated some test-cases

* [LPT] Subtract is not used

* [LPT] AddTransformation: zero handling

* [LPT] AddTransformation test

27 files changed:
inference-engine/src/low_precision_transformations/include/low_precision/add.hpp
inference-engine/src/low_precision_transformations/include/low_precision/common/fake_quantize_dequantization.hpp
inference-engine/src/low_precision_transformations/include/low_precision/network_helper.hpp
inference-engine/src/low_precision_transformations/src/add.cpp
inference-engine/src/low_precision_transformations/src/clamp.cpp
inference-engine/src/low_precision_transformations/src/fake_quantize_dequantization.cpp
inference-engine/src/low_precision_transformations/src/group_convolution.cpp
inference-engine/src/low_precision_transformations/src/layer_transformation.cpp
inference-engine/src/low_precision_transformations/src/mvn.cpp
inference-engine/src/low_precision_transformations/src/network_helper.cpp
inference-engine/src/low_precision_transformations/src/normalize_l2.cpp
inference-engine/src/low_precision_transformations/src/prelu.cpp
inference-engine/src/low_precision_transformations/src/relu.cpp
inference-engine/src/low_precision_transformations/src/subtract.cpp
inference-engine/tests/functional/inference_engine/lp_transformations/add_transformation.cpp
inference-engine/tests/functional/inference_engine/lp_transformations/clamp_transformation.cpp
inference-engine/tests/functional/inference_engine/lp_transformations/convolution_transformation.cpp
inference-engine/tests/functional/inference_engine/lp_transformations/group_convolution_transformation.cpp
inference-engine/tests/functional/inference_engine/lp_transformations/mat_mul_transformation.cpp
inference-engine/tests/functional/inference_engine/lp_transformations/move_dequantization_after_transformation.cpp
inference-engine/tests/functional/inference_engine/lp_transformations/mvn_transformation.cpp
inference-engine/tests/functional/inference_engine/lp_transformations/normalize_l2_transformation.cpp
inference-engine/tests/functional/inference_engine/lp_transformations/prelu_transformation.cpp
inference-engine/tests/functional/inference_engine/lp_transformations/relu_transformation.cpp
inference-engine/tests/functional/inference_engine/lp_transformations/round_transformation.cpp [new file with mode: 0644]
inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/round_function.hpp [new file with mode: 0644]
inference-engine/tests/ngraph_functions/src/low_precision_transformations/round_function.cpp [new file with mode: 0644]

index 187b8e3..bae35aa 100644 (file)
@@ -17,6 +17,7 @@ public:
     ~AddTransformation() override {}
     void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
     bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
+    bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
 };
 
 } // namespace low_precision
index 87a4dab..2bfc6ee 100644 (file)
@@ -26,6 +26,7 @@ public:
         std::shared_ptr<ngraph::opset1::Multiply> multiply);
 
     bool empty() const;
+    bool multiplyHasZero() const;
     bool isShared() const;
     bool isLowPrecision() const;
     static bool checkElementwise(const std::shared_ptr<ngraph::Node>& elementwise);
index f27462b..91df6bb 100644 (file)
@@ -81,7 +81,7 @@ public:
     // Optimizes the series of multiplies after a given output port
     static std::shared_ptr<ngraph::opset1::Multiply> optimizeMultipliesAfter(std::shared_ptr<Node> multiply);
 
-    static std::shared_ptr<opset1::Constant> roundWithTolerance(std::shared_ptr<Node> node, element::Type target_type, float tolerance = 0.1);
+    static std::shared_ptr<opset1::Constant> round(std::shared_ptr<Node> node, element::Type target_type);
 
     static std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> decomposeFakeQuantize(
         std::shared_ptr<opset1::FakeQuantize> fq,
index ce9d3a8..daddb07 100644 (file)
@@ -199,6 +199,20 @@ bool AddTransformation::transform(TransformationContext& context, ngraph::patter
     return true;
 }
 
+bool AddTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
+    const FakeQuantizeDequantization dequantization1 = pass::low_precision::NetworkHelper::getDequantization(layer, 0ul);
+    if (dequantization1.multiplyHasZero()) {
+        return false;
+    }
+
+    const FakeQuantizeDequantization dequantization2 = pass::low_precision::NetworkHelper::getDequantization(layer, 1ul);
+    if (dequantization2.multiplyHasZero()) {
+        return false;
+    }
+
+    return EltwiseBaseTransformation::canBeTransformed(context, layer);
+}
+
 } // namespace low_precision
 } // namespace pass
 } // namespace ngraph
index 16d7220..c93eec4 100644 (file)
@@ -42,7 +42,8 @@ bool ClampTransformation::transform(TransformationContext& context, ngraph::patt
     const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(clamp);
 
     const bool moveSubtract = subWithTheSameValues(dequantization.subtract);
-    if (!moveSubtract && !canSubtractBeHandled(clamp, dequantization)) {
+    // issue #43136
+    if (!moveSubtract && (dequantization.subtract != nullptr)) {
         return false;
     }
     const auto newClamp = as_type_ptr<opset1::Clamp>(moveDequantizationAfter(context, clamp, dequantization, false, moveSubtract));
index b8ebdbc..9244cd1 100644 (file)
@@ -30,6 +30,23 @@ bool FakeQuantizeDequantization::empty() const {
     return (convert == nullptr) && (subtract == nullptr) && (multiply == nullptr);
 }
 
+bool FakeQuantizeDequantization::multiplyHasZero() const {
+    if (multiply == nullptr) {
+        return false;
+    }
+
+    std::shared_ptr<opset1::Constant> multiplyConstant = as_type_ptr<opset1::Constant>(multiply->get_input_node_shared_ptr(1));
+    if (multiplyConstant == nullptr) {
+        multiplyConstant = as_type_ptr<opset1::Constant>(multiply->get_input_node_shared_ptr(0));
+    }
+    if (multiplyConstant == nullptr) {
+        return false;
+    }
+
+    auto const values = multiplyConstant->cast_vector<float>();
+    return std::any_of(values.begin(), values.end(), [](const float value) { return value == 0.f; });
+}
+
 bool FakeQuantizeDequantization::isShared() const {
     if ((convert != nullptr) && (convert->get_output_target_inputs(0).size() > 1ul)) {
         return true;
index 20b67fb..d563785 100644 (file)
@@ -33,6 +33,7 @@ bool GroupConvolutionTransformation::isQuantized(std::shared_ptr<Node> layer) co
 
 bool GroupConvolutionTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
     auto convolution = m.get_match_root();
+
     if (!GroupConvolutionTransformation::canBeTransformed(context, convolution)) {
         return false;
     }
index e3b5622..dbe2198 100644 (file)
@@ -138,9 +138,7 @@ bool LayerTransformation::canSubtractBeHandled(const std::shared_ptr<Node>& op,
         return false;
     }
 
-    std::shared_ptr<Node> zeroPoint = dequantization.subtract->input_value(1).get_node_shared_ptr();
-    auto convertedZeroPoint = NetworkHelper::roundWithTolerance(zeroPoint, operationType);
-    return convertedZeroPoint->output(0).get_element_type() == operationType;
+    return true;
 }
 
 #ifdef LPT_PRINT_DEQUANTIZATION_INFO
index b54abdc..b4540d2 100644 (file)
@@ -41,7 +41,7 @@ bool MVNTransformation::canBeTransformed(const TransformationContext& context, s
         return false;
     }
 
-    if (!canSubtractBeHandled(operation)) {
+    if (NetworkHelper::getDequantization(operation).subtract != nullptr) {
         return false;
     }
 
index 8bd84d5..e2b84a2 100644 (file)
@@ -321,52 +321,15 @@ std::shared_ptr<ngraph::opset1::Multiply> NetworkHelper::optimizeMultipliesAfter
     return nullptr;
 }
 
-std::shared_ptr<opset1::Constant> NetworkHelper::roundWithTolerance(std::shared_ptr<Node> node, element::Type target_type, float tolerance) {
-    auto constant = as_type_ptr<opset1::Constant>(node);
+std::shared_ptr<opset1::Constant> NetworkHelper::round(std::shared_ptr<Node> node, element::Type target_type) {
+    const auto constant = as_type_ptr<opset1::Constant>(node);
     assert(constant);
-    auto values = constant->cast_vector<float>();
-
-    auto castedConstant = as_type_ptr<opset1::Constant>(fold<opset1::Convert>(constant, target_type));
-    auto castedValues = castedConstant->cast_vector<float>();
-
-    // TODO: implement with constant folding when ReduceAnd constant folding is ready
-    if (std::equal(values.begin(), values.end(), castedValues.begin(), [tolerance](float a, float b) { return fabs(a - b) < tolerance; })) {
-        return castedConstant;
-    }
 
-    auto round = [](
-        const std::shared_ptr<opset1::Constant>& constant,
-        element::Type target_type,
-        float tolerance,
-        std::vector<float>& values,
-        float increaseValue) -> std::shared_ptr<opset1::Constant> {
-        const auto castedConstant = as_type_ptr<opset1::Constant>(fold<opset1::Convert>(
-            fold<opset1::Add>(constant, std::make_shared<opset1::Constant>(constant->get_output_element_type(0), Shape{ 1 }, increaseValue)),
-            target_type));
-        const auto castedValues = castedConstant->cast_vector<float>();
-        if (std::equal(values.begin(), values.end(), castedValues.begin(), [tolerance](float a, float b) { return fabs(a - b) < tolerance; })) {
-            return castedConstant;
-        }
-
-        return nullptr;
-    };
+    const auto castedConstant = as_type_ptr<ngraph::opset1::Constant>(fold<op::v0::Convert>(
+        fold<ngraph::op::v5::Round>(constant->output(0), ngraph::op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO),
+        target_type));
 
-    castedConstant = round(constant, target_type, tolerance, values, 0.5f);
-    if (castedConstant != nullptr) {
-        return castedConstant;
-    }
-
-    castedConstant = round(constant, target_type, tolerance, values, -0.5f);
-    if (castedConstant != nullptr) {
-        return castedConstant;
-    }
-
-    castedConstant = round(constant, target_type, tolerance, values, 1.f);
-    if (castedConstant != nullptr) {
-        return castedConstant;
-    }
-
-    return constant;
+    return castedConstant;
 }
 
 std::shared_ptr<Node> NetworkHelper::fold_fake_quantize(const std::shared_ptr<opset1::FakeQuantize>& fq) {
@@ -889,16 +852,13 @@ std::shared_ptr<Node> NetworkHelper::optimizeSubtract(std::shared_ptr<opset1::Su
 
     auto data = convertOnSubtract->input_value(0);
     auto shift = subtract->input_value(1).get_node_shared_ptr();
-    auto roundedShift = NetworkHelper::roundWithTolerance(shift, convertInputType);
-
-    std::shared_ptr<Node> replacement;
-    if (roundedShift->get_element_type() == convertInputType) {
-        // Propagate convertInputType down
-        replacement = std::make_shared<op::TypeRelaxed<opset1::Subtract>>(data, roundedShift);
-        NetworkHelper::copyInfo(subtract, replacement);
-        NetworkHelper::setOutDataPrecisionForTypeRelaxed(replacement, convertOutputType);
-        replace_node(subtract, replacement);
-    }
+    auto roundedShift = NetworkHelper::round(shift, convertInputType);
+
+    // Propagate convertInputType down
+    const auto replacement = std::make_shared<op::TypeRelaxed<opset1::Subtract>>(data, roundedShift);
+    NetworkHelper::copyInfo(subtract, replacement);
+    NetworkHelper::setOutDataPrecisionForTypeRelaxed(replacement, convertOutputType);
+    replace_node(subtract, replacement);
 
     // We lose the tail conversion here; not needed if the next node is a TypeRelaxed
     // TODO: check cases when Convert should be preserved
@@ -992,7 +952,8 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
 
     if ((!moveSubtract) && (dequantization.convert != nullptr) && (dequantization.subtract != nullptr)) {
         NetworkHelper::cleanRunTimeInfo(dequantization.subtract);
-        optimizeSubtract(dequantization.subtract);
+        // issue #43088
+        // NetworkHelper::optimizeElementwise(dequantization.subtract);
     }
 
     return InsertDequantizationResult(newOperation, parent);
index da156c7..f969600 100644 (file)
@@ -40,7 +40,7 @@ bool NormalizeL2Transformation::canBeTransformed(const TransformationContext& co
         return false;
     }
 
-    if (!canSubtractBeHandled(operation)) {
+    if (NetworkHelper::getDequantization(operation).subtract != nullptr) {
         return false;
     }
 
index 4e40523..4907206 100644 (file)
@@ -40,7 +40,7 @@ bool PReluTransformation::isPrecisionPreserved(std::shared_ptr<Node> op) const n
 
 bool PReluTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const {
     const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(op, 0);
-    if (dequantization.empty()) {
+    if (dequantization.empty() || (dequantization.subtract != nullptr)) {
         return false;
     }
 
index 37eb5dd..c4dd184 100644 (file)
@@ -48,11 +48,7 @@ bool ReluTransformation::canBeTransformed(const TransformationContext& context,
     }
 
     const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(op, 0);
-    if (dequantization.empty()) {
-        return false;
-    }
-
-    if (!canSubtractBeHandled(op, dequantization)) {
+    if (dequantization.empty() || (dequantization.subtract != nullptr)) {
         return false;
     }
 
index 730cb3e..10b49dc 100644 (file)
@@ -72,12 +72,13 @@ bool SubtractTransformation::transform(TransformationContext& context, ngraph::p
     }
 
     if (dequantization.convert != nullptr) {
-        std::shared_ptr<Node> newSubtract = NetworkHelper::optimizeSubtract(subtract);
-        newSubtract->set_output_type(0, originalPrecision, newSubtract->get_output_partial_shape(0));
+        // issue #43088
+        // std::shared_ptr<Node> newSubtract = NetworkHelper::optimizeElementwise(subtract);
+        subtract->set_output_type(0, originalPrecision, subtract->get_output_partial_shape(0));
 
-        replace_node(newSubtract, std::make_shared<op::TypeRelaxed<opset1::Subtract>>(
-            newSubtract->get_input_node_shared_ptr(0),
-            newSubtract->get_input_node_shared_ptr(1)));
+        replace_node(subtract, std::make_shared<op::TypeRelaxed<opset1::Subtract>>(
+            subtract->get_input_node_shared_ptr(0),
+            subtract->get_input_node_shared_ptr(1)));
     }
     return true;
 }
index b66d0f1..a105f08 100644 (file)
@@ -147,6 +147,54 @@ TEST_P(AddTransformation, CompareFunctions) {
 }
 
 const std::vector<AddTransformationTestValues> addTransformationTestValues = {
+    // Multiply with zero on the first branch
+    {
+        ngraph::element::f32,
+        ngraph::Shape{1, 4, 16, 16},
+        false,
+        -1,
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::f32,
+            { },
+            ngraph::element::u8,
+            { {ngraph::element::f32},  { 7.f }, { {1.f, 0.f, 2.f, 3.f} }},
+            { }
+        },
+        {
+            ngraph::element::f32,
+            { },
+            ngraph::element::u8,
+            { {ngraph::element::f32},  { 7.f }, { {1.f, 0.f, 2.f, 3.f} }},
+            { },
+            { }
+        },
+        ""
+    },
+    // Multiply with zero on the second branch
+    {
+        ngraph::element::f32,
+        ngraph::Shape{1, 4, 16, 16},
+        false,
+        -1,
+        LayerTransformation::createParamsU8I8(),
+        {
+            ngraph::element::u8,
+            { {ngraph::element::f32},  { 7.f }, { {1.f, 0.f, 2.f, 3.f} }},
+            ngraph::element::f32,
+            { },
+            { }
+        },
+        {
+            ngraph::element::u8,
+            { {ngraph::element::f32},  { 7.f }, { {1.f, 0.f, 2.f, 3.f} }},
+            ngraph::element::f32,
+            { },
+            { },
+            { }
+        },
+        ""
+    },
     // U8
     {
         ngraph::element::f32,
index 6cb622c..22e1609 100644 (file)
@@ -331,9 +331,13 @@ const std::vector<ClampTransformationTestValues> testValues = {
         },
         {
             ngraph::element::u8,
-            {{}, {{ 128.f, 0.f, 128.f }, ngraph::element::f32}, {}},
+            {
+                {ngraph::element::f32},
+                {{ 128.f, 0.f, 128.f }},
+                {{ 3.f, 3.f, 3.f }}
+            },
             ngraph::element::f32,
-            {{}, {}, {{3.f, 3.f, 3.f}}}
+            {{}, {}, {}}
         }
     },
     // U8 without asymmetric quantization
index e090447..8e1f042 100644 (file)
@@ -154,7 +154,7 @@ const std::vector<ConvolutionTransformationTestValues> testValues = {
         // ActualValues
         {
             ngraph::element::f32,
-            {{ngraph::element::f32}, { 128.f }, { 0.02f }},
+            {{}, { 128.f }, { 0.02f }},
             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
         },
@@ -214,7 +214,7 @@ const std::vector<ConvolutionTransformationTestValues> testValues = {
         // ActualValues
         {
             ngraph::element::f32,
-            {{ngraph::element::f32}, {}, { 0.02f }},
+            {{}, {}, { 0.02f }},
             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
         },
index 3101d68..79f05d2 100644 (file)
@@ -165,7 +165,7 @@ const std::vector<GroupConvolutionTestValues> testValues = {
         // ActualValues
         {
             ngraph::element::f32,
-            {{ngraph::element::f32}, { 128.f }, { 0.02f }},
+            {{}, { 128.f }, { 0.02f }},
             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
         },
@@ -329,7 +329,7 @@ const std::vector<GroupConvolutionTestValues> testValues = {
         // ActualValues
         {
             ngraph::element::f32,
-            {{ngraph::element::f32}, { 128.f }, { 0.02f }},
+            {{}, { 128.f }, { 0.02f }},
             op::Constant::create(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{ 2.f }),
             { 255ul, Shape({ 1, 1, 1, 1 }), { 0.f }, { 254.f }, { -1.27f }, { 1.27f } }
         },
index 5722628..865ded0 100644 (file)
@@ -218,12 +218,12 @@ std::vector<MatMullTransformationTestValues> testValues = {
         },
         {
             ngraph::element::u8,
-            { ngraph::element::f32, { 127.5f }, { 0.02f } },
+            { {}, {{128.f}, ngraph::element::f32, ngraph::Shape{ }, false}, {} },
             ngraph::element::i8,
-            { ngraph::element::f32, {}, { 0.03f } },
+            { },
             ngraph::element::f32,
             ngraph::element::f32,
-            {},
+            { {}, {}, { 0.0006f } },
         }
     },
     // U8 + FP32
index f528121..1589ba3 100644 (file)
@@ -129,7 +129,7 @@ const std::vector<MoveDequantizationAfterTransformationParams> testValues = {
             { {ngraph::element::f32},  { 7.f }, { 10.f } },
         },
         {
-            { {},  { { 7.f }, ngraph::element::f32, {}, false }, {} },
+            { {ngraph::element::f32},  { { 7.f }, ngraph::element::f32, {}, false }, {} },
             ngraph::element::f32,
             { {},  {}, { 10.f } },
         },
@@ -159,7 +159,7 @@ const std::vector<MoveDequantizationAfterTransformationParams> testValues = {
             { {ngraph::element::f32},  { 7.f }, { 10.f } },
         },
         {
-            { {},  { { 7.f }, ngraph::element::f32, {}, false }, {} },
+            { {ngraph::element::f32},  { { 7.f }, ngraph::element::f32, {}, false }, {} },
             ngraph::element::f32,
             { {},  {}, { 10.f } },
         },
@@ -189,7 +189,7 @@ const std::vector<MoveDequantizationAfterTransformationParams> testValues = {
             { {ngraph::element::f32},  { 7.f }, { 10.f } },
         },
         {
-            { {},  { { 7.f }, ngraph::element::f32, {}, false }, {} },
+            { {ngraph::element::f32},  { { 7.f }, ngraph::element::f32, {}, false }, {} },
             ngraph::element::f32,
             { {},  {}, { 10.f } },
         },
@@ -219,7 +219,7 @@ const std::vector<MoveDequantizationAfterTransformationParams> testValues = {
             { {ngraph::element::f32},  { 7.f }, { 10.f } },
         },
         {
-            { {},  { { 7.f }, ngraph::element::f32, {}, false }, {} },
+            { {ngraph::element::f32},  { { 7.f }, ngraph::element::f32, {}, false }, {} },
             ngraph::element::f32,
             { {},  {}, { 10.f } },
         },
@@ -234,12 +234,12 @@ const std::vector<MoveDequantizationAfterTransformationParams> testValues = {
             { {ngraph::element::f32},  { { 7.f, 7.f, 7.f } }, { { 10.f, 10.f, 10.f } } },
         },
         {
-            { {},  { { 7.f, 7.f, 7.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {} },
+            { {ngraph::element::f32},  { { 7.f, 7.f, 7.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {} },
             ngraph::element::f32,
             { {},  {}, { { 10.f, 10.f, 10.f } } },
         },
     },
-    // per-channel quantizations with the same values
+    // per-channel quantizations with different values
     {
         ngraph::element::u8,
         LayerTransformation::createParamsU8I8(),
@@ -249,7 +249,7 @@ const std::vector<MoveDequantizationAfterTransformationParams> testValues = {
             { {ngraph::element::f32},  { { 7.f, 8.f, 9.f } }, { { 10.f, 12.f, 16.f } } },
         },
         {
-            { {},  { { 7.f, 8.f, 9.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {} },
+            { {ngraph::element::f32},  { { 7.f, 8.f, 9.f }, ngraph::element::f32, { 1, 3, 1, 1 }, false }, {} },
             ngraph::element::f32,
             { {},  {}, { { 10.f, 12.f, 16.f } } },
         },
index 0e30a84..f69ab33 100644 (file)
@@ -91,6 +91,7 @@ public:
 
         std::ostringstream result;
         result <<
+            toString(testValues.params) << "_" <<
             testValues.inputShape << "_" <<
             testValues.reductionAxes << "_" <<
             testValues.normalizeVariance << "_" <<
@@ -145,9 +146,9 @@ const std::vector<MVNTransformationTestValues> testValues = {
         },
         {
             ngraph::element::u8,
-            {{ngraph::element::f32}, {127.f}, {}},
+            {{ngraph::element::f32}, {127.f}, {0.45f}},
             ngraph::element::f32,
-            {{}, {}, {1.f}}
+            {{}, {}, {}}
         }
     },
     {
@@ -163,7 +164,7 @@ const std::vector<MVNTransformationTestValues> testValues = {
             ngraph::element::u8,
             {{ngraph::element::f32}, {12.5f}, {0.45f}},
             ngraph::element::f32,
-            {}
+            {{}, {}, {}}
         }
     },
     {
index e9efaef..3a00f9b 100644 (file)
@@ -53,7 +53,7 @@ public:
             low_precision::LayerTransformation::Params(params.transformationParams));
         transform.transform(actualFunction);
 
-        referenceFunction = (!params.transformationParams.supportAsymmetricQuantization) && (!params.expected.subtractValues.empty()) ?
+        referenceFunction = !params.expected.subtractValues.empty() ?
             ngraph::builder::subgraph::NormalizeL2Function::getOriginal(
                 precision,
                 shape,
index f636ad8..daa0204 100644 (file)
@@ -137,9 +137,9 @@ const std::vector<PReluTransformationTestValues> testValues = {
         },
         {
             ngraph::element::u8,
-            {{}, { {128}, ngraph::element::f32 }, {}},
+            {{ngraph::element::f32}, { 128 }, {0.1f}},
             ngraph::element::f32,
-            {{}, {}, {0.1f}}
+            {{}, {}, {}}
         }
     },
     // I8: with positive subtract value
@@ -152,24 +152,9 @@ const std::vector<PReluTransformationTestValues> testValues = {
         },
         {
             ngraph::element::i8,
-            {{}, { {127}, ngraph::element::f32 }, {}},
-            ngraph::element::f32,
-            {{}, {}, {0.1f}}
-        }
-    },
-    // U8: with negative subtract value: Convert is still here
-    {
-        ngraph::Shape({ 1, 3, 16, 16 }),
-        LayerTransformation::createParamsU8I8(),
-        {
-            ngraph::element::u8,
-            {{ngraph::element::f32}, { -128 }, {0.1f}}
-        },
-        {
-            ngraph::element::u8,
-            {{ngraph::element::f32}, { {-128}, ngraph::element::f32 }, {}},
+            {{ngraph::element::f32}, { 127 }, {0.1f}},
             ngraph::element::f32,
-            {{}, {}, {0.1f}}
+            {{}, {}, {}}
         }
     },
 };
index 0259edd..60c60a2 100644 (file)
@@ -73,6 +73,7 @@ public:
 
         std::ostringstream result;
         result <<
+            toString(testValues.params) << "_" <<
             testValues.shape << "_" <<
             testValues.actual.precisionBeforeDequantization << "_" <<
             testValues.actual.dequantization << "_" <<
@@ -166,9 +167,9 @@ const std::vector<ReluTransformationTestValues> testValues = {
         },
         {
             ngraph::element::u8,
-            {{}, { {128}, ngraph::element::f32, {}, false }, {}},
+            {{ngraph::element::f32}, { 128 }, {0.1f}},
             ngraph::element::f32,
-            {{}, {}, {0.1f}}
+            {{}, {}, {}}
         }
     },
     // I8: with subtract value
@@ -181,9 +182,9 @@ const std::vector<ReluTransformationTestValues> testValues = {
         },
         {
             ngraph::element::i8,
-            {{}, { {127}, ngraph::element::f32, {}, false }, {}},
+            {{ngraph::element::f32}, { 127 }, {0.1f}},
             ngraph::element::f32,
-            {{}, {}, {0.1f}}
+            {{}, {}, {}}
         }
     },
     // I8: with subtract value
diff --git a/inference-engine/tests/functional/inference_engine/lp_transformations/round_transformation.cpp b/inference-engine/tests/functional/inference_engine/lp_transformations/round_transformation.cpp
new file mode 100644 (file)
index 0000000..eafa072
--- /dev/null
@@ -0,0 +1,111 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "layer_transformation.hpp"
+
+#include <string>
+#include <sstream>
+#include <gtest/gtest.h>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+#include "ngraph_functions/low_precision_transformations/round_function.hpp"
+#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
+#include "ngraph_functions/low_precision_transformations/common/builders.hpp"
+#include "low_precision/network_helper.hpp"
+
+
+namespace {
+using namespace testing;
+using namespace ngraph;
+using namespace ngraph::pass;
+
+class RoundTestValues {
+public:
+    ngraph::element::Type inputPrecision;
+    ngraph::Shape inputShape;
+    ngraph::builder::subgraph::DequantizationOperations actualDequantization;
+    ngraph::builder::subgraph::DequantizationOperations referenceDequantization;
+};
+
+
+
+class RoundTransformation : public LayerTransformation, public testing::WithParamInterface<RoundTestValues> {
+public:
+    void SetUp() override {
+        const auto testValues = this->GetParam();
+
+        actualFunction = ngraph::builder::subgraph::RoundWithToleranceFunction::getOriginal(
+            testValues.inputPrecision,
+            testValues.inputShape,
+            testValues.actualDequantization);
+        const auto lastNode = actualFunction->get_output_op(0)->get_input_node_shared_ptr(0);
+        const auto dequantization = ngraph::pass::low_precision::NetworkHelper::getDequantization(lastNode);
+        const auto subtractConstant = dequantization.subtract->get_input_node_shared_ptr(1);
+        const auto roundedConst = ngraph::pass::low_precision::NetworkHelper::round(
+            subtractConstant,
+            testValues.inputPrecision);
+
+        if (roundedConst->get_element_type() == testValues.inputPrecision) {
+            const auto replacement = std::make_shared<op::TypeRelaxed<opset1::Subtract>>(dequantization.data, roundedConst);
+            ngraph::pass::low_precision::NetworkHelper::copyInfo(dequantization.subtract, replacement);
+            ngraph::pass::low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(replacement, dequantization.convert->get_element_type());
+            replace_node(dequantization.subtract, replacement);
+        }
+
+        referenceFunction = ngraph::builder::subgraph::RoundWithToleranceFunction::getReference(
+            testValues.inputPrecision,
+            testValues.inputShape,
+            testValues.referenceDequantization);
+    }
+
+    static std::string getTestCaseName(testing::TestParamInfo<RoundTestValues> obj) {
+        const auto testValues = obj.param;
+
+        std::ostringstream result;
+        result << testValues.inputPrecision << "_"
+               << testValues.actualDequantization << "_"
+               << testValues.referenceDequantization;
+        return result.str();
+    }
+};
+
+std::vector<RoundTestValues> testValues = {
+    {
+        ngraph::element::u8,
+        ngraph::Shape{ 1, 3, 16, 16 },
+        { { ngraph::element::f32 }, { 125.5f }, { 0.1f } },
+        { {}, { { 126.f }, ngraph::element::f32 }, { 0.1f } }
+    },
+    {
+        ngraph::element::u8,
+        ngraph::Shape{ 1, 3, 16, 16 },
+        { { ngraph::element::f32 }, { { 128.3f, 64.5f, 31.7f } }, { { 0.1f, 0.1f, 0.1f } } },
+        { {}, { { 128.f, 65.f, 32.f }, ngraph::element::f32 }, { { 0.1f, 0.1f, 0.1f } } }
+    },
+    {
+        ngraph::element::i8,
+        ngraph::Shape{ 1, 3, 16, 16 },
+        { { ngraph::element::f32 }, { 126.6f }, { 0.1f } },
+        { {}, { { 127.f }, ngraph::element::f32 }, { 0.1f } }
+    },
+    {
+        ngraph::element::i8,
+        ngraph::Shape{ 1, 3, 16, 16 },
+        { { ngraph::element::f32 }, { { 126.5f, 32.25f, -127.5f } }, { { 0.1f, 0.1f, 0.1f } } },
+        { {}, { { 127.f, 32.f, -128.f }, ngraph::element::f32 }, { { 0.1f, 0.1f, 0.1f } } }
+    },
+};
+
+TEST_P(RoundTransformation, CompareFunctions) {
+    actualFunction->validate_nodes_and_infer_types();
+    auto res = compare_functions(referenceFunction, actualFunction, true, true);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+INSTANTIATE_TEST_CASE_P(
+    LPT,
+    RoundTransformation,
+    ::testing::ValuesIn(testValues),
+    RoundTransformation::getTestCaseName);
+} // namespace
diff --git a/inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/round_function.hpp b/inference-engine/tests/ngraph_functions/include/ngraph_functions/low_precision_transformations/round_function.hpp
new file mode 100644 (file)
index 0000000..d854c33
--- /dev/null
@@ -0,0 +1,32 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <memory>
+#include <ngraph/ngraph.hpp>
+
+#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
+#include "ngraph_functions/subgraph_builders.hpp"
+
+namespace ngraph {
+namespace builder {
+namespace subgraph {
+
+class RoundWithToleranceFunction {
+public:
+    static std::shared_ptr<ngraph::Function> getOriginal(
+        const ngraph::element::Type precision,
+        const ngraph::Shape& inputShape,
+        const ngraph::builder::subgraph::DequantizationOperations dequantization);
+
+    static std::shared_ptr<ngraph::Function> getReference(
+        const ngraph::element::Type precision,
+        const ngraph::Shape& inputShape,
+        const ngraph::builder::subgraph::DequantizationOperations dequantization);
+};
+
+}  // namespace subgraph
+}  // namespace builder
+}  // namespace ngraph
diff --git a/inference-engine/tests/ngraph_functions/src/low_precision_transformations/round_function.cpp b/inference-engine/tests/ngraph_functions/src/low_precision_transformations/round_function.cpp
new file mode 100644 (file)
index 0000000..cdc4b3e
--- /dev/null
@@ -0,0 +1,56 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "ngraph_functions/low_precision_transformations/round_function.hpp"
+
+#include <ngraph/opsets/opset1.hpp>
+#include "ngraph_functions/subgraph_builders.hpp"
+#include "ngraph_functions/low_precision_transformations/common/builders.hpp"
+
+using namespace ngraph::pass::low_precision;
+
+namespace ngraph {
+namespace builder {
+namespace subgraph {
+    std::shared_ptr<ngraph::Function> RoundWithToleranceFunction::getOriginal(
+        const ngraph::element::Type precision,
+        const ngraph::Shape& inputShape,
+        const ngraph::builder::subgraph::DequantizationOperations dequantization) {
+        const auto input = std::make_shared<ngraph::op::v0::Parameter>(precision, inputShape);
+        input->set_friendly_name("input");
+
+        const auto deq = makeDequantization(input, dequantization);
+        deq->set_friendly_name("output");
+
+        const auto result = std::make_shared<ngraph::opset1::Result>(deq);
+        result->set_friendly_name("result");
+
+        return std::make_shared<ngraph::Function>(
+            ngraph::ResultVector{ result },
+            ngraph::ParameterVector{ input },
+            "RoundWithToleranceFunction");
+    }
+
+    std::shared_ptr<ngraph::Function> RoundWithToleranceFunction::getReference(
+        const ngraph::element::Type precision,
+        const ngraph::Shape& inputShape,
+        const ngraph::builder::subgraph::DequantizationOperations dequantization) {
+        const auto input = std::make_shared<ngraph::op::v0::Parameter>(precision, inputShape);
+        input->set_friendly_name("input");
+
+        const auto deq = makeDequantization(input, dequantization);
+        deq->set_friendly_name("output");
+
+        const auto result = std::make_shared<ngraph::opset1::Result>(deq);
+        result->set_friendly_name("result");
+
+        return std::make_shared<ngraph::Function>(
+            ngraph::ResultVector{ result },
+            ngraph::ParameterVector{ input },
+            "RoundWithToleranceFunction");
+    }
+
+}  // namespace subgraph
+}  // namespace builder
+}  // namespace ngraph