From 9a62e00674ffb1b0722edd7c63d27572844bc558 Mon Sep 17 00:00:00 2001 From: Sergey Lyalin Date: Thu, 13 Aug 2020 18:45:37 +0300 Subject: [PATCH] TypeRelaxed implementation (#1561) * RTTI base for ngraph::Node; cherry-pick from another branch, draft * Added comments, moved code, switched to custom RTTI-based version of is_type * Move rtti definitions in ngraph op class to the beginning of each class definition as a preparation for the next replacement * Migrate part of operations to new RTTI * Migrate GroupConvolution and Concat to new RTTI * Apply code style for ngraph part * Rename RTTI_DECLARATION/DEFINITION to NGRAPH_RTTI_DECLARATION/DEFINITION * Reverted accidentally updated version of mkldnn * TMP: rewrite RTTI back to constexprions as an attempt to fix static objects initialization order issue * Apply ngraph code style * Finalize move back to constexpr for RTTI * Applied code-style * TypeRelaxed template class implementation and necessary changes in ngraph + tests. * Applied code-style * Fix in fast algorithm in GraphRewrite, add new tests for this and other cases * Make parent optional parameter for NGRAPH_RTTI_DECLARATION and remove Node::type_info; remove ability to have Node as a parent for type_info * Try to resolve compilation error on Windows * The next attempt to fix Windows build: re-introduce get_type_info_static * Removed file that was removed in master and kept in this branch by mistake * Next attempt to fix Windows build: externConstexpr * Attempt to fix win build: extra public (suspect icc bug), remove get_type_info_static as useless. * Next attempt to fix Windows: proxy const and constexpr * Fixed constexpr * Next attmpts: move get_type_info to cpp file * Code stype fix * Re-implemented RTTI without use of constexpr; run-time initialization is used; removed global definitions to avoid issues with order of static objects initialization * Removed externConstexpr flag and removed TRANSFOMRATIONS_API for TypeRelaxed * get_type_info_static initializes static local constant with type_info that is used for CLASS::type_info and CLASS::get_type_info * Removed not needed debug output and useless comments * Implemented better copy ctor for Node * Fixed VisualizeTree issue for TypeRelaxed: stopped using < and > in type_info::name * Better comments and names for methods * Remove unused include * Remove commented line * Workaround for legacy conversion that uses Node::get_type_info().name as a type for the resulting CNNLayer leading to incorrect types for TypeRelaxed-based operations and then to fail in plugins * Fixed typos, explicit ctor for TypeRelaxedBase, explanation for the need of get_overridden_output_type * Fix typo * Fixed issue with non-static name in type definition for TypeRelaxed and fixed WrapType to make it compatible with hierarchical relations between types * Reverted default ctor for Output and reverted ability to reduce number of outputs for a Node; syntactically better debug message for a Node * Cover methods of TypeRelaxedBase by tests * Apply code-style --- .../include/ngraph_ops/type_relaxed.hpp | 229 +++++++++++++++++ .../src/ngraph_ops/type_relaxed.cpp | 17 ++ .../transformations/type_relaxed_tests.cpp | 273 +++++++++++++++++++++ ngraph/core/include/ngraph/node.hpp | 5 +- ngraph/core/include/ngraph/type/element_type.hpp | 3 + ngraph/core/src/node.cpp | 32 ++- ngraph/core/src/pattern/op/wrap_type.cpp | 2 +- ngraph/core/src/type/element_type.cpp | 1 + 8 files changed, 554 insertions(+), 8 deletions(-) create mode 100644 inference-engine/src/transformations/include/ngraph_ops/type_relaxed.hpp create mode 100644 inference-engine/src/transformations/src/ngraph_ops/type_relaxed.cpp create mode 100644 inference-engine/tests/functional/inference_engine/transformations/type_relaxed_tests.cpp diff --git a/inference-engine/src/transformations/include/ngraph_ops/type_relaxed.hpp b/inference-engine/src/transformations/include/ngraph_ops/type_relaxed.hpp new file mode 100644 index 0000000..c8dff13 --- /dev/null +++ b/inference-engine/src/transformations/include/ngraph_ops/type_relaxed.hpp @@ -0,0 +1,229 @@ +// Copyright (C) 2018-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#include + +#include "ngraph/op/op.hpp" + +namespace ngraph { +namespace op { + +/// A base class for templated TypeRelaxed that maintains overridden input types and output types for an operation. +class TRANSFORMATIONS_API TypeRelaxedBase { +public: + virtual ~TypeRelaxedBase(); + + explicit TypeRelaxedBase( + const element::TypeVector& _input_data_types = {}, + const element::TypeVector& _output_data_types = {}) : + m_input_data_types(_input_data_types), + m_output_data_types(_output_data_types) { + } + + /// \return Data type that will be set for output with a given index outputIndex. + /// If output with a specified index outputIndex hasn't been set before, element::undefined will returned. + /// Undefined means no type override happens for a given outputIndex and it will deduced as original + /// operation defineds in its infer function. + /// + /// This method may look similar to Node::get_output_element_type, but it is not the same thing, because + /// get_output_element_type returns the result of type inference, so it is completely deduced from + /// an operation inputs and attributes, and get_overridden_output_type returns value of the attribute that + /// is used to deduce output type. In some cases they don't match: get_overridden_output_type may return + /// element::undefined for some index i, and get_output_element_type will return some real type for + /// the same index i. + const element::Type& get_overridden_output_type(size_t outputIndex = 0) const { + if (outputIndex >= m_output_data_types.size()) { + return element::undefined; + } + return m_output_data_types[outputIndex]; + } + + /// Set data type that overrides the original data type for output port with outputIndex index + /// In case if outputIndex is out of range of known outputs (and this class cannot detect + /// the real number of outputs for original operation), the number of overridden outputs + /// is changed according to a given outputIndex value. + void set_overridden_output_type(const element::Type& element_type, size_t outputIndex = 0) { + if (outputIndex >= m_output_data_types.size()) { + m_output_data_types.resize(outputIndex + 1, element::undefined); + } + m_output_data_types[outputIndex] = element_type; + } + + /// \return Data type that will be set for input when original shape/type inference function is called. + /// If index inputIndex hasn't been set before, element::undefined will returned. Undefined means that + /// the type from input tensor descriptor is used for a given index. + const element::Type& get_origin_input_type(size_t inputIndex = 0) const { + if (inputIndex >= m_input_data_types.size()) { + return element::undefined; + } + return m_input_data_types[inputIndex]; + } + + /// Set data type that overrides the original data type for input port with inputIndex index. + /// In case if inputIndex is out of range of known inputs (and this class cannot detect + /// the real number of inputs for original operation), the number of overridden inputs + /// is changed according to a given inputIndex value. All new entries except one added + /// at inputIndex position are undefined. + void set_origin_input_type(const element::Type& element_type, size_t inputIndex = 0) { + if (inputIndex >= m_input_data_types.size()) { + m_input_data_types.resize(inputIndex + 1, element::undefined); + } + m_input_data_types[inputIndex] = element_type; + } + +protected: + // Data types that are used for parent shape/type infer function input ports + // to infer output data types + element::TypeVector m_input_data_types; + element::TypeVector m_output_data_types; +}; + +/// Set another type for a specified output for the period of time when an instance of the class exists. +/// When the execution leaves the scope where an onject of TemporaryReplaceOutputType is defined, +/// the type of the output is set to its original value. Used when initialized TypeRelaxed operation +/// in case when inputs have types that are not compatible with BaseOp infer function. In this case +/// before TypeRelaxed is constructed the BaseOp contructor requires modified data types. +/// So it should be +class TemporaryReplaceOutputType { + Output m_output; + element::Type orig_type; + +public: + /// Replace element type for a given output port by tmp_type + TemporaryReplaceOutputType(Output output, element::Type tmp_type) : m_output(output) { + // save original element type in order to restore it in the destructor + orig_type = m_output.get_element_type(); + m_output.get_tensor().set_element_type(tmp_type); + } + + /// Return the output port that was used in the constructor + Output get() const { + return m_output; + } + + /// Restores the original element type for the output + ~TemporaryReplaceOutputType() { + m_output.get_tensor().set_element_type(orig_type); + } +}; + +/// Relaxes tensor element type requirements for BaseOp inputs and outputs +/// This class template should be used with Node descendant class. Defines a new operation by extending the +/// original BaseOp operation with ability to accept inputs and provide outputs with element type that is +/// unusual for BaseOp. For example, TypeRelaxed can accept mixed-precision inputs and provide +/// another type of output. New types are provided as inputs attributes for TypeRelaxed template and fixed. +/// There is no any deduction logic for types are provided as a part of this class and it should be +/// implemented outside if required. +template +class TypeRelaxed : public BaseOp, public TypeRelaxedBase { +public: + NGRAPH_RTTI_DECLARATION; + + using BaseOp::BaseOp; + + TypeRelaxed() = default; + + TypeRelaxed( + const BaseOp& base_op, + element::Type overridden_type) : + TypeRelaxed(base_op, + element::TypeVector(base_op.get_input_size(), overridden_type), + element::TypeVector(base_op.get_output_size(), overridden_type)) { + } + + explicit TypeRelaxed( + const BaseOp& base_op, + const element::TypeVector& _input_data_types = {}, + const element::TypeVector& _output_data_types = {}) : + BaseOp(base_op), TypeRelaxedBase(_input_data_types, _output_data_types) { + init(); + } + + /// Creating a new TypeRelaxed operation by calling one of the original op ctors forwarding arguments directly. + template + TypeRelaxed( + const element::TypeVector& _input_data_types, + const element::TypeVector& _output_data_types, + Args&& ... args) : + BaseOp(std::forward(args)...), TypeRelaxedBase(_input_data_types, _output_data_types) { + init(); + } + + void validate_and_infer_types() override; + + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; + +private: + void init() { + validate_and_infer_types(); + } +}; + +template +void TypeRelaxed::validate_and_infer_types() { + // Remember all input data types and reset them to m_output_data_type. + element::TypeVector old_input_types; + for (size_t i = 0; i < BaseOp::get_input_size(); ++i) { + old_input_types.push_back(BaseOp::get_input_element_type(i)); + auto origin_input_type = get_origin_input_type(i); + if (origin_input_type != element::undefined) { + BaseOp::get_input_tensor(i).set_tensor_type(origin_input_type, BaseOp::get_input_partial_shape(i)); + } + } + + BaseOp::validate_and_infer_types(); + + // Restore original input data types + for (size_t i = 0; i < BaseOp::get_input_size(); ++i) { + BaseOp::get_input_tensor(i).set_tensor_type(old_input_types[i], BaseOp::get_input_partial_shape(i)); + } + + // Override (some) output types + for (size_t i = 0; i < BaseOp::get_output_size(); ++i) { + auto overridden_output_type = get_overridden_output_type(i); + if (overridden_output_type != element::undefined) { + BaseOp::set_output_type(0, overridden_output_type, BaseOp::get_output_partial_shape(i)); + } + } +} + + +template +std::shared_ptr TypeRelaxed::clone_with_new_inputs(const OutputVector& new_args) const { + // copy then modify inputs + std::shared_ptr new_node = std::make_shared>((BaseOp&)(*this), m_input_data_types, m_output_data_types); + for (size_t i = 0; i < new_node->get_input_size(); ++i) { + new_node->input(i).replace_source_output(new_args[i]); + } + return new_node; +} + +template +const ::ngraph::Node::type_info_t& TypeRelaxed::get_type_info() const { return get_type_info_static(); } + +template +const ::ngraph::Node::type_info_t& TypeRelaxed::get_type_info_static() { + auto baseOpTypeInfoPtr = &BaseOp::get_type_info_static(); + + // TODO: it should be static const std::string name = std::string("TypeRelaxed_") + baseOpTypeInfoPtr->name; + // but currently it will not pass conversion ot Legacy Opset correctly + static const std::string name = baseOpTypeInfoPtr->name; + + static const ::ngraph::Node::type_info_t type_info_static{ + name.c_str(), baseOpTypeInfoPtr->version, baseOpTypeInfoPtr}; + return type_info_static; +} + +template +const ::ngraph::Node::type_info_t TypeRelaxed::type_info = TypeRelaxed::get_type_info_static(); + +} // namespace op +} // namespace ngraph diff --git a/inference-engine/src/transformations/src/ngraph_ops/type_relaxed.cpp b/inference-engine/src/transformations/src/ngraph_ops/type_relaxed.cpp new file mode 100644 index 0000000..03dcfc6 --- /dev/null +++ b/inference-engine/src/transformations/src/ngraph_ops/type_relaxed.cpp @@ -0,0 +1,17 @@ +// Copyright (C) 2018-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include + +#include "ngraph_ops/type_relaxed.hpp" + +namespace ngraph { +namespace op { + +TypeRelaxedBase::~TypeRelaxedBase() {} + +} // namespace op +} // namespace ngraph diff --git a/inference-engine/tests/functional/inference_engine/transformations/type_relaxed_tests.cpp b/inference-engine/tests/functional/inference_engine/transformations/type_relaxed_tests.cpp new file mode 100644 index 0000000..c93a2cc --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/transformations/type_relaxed_tests.cpp @@ -0,0 +1,273 @@ +// Copyright (C) 2018-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include "common_test_utils/test_common.hpp" +#include +#include + + +namespace element = ngraph::element; +using std::make_shared; +using TypeVector = element::TypeVector; + +using TypeRelaxedTests = CommonTestUtils::TestsCommon; + +TEST_F(TypeRelaxedTests, noOverrideCopyCtor) { + std::shared_ptr ngraph; + { + ngraph::PartialShape shape({1, 3, 22, 22}); + element::Type type(element::Type_t::f32); + auto param = make_shared(type, shape); + auto op = ngraph::opset1::Relu(param); + auto relaxed_op = make_shared>(op); + auto result = make_shared(relaxed_op); + + ngraph::ParameterVector params = {param}; + ngraph::ResultVector results = {result}; + + ngraph = make_shared(results, params); + + ASSERT_EQ(element::f32, relaxed_op->get_input_element_type(0)); + ASSERT_EQ(element::f32, relaxed_op->get_output_element_type(0)); + } + + ASSERT_EQ(3, ngraph->get_ops().size()); +} + +TEST_F(TypeRelaxedTests, overrideOutputCopyCtor) { + auto input_type = element::f32; + auto overriden_type = element::i32; + std::shared_ptr ngraph; + { + ngraph::PartialShape shape({1, 3, 22, 22}); + auto param = make_shared(input_type, shape); + auto op = ngraph::opset1::Relu(param); + auto relaxed_op = make_shared>( + op, TypeVector{}, TypeVector{overriden_type}); + auto result = make_shared(relaxed_op); + + ngraph = make_shared(ngraph::ResultVector{result}, ngraph::ParameterVector{param}); + + ASSERT_EQ(input_type, relaxed_op->get_input_element_type(0)); + ASSERT_EQ(overriden_type, relaxed_op->get_output_element_type(0)); + } + + ASSERT_EQ(3, ngraph->get_ops().size()); +} + +TEST_F(TypeRelaxedTests, overrideInputCopyCtor) { + auto input_type = element::f32; + auto overriden_type = element::i32; + std::shared_ptr ngraph; + { + ngraph::PartialShape shape({1, 3, 22, 22}); + auto param = make_shared(input_type, shape); + auto op = ngraph::opset1::Relu(param); + auto relaxed_op = make_shared>( + op, TypeVector{overriden_type}, TypeVector{}); + auto result = make_shared(relaxed_op); + + ngraph = make_shared(ngraph::ResultVector{result}, ngraph::ParameterVector{param}); + + ASSERT_EQ(input_type, relaxed_op->get_input_element_type(0)); + ASSERT_EQ(overriden_type, relaxed_op->get_output_element_type(0)); + } + + ASSERT_EQ(3, ngraph->get_ops().size()); +} + +TEST_F(TypeRelaxedTests, mixedInputsAutoOutput) { + auto input_type1 = element::u8; + auto input_type2 = element::i8; + auto overriden_type = element::i16; + std::shared_ptr ngraph; + { + ngraph::PartialShape shape({1, 3, 22, 22}); + auto param1 = make_shared(input_type1, shape); + auto param2 = make_shared(input_type2, shape); + auto op = ngraph::opset1::Add( + ngraph::op::TemporaryReplaceOutputType(param1->output(0), overriden_type).get(), + ngraph::op::TemporaryReplaceOutputType(param2->output(0), overriden_type).get()); + auto relaxed_op = make_shared>( + op, TypeVector{overriden_type, overriden_type}, TypeVector{}); + auto result = make_shared(relaxed_op); + + ngraph = make_shared(ngraph::ResultVector{result}, ngraph::ParameterVector{param1, param2}); + + ASSERT_EQ(input_type1, relaxed_op->get_input_element_type(0)); + ASSERT_EQ(input_type2, relaxed_op->get_input_element_type(1)); + ASSERT_EQ(overriden_type, relaxed_op->get_output_element_type(0)); + } + + ASSERT_EQ(4, ngraph->get_ops().size()); +} + +TEST_F(TypeRelaxedTests, mixedInputsAutoOutputForwardCtor) { + auto input_type1 = element::u8; + auto input_type2 = element::i8; + auto overriden_type = element::i16; + std::shared_ptr ngraph; + { + ngraph::PartialShape shape({1, 3, 22, 22}); + auto param1 = make_shared(input_type1, shape); + auto param2 = make_shared(input_type2, shape); + auto relaxed_op = make_shared>( + TypeVector{overriden_type, overriden_type}, TypeVector{}, + ngraph::op::TemporaryReplaceOutputType(param1, overriden_type).get(), + ngraph::op::TemporaryReplaceOutputType(param2, overriden_type).get()); + auto result = make_shared(relaxed_op); + + ngraph = make_shared(ngraph::ResultVector{result}, ngraph::ParameterVector{param1, param2}); + + ASSERT_EQ(input_type1, relaxed_op->get_input_element_type(0)); + ASSERT_EQ(input_type2, relaxed_op->get_input_element_type(1)); + ASSERT_EQ(overriden_type, relaxed_op->get_output_element_type(0)); + } + + ASSERT_EQ(4, ngraph->get_ops().size()); +} + +TEST_F(TypeRelaxedTests, notSupportedTypeOverride) { + auto overriden_type = element::u8; + auto orig_type = element::boolean; + std::shared_ptr ngraph; + { + ngraph::PartialShape shape({1, 3, 22, 22}); + auto param1 = make_shared(overriden_type, shape); + auto param2 = make_shared(overriden_type, shape); + auto op = ngraph::opset1::LogicalAnd( + ngraph::op::TemporaryReplaceOutputType(param1, orig_type).get(), + ngraph::op::TemporaryReplaceOutputType(param2, orig_type).get()); + auto relaxed_op = make_shared>( + op, TypeVector{orig_type, orig_type}, TypeVector{overriden_type}); + auto result = make_shared(relaxed_op); + + ngraph = make_shared(ngraph::ResultVector{result}, ngraph::ParameterVector{param1, param2}); + + ASSERT_EQ(overriden_type, relaxed_op->get_input_element_type(0)); + ASSERT_EQ(overriden_type, relaxed_op->get_input_element_type(1)); + ASSERT_EQ(overriden_type, relaxed_op->get_output_element_type(0)); + } + + ASSERT_EQ(4, ngraph->get_ops().size()); +} + +TEST_F(TypeRelaxedTests, notSupportedTypeOverridePartially) { + auto some_type = element::u8; + auto overriden_type = element::f32; + auto orig_type = element::i64; + std::shared_ptr ngraph; + { + ngraph::PartialShape shape({1, 3, 22, 22}); + auto param1 = make_shared(some_type, shape); + auto param2 = make_shared(overriden_type, ngraph::PartialShape{1}); + auto op = ngraph::opset1::Reshape( + param1, + ngraph::op::TemporaryReplaceOutputType(param2, orig_type).get(), + false); + auto relaxed_op = make_shared>( + op, TypeVector{element::undefined, orig_type}, TypeVector{}); + auto result = make_shared(relaxed_op); + + ngraph = make_shared(ngraph::ResultVector{result}, ngraph::ParameterVector{param1, param2}); + + ASSERT_EQ(some_type, relaxed_op->get_input_element_type(0)); + ASSERT_EQ(overriden_type, relaxed_op->get_input_element_type(1)); + ASSERT_EQ(some_type, relaxed_op->get_output_element_type(0)); + } + + ASSERT_EQ(4, ngraph->get_ops().size()); +} + +TEST_F(TypeRelaxedTests, setGetTypes) { + std::shared_ptr ngraph; + { + ngraph::PartialShape shape({1, 3, 22, 22}); + auto param1 = make_shared(element::u8, shape); + auto param2 = make_shared(element::u8, shape); + // create TypeRelaxed without any type adjustment, the same behaviour as for opset1::Add + auto relaxed_op = make_shared>(param1, param2); + auto result = make_shared(relaxed_op); + + ngraph = make_shared(ngraph::ResultVector{result}, ngraph::ParameterVector{param1, param2}); + + ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(0)); + ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(1)); + ASSERT_EQ(element::u8, relaxed_op->get_output_element_type(0)); + + // internally set types for opset1::Add inference wasn't set when TypeRelaxed created, check it + ASSERT_EQ(element::undefined, relaxed_op->get_origin_input_type(0)); + ASSERT_EQ(element::undefined, relaxed_op->get_origin_input_type(1)); + // if we access elements outside really existing inputs, it should give undefined as well + ASSERT_EQ(element::undefined, relaxed_op->get_origin_input_type(2)); + // number of inputs for the operation node shouldn't change after that + ASSERT_EQ(2, relaxed_op->get_input_size()); + + // similar checks for outputs + ASSERT_EQ(element::undefined, relaxed_op->get_overridden_output_type(0)); + ASSERT_EQ(element::undefined, relaxed_op->get_overridden_output_type(1)); + ASSERT_EQ(1, relaxed_op->get_output_size()); + + // previous checks for input/output indices that are out of number of real inputs/outputs + // should resize internal vectors that hold orig/overridden types, it may affect + // inference for the op, so here we check if the inference is still OK: + ngraph->validate_nodes_and_infer_types(); + + // recheck basic statements about input/output types; they should be the same as we haven't changed anything + ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(0)); + ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(1)); + ASSERT_EQ(element::u8, relaxed_op->get_output_element_type(0)); + + // now we are modifying input types and see if the output type reflects this change + relaxed_op->set_origin_input_type(element::i8, 0); + relaxed_op->set_origin_input_type(element::i8, 1); + ngraph->validate_nodes_and_infer_types(); + ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(0)); + ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(1)); + ASSERT_EQ(element::i8, relaxed_op->get_output_element_type(0)); + + // override output type + relaxed_op->set_overridden_output_type(element::f32, 0); + ngraph->validate_nodes_and_infer_types(); + ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(0)); + ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(1)); + ASSERT_EQ(element::f32, relaxed_op->get_output_element_type(0)); + + // check if get methods reflect recent changes after set methods + ASSERT_EQ(element::i8, relaxed_op->get_origin_input_type(0)); + ASSERT_EQ(element::i8, relaxed_op->get_origin_input_type(1)); + ASSERT_EQ(element::f32, relaxed_op->get_overridden_output_type(0)); + + // Now, a more advanced trick: set real orig/overridden type for a not existing input/output + // it shouldn't affect inference as corresponding inputs/outputs don't exist. + // This scenario is tested for cases when we want to set new types for operation that will + // be further modified in the code by adding new inputs (Concat) or outputs (Split) and this code + // is not aware of TypeRelaxed and shouldn't bother about setting types for new items + // (a bit hypothetical though). + relaxed_op->set_origin_input_type(element::i32, 2); + relaxed_op->set_overridden_output_type(element::i32, 1); + ngraph->validate_nodes_and_infer_types(); + ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(0)); + ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(1)); + ASSERT_EQ(element::f32, relaxed_op->get_output_element_type(0)); + ASSERT_EQ(2, relaxed_op->get_input_size()); + ASSERT_EQ(1, relaxed_op->get_output_size()); + + // lets try to reset types to undefined again and make sure that all original types are restored + relaxed_op->set_origin_input_type(element::undefined, 0); + relaxed_op->set_origin_input_type(element::undefined, 1); + relaxed_op->set_overridden_output_type(element::undefined, 0); + ngraph->validate_nodes_and_infer_types(); + ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(0)); + ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(1)); + ASSERT_EQ(element::u8, relaxed_op->get_output_element_type(0)); + + ASSERT_EQ(element::undefined, relaxed_op->get_origin_input_type(0)); + ASSERT_EQ(element::undefined, relaxed_op->get_origin_input_type(1)); + ASSERT_EQ(element::undefined, relaxed_op->get_origin_input_type(0)); + } + + ASSERT_EQ(4, ngraph->get_ops().size()); +} diff --git a/ngraph/core/include/ngraph/node.hpp b/ngraph/core/include/ngraph/node.hpp index 7fb73ae..c55de14 100644 --- a/ngraph/core/include/ngraph/node.hpp +++ b/ngraph/core/include/ngraph/node.hpp @@ -155,6 +155,9 @@ namespace ngraph protected: /// \brief Construct an unitialized Node Node() {} + /// \brief Copying a node + Node(const Node&); + /// \brief Construct an unitialized Node /// \param output_size Number of outputs for this node Node(size_t output_size); @@ -233,7 +236,7 @@ namespace ngraph /// \brief Get the string name for the type of the node, such as `Add` or `Multiply`. /// The class name, must not contain spaces as it is used for codegen. /// \returns A const reference to the node's type name - virtual const std::string& description() const; + virtual std::string description() const; /// \brief Get the unique name of the node. /// \returns A const reference to the node's unique name. const std::string& get_name() const; diff --git a/ngraph/core/include/ngraph/type/element_type.hpp b/ngraph/core/include/ngraph/type/element_type.hpp index 8b5783f..9bd5edb 100644 --- a/ngraph/core/include/ngraph/type/element_type.hpp +++ b/ngraph/core/include/ngraph/type/element_type.hpp @@ -128,6 +128,9 @@ namespace ngraph Type_t m_type{Type_t::undefined}; }; + typedef std::vector TypeVector; + + extern NGRAPH_API const Type undefined; extern NGRAPH_API const Type dynamic; extern NGRAPH_API const Type boolean; extern NGRAPH_API const Type bf16; diff --git a/ngraph/core/src/node.cpp b/ngraph/core/src/node.cpp index 1d6eb09..6a40c66 100644 --- a/ngraph/core/src/node.cpp +++ b/ngraph/core/src/node.cpp @@ -34,6 +34,28 @@ using namespace ngraph; atomic Node::m_next_instance_id(0); +Node::Node(const Node& node) + : m_control_dependents(node.m_control_dependents) + , m_control_dependencies(node.m_control_dependencies) + // skip m_node_type -- will be generated automatically + , m_instance_id(m_next_instance_id.fetch_add(1)) + , m_friendly_name(node.m_friendly_name) + // skip m_unique_name -- will be generated automatically + , m_provenance_tags(node.m_provenance_tags) + , m_provenance_group(node.m_provenance_group) + , m_inputs(node.m_inputs) // will be modified in the body + // skip m_outputs -- should be initialized outside + , m_op_annotations(node.m_op_annotations) + , m_rt_info(node.m_rt_info) +{ + // cannot do it without copying node.m_inputs first due to too limiting const qualifiers + for (auto& input : m_inputs) + { + input = descriptor::Input(this, input.get_index(), input.get_output()); + input.get_output().add_input(&input); + } +} + Node::Node(size_t output_size) : Node() { @@ -243,12 +265,9 @@ void Node::set_output_type(size_t i, const element::Type& element_type, const Pa get_output_descriptor(i).get_tensor_ptr()->set_tensor_type(element_type, pshape); } -const std::string& Node::description() const +std::string Node::description() const { - // Terrible transitional kludge to keep description working while we change - // type_name to const_char and virtual description() to virtual get_type_name() - const_cast(this)->m_node_type = get_type_name(); - return m_node_type; + return get_type_name(); } const std::string& Node::get_friendly_name() const @@ -711,7 +730,8 @@ NodeVector Node::get_users(bool check_is_used) const std::string ngraph::node_validation_failure_loc_string(const Node* node) { std::stringstream ss; - ss << "While validating node '" << *node << "'"; + ss << "While validating node '" << *node << "' with friendly_name '" + << node->get_friendly_name() << '\''; return ss.str(); } diff --git a/ngraph/core/src/pattern/op/wrap_type.cpp b/ngraph/core/src/pattern/op/wrap_type.cpp index bb639d0..6204dce 100644 --- a/ngraph/core/src/pattern/op/wrap_type.cpp +++ b/ngraph/core/src/pattern/op/wrap_type.cpp @@ -31,7 +31,7 @@ bool pattern::op::WrapType::match_value(Matcher* matcher, const Output& pattern_value, const Output& graph_value) { - if (graph_value.get_node_shared_ptr()->get_type_info() == get_wrapped_type() && + if (graph_value.get_node_shared_ptr()->get_type_info().is_castable(get_wrapped_type()) && m_predicate(graph_value)) { auto& pattern_map = matcher->get_pattern_value_map(); diff --git a/ngraph/core/src/type/element_type.cpp b/ngraph/core/src/type/element_type.cpp index 1ffe2d3..588a140 100644 --- a/ngraph/core/src/type/element_type.cpp +++ b/ngraph/core/src/type/element_type.cpp @@ -25,6 +25,7 @@ using namespace ngraph; using namespace std; +const element::Type element::undefined(element::Type_t::undefined); const element::Type element::dynamic(element::Type_t::dynamic); const element::Type element::boolean(element::Type_t::boolean); const element::Type element::bf16(element::Type_t::bf16); -- 2.7.4