ConvertPrecision for element::boolean (#1772)
authorGleb Kazantaev <gleb.kazantaev@intel.com>
Tue, 18 Aug 2020 08:10:24 +0000 (11:10 +0300)
committerGitHub <noreply@github.com>
Tue, 18 Aug 2020 08:10:24 +0000 (11:10 +0300)
* Added bool to u8 conversion

* Added opset1::ShapeOf handler

* Added ReduceLogicalAnd/Or support in ConvertPrecision pass

* Moved static map inside function; Updated callbacks

* Removed header

* Fixed type relaxed for cases when the same output is consumed by multiple inputs of the same operation; added tests; fixed input type setting for already created type relaxed operations (see the usage sketch below)
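
A minimal usage sketch of the extended pass, mirroring the tests added in this change (names of the parameters and shapes here are illustrative): boolean-producing nodes are fused to u8 in place, without inserting extra Convert operations.

    #include <memory>
    #include <ngraph/function.hpp>
    #include <ngraph/opsets/opset4.hpp>
    #include <ngraph/pass/manager.hpp>
    #include <transformations/convert_precision.hpp>

    // Build a small function whose comparison node produces element::boolean.
    auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3});
    auto input2 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3});
    auto equal  = std::make_shared<ngraph::opset4::Equal>(input1, input2);
    auto f = std::make_shared<ngraph::Function>(ngraph::OutputVector{equal},
                                                ngraph::ParameterVector{input1, input2});

    ngraph::pass::Manager manager;
    // The boolean output of Equal is overridden to u8 rather than wrapped in a Convert node.
    manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
    manager.run_passes(f);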

32 files changed:
inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp
inference-engine/src/transformations/include/ngraph_ops/type_relaxed.hpp
inference-engine/src/transformations/include/transformations/convert_precision.hpp
inference-engine/src/transformations/src/transformations/convert_precision.cpp
inference-engine/tests/functional/inference_engine/transformations/convert_precision.cpp
inference-engine/tests/functional/inference_engine/transformations/type_relaxed_tests.cpp
ngraph/core/include/ngraph/op/equal.hpp
ngraph/core/include/ngraph/op/greater.hpp
ngraph/core/include/ngraph/op/greater_eq.hpp
ngraph/core/include/ngraph/op/less.hpp
ngraph/core/include/ngraph/op/less_eq.hpp
ngraph/core/include/ngraph/op/not.hpp
ngraph/core/include/ngraph/op/not_equal.hpp
ngraph/core/include/ngraph/op/or.hpp
ngraph/core/include/ngraph/op/reduce_logical_and.hpp
ngraph/core/include/ngraph/op/reduce_logical_or.hpp
ngraph/core/include/ngraph/op/select.hpp
ngraph/core/include/ngraph/op/shape_of.hpp
ngraph/core/include/ngraph/op/xor.hpp
ngraph/core/src/op/equal.cpp
ngraph/core/src/op/greater.cpp
ngraph/core/src/op/greater_eq.cpp
ngraph/core/src/op/less.cpp
ngraph/core/src/op/less_eq.cpp
ngraph/core/src/op/not.cpp
ngraph/core/src/op/not_equal.cpp
ngraph/core/src/op/or.cpp
ngraph/core/src/op/reduce_logical_and.cpp
ngraph/core/src/op/reduce_logical_or.cpp
ngraph/core/src/op/select.cpp
ngraph/core/src/op/shape_of.cpp
ngraph/core/src/op/xor.cpp
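
The new fuse_type_to_* handlers in convert_precision.cpp all follow one pattern: if the node is already type relaxed, only its overridden output type is updated; otherwise the original op is wrapped into op::TypeRelaxed<T>. A condensed sketch of that pattern (the helper name below is illustrative, not part of the change):

    #include <memory>
    #include <ngraph_ops/type_relaxed.hpp>

    template <typename T>
    bool fuse_boolean_output_sketch(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to) {
        // Already wrapped: just override the reported output type.
        if (auto type_relaxed = std::dynamic_pointer_cast<ngraph::op::TypeRelaxedBase>(node)) {
            type_relaxed->set_overridden_output_type(to);
            return true;
        }
        // Plain op: wrap it into TypeRelaxed<T> so its output can be reported as `to`.
        if (auto casted = std::dynamic_pointer_cast<T>(node)) {
            auto relaxed_op = std::make_shared<ngraph::op::TypeRelaxed<T>>(
                    *casted, ngraph::element::TypeVector{}, ngraph::element::TypeVector{to});
            ngraph::replace_node(node, relaxed_op);
            return true;
        }
        return false;
    }
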

index 95f3747..b9bb493 100644 (file)
@@ -99,6 +99,7 @@ static void Transformation(ICNNNetwork::Ptr& clonedNetwork) {
             {ngraph::element::u16, ngraph::element::i32},
             {ngraph::element::u32, ngraph::element::i32},
             {ngraph::element::f16, ngraph::element::f32},
+            {ngraph::element::boolean, ngraph::element::u8},
     };
 
     for (auto & precision : convert_precision_list) {
@@ -118,10 +119,6 @@ static void Transformation(ICNNNetwork::Ptr& clonedNetwork) {
 
     clonedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, *clonedNetwork);
 
-    // WA: ngraph::pass:ConvertPrecision doesn't support BOOL to U8 conversion
-    // so we temporary have to call CNNNetwork ConvertPrecision transformation
-    NetPass::ConvertPrecision(*clonedNetwork, Precision::BOOL, Precision::U8);
-
     // WA: after conversion to CNNNetwork user precision can redefine input/output precisions
     // so we need to apply additional precision conversion but only for inputs and outputs
     for (auto & precision : convert_precision_list) {
index c8dff13..d0cf548 100644 (file)
@@ -169,10 +169,14 @@ private:
 
 template <typename BaseOp>
 void TypeRelaxed<BaseOp>::validate_and_infer_types() {
-    // Remember all input data types and reset them to m_output_data_type.
+    // Remember all input data types
     element::TypeVector old_input_types;
     for (size_t i = 0; i < BaseOp::get_input_size(); ++i) {
         old_input_types.push_back(BaseOp::get_input_element_type(i));
+    }
+
+    // Reset input data types to m_output_data_type.
+    for (size_t i = 0; i < BaseOp::get_input_size(); ++i) {
         auto origin_input_type = get_origin_input_type(i);
         if (origin_input_type != element::undefined) {
             BaseOp::get_input_tensor(i).set_tensor_type(origin_input_type, BaseOp::get_input_partial_shape(i));
index 706710c..8d3c9d5 100644 (file)
@@ -36,11 +36,15 @@ class TRANSFORMATIONS_API ConvertPrecision;
  *     u64 -> i32
  *     i64 -> i32
  *     f16 -> f32
+ *    bool -> u8
+ *    bool -> i32
+ *
 * For all operations from opset1-opset4 these conversions can be applied without adding Convert operations.
 * That is possible because every operation that produces the "FROM" type can also produce the "TO" type, and for these
 * operations we have created special fuse_type_into_<type> functions (found in the cpp file) that perform type fusion
 * into the operation.
- * List of operations that are supported by this transformations:
+ *
+ * List of operations that are supported by this transformation for the i64 -> i32 conversion:
  *     opset4::Parameter
  *     opset4::Convert
  *     opset4::ShapeOf
@@ -49,6 +53,20 @@ class TRANSFORMATIONS_API ConvertPrecision;
  *     opset4::TopK
  *     opset4::NonZero
  *     opset4::Bucketize
+ *
+ * List of operations that are supported by this transformation for the bool -> u8 conversion:
+ *     LogicalAnd
+ *     LogicalNot
+ *     LogicalOr
+ *     LogicalXor
+ *     ReduceLogicalAnd
+ *     ReduceLogicalOr
+ *     Equal
+ *     NotEqual
+ *     Greater
+ *     GreaterEqual
+ *     Less
+ *     LessEqual
  */
 
 class ngraph::pass::ConvertPrecision : public ngraph::pass::FunctionPass {
index ffd9956..250da61 100644 (file)
 #include <ngraph/opsets/opset4.hpp>
 #include <ngraph/opsets/opset3.hpp>
 #include <ngraph/opsets/opset1.hpp>
+#include <ngraph_ops/type_relaxed.hpp>
 
 using namespace ngraph;
 
 bool fuse_type_to_constant(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, const std::vector<ngraph::Input<ngraph::Node>> & consumers);
 bool fuse_type_to_shapeof(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
+bool fuse_type_to_shapeof_v0(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
 bool fuse_type_to_parameter(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
 bool fuse_type_to_convert(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
 bool fuse_type_to_nms3(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
@@ -24,7 +26,54 @@ bool fuse_type_to_nonzero(std::shared_ptr<ngraph::Node> & node, ngraph::element:
 bool fuse_type_to_bucketize(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
 bool fuse_type_to_generic_ie(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
 
-static std::map<ngraph::NodeTypeInfo, std::function<bool(std::shared_ptr<Node>&, element::Type, size_t idx)>> type_to_fuse {
+bool extend_select_type(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
+
+template <typename T>
+bool fuse_type_to_binary_comparision(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
+    if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
+        type_relaxed->set_overridden_output_type(to);
+        return true;
+    } else if (auto casted = std::dynamic_pointer_cast<T>(node)) {
+        auto relaxed_op = std::make_shared<ngraph::op::TypeRelaxed<T>>(*casted, element::TypeVector{}, element::TypeVector{to});
+        replace_node(node, relaxed_op);
+        return true;
+    }
+    return false;
+}
+
+template <typename T>
+bool fuse_type_to_logical(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
+    if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
+        type_relaxed->set_overridden_output_type(to);
+        type_relaxed->set_origin_input_type(element::boolean, 0);
+        type_relaxed->set_origin_input_type(element::boolean, 1);
+        return true;
+    } else if (auto casted = std::dynamic_pointer_cast<T>(node)) {
+        auto relaxed_op = std::make_shared<ngraph::op::TypeRelaxed<T>>(*casted,
+                element::TypeVector{element::boolean, element::boolean}, element::TypeVector{to});
+        replace_node(node, relaxed_op);
+        return true;
+    }
+    return false;
+}
+
+template <class T>
+bool fuse_type_to_reduce_logical(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
+    if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
+        type_relaxed->set_overridden_output_type(to);
+        type_relaxed->set_origin_input_type(element::boolean, 0);
+        return true;
+    } else if (auto casted = std::dynamic_pointer_cast<T>(node)) {
+        auto relaxed_op = std::make_shared<ngraph::op::TypeRelaxed<T>>(*casted,
+                element::TypeVector{element::boolean}, element::TypeVector{to});
+        replace_node(node, relaxed_op);
+        return true;
+    }
+    return false;
+}
+
+bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Function> f) {
+    static std::map<ngraph::NodeTypeInfo, std::function<bool(std::shared_ptr<Node>&, element::Type, size_t idx)>> type_to_fuse {
         {opset4::Parameter::type_info, fuse_type_to_parameter},
         {opset4::Convert::type_info, fuse_type_to_convert},
         {opset4::ShapeOf::type_info, fuse_type_to_shapeof},
@@ -34,9 +83,25 @@ static std::map<ngraph::NodeTypeInfo, std::function<bool(std::shared_ptr<Node>&,
         {opset4::NonZero::type_info, fuse_type_to_nonzero},
         {opset4::Bucketize::type_info, fuse_type_to_bucketize},
         {NodeTypeInfo("GenericIE", 1), fuse_type_to_generic_ie},
-};
+        {opset4::Equal::type_info, fuse_type_to_binary_comparision<opset4::Equal>},
+        {opset4::NotEqual::type_info, fuse_type_to_binary_comparision<opset4::NotEqual>},
+        {opset4::Greater::type_info, fuse_type_to_binary_comparision<opset4::Greater>},
+        {opset4::GreaterEqual::type_info, fuse_type_to_binary_comparision<opset4::GreaterEqual>},
+        {opset4::Less::type_info, fuse_type_to_binary_comparision<opset4::Less>},
+        {opset4::LessEqual::type_info, fuse_type_to_binary_comparision<opset4::LessEqual>},
+        {opset4::LogicalAnd::type_info, fuse_type_to_logical<opset4::LogicalAnd>},
+        {opset4::LogicalOr::type_info, fuse_type_to_logical<opset4::LogicalOr>},
+        {opset4::LogicalXor::type_info, fuse_type_to_logical<opset4::LogicalXor>},
+        {opset4::LogicalNot::type_info, fuse_type_to_logical<opset4::LogicalNot>},
+        {opset4::ReduceLogicalAnd::type_info, fuse_type_to_reduce_logical<opset4::ReduceLogicalAnd>},
+        {opset4::ReduceLogicalOr::type_info, fuse_type_to_reduce_logical<opset4::ReduceLogicalOr>},
+        {opset1::ShapeOf::type_info, fuse_type_to_shapeof_v0}
+    };
+
+    static std::map<ngraph::NodeTypeInfo, std::function<bool(std::shared_ptr<Node>&, element::Type, size_t idx)>> type_to_extend {
+            {opset4::Select::type_info, extend_select_type},
+    };
 
-bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Function> f) {
     // As Constant operations can be shared between multiple nGraph Functions, before
     // changing precision we need to understand which Constant consumers belong
     // to the current nGraph Function
@@ -45,10 +110,6 @@ bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Fun
     std::function<void(const std::shared_ptr<Function> &)> register_constants =
             [&const_to_internal_output, &register_constants](const std::shared_ptr<Function> & f) {
         for (auto & node : f->get_ordered_ops()) {
-            // Recursively run for TensorIterator body function
-            if (auto ti = std::dynamic_pointer_cast<opset4::TensorIterator>(node)) {
-                register_constants(ti->get_body()->to_function());
-            }
             for (auto & input : node->inputs()) {
                 if (auto const_node = std::dynamic_pointer_cast<opset4::Constant>(input.get_source_output().get_node_shared_ptr())) {
                     const_to_internal_output[const_node].emplace_back(input);
@@ -57,39 +118,42 @@ bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Fun
         }
     };
 
-    register_constants(f);
-
-    auto convert_node_precision = [this, &const_to_internal_output](std::shared_ptr<Node> & node) {
-        // As input type could changed we need to propagate output type calculation manually
-        node->validate_and_infer_types();
-
+    auto convert_node_output_precision = [this, &const_to_internal_output](std::shared_ptr<Node> & node) {
         for (auto output : node->outputs()) {
             if (output.get_element_type() == m_from) {
                 // Handle case with Constants as they can have consumers from other nGraph Function object
                 if (ngraph::op::is_constant(node) && const_to_internal_output.count(node)) {
                     fuse_type_to_constant(node, m_to, const_to_internal_output.at(node));
-                    continue;
+                    break;
                 }
 
                 // If node type in map and convert can be fused into node we skip Convert creation
                 if (type_to_fuse.count(node->get_type_info()) &&
                     type_to_fuse.at(node->get_type_info())(node, m_to, output.get_index())) {
-                    node->validate_and_infer_types();
-                    continue;
+                    break;
                 }
+            }
+        }
+    };
 
-                // Create Convert operation and reconnect consumers
-                auto consumers = output.get_target_inputs();
-                auto convert = std::make_shared<opset4::Convert>(output, m_to);
-                for (auto & input : consumers) {
-                    input.replace_source_output(convert);
+    auto convert_node_input_precision = [this](std::shared_ptr<Node> & node) {
+        for (auto input : node->inputs()) {
+            if (input.get_element_type() == m_from) {
+                // For some operations we need to extend their input types to support the new type
+                if (type_to_extend.count(node->get_type_info()) &&
+                    type_to_extend.at(node->get_type_info())(node, m_to, input.get_index())) {
+                    break;
                 }
             }
         }
     };
 
     std::function<void(const std::shared_ptr<Function> &)> convert_function_precision =
-            [this, &const_to_internal_output, &convert_node_precision, &convert_function_precision](const std::shared_ptr<Function> & f) {
+            [this, &const_to_internal_output,
+                   &register_constants,
+                   &convert_node_output_precision,
+                   &convert_node_input_precision,
+                   &convert_function_precision] (const std::shared_ptr<Function> & f) {
         // Iterate over all nodes in topological order and then iterate over node outputs.
         // If the output type mismatches the given type we try to fuse the type into this operation,
         // otherwise we insert a Convert operation.
@@ -98,11 +162,18 @@ bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Fun
             if (auto ti = std::dynamic_pointer_cast<opset4::TensorIterator>(node)) {
                 convert_function_precision(ti->get_body()->to_function());
             }
-            convert_node_precision(node);
+            convert_node_input_precision(node);
+        }
+        // Register internal constants only after fixing input types, since fixing them can lead to node replacement
+        register_constants(f);
+
+        for (auto &node : f->get_ordered_ops()) {
+            convert_node_output_precision(node);
         }
     };
 
     convert_function_precision(f);
+    f->validate_nodes_and_infer_types();
 
     // TODO: we need to split NopElimination pass to separate MatcherPasses and call Convert elimination here
     for (auto &node : f->get_ordered_ops()) {
@@ -128,6 +199,7 @@ bool fuse_type_to_shapeof(std::shared_ptr<Node> & node, element::Type to, size_t
 bool fuse_type_to_parameter(std::shared_ptr<Node> & node, element::Type to, size_t idx) {
     if (auto param = as_type_ptr<opset4::Parameter>(node)) {
         param->set_element_type(to);
+        param->validate_and_infer_types();
         return true;
     }
     return false;
@@ -192,6 +264,33 @@ bool fuse_type_to_generic_ie(std::shared_ptr<ngraph::Node> & node, ngraph::eleme
     return true;
 }
 
+bool fuse_type_to_shapeof_v0(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
+    if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
+        type_relaxed->set_overridden_output_type(to);
+        return true;
+    } else if (auto casted = std::dynamic_pointer_cast<opset1::ShapeOf>(node)) {
+        auto relaxed_op = std::make_shared<ngraph::op::TypeRelaxed<opset1::ShapeOf>>(*casted,
+                element::TypeVector{}, element::TypeVector{to});
+        replace_node(node, relaxed_op);
+        return true;
+    }
+    return false;
+}
+
+bool extend_select_type(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
+    if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
+        type_relaxed->set_origin_input_type(element::boolean, 0);
+        return true;
+    } else if (auto casted = std::dynamic_pointer_cast<opset4::Select>(node)) {
+        auto relaxed_op = std::make_shared<op::TypeRelaxed<opset4::Select>>(*casted,
+                element::TypeVector{element::boolean},
+                element::TypeVector{});
+        replace_node(node, relaxed_op);
+        return true;
+    }
+    return false;
+}
+
 template <element::Type_t PREC_FROM, element::Type_t PREC_TO>
 std::shared_ptr<Node> change_constant_precision(std::shared_ptr<opset4::Constant> & constant) {
     using src_type = typename element_type_traits<PREC_FROM>::value_type;
index 45d3aa3..418f460 100644 (file)
@@ -16,6 +16,7 @@
 #include <transformations/convert_precision.hpp>
 #include <transformations/utils/utils.hpp>
 #include <ngraph/pass/manager.hpp>
+#include <ngraph_ops/type_relaxed.hpp>
 
 #include "common_test_utils/ngraph_test_utils.hpp"
 
@@ -289,6 +290,255 @@ TEST(TransformationTests, ConvertPrecision_TIBody) {
     }
 }
 
+TEST(TransformationTests, ConvertPrecision_Equal) {
+    std::shared_ptr<Function> f(nullptr);
+    {
+        auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+        auto input2 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+        auto node = std::make_shared<ngraph::opset4::Equal>(input1, input2);
+
+        f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
+
+        pass::Manager manager;
+        manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+        manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
+        manager.run_passes(f);
+    }
+
+    ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
+    ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+    ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_NotEqual) {
+    std::shared_ptr<Function> f(nullptr);
+    {
+        auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+        auto input2 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+        auto node = std::make_shared<ngraph::opset4::NotEqual>(input1, input2);
+
+        f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
+
+        pass::Manager manager;
+        manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+        manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
+        manager.run_passes(f);
+    }
+
+    ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
+    ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+    ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_Greater) {
+    std::shared_ptr<Function> f(nullptr);
+    {
+        auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+        auto input2 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+        auto node = std::make_shared<ngraph::opset4::Greater>(input1, input2);
+
+        f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
+
+        pass::Manager manager;
+        manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+        manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
+        manager.run_passes(f);
+    }
+
+    ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
+    ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+    ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_GreaterEqual) {
+    std::shared_ptr<Function> f(nullptr);
+    {
+        auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+        auto input2 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+        auto node = std::make_shared<ngraph::opset4::GreaterEqual>(input1, input2);
+
+        f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
+
+        pass::Manager manager;
+        manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+        manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
+        manager.run_passes(f);
+    }
+
+    ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
+    ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+    ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_Less) {
+    std::shared_ptr<Function> f(nullptr);
+    {
+        auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+        auto input2 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+        auto node = std::make_shared<ngraph::opset4::Less>(input1, input2);
+
+        f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
+
+        pass::Manager manager;
+        manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+        manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
+        manager.run_passes(f);
+    }
+
+    ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
+    ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+    ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_LessEqual) {
+    std::shared_ptr<Function> f(nullptr);
+    {
+        auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+        auto input2 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+        auto node = std::make_shared<ngraph::opset4::LessEqual>(input1, input2);
+
+        f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
+
+        pass::Manager manager;
+        manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+        manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
+        manager.run_passes(f);
+    }
+
+    ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
+    ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+    ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_LogicalAnd) {
+    std::shared_ptr<Function> f(nullptr);
+    {
+        auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::boolean, ngraph::Shape{15, 20, 3});
+        auto input2 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::boolean, ngraph::Shape{15, 20, 3});
+        auto node = std::make_shared<ngraph::opset4::LogicalAnd>(input1, input2);
+
+        f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
+
+        pass::Manager manager;
+        manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+        manager.run_passes(f);
+    }
+
+    ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+    ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_LogicalOr) {
+    std::shared_ptr<Function> f(nullptr);
+    {
+        auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::boolean, ngraph::Shape{15, 20, 3});
+        auto input2 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::boolean, ngraph::Shape{15, 20, 3});
+        auto node = std::make_shared<ngraph::opset4::LogicalOr>(input1, input2);
+
+        f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
+
+        pass::Manager manager;
+        manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+        manager.run_passes(f);
+    }
+
+    ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+    ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_LogicalXor) {
+    std::shared_ptr<Function> f(nullptr);
+    {
+        auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::boolean, ngraph::Shape{15, 20, 3});
+        auto input2 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::boolean, ngraph::Shape{15, 20, 3});
+        auto node = std::make_shared<ngraph::opset4::LogicalXor>(input1, input2);
+
+        f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
+
+        pass::Manager manager;
+        manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+        manager.run_passes(f);
+    }
+
+    ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+    ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_LogicalNot) {
+    std::shared_ptr<Function> f(nullptr);
+    {
+        auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::boolean, ngraph::Shape{15, 20, 3});
+        auto node = std::make_shared<ngraph::opset4::LogicalNot>(input1);
+
+        f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1});
+
+        pass::Manager manager;
+        manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+        manager.run_passes(f);
+    }
+
+    ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+    ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_Select) {
+    std::shared_ptr<Function> f(nullptr);
+    {
+        auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::boolean, ngraph::Shape{15, 20, 3});
+        auto node = std::make_shared<ngraph::opset4::LogicalNot>(input1);
+        auto select = std::make_shared<ngraph::opset4::Select>(node, input1, input1);
+
+        f = std::make_shared<Function>(OutputVector{select}, ParameterVector{input1});
+
+        pass::Manager manager;
+        manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+        manager.run_passes(f);
+    }
+
+    ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+    ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_TypeRelaxedWithSelect) {
+    std::shared_ptr<Function> f(nullptr);
+    {
+        auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::boolean, ngraph::Shape{15, 20, 3});
+        auto node = std::make_shared<ngraph::opset4::LogicalNot>(input1);
+        auto select = std::make_shared<ngraph::opset4::Select>(node, input1, input1);
+
+        f = std::make_shared<Function>(OutputVector{select}, ParameterVector{input1});
+
+        pass::Manager manager;
+        manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::i32);
+        manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i32, ngraph::element::i64);
+        manager.run_passes(f);
+    }
+
+    ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+    ASSERT_FALSE(has_type<ngraph::element::Type_t::i32>(f));
+    ASSERT_TRUE(has_type<ngraph::element::Type_t::i64>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_TypeRelaxed) {
+    std::shared_ptr<Function> f(nullptr);
+    {
+        auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::boolean, ngraph::Shape{15, 20, 3});
+        auto select = std::make_shared<ngraph::opset4::Select>(input1, input1, input1);
+        auto type_relaxed = std::make_shared<op::TypeRelaxed<opset4::Select>>(*select, element::TypeVector{}, element::TypeVector{element::i64});
+
+        f = std::make_shared<Function>(OutputVector{type_relaxed}, ParameterVector{input1});
+
+        pass::Manager manager;
+        manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::i32);
+        manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i32, ngraph::element::i64);
+        manager.run_passes(f);
+
+        ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+        ASSERT_FALSE(has_type<ngraph::element::Type_t::i32>(f));
+        ASSERT_TRUE(has_type<ngraph::element::Type_t::i64>(f));
+    }
+}
+
 TEST(TransformationTests, ConvertPrecision_Variables) {
     std::shared_ptr<ngraph::Function> f(nullptr);
     {
index c93a2cc..e985a96 100644 (file)
@@ -271,3 +271,28 @@ TEST_F(TypeRelaxedTests, setGetTypes) {
 
     ASSERT_EQ(4, ngraph->get_ops().size());
 }
+
+TEST_F(TypeRelaxedTests, OneOutputMultipleInputPorts) {
+    std::shared_ptr<ngraph::Function> f;
+    {
+        auto param1 = make_shared<ngraph::opset1::Parameter>(element::boolean, ngraph::Shape{1, 3, 22, 22});
+        auto op = ngraph::opset1::Select(param1, param1, param1);
+        auto relaxed_op = make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Select>>(
+                op, TypeVector{}, TypeVector{element::i64});
+
+        f = make_shared<ngraph::Function>(ngraph::OutputVector{relaxed_op}, ngraph::ParameterVector{param1});
+
+        // Prepare relaxed op for input change
+        relaxed_op->set_origin_input_type(element::boolean, 0);
+
+        // Change Parameter element type
+        param1->set_element_type(element::i64);
+        param1->validate_and_infer_types();
+        ASSERT_EQ(param1->output(0).get_element_type(), element::i64);
+
+        // Check that restoring the original precisions inside validate_and_infer_types
+        // does not corrupt the original types
+        relaxed_op->validate_and_infer_types();
+        ASSERT_EQ(param1->output(0).get_element_type(), element::i64);
+    }
+}
\ No newline at end of file
index 6c0f827..d962ae4 100644 (file)
@@ -90,8 +90,7 @@ namespace ngraph
             class NGRAPH_API Equal : public util::BinaryElementwiseComparison
             {
             public:
-                static constexpr NodeTypeInfo type_info{"Equal", 1};
-                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                NGRAPH_RTTI_DECLARATION;
                 /// \brief Constructs an equal operation.
                 Equal()
                     : util::BinaryElementwiseComparison(AutoBroadcastSpec::NUMPY)
index ac84a18..c81eebd 100644 (file)
@@ -57,8 +57,7 @@ namespace ngraph
             class NGRAPH_API Greater : public util::BinaryElementwiseComparison
             {
             public:
-                static constexpr NodeTypeInfo type_info{"Greater", 1};
-                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                NGRAPH_RTTI_DECLARATION;
                 /// \brief Constructs a greater-than operation.
                 Greater()
                     : util::BinaryElementwiseComparison(AutoBroadcastSpec::NUMPY)
index 7eb3b05..6c52dcc 100644 (file)
@@ -57,8 +57,7 @@ namespace ngraph
             class NGRAPH_API GreaterEqual : public util::BinaryElementwiseComparison
             {
             public:
-                static constexpr NodeTypeInfo type_info{"GreaterEqual", 1};
-                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                NGRAPH_RTTI_DECLARATION;
                 /// \brief Constructs a greater-than-or-equal operation.
                 GreaterEqual()
                     : util::BinaryElementwiseComparison(AutoBroadcastSpec::NUMPY)
index 88c5eff..6105c5a 100644 (file)
@@ -57,8 +57,7 @@ namespace ngraph
             class NGRAPH_API Less : public util::BinaryElementwiseComparison
             {
             public:
-                static constexpr NodeTypeInfo type_info{"Less", 1};
-                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                NGRAPH_RTTI_DECLARATION;
                 /// \brief Constructs a less-than operation.
                 Less()
                     : util::BinaryElementwiseComparison(AutoBroadcastSpec::NUMPY)
index a63748e..94afc11 100644 (file)
@@ -28,8 +28,7 @@ namespace ngraph
             class NGRAPH_API LessEqual : public util::BinaryElementwiseComparison
             {
             public:
-                static constexpr NodeTypeInfo type_info{"LessEqual", 1};
-                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                NGRAPH_RTTI_DECLARATION;
                 /// \brief Constructs a less-than-or-equal operation.
                 LessEqual()
                     : util::BinaryElementwiseComparison(AutoBroadcastSpec::NUMPY)
index 5c905c4..63f3283 100644 (file)
@@ -28,8 +28,7 @@ namespace ngraph
             class NGRAPH_API LogicalNot : public Op
             {
             public:
-                static constexpr NodeTypeInfo type_info{"LogicalNot", 1};
-                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                NGRAPH_RTTI_DECLARATION;
                 /// \brief Constructs a logical negation operation.
                 LogicalNot() = default;
                 /// \brief Constructs a logical negation operation.
index 03f29be..eab9120 100644 (file)
@@ -58,8 +58,7 @@ namespace ngraph
             class NGRAPH_API NotEqual : public util::BinaryElementwiseComparison
             {
             public:
-                static constexpr NodeTypeInfo type_info{"NotEqual", 1};
-                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                NGRAPH_RTTI_DECLARATION;
                 /// \brief Constructs a not-equal operation.
                 NotEqual()
                     : util::BinaryElementwiseComparison(AutoBroadcastSpec::NUMPY)
index 4ed7f15..4f3abac 100644 (file)
@@ -31,8 +31,7 @@ namespace ngraph
             class NGRAPH_API LogicalOr : public util::BinaryElementwiseLogical
             {
             public:
-                static constexpr NodeTypeInfo type_info{"LogicalOr", 1};
-                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                NGRAPH_RTTI_DECLARATION;
                 LogicalOr() = default;
                 /// \brief Constructs a logical-or operation.
                 ///
index 3be104a..d853795 100644 (file)
@@ -31,8 +31,7 @@ namespace ngraph
             class NGRAPH_API ReduceLogicalAnd : public util::LogicalReductionKeepDims
             {
             public:
-                static constexpr NodeTypeInfo type_info{"ReduceLogicalAnd", 1};
-                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                NGRAPH_RTTI_DECLARATION;
                 ReduceLogicalAnd() = default;
                 /// \brief Constructs a ReduceLogicalAnd node.
                 ///
index 77ebf58..ee41e4c 100644 (file)
@@ -31,8 +31,7 @@ namespace ngraph
             class NGRAPH_API ReduceLogicalOr : public util::LogicalReductionKeepDims
             {
             public:
-                static constexpr NodeTypeInfo type_info{"ReduceLogicalOr", 1};
-                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                NGRAPH_RTTI_DECLARATION;
                 ReduceLogicalOr() = default;
                 /// \brief Constructs a ReduceLogicalOr node.
                 ///
index d5daf70..1faf15b 100644 (file)
@@ -86,8 +86,7 @@ namespace ngraph
             class NGRAPH_API Select : public Op
             {
             public:
-                static constexpr NodeTypeInfo type_info{"Select", 1};
-                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                NGRAPH_RTTI_DECLARATION;
                 /// \brief Constructs a selection operation.
                 Select()
                     : m_auto_broadcast(AutoBroadcastSpec(AutoBroadcastType::NUMPY))
index c005ca8..38aa6d3 100644 (file)
@@ -70,8 +70,7 @@ namespace ngraph
             class NGRAPH_API ShapeOf : public Op
             {
             public:
-                static constexpr NodeTypeInfo type_info{"ShapeOf", 0};
-                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                NGRAPH_RTTI_DECLARATION;
                 ShapeOf() = default;
                 /// \brief Constructs a shape-of operation.
                 ShapeOf(const Output<Node>& arg);
index cff16e1..5c69d69 100644 (file)
@@ -31,8 +31,7 @@ namespace ngraph
             class NGRAPH_API LogicalXor : public util::BinaryElementwiseLogical
             {
             public:
-                static constexpr NodeTypeInfo type_info{"LogicalXor", 1};
-                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                NGRAPH_RTTI_DECLARATION;
                 LogicalXor() = default;
                 /// \brief Constructs a logical-xor operation.
                 ///
index 0b5ae6d..a6c3c88 100644 (file)
@@ -94,7 +94,7 @@ bool op::v0::Equal::evaluate(const HostTensorVector& outputs, const HostTensorVe
 
 //------------------------------- v1 -------------------------------------------
 
-constexpr NodeTypeInfo op::v1::Equal::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::Equal, "Equal", 1);
 
 op::v1::Equal::Equal(const Output<Node>& arg0,
                      const Output<Node>& arg1,
index a74cfc8..cb35f8d 100644 (file)
@@ -95,7 +95,7 @@ bool op::v0::Greater::evaluate(const HostTensorVector& outputs,
 
 //-------------------------------------- v1 ------------------------------------
 
-constexpr NodeTypeInfo op::v1::Greater::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::Greater, "Greater", 1);
 
 op::v1::Greater::Greater(const Output<Node>& arg0,
                          const Output<Node>& arg1,
index 0be7702..398f82d 100644 (file)
@@ -95,7 +95,7 @@ bool op::v0::GreaterEq::evaluate(const HostTensorVector& outputs,
 
 //---------------------------------- v1 ----------------------------------------
 
-constexpr NodeTypeInfo op::v1::GreaterEqual::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::GreaterEqual, "GreaterEqual", 1);
 
 op::v1::GreaterEqual::GreaterEqual(const Output<Node>& arg0,
                                    const Output<Node>& arg1,
index 2c217e7..01bb4c4 100644 (file)
@@ -94,7 +94,7 @@ bool op::v0::Less::evaluate(const HostTensorVector& outputs, const HostTensorVec
 
 // ----------------------------- v1 --------------------------------------------
 
-constexpr NodeTypeInfo op::v1::Less::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::Less, "Less", 1);
 
 op::v1::Less::Less(const Output<Node>& arg0,
                    const Output<Node>& arg1,
index e580d4c..a03d637 100644 (file)
@@ -24,7 +24,7 @@ using namespace ngraph;
 
 // ---------------------------------- v1 ---------------------------------------
 
-constexpr NodeTypeInfo op::v1::LessEqual::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::LessEqual, "LessEqual", 1);
 
 op::v1::LessEqual::LessEqual(const Output<Node>& arg0,
                              const Output<Node>& arg1,
index f720fc9..e43eedc 100644 (file)
@@ -26,7 +26,7 @@
 using namespace ngraph;
 using namespace std;
 
-constexpr NodeTypeInfo op::v1::LogicalNot::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::LogicalNot, "LogicalNot", 1);
 
 op::v1::LogicalNot::LogicalNot(const Output<Node>& arg)
     : Op({arg})
index 52928db..21da5c5 100644 (file)
@@ -95,7 +95,7 @@ bool op::v0::NotEqual::evaluate(const HostTensorVector& outputs,
 
 // ----------------------------------- v1 --------------------------------------
 
-constexpr NodeTypeInfo op::v1::NotEqual::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::NotEqual, "NotEqual", 1);
 
 op::v1::NotEqual::NotEqual(const Output<Node>& arg0,
                            const Output<Node>& arg1,
index f535c14..8bc3d97 100644 (file)
@@ -22,7 +22,7 @@
 using namespace std;
 using namespace ngraph;
 
-constexpr NodeTypeInfo op::v1::LogicalOr::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::LogicalOr, "LogicalOr", 1, util::BinaryElementwiseLogical);
 
 op::v1::LogicalOr::LogicalOr(const Output<Node>& arg0,
                              const Output<Node>& arg1,
index 47bfae3..6b8dd53 100644 (file)
@@ -24,7 +24,7 @@
 using namespace ngraph;
 using namespace std;
 
-constexpr NodeTypeInfo op::v1::ReduceLogicalAnd::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::ReduceLogicalAnd, "ReduceLogicalAnd", 1);
 
 op::v1::ReduceLogicalAnd::ReduceLogicalAnd(const Output<Node>& data,
                                            const Output<Node>& reduction_axes,
index cb054b9..33d85c0 100644 (file)
@@ -24,7 +24,7 @@
 using namespace ngraph;
 using namespace std;
 
-constexpr NodeTypeInfo op::v1::ReduceLogicalOr::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::ReduceLogicalOr, "ReduceLogicalOr", 1);
 
 op::v1::ReduceLogicalOr::ReduceLogicalOr(const Output<Node>& data,
                                          const Output<Node>& reduction_axes,
index 2cdd22d..45cd528 100644 (file)
@@ -26,7 +26,7 @@
 using namespace std;
 using namespace ngraph;
 
-constexpr NodeTypeInfo op::v1::Select::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::Select, "Select", 1);
 
 op::v1::Select::Select(const Output<Node>& arg0,
                        const Output<Node>& arg1,
index 4fb92d5..92e3401 100644 (file)
@@ -170,7 +170,7 @@ bool op::v3::ShapeOf::constant_fold(OutputVector& output_values, const OutputVec
 }
 
 // op::v0::ShapeOf
-constexpr NodeTypeInfo op::v0::ShapeOf::type_info;
+NGRAPH_RTTI_DEFINITION(op::v0::ShapeOf, "ShapeOf", 0);
 
 op::v0::ShapeOf::ShapeOf(const Output<Node>& arg)
     : Op({arg})
index 881e2ef..a34cd35 100644 (file)
@@ -22,7 +22,7 @@
 using namespace std;
 using namespace ngraph;
 
-constexpr NodeTypeInfo op::v1::LogicalXor::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::LogicalXor, "LogicalXor", 1, util::BinaryElementwiseLogical);
 
 op::v1::LogicalXor::LogicalXor(const Output<Node>& arg0,
                                const Output<Node>& arg1,