TypeRelaxed implementation (#1561)
author Sergey Lyalin <sergey.lyalin@intel.com>
Thu, 13 Aug 2020 15:45:37 +0000 (18:45 +0300)
committer GitHub <noreply@github.com>
Thu, 13 Aug 2020 15:45:37 +0000 (18:45 +0300)
* RTTI base for ngraph::Node; cherry-pick from another branch, draft

* Added comments, moved code, switched to custom RTTI-based version of is_type

* Move rtti definitions in ngraph op class to the beginning of each class definition as a preparation for the next replacement

* Migrate part of operations to new RTTI

* Migrate GroupConvolution and Concat to new RTTI

* Apply code style for ngraph part

* Rename RTTI_DECLARATION/DEFINITION to NGRAPH_RTTI_DECLARATION/DEFINITION

* Reverted accidentally updated version of mkldnn

* TMP: rewrite RTTI back to constexpr expressions as an attempt to fix the static object initialization order issue

* Apply ngraph code style

* Finalize move back to constexpr for RTTI

* Applied code-style

* TypeRelaxed template class implementation and necessary changes in ngraph + tests.

* Applied code-style

* Fix in fast algorithm in GraphRewrite, add new tests for this and other cases

* Make parent optional parameter for NGRAPH_RTTI_DECLARATION and remove Node::type_info; remove ability to have Node as a parent for type_info

* Try to resolve compilation error on Windows

* The next attempt to fix Windows build: re-introduce get_type_info_static

* Removed file that was removed in master and kept in this branch by mistake

* Next attempt to fix Windows build: externConstexpr

* Attempt to fix Windows build: extra public (suspected ICC bug); remove get_type_info_static as unnecessary.

* Next attempt to fix Windows: proxy const and constexpr

* Fixed constexpr

* Next attempt: move get_type_info to cpp file

* Code style fix

* Re-implemented RTTI without use of constexpr; run-time initialization is used; removed global definitions to avoid issues with order of static objects initialization

* Removed externConstexpr flag and removed TRANSFORMATIONS_API for TypeRelaxed

* get_type_info_static initializes static local constant with type_info that is used for CLASS::type_info and CLASS::get_type_info

* Removed not needed debug output and useless comments

* Implemented better copy ctor for Node

* Fixed VisualizeTree issue for TypeRelaxed: stopped using < and > in type_info::name

* Better comments and names for methods

* Remove unused include

* Remove commented line

* Workaround for legacy conversion that uses Node::get_type_info().name as a type for the resulting CNNLayer, leading to incorrect types for TypeRelaxed-based operations and then to failures in plugins

* Fixed typos, explicit ctor for TypeRelaxedBase, explanation for the need of get_overridden_output_type

* Fix typo

* Fixed issue with non-static name in type definition for TypeRelaxed and fixed WrapType to make it compatible with hierarchical relations between types

* Reverted default ctor for Output and reverted ability to reduce number of outputs for a Node; syntactically better debug message for a Node

* Cover methods of TypeRelaxedBase by tests

* Apply code-style
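
For context, the new RTTI scheme described above replaces the constexpr-based type_info with run-time-initialized static locals exposed through a get_type_info_static helper. A minimal sketch of how an operation opts into the macros; MyRelu is hypothetical, and the exact macro argument shapes (name, version, optional parent) are assumed from the commit notes above:

    #include <memory>
    #include "ngraph/op/op.hpp"

    // Hypothetical operation illustrating the new RTTI macros; not part of this patch.
    class MyRelu : public ngraph::op::Op {
    public:
        NGRAPH_RTTI_DECLARATION;  // declares type_info and get_type_info()
        MyRelu() = default;
        std::shared_ptr<ngraph::Node>
        clone_with_new_inputs(const ngraph::OutputVector&) const override {
            return std::make_shared<MyRelu>();
        }
    };

    // In the .cpp file: type name and version; a parent type may be passed
    // as an optional extra argument per the commit notes above.
    NGRAPH_RTTI_DEFINITION(MyRelu, "MyRelu", 0);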

inference-engine/src/transformations/include/ngraph_ops/type_relaxed.hpp [new file with mode: 0644]
inference-engine/src/transformations/src/ngraph_ops/type_relaxed.cpp [new file with mode: 0644]
inference-engine/tests/functional/inference_engine/transformations/type_relaxed_tests.cpp [new file with mode: 0644]
ngraph/core/include/ngraph/node.hpp
ngraph/core/include/ngraph/type/element_type.hpp
ngraph/core/src/node.cpp
ngraph/core/src/pattern/op/wrap_type.cpp
ngraph/core/src/type/element_type.cpp

diff --git a/inference-engine/src/transformations/include/ngraph_ops/type_relaxed.hpp b/inference-engine/src/transformations/include/ngraph_ops/type_relaxed.hpp
new file mode 100644 (file)
index 0000000..c8dff13
--- /dev/null
@@ -0,0 +1,229 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <memory>
+#include <vector>
+#include <algorithm>
+#include <string>
+
+#include <transformations_visibility.hpp>
+
+#include "ngraph/op/op.hpp"
+
+namespace ngraph {
+namespace op {
+
+/// A base class for templated TypeRelaxed that maintains overridden input types and output types for an operation.
+class TRANSFORMATIONS_API TypeRelaxedBase {
+public:
+    virtual ~TypeRelaxedBase();
+
+    explicit TypeRelaxedBase(
+            const element::TypeVector& _input_data_types = {},
+            const element::TypeVector& _output_data_types = {}) :
+            m_input_data_types(_input_data_types),
+            m_output_data_types(_output_data_types) {
+    }
+
+    /// \return Data type that will be set for the output with the given index outputIndex.
+    /// If the output with the specified index outputIndex hasn't been set before, element::undefined is returned.
+    /// Undefined means that no type override happens for the given outputIndex, and the type will be deduced
+    /// as the original operation defines it in its infer function.
+    ///
+    /// This method may look similar to Node::get_output_element_type, but it is not the same thing, because
+    /// get_output_element_type returns the result of type inference, so it is completely deduced from
+    /// an operation's inputs and attributes, while get_overridden_output_type returns the value of the attribute
+    /// that is used to deduce the output type. In some cases they don't match: get_overridden_output_type may
+    /// return element::undefined for some index i, while get_output_element_type returns some real type for
+    /// the same index i.
+    const element::Type& get_overridden_output_type(size_t outputIndex = 0) const {
+        if (outputIndex >= m_output_data_types.size()) {
+            return element::undefined;
+        }
+        return m_output_data_types[outputIndex];
+    }
+
+    /// Set the data type that overrides the original data type for the output port with index outputIndex.
+    /// If outputIndex is out of range of the known outputs (this class cannot detect
+    /// the real number of outputs of the original operation), the number of overridden outputs
+    /// is extended according to the given outputIndex value.
+    void set_overridden_output_type(const element::Type& element_type, size_t outputIndex = 0) {
+        if (outputIndex >= m_output_data_types.size()) {
+            m_output_data_types.resize(outputIndex + 1, element::undefined);
+        }
+        m_output_data_types[outputIndex] = element_type;
+    }
+
+    /// \return Data type that will be set for the input when the original shape/type inference function is called.
+    /// If the index inputIndex hasn't been set before, element::undefined is returned. Undefined means that
+    /// the type from the input tensor descriptor is used for the given index.
+    const element::Type& get_origin_input_type(size_t inputIndex = 0) const {
+        if (inputIndex >= m_input_data_types.size()) {
+            return element::undefined;
+        }
+        return m_input_data_types[inputIndex];
+    }
+
+    /// Set the data type that overrides the original data type for the input port with index inputIndex.
+    /// If inputIndex is out of range of the known inputs (this class cannot detect
+    /// the real number of inputs of the original operation), the number of overridden inputs
+    /// is extended according to the given inputIndex value. All new entries except the one added
+    /// at the inputIndex position are undefined.
+    void set_origin_input_type(const element::Type& element_type, size_t inputIndex = 0) {
+        if (inputIndex >= m_input_data_types.size()) {
+            m_input_data_types.resize(inputIndex + 1, element::undefined);
+        }
+        m_input_data_types[inputIndex] = element_type;
+    }
+
+protected:
+    // Data types that are used for parent shape/type infer function input ports
+    // to infer output data types
+    element::TypeVector m_input_data_types;
+    element::TypeVector m_output_data_types;
+};
+
+/// Sets another type for a specified output for the period of time when an instance of the class exists.
+/// When the execution leaves the scope where an object of TemporaryReplaceOutputType is defined,
+/// the type of the output is set back to its original value. Used when initializing a TypeRelaxed<BaseOp>
+/// operation whose inputs have types that are not compatible with the BaseOp infer function: the BaseOp
+/// constructor requires already-adjusted data types, so they are replaced temporarily, before
+/// TypeRelaxed itself is constructed.
+class TemporaryReplaceOutputType {
+    Output<Node> m_output;
+    element::Type orig_type;
+
+public:
+    /// Replace element type for a given output port by tmp_type
+    TemporaryReplaceOutputType(Output<Node> output, element::Type tmp_type) : m_output(output) {
+        // save original element type in order to restore it in the destructor
+        orig_type = m_output.get_element_type();
+        m_output.get_tensor().set_element_type(tmp_type);
+    }
+
+    /// Return the output port that was used in the constructor
+    Output<Node> get() const {
+        return m_output;
+    }
+
+    /// Restores the original element type for the output
+    ~TemporaryReplaceOutputType() {
+        m_output.get_tensor().set_element_type(orig_type);
+    }
+};
+
+/// Relaxes tensor element type requirements for BaseOp inputs and outputs.
+/// This class template should be used with a Node descendant class. It defines a new operation by extending
+/// the original BaseOp operation with the ability to accept inputs and provide outputs with element types
+/// that are unusual for BaseOp. For example, TypeRelaxed<opset1::Add> can accept mixed-precision inputs and
+/// provide an output of another type. New types are provided as attributes of the TypeRelaxed template and
+/// are fixed; there is no type deduction logic as part of this class, and it should be implemented outside
+/// if required.
+template <typename BaseOp>
+class TypeRelaxed : public BaseOp, public TypeRelaxedBase {
+public:
+    NGRAPH_RTTI_DECLARATION;
+
+    using BaseOp::BaseOp;
+
+    TypeRelaxed() = default;
+
+    TypeRelaxed(
+            const BaseOp& base_op,
+            element::Type overridden_type) :
+            TypeRelaxed(base_op,
+                    element::TypeVector(base_op.get_input_size(), overridden_type),
+                    element::TypeVector(base_op.get_output_size(), overridden_type)) {
+    }
+
+    explicit TypeRelaxed(
+            const BaseOp& base_op,
+            const element::TypeVector& _input_data_types = {},
+            const element::TypeVector& _output_data_types = {}) :
+            BaseOp(base_op), TypeRelaxedBase(_input_data_types, _output_data_types) {
+        init();
+    }
+
+    /// Creates a new TypeRelaxed operation by calling one of the original op ctors and forwarding arguments directly.
+    template <typename ... Args>
+    TypeRelaxed(
+            const element::TypeVector& _input_data_types,
+            const element::TypeVector& _output_data_types,
+            Args&& ... args) :
+            BaseOp(std::forward<Args>(args)...), TypeRelaxedBase(_input_data_types, _output_data_types) {
+        init();
+    }
+
+    void validate_and_infer_types() override;
+
+    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
+
+private:
+    void init() {
+        validate_and_infer_types();
+    }
+};
+
+template <typename BaseOp>
+void TypeRelaxed<BaseOp>::validate_and_infer_types() {
+    // Remember all input data types and temporarily reset them to the origin input types where overridden.
+    element::TypeVector old_input_types;
+    for (size_t i = 0; i < BaseOp::get_input_size(); ++i) {
+        old_input_types.push_back(BaseOp::get_input_element_type(i));
+        auto origin_input_type = get_origin_input_type(i);
+        if (origin_input_type != element::undefined) {
+            BaseOp::get_input_tensor(i).set_tensor_type(origin_input_type, BaseOp::get_input_partial_shape(i));
+        }
+    }
+
+    BaseOp::validate_and_infer_types();
+
+    // Restore original input data types
+    for (size_t i = 0; i < BaseOp::get_input_size(); ++i) {
+        BaseOp::get_input_tensor(i).set_tensor_type(old_input_types[i], BaseOp::get_input_partial_shape(i));
+    }
+
+    // Override (some) output types
+    for (size_t i = 0; i < BaseOp::get_output_size(); ++i) {
+        auto overridden_output_type = get_overridden_output_type(i);
+        if (overridden_output_type != element::undefined) {
+            BaseOp::set_output_type(i, overridden_output_type, BaseOp::get_output_partial_shape(i));
+        }
+    }
+}
+
+
+template <typename BaseOp>
+std::shared_ptr<Node> TypeRelaxed<BaseOp>::clone_with_new_inputs(const OutputVector& new_args) const {
+    // copy then modify inputs
+    std::shared_ptr<Node> new_node = std::make_shared<TypeRelaxed<BaseOp>>((BaseOp&)(*this), m_input_data_types, m_output_data_types);
+    for (size_t i = 0; i < new_node->get_input_size(); ++i) {
+        new_node->input(i).replace_source_output(new_args[i]);
+    }
+    return new_node;
+}
+
+template <typename BaseOp>
+const ::ngraph::Node::type_info_t& TypeRelaxed<BaseOp>::get_type_info() const { return get_type_info_static(); }
+
+template <typename BaseOp>
+const ::ngraph::Node::type_info_t& TypeRelaxed<BaseOp>::get_type_info_static() {
+    auto baseOpTypeInfoPtr = &BaseOp::get_type_info_static();
+
+    // TODO: it should be static const std::string name = std::string("TypeRelaxed_") + baseOpTypeInfoPtr->name;
+    //       but currently it will not pass conversion to the Legacy Opset correctly
+    static const std::string name = baseOpTypeInfoPtr->name;
+
+    static const ::ngraph::Node::type_info_t type_info_static{
+        name.c_str(), baseOpTypeInfoPtr->version, baseOpTypeInfoPtr};
+    return type_info_static;
+}
+
+template <typename BaseOp>
+const ::ngraph::Node::type_info_t TypeRelaxed<BaseOp>::type_info = TypeRelaxed<BaseOp>::get_type_info_static();
+
+}  // namespace op
+}  // namespace ngraph
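
As a usage sketch mirroring the tests added below: a transformation can wrap an opset operation so that mixed-precision inputs pass through the original infer function. TemporaryReplaceOutputType presents compatible types to the BaseOp constructor, and the origin input types keep driving inference afterwards (make_mixed_precision_add is an illustrative name):

    #include <memory>
    #include <ngraph/opsets/opset1.hpp>
    #include "ngraph_ops/type_relaxed.hpp"

    using namespace ngraph;

    // a (e.g. u8) and b (e.g. i8) are mixed-precision producers; infer Add as if both were i16.
    std::shared_ptr<Node> make_mixed_precision_add(const Output<Node>& a, const Output<Node>& b) {
        return std::make_shared<op::TypeRelaxed<opset1::Add>>(
            element::TypeVector{element::i16, element::i16},  // origin input types
            element::TypeVector{},                            // no output override: deduced as i16
            op::TemporaryReplaceOutputType(a, element::i16).get(),
            op::TemporaryReplaceOutputType(b, element::i16).get());
    }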
diff --git a/inference-engine/src/transformations/src/ngraph_ops/type_relaxed.cpp b/inference-engine/src/transformations/src/ngraph_ops/type_relaxed.cpp
new file mode 100644 (file)
index 0000000..03dcfc6
--- /dev/null
@@ -0,0 +1,17 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <memory>
+#include <vector>
+#include <algorithm>
+
+#include "ngraph_ops/type_relaxed.hpp"
+
+namespace ngraph {
+namespace op {
+
+TypeRelaxedBase::~TypeRelaxedBase() {}
+
+}  // namespace op
+}  // namespace ngraph
diff --git a/inference-engine/tests/functional/inference_engine/transformations/type_relaxed_tests.cpp b/inference-engine/tests/functional/inference_engine/transformations/type_relaxed_tests.cpp
new file mode 100644 (file)
index 0000000..c93a2cc
--- /dev/null
@@ -0,0 +1,273 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+#include "common_test_utils/test_common.hpp"
+#include <ngraph/opsets/opset1.hpp>
+#include <ngraph_ops/type_relaxed.hpp>
+
+
+namespace element = ngraph::element;
+using std::make_shared;
+using TypeVector = element::TypeVector;
+
+using TypeRelaxedTests = CommonTestUtils::TestsCommon;
+
+TEST_F(TypeRelaxedTests, noOverrideCopyCtor) {
+    std::shared_ptr<ngraph::Function> ngraph;
+    {
+        ngraph::PartialShape shape({1, 3, 22, 22});
+        element::Type type(element::Type_t::f32);
+        auto param = make_shared<ngraph::opset1::Parameter>(type, shape);
+        auto op = ngraph::opset1::Relu(param);
+        auto relaxed_op = make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Relu>>(op);
+        auto result = make_shared<ngraph::opset1::Result>(relaxed_op);
+
+        ngraph::ParameterVector params = {param};
+        ngraph::ResultVector results = {result};
+
+        ngraph = make_shared<ngraph::Function>(results, params);
+
+        ASSERT_EQ(element::f32, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(element::f32, relaxed_op->get_output_element_type(0));
+    }
+
+    ASSERT_EQ(3, ngraph->get_ops().size());
+}
+
+TEST_F(TypeRelaxedTests, overrideOutputCopyCtor) {
+    auto input_type = element::f32;
+    auto overriden_type = element::i32;
+    std::shared_ptr<ngraph::Function> ngraph;
+    {
+        ngraph::PartialShape shape({1, 3, 22, 22});
+        auto param = make_shared<ngraph::opset1::Parameter>(input_type, shape);
+        auto op = ngraph::opset1::Relu(param);
+        auto relaxed_op = make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Relu>>(
+                op, TypeVector{}, TypeVector{overriden_type});
+        auto result = make_shared<ngraph::opset1::Result>(relaxed_op);
+
+        ngraph = make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{param});
+
+        ASSERT_EQ(input_type, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(overriden_type, relaxed_op->get_output_element_type(0));
+    }
+
+    ASSERT_EQ(3, ngraph->get_ops().size());
+}
+
+TEST_F(TypeRelaxedTests, overrideInputCopyCtor) {
+    auto input_type = element::f32;
+    auto overriden_type = element::i32;
+    std::shared_ptr<ngraph::Function> ngraph;
+    {
+        ngraph::PartialShape shape({1, 3, 22, 22});
+        auto param = make_shared<ngraph::opset1::Parameter>(input_type, shape);
+        auto op = ngraph::opset1::Relu(param);
+        auto relaxed_op = make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Relu>>(
+                op, TypeVector{overriden_type}, TypeVector{});
+        auto result = make_shared<ngraph::opset1::Result>(relaxed_op);
+
+        ngraph = make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{param});
+
+        ASSERT_EQ(input_type, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(overriden_type, relaxed_op->get_output_element_type(0));
+    }
+
+    ASSERT_EQ(3, ngraph->get_ops().size());
+}
+
+TEST_F(TypeRelaxedTests, mixedInputsAutoOutput) {
+    auto input_type1 = element::u8;
+    auto input_type2 = element::i8;
+    auto overriden_type = element::i16;
+    std::shared_ptr<ngraph::Function> ngraph;
+    {
+        ngraph::PartialShape shape({1, 3, 22, 22});
+        auto param1 = make_shared<ngraph::opset1::Parameter>(input_type1, shape);
+        auto param2 = make_shared<ngraph::opset1::Parameter>(input_type2, shape);
+        auto op = ngraph::opset1::Add(
+                ngraph::op::TemporaryReplaceOutputType(param1->output(0), overriden_type).get(),
+                ngraph::op::TemporaryReplaceOutputType(param2->output(0), overriden_type).get());
+        auto relaxed_op = make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Add>>(
+                op, TypeVector{overriden_type, overriden_type}, TypeVector{});
+        auto result = make_shared<ngraph::opset1::Result>(relaxed_op);
+
+        ngraph = make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{param1, param2});
+
+        ASSERT_EQ(input_type1, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(input_type2, relaxed_op->get_input_element_type(1));
+        ASSERT_EQ(overriden_type, relaxed_op->get_output_element_type(0));
+    }
+
+    ASSERT_EQ(4, ngraph->get_ops().size());
+}
+
+TEST_F(TypeRelaxedTests, mixedInputsAutoOutputForwardCtor) {
+    auto input_type1 = element::u8;
+    auto input_type2 = element::i8;
+    auto overriden_type = element::i16;
+    std::shared_ptr<ngraph::Function> ngraph;
+    {
+        ngraph::PartialShape shape({1, 3, 22, 22});
+        auto param1 = make_shared<ngraph::opset1::Parameter>(input_type1, shape);
+        auto param2 = make_shared<ngraph::opset1::Parameter>(input_type2, shape);
+        auto relaxed_op = make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Add>>(
+                TypeVector{overriden_type, overriden_type}, TypeVector{},
+                ngraph::op::TemporaryReplaceOutputType(param1, overriden_type).get(),
+                ngraph::op::TemporaryReplaceOutputType(param2, overriden_type).get());
+        auto result = make_shared<ngraph::opset1::Result>(relaxed_op);
+
+        ngraph = make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{param1, param2});
+
+        ASSERT_EQ(input_type1, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(input_type2, relaxed_op->get_input_element_type(1));
+        ASSERT_EQ(overriden_type, relaxed_op->get_output_element_type(0));
+    }
+
+    ASSERT_EQ(4, ngraph->get_ops().size());
+}
+
+TEST_F(TypeRelaxedTests, notSupportedTypeOverride) {
+    auto overriden_type = element::u8;
+    auto orig_type = element::boolean;
+    std::shared_ptr<ngraph::Function> ngraph;
+    {
+        ngraph::PartialShape shape({1, 3, 22, 22});
+        auto param1 = make_shared<ngraph::opset1::Parameter>(overriden_type, shape);
+        auto param2 = make_shared<ngraph::opset1::Parameter>(overriden_type, shape);
+        auto op = ngraph::opset1::LogicalAnd(
+                ngraph::op::TemporaryReplaceOutputType(param1, orig_type).get(),
+                ngraph::op::TemporaryReplaceOutputType(param2, orig_type).get());
+        auto relaxed_op = make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::LogicalAnd>>(
+                op, TypeVector{orig_type, orig_type}, TypeVector{overriden_type});
+        auto result = make_shared<ngraph::opset1::Result>(relaxed_op);
+
+        ngraph = make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{param1, param2});
+
+        ASSERT_EQ(overriden_type, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(overriden_type, relaxed_op->get_input_element_type(1));
+        ASSERT_EQ(overriden_type, relaxed_op->get_output_element_type(0));
+    }
+
+    ASSERT_EQ(4, ngraph->get_ops().size());
+}
+
+TEST_F(TypeRelaxedTests, notSupportedTypeOverridePartially) {
+    auto some_type = element::u8;
+    auto overriden_type = element::f32;
+    auto orig_type = element::i64;
+    std::shared_ptr<ngraph::Function> ngraph;
+    {
+        ngraph::PartialShape shape({1, 3, 22, 22});
+        auto param1 = make_shared<ngraph::opset1::Parameter>(some_type, shape);
+        auto param2 = make_shared<ngraph::opset1::Parameter>(overriden_type, ngraph::PartialShape{1});
+        auto op = ngraph::opset1::Reshape(
+                param1,
+                ngraph::op::TemporaryReplaceOutputType(param2, orig_type).get(),
+                false);
+        auto relaxed_op = make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Reshape>>(
+                op, TypeVector{element::undefined, orig_type}, TypeVector{});
+        auto result = make_shared<ngraph::opset1::Result>(relaxed_op);
+
+        ngraph = make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{param1, param2});
+
+        ASSERT_EQ(some_type, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(overriden_type, relaxed_op->get_input_element_type(1));
+        ASSERT_EQ(some_type, relaxed_op->get_output_element_type(0));
+    }
+
+    ASSERT_EQ(4, ngraph->get_ops().size());
+}
+
+TEST_F(TypeRelaxedTests, setGetTypes) {
+    std::shared_ptr<ngraph::Function> ngraph;
+    {
+        ngraph::PartialShape shape({1, 3, 22, 22});
+        auto param1 = make_shared<ngraph::opset1::Parameter>(element::u8, shape);
+        auto param2 = make_shared<ngraph::opset1::Parameter>(element::u8, shape);
+        // create TypeRelaxed without any type adjustment, the same behaviour as for opset1::Add
+        auto relaxed_op = make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Add>>(param1, param2);
+        auto result = make_shared<ngraph::opset1::Result>(relaxed_op);
+
+        ngraph = make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{param1, param2});
+
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(1));
+        ASSERT_EQ(element::u8, relaxed_op->get_output_element_type(0));
+
+        // the internally stored types used for opset1::Add inference weren't set when TypeRelaxed was created; check it
+        ASSERT_EQ(element::undefined, relaxed_op->get_origin_input_type(0));
+        ASSERT_EQ(element::undefined, relaxed_op->get_origin_input_type(1));
+        // if we access elements outside the really existing inputs, it should give undefined as well
+        ASSERT_EQ(element::undefined, relaxed_op->get_origin_input_type(2));
+        // number of inputs for the operation node shouldn't change after that
+        ASSERT_EQ(2, relaxed_op->get_input_size());
+
+        // similar checks for outputs
+        ASSERT_EQ(element::undefined, relaxed_op->get_overridden_output_type(0));
+        ASSERT_EQ(element::undefined, relaxed_op->get_overridden_output_type(1));
+        ASSERT_EQ(1, relaxed_op->get_output_size());
+
+        // the previous checks accessed input/output indices beyond the real number of inputs/outputs;
+        // this may affect the internal vectors that hold orig/overridden types and hence
+        // inference for the op, so here we check that inference is still OK:
+        ngraph->validate_nodes_and_infer_types();
+
+        // recheck basic statements about input/output types; they should be the same as we haven't changed anything
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(1));
+        ASSERT_EQ(element::u8, relaxed_op->get_output_element_type(0));
+
+        // now we modify the input types and see if the output type reflects this change
+        relaxed_op->set_origin_input_type(element::i8, 0);
+        relaxed_op->set_origin_input_type(element::i8, 1);
+        ngraph->validate_nodes_and_infer_types();
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(1));
+        ASSERT_EQ(element::i8, relaxed_op->get_output_element_type(0));
+
+        // override output type
+        relaxed_op->set_overridden_output_type(element::f32, 0);
+        ngraph->validate_nodes_and_infer_types();
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(1));
+        ASSERT_EQ(element::f32, relaxed_op->get_output_element_type(0));
+
+        // check if get methods reflect recent changes after set methods
+        ASSERT_EQ(element::i8, relaxed_op->get_origin_input_type(0));
+        ASSERT_EQ(element::i8, relaxed_op->get_origin_input_type(1));
+        ASSERT_EQ(element::f32, relaxed_op->get_overridden_output_type(0));
+
+        // Now, a more advanced trick: set a real orig/overridden type for a non-existing input/output;
+        // it shouldn't affect inference as the corresponding inputs/outputs don't exist.
+        // This scenario covers cases when we want to set new types for an operation that will
+        // be further modified in the code by adding new inputs (Concat) or outputs (Split), and that code
+        // is not aware of TypeRelaxed and shouldn't bother about setting types for the new items
+        // (a bit hypothetical though).
+        relaxed_op->set_origin_input_type(element::i32, 2);
+        relaxed_op->set_overridden_output_type(element::i32, 1);
+        ngraph->validate_nodes_and_infer_types();
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(1));
+        ASSERT_EQ(element::f32, relaxed_op->get_output_element_type(0));
+        ASSERT_EQ(2, relaxed_op->get_input_size());
+        ASSERT_EQ(1, relaxed_op->get_output_size());
+
+        // let's try to reset types to undefined again and make sure that all original types are restored
+        relaxed_op->set_origin_input_type(element::undefined, 0);
+        relaxed_op->set_origin_input_type(element::undefined, 1);
+        relaxed_op->set_overridden_output_type(element::undefined, 0);
+        ngraph->validate_nodes_and_infer_types();
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(1));
+        ASSERT_EQ(element::u8, relaxed_op->get_output_element_type(0));
+
+        ASSERT_EQ(element::undefined, relaxed_op->get_origin_input_type(0));
+        ASSERT_EQ(element::undefined, relaxed_op->get_origin_input_type(1));
+        ASSERT_EQ(element::undefined, relaxed_op->get_overridden_output_type(0));
+    }
+
+    ASSERT_EQ(4, ngraph->get_ops().size());
+}
diff --git a/ngraph/core/include/ngraph/node.hpp b/ngraph/core/include/ngraph/node.hpp
index 7fb73ae..c55de14 100644 (file)
@@ -155,6 +155,9 @@ namespace ngraph
     protected:
        /// \brief Construct an uninitialized Node
         Node() {}
+        /// \brief Copying a node
+        Node(const Node&);
+
        /// \brief Construct an uninitialized Node
         /// \param output_size Number of outputs for this node
         Node(size_t output_size);
@@ -233,7 +236,7 @@ namespace ngraph
         /// \brief Get the string name for the type of the node, such as `Add` or `Multiply`.
         ///        The class name, must not contain spaces as it is used for codegen.
         /// \returns A const reference to the node's type name
-        virtual const std::string& description() const;
+        virtual std::string description() const;
         /// \brief Get the unique name of the node.
         /// \returns A const reference to the node's unique name.
         const std::string& get_name() const;
diff --git a/ngraph/core/include/ngraph/type/element_type.hpp b/ngraph/core/include/ngraph/type/element_type.hpp
index 8b5783f..9bd5edb 100644 (file)
@@ -128,6 +128,9 @@ namespace ngraph
             Type_t m_type{Type_t::undefined};
         };
 
+        typedef std::vector<Type> TypeVector;
+
+        extern NGRAPH_API const Type undefined;
         extern NGRAPH_API const Type dynamic;
         extern NGRAPH_API const Type boolean;
         extern NGRAPH_API const Type bf16;
diff --git a/ngraph/core/src/node.cpp b/ngraph/core/src/node.cpp
index 1d6eb09..6a40c66 100644 (file)
@@ -34,6 +34,28 @@ using namespace ngraph;
 
 atomic<size_t> Node::m_next_instance_id(0);
 
+Node::Node(const Node& node)
+    : m_control_dependents(node.m_control_dependents)
+    , m_control_dependencies(node.m_control_dependencies)
+    // skip m_node_type -- will be generated automatically
+    , m_instance_id(m_next_instance_id.fetch_add(1))
+    , m_friendly_name(node.m_friendly_name)
+    // skip m_unique_name -- will be generated automatically
+    , m_provenance_tags(node.m_provenance_tags)
+    , m_provenance_group(node.m_provenance_group)
+    , m_inputs(node.m_inputs) // will be modified in the body
+    // skip m_outputs -- should be initialized outside
+    , m_op_annotations(node.m_op_annotations)
+    , m_rt_info(node.m_rt_info)
+{
+    // cannot do this without copying node.m_inputs first due to overly restrictive const qualifiers
+    for (auto& input : m_inputs)
+    {
+        input = descriptor::Input(this, input.get_index(), input.get_output());
+        input.get_output().add_input(&input);
+    }
+}
+
 Node::Node(size_t output_size)
     : Node()
 {
@@ -243,12 +265,9 @@ void Node::set_output_type(size_t i, const element::Type& element_type, const Pa
     get_output_descriptor(i).get_tensor_ptr()->set_tensor_type(element_type, pshape);
 }
 
-const std::string& Node::description() const
+std::string Node::description() const
 {
-    // Terrible transitional kludge to keep description working while we change
-    // type_name to const_char and virtual description() to virtual get_type_name()
-    const_cast<Node*>(this)->m_node_type = get_type_name();
-    return m_node_type;
+    return get_type_name();
 }
 
 const std::string& Node::get_friendly_name() const
@@ -711,7 +730,8 @@ NodeVector Node::get_users(bool check_is_used) const
 std::string ngraph::node_validation_failure_loc_string(const Node* node)
 {
     std::stringstream ss;
-    ss << "While validating node '" << *node << "'";
+    ss << "While validating node '" << *node << "' with friendly_name '"
+       << node->get_friendly_name() << '\'';
     return ss.str();
 }
 
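The Node copy constructor added above is what makes the copy-based TypeRelaxed ctor (BaseOp(base_op) in type_relaxed.hpp) well-defined: the inputs are cloned and re-registered against the new node, while identity fields such as the instance id and unique name are regenerated. A sketch of the pattern it enables, mirroring the overrideOutputCopyCtor test (relax_output is an illustrative name):

    #include <memory>
    #include <ngraph/opsets/opset1.hpp>
    #include "ngraph_ops/type_relaxed.hpp"

    using namespace ngraph;

    std::shared_ptr<Node> relax_output(const std::shared_ptr<opset1::Relu>& relu) {
        // Copy-constructs the wrapped Relu (via the new Node copy ctor) and
        // forces its output to i32 while the inputs keep their original types.
        return std::make_shared<op::TypeRelaxed<opset1::Relu>>(
            *relu, element::TypeVector{}, element::TypeVector{element::i32});
    }
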
diff --git a/ngraph/core/src/pattern/op/wrap_type.cpp b/ngraph/core/src/pattern/op/wrap_type.cpp
index bb639d0..6204dce 100644 (file)
@@ -31,7 +31,7 @@ bool pattern::op::WrapType::match_value(Matcher* matcher,
                                         const Output<Node>& pattern_value,
                                         const Output<Node>& graph_value)
 {
-    if (graph_value.get_node_shared_ptr()->get_type_info() == get_wrapped_type() &&
+    if (graph_value.get_node_shared_ptr()->get_type_info().is_castable(get_wrapped_type()) &&
         m_predicate(graph_value))
     {
         auto& pattern_map = matcher->get_pattern_value_map();
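
Since TypeRelaxed registers the wrapped operation's type_info as its parent (see get_type_info_static above), switching WrapType from an exact type_info comparison to is_castable lets patterns written against the base type also match the relaxed variant. A sketch of the intended effect, assuming the pattern::wrap_type helper from this codebase (is_add_like is an illustrative name):

    #include <ngraph/opsets/opset1.hpp>
    #include <ngraph/pattern/matcher.hpp>
    #include <ngraph/pattern/op/wrap_type.hpp>

    using namespace ngraph;

    bool is_add_like(const std::shared_ptr<Node>& node) {
        // A pattern written against plain opset1::Add...
        auto label = pattern::wrap_type<opset1::Add>();
        pattern::Matcher matcher(label, "AddMatcher");
        // ...now also matches TypeRelaxed<opset1::Add>: its type_info names
        // opset1::Add's type_info as its parent, and is_castable() walks that chain.
        return matcher.match(node->output(0));
    }
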
diff --git a/ngraph/core/src/type/element_type.cpp b/ngraph/core/src/type/element_type.cpp
index 1ffe2d3..588a140 100644 (file)
@@ -25,6 +25,7 @@
 using namespace ngraph;
 using namespace std;
 
+const element::Type element::undefined(element::Type_t::undefined);
 const element::Type element::dynamic(element::Type_t::dynamic);
 const element::Type element::boolean(element::Type_t::boolean);
 const element::Type element::bf16(element::Type_t::bf16);
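
The new element::undefined constant serves as the "no override" sentinel throughout TypeRelaxedBase: any port whose stored type is undefined keeps whatever the original infer function deduces. A short sketch of the convention (make_overrides is an illustrative name):

    #include <ngraph/type/element_type.hpp>

    using namespace ngraph;

    element::TypeVector make_overrides() {
        // Override only port 1 of three; element::undefined leaves the other
        // ports' types to the original operation's infer function.
        element::TypeVector overrides(3, element::undefined);
        overrides[1] = element::i32;
        return overrides;
    }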