--- /dev/null
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <memory>
+#include <vector>
+#include <algorithm>
+#include <string>
+
+#include <transformations_visibility.hpp>
+
+#include "ngraph/op/op.hpp"
+
+namespace ngraph {
+namespace op {
+
+/// Base class for the templated TypeRelaxed wrapper. Keeps two vectors of element
+/// types: the types substituted on input ports while the original infer function
+/// runs, and the types that override the inferred output types.
+class TRANSFORMATIONS_API TypeRelaxedBase {
+public:
+    virtual ~TypeRelaxedBase();
+
+    explicit TypeRelaxedBase(
+            const element::TypeVector& _input_data_types = {},
+            const element::TypeVector& _output_data_types = {}) :
+        m_input_data_types(_input_data_types),
+        m_output_data_types(_output_data_types) {
+    }
+
+    /// \return the type that will be forced on output port outputIndex, or
+    /// element::undefined when no override was registered for that port; undefined
+    /// means the type deduced by the original operation's infer function is kept.
+    ///
+    /// This may look similar to Node::get_output_element_type, but it is not the
+    /// same thing: get_output_element_type reports the result of type inference,
+    /// while this method reports the attribute that drives the inference. For the
+    /// same index i the former may return a real type while this returns undefined.
+    const element::Type& get_overridden_output_type(size_t outputIndex = 0) const {
+        return outputIndex < m_output_data_types.size() ? m_output_data_types[outputIndex] : element::undefined;
+    }
+
+    /// Registers element_type as an override for output port outputIndex.
+    /// When outputIndex is beyond the currently known outputs (this class cannot
+    /// detect the real number of outputs of the original operation), the internal
+    /// vector grows to outputIndex + 1; intermediate entries stay undefined.
+    void set_overridden_output_type(const element::Type& element_type, size_t outputIndex = 0) {
+        if (m_output_data_types.size() <= outputIndex) {
+            m_output_data_types.resize(outputIndex + 1, element::undefined);
+        }
+        m_output_data_types[outputIndex] = element_type;
+    }
+
+    /// \return the type that is applied to input port inputIndex when the original
+    /// shape/type inference runs, or element::undefined when the type from the
+    /// input tensor descriptor is used as-is for that index.
+    const element::Type& get_origin_input_type(size_t inputIndex = 0) const {
+        return inputIndex < m_input_data_types.size() ? m_input_data_types[inputIndex] : element::undefined;
+    }
+
+    /// Registers element_type as the origin type for input port inputIndex.
+    /// When inputIndex is beyond the currently known inputs (this class cannot
+    /// detect the real number of inputs of the original operation), the internal
+    /// vector grows to inputIndex + 1; all new entries except the one at
+    /// inputIndex are undefined.
+    void set_origin_input_type(const element::Type& element_type, size_t inputIndex = 0) {
+        if (m_input_data_types.size() <= inputIndex) {
+            m_input_data_types.resize(inputIndex + 1, element::undefined);
+        }
+        m_input_data_types[inputIndex] = element_type;
+    }
+
+protected:
+    // Types substituted on input ports while the parent shape/type infer function runs.
+    element::TypeVector m_input_data_types;
+    // Types forced on output ports after the parent shape/type infer function runs.
+    element::TypeVector m_output_data_types;
+};
+
+/// RAII helper that swaps the element type of a given output port for its own
+/// lifetime and restores the original type in its destructor. Used while
+/// initializing a TypeRelaxed<BaseOp> when the real input types are incompatible
+/// with the BaseOp infer function: the BaseOp constructor must temporarily see
+/// the modified data types.
+class TemporaryReplaceOutputType {
+    Output<Node> m_output;
+    element::Type orig_type;
+
+public:
+    /// Remembers the current element type of output and replaces it with tmp_type.
+    TemporaryReplaceOutputType(Output<Node> output, element::Type tmp_type)
+            : m_output(output), orig_type(output.get_element_type()) {
+        m_output.get_tensor().set_element_type(tmp_type);
+    }
+
+    /// \return the output port that was passed to the constructor.
+    Output<Node> get() const {
+        return m_output;
+    }
+
+    /// Restores the element type that was saved in the constructor.
+    ~TemporaryReplaceOutputType() {
+        m_output.get_tensor().set_element_type(orig_type);
+    }
+};
+
+/// Relaxes tensor element type requirements for BaseOp inputs and outputs.
+/// This class template should be used with a Node descendant class. It defines a new
+/// operation by extending the original BaseOp operation with the ability to accept
+/// inputs and provide outputs with element types that are unusual for BaseOp.
+/// For example, TypeRelaxed<opset1::Add> can accept mixed-precision inputs and provide
+/// another type of output. The new types are provided as attributes of this template
+/// and are fixed; no type-deduction logic is provided as a part of this class, it
+/// should be implemented outside if required.
+template <typename BaseOp>
+class TypeRelaxed : public BaseOp, public TypeRelaxedBase {
+public:
+    NGRAPH_RTTI_DECLARATION;
+
+    using BaseOp::BaseOp;
+
+    TypeRelaxed() = default;
+
+    /// Copies base_op and overrides every input and every output with overridden_type.
+    TypeRelaxed(
+            const BaseOp& base_op,
+            element::Type overridden_type) :
+        TypeRelaxed(base_op,
+                element::TypeVector(base_op.get_input_size(), overridden_type),
+                element::TypeVector(base_op.get_output_size(), overridden_type)) {
+    }
+
+    /// Copies base_op and applies the given per-port origin input types and
+    /// overridden output types (element::undefined entries mean "no override").
+    explicit TypeRelaxed(
+            const BaseOp& base_op,
+            const element::TypeVector& _input_data_types = {},
+            const element::TypeVector& _output_data_types = {}) :
+        BaseOp(base_op), TypeRelaxedBase(_input_data_types, _output_data_types) {
+        init();
+    }
+
+    /// Creates a new TypeRelaxed operation by calling one of the original op ctors,
+    /// forwarding the trailing arguments directly.
+    template <typename ... Args>
+    TypeRelaxed(
+            const element::TypeVector& _input_data_types,
+            const element::TypeVector& _output_data_types,
+            Args&& ... args) :
+        BaseOp(std::forward<Args>(args)...), TypeRelaxedBase(_input_data_types, _output_data_types) {
+        init();
+    }
+
+    void validate_and_infer_types() override;
+
+    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
+
+private:
+    // Re-runs type/shape inference so the overrides passed to the ctor take effect
+    // immediately. Called from TypeRelaxed ctor bodies, so the virtual call
+    // resolves to TypeRelaxed::validate_and_infer_types — this is intentional.
+    void init() {
+        validate_and_infer_types();
+    }
+};
+
+template <typename BaseOp>
+void TypeRelaxed<BaseOp>::validate_and_infer_types() {
+    // Remember all input data types and temporarily replace them with the
+    // registered origin types so that BaseOp's own shape/type inference sees
+    // the types it expects.
+    element::TypeVector old_input_types;
+    for (size_t i = 0; i < BaseOp::get_input_size(); ++i) {
+        old_input_types.push_back(BaseOp::get_input_element_type(i));
+        auto origin_input_type = get_origin_input_type(i);
+        if (origin_input_type != element::undefined) {
+            BaseOp::get_input_tensor(i).set_tensor_type(origin_input_type, BaseOp::get_input_partial_shape(i));
+        }
+    }
+
+    BaseOp::validate_and_infer_types();
+
+    // Restore the original input data types.
+    for (size_t i = 0; i < BaseOp::get_input_size(); ++i) {
+        BaseOp::get_input_tensor(i).set_tensor_type(old_input_types[i], BaseOp::get_input_partial_shape(i));
+    }
+
+    // Override (some) output types after the original inference has run.
+    for (size_t i = 0; i < BaseOp::get_output_size(); ++i) {
+        auto overridden_output_type = get_overridden_output_type(i);
+        if (overridden_output_type != element::undefined) {
+            // Bug fix: use index i instead of a hard-coded 0 so that every
+            // overridden output port receives its own type; previously only
+            // output 0 was ever updated for multi-output operations.
+            BaseOp::set_output_type(i, overridden_output_type, BaseOp::get_output_partial_shape(i));
+        }
+    }
+}
+
+
+template <typename BaseOp>
+std::shared_ptr<Node> TypeRelaxed<BaseOp>::clone_with_new_inputs(const OutputVector& new_args) const {
+    // Clone this node first (the copy ctor carries over the relaxed type
+    // attributes), then rewire every input to the corresponding new argument.
+    auto cloned = std::make_shared<TypeRelaxed<BaseOp>>(
+            static_cast<const BaseOp&>(*this), m_input_data_types, m_output_data_types);
+    for (size_t idx = 0; idx < cloned->get_input_size(); ++idx) {
+        cloned->input(idx).replace_source_output(new_args[idx]);
+    }
+    return cloned;
+}
+
+// Virtual RTTI accessor: every instance of a given TypeRelaxed<BaseOp>
+// specialization reports the shared static type info.
+template <typename BaseOp>
+const ::ngraph::Node::type_info_t& TypeRelaxed<BaseOp>::get_type_info() const { return get_type_info_static(); }
+
+template <typename BaseOp>
+const ::ngraph::Node::type_info_t& TypeRelaxed<BaseOp>::get_type_info_static() {
+    auto baseOpTypeInfoPtr = &BaseOp::get_type_info_static();
+
+    // TODO: it should be static const std::string name = std::string("TypeRelaxed_") + baseOpTypeInfoPtr->name;
+    // but currently it will not pass conversion to the Legacy Opset correctly,
+    // so the wrapper deliberately reuses the base op's name for now.
+    static const std::string name = baseOpTypeInfoPtr->name;
+
+    // The base op's type_info is passed as the third field — presumably the
+    // parent pointer used for hierarchy checks; confirm against type_info_t.
+    static const ::ngraph::Node::type_info_t type_info_static{
+        name.c_str(), baseOpTypeInfoPtr->version, baseOpTypeInfoPtr};
+    return type_info_static;
+}
+
+// Out-of-line definition of the static type_info member (declared via
+// NGRAPH_RTTI_DECLARATION), initialized once per TypeRelaxed<BaseOp> specialization.
+template <typename BaseOp>
+const ::ngraph::Node::type_info_t TypeRelaxed<BaseOp>::type_info = TypeRelaxed<BaseOp>::get_type_info_static();
+
+} // namespace op
+} // namespace ngraph
--- /dev/null
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+#include "common_test_utils/test_common.hpp"
+#include <ngraph/opsets/opset1.hpp>
+#include <ngraph_ops/type_relaxed.hpp>
+
+
+namespace element = ngraph::element;
+using std::make_shared;
+using TypeVector = element::TypeVector;
+
+using TypeRelaxedTests = CommonTestUtils::TestsCommon;
+
+TEST_F(TypeRelaxedTests, noOverrideCopyCtor) {
+    std::shared_ptr<ngraph::Function> model;
+    {
+        ngraph::PartialShape shape({1, 3, 22, 22});
+        element::Type type(element::Type_t::f32);
+        auto param = make_shared<ngraph::opset1::Parameter>(type, shape);
+        auto op = ngraph::opset1::Relu(param);
+        // No overrides at all: the relaxed wrapper must behave exactly like plain Relu.
+        auto relaxed_op = make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Relu>>(op);
+        auto result = make_shared<ngraph::opset1::Result>(relaxed_op);
+
+        model = make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{param});
+
+        ASSERT_EQ(element::f32, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(element::f32, relaxed_op->get_output_element_type(0));
+    }
+
+    // Parameter + Relu + Result
+    ASSERT_EQ(3, model->get_ops().size());
+}
+
+TEST_F(TypeRelaxedTests, overrideOutputCopyCtor) {
+    const auto input_type = element::f32;
+    const auto overriden_type = element::i32;
+    std::shared_ptr<ngraph::Function> model;
+    {
+        ngraph::PartialShape shape({1, 3, 22, 22});
+        auto param = make_shared<ngraph::opset1::Parameter>(input_type, shape);
+        auto op = ngraph::opset1::Relu(param);
+        // Override only the output type; the input keeps its original f32.
+        auto relaxed_op = make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Relu>>(
+                op, TypeVector{}, TypeVector{overriden_type});
+        auto result = make_shared<ngraph::opset1::Result>(relaxed_op);
+
+        model = make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{param});
+
+        ASSERT_EQ(input_type, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(overriden_type, relaxed_op->get_output_element_type(0));
+    }
+
+    // Parameter + Relu + Result
+    ASSERT_EQ(3, model->get_ops().size());
+}
+
+TEST_F(TypeRelaxedTests, overrideInputCopyCtor) {
+    const auto input_type = element::f32;
+    const auto overriden_type = element::i32;
+    std::shared_ptr<ngraph::Function> model;
+    {
+        ngraph::PartialShape shape({1, 3, 22, 22});
+        auto param = make_shared<ngraph::opset1::Parameter>(input_type, shape);
+        auto op = ngraph::opset1::Relu(param);
+        // Override only the origin input type; the output type is then deduced
+        // by Relu's own inference from the overridden (i32) input.
+        auto relaxed_op = make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Relu>>(
+                op, TypeVector{overriden_type}, TypeVector{});
+        auto result = make_shared<ngraph::opset1::Result>(relaxed_op);
+
+        model = make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{param});
+
+        ASSERT_EQ(input_type, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(overriden_type, relaxed_op->get_output_element_type(0));
+    }
+
+    // Parameter + Relu + Result
+    ASSERT_EQ(3, model->get_ops().size());
+}
+
+TEST_F(TypeRelaxedTests, mixedInputsAutoOutput) {
+    auto input_type1 = element::u8;
+    auto input_type2 = element::i8;
+    auto overriden_type = element::i16;
+    std::shared_ptr<ngraph::Function> ngraph;
+    {
+        ngraph::PartialShape shape({1, 3, 22, 22});
+        auto param1 = make_shared<ngraph::opset1::Parameter>(input_type1, shape);
+        auto param2 = make_shared<ngraph::opset1::Parameter>(input_type2, shape);
+        // Add requires matching input types, so the u8/i8 outputs are temporarily
+        // presented as i16 (restored when the TemporaryReplaceOutputType
+        // temporaries are destroyed at the end of the full expression).
+        auto op = ngraph::opset1::Add(
+                ngraph::op::TemporaryReplaceOutputType(param1->output(0), overriden_type).get(),
+                ngraph::op::TemporaryReplaceOutputType(param2->output(0), overriden_type).get());
+        auto relaxed_op = make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Add>>(
+                op, TypeVector{overriden_type, overriden_type}, TypeVector{});
+        auto result = make_shared<ngraph::opset1::Result>(relaxed_op);
+
+        ngraph = make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{param1, param2});
+
+        // Real input types are preserved; the output follows the i16 origin types.
+        ASSERT_EQ(input_type1, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(input_type2, relaxed_op->get_input_element_type(1));
+        ASSERT_EQ(overriden_type, relaxed_op->get_output_element_type(0));
+    }
+
+    // 2 Parameters + Add + Result
+    ASSERT_EQ(4, ngraph->get_ops().size());
+}
+
+TEST_F(TypeRelaxedTests, mixedInputsAutoOutputForwardCtor) {
+    auto input_type1 = element::u8;
+    auto input_type2 = element::i8;
+    auto overriden_type = element::i16;
+    std::shared_ptr<ngraph::Function> ngraph;
+    {
+        ngraph::PartialShape shape({1, 3, 22, 22});
+        auto param1 = make_shared<ngraph::opset1::Parameter>(input_type1, shape);
+        auto param2 = make_shared<ngraph::opset1::Parameter>(input_type2, shape);
+        // The forwarding ctor builds the underlying Add directly; the inputs are
+        // temporarily re-typed to i16 so Add's own inference accepts them.
+        auto relaxed_op = make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Add>>(
+                TypeVector{overriden_type, overriden_type}, TypeVector{},
+                ngraph::op::TemporaryReplaceOutputType(param1, overriden_type).get(),
+                ngraph::op::TemporaryReplaceOutputType(param2, overriden_type).get());
+        auto result = make_shared<ngraph::opset1::Result>(relaxed_op);
+
+        ngraph = make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{param1, param2});
+
+        // Real input types are preserved; the output follows the i16 origin types.
+        ASSERT_EQ(input_type1, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(input_type2, relaxed_op->get_input_element_type(1));
+        ASSERT_EQ(overriden_type, relaxed_op->get_output_element_type(0));
+    }
+
+    // 2 Parameters + Add + Result
+    ASSERT_EQ(4, ngraph->get_ops().size());
+}
+
+TEST_F(TypeRelaxedTests, notSupportedTypeOverride) {
+    auto overriden_type = element::u8;
+    auto orig_type = element::boolean;
+    std::shared_ptr<ngraph::Function> ngraph;
+    {
+        ngraph::PartialShape shape({1, 3, 22, 22});
+        auto param1 = make_shared<ngraph::opset1::Parameter>(overriden_type, shape);
+        auto param2 = make_shared<ngraph::opset1::Parameter>(overriden_type, shape);
+        // LogicalAnd works on boolean, so the u8 parameters are temporarily
+        // presented as boolean while the base op is constructed.
+        auto op = ngraph::opset1::LogicalAnd(
+                ngraph::op::TemporaryReplaceOutputType(param1, orig_type).get(),
+                ngraph::op::TemporaryReplaceOutputType(param2, orig_type).get());
+        // Origin inputs are boolean for LogicalAnd's inference, while the output
+        // is overridden back to u8.
+        auto relaxed_op = make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::LogicalAnd>>(
+                op, TypeVector{orig_type, orig_type}, TypeVector{overriden_type});
+        auto result = make_shared<ngraph::opset1::Result>(relaxed_op);
+
+        ngraph = make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{param1, param2});
+
+        // Externally everything is u8 even though the op natively supports only boolean.
+        ASSERT_EQ(overriden_type, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(overriden_type, relaxed_op->get_input_element_type(1));
+        ASSERT_EQ(overriden_type, relaxed_op->get_output_element_type(0));
+    }
+
+    // 2 Parameters + LogicalAnd + Result
+    ASSERT_EQ(4, ngraph->get_ops().size());
+}
+
+TEST_F(TypeRelaxedTests, notSupportedTypeOverridePartially) {
+    auto some_type = element::u8;
+    auto overriden_type = element::f32;
+    auto orig_type = element::i64;
+    std::shared_ptr<ngraph::Function> ngraph;
+    {
+        ngraph::PartialShape shape({1, 3, 22, 22});
+        auto param1 = make_shared<ngraph::opset1::Parameter>(some_type, shape);
+        auto param2 = make_shared<ngraph::opset1::Parameter>(overriden_type, ngraph::PartialShape{1});
+        // Only the second input (Reshape's shape pattern) is re-typed to i64;
+        // the data input keeps its own type.
+        auto op = ngraph::opset1::Reshape(
+                param1,
+                ngraph::op::TemporaryReplaceOutputType(param2, orig_type).get(),
+                false);
+        // element::undefined for input 0 means "no override for that port".
+        auto relaxed_op = make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Reshape>>(
+                op, TypeVector{element::undefined, orig_type}, TypeVector{});
+        auto result = make_shared<ngraph::opset1::Result>(relaxed_op);
+
+        ngraph = make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{param1, param2});
+
+        ASSERT_EQ(some_type, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(overriden_type, relaxed_op->get_input_element_type(1));
+        // The output type follows the u8 data input, untouched by the f32 pattern input.
+        ASSERT_EQ(some_type, relaxed_op->get_output_element_type(0));
+    }
+
+    // 2 Parameters + Reshape + Result
+    ASSERT_EQ(4, ngraph->get_ops().size());
+}
+
+TEST_F(TypeRelaxedTests, setGetTypes) {
+    std::shared_ptr<ngraph::Function> ngraph;
+    {
+        ngraph::PartialShape shape({1, 3, 22, 22});
+        auto param1 = make_shared<ngraph::opset1::Parameter>(element::u8, shape);
+        auto param2 = make_shared<ngraph::opset1::Parameter>(element::u8, shape);
+        // create TypeRelaxed without any type adjustment, the same behaviour as for opset1::Add
+        auto relaxed_op = make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Add>>(param1, param2);
+        auto result = make_shared<ngraph::opset1::Result>(relaxed_op);
+
+        ngraph = make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{param1, param2});
+
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(1));
+        ASSERT_EQ(element::u8, relaxed_op->get_output_element_type(0));
+
+        // internal types for opset1::Add inference weren't set when TypeRelaxed was created, check it
+        ASSERT_EQ(element::undefined, relaxed_op->get_origin_input_type(0));
+        ASSERT_EQ(element::undefined, relaxed_op->get_origin_input_type(1));
+        // if we access elements outside really existing inputs, it should give undefined as well
+        ASSERT_EQ(element::undefined, relaxed_op->get_origin_input_type(2));
+        // number of inputs for the operation node shouldn't change after that
+        ASSERT_EQ(2, relaxed_op->get_input_size());
+
+        // similar checks for outputs
+        ASSERT_EQ(element::undefined, relaxed_op->get_overridden_output_type(0));
+        ASSERT_EQ(element::undefined, relaxed_op->get_overridden_output_type(1));
+        ASSERT_EQ(1, relaxed_op->get_output_size());
+
+        // the out-of-range queries above must not disturb the node state
+        // (getters don't resize the internal type vectors), so inference
+        // must still succeed:
+        ngraph->validate_nodes_and_infer_types();
+
+        // recheck basic statements about input/output types; they should be the same as we haven't changed anything
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(1));
+        ASSERT_EQ(element::u8, relaxed_op->get_output_element_type(0));
+
+        // now we are modifying input types and see if the output type reflects this change
+        relaxed_op->set_origin_input_type(element::i8, 0);
+        relaxed_op->set_origin_input_type(element::i8, 1);
+        ngraph->validate_nodes_and_infer_types();
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(1));
+        ASSERT_EQ(element::i8, relaxed_op->get_output_element_type(0));
+
+        // override output type
+        relaxed_op->set_overridden_output_type(element::f32, 0);
+        ngraph->validate_nodes_and_infer_types();
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(1));
+        ASSERT_EQ(element::f32, relaxed_op->get_output_element_type(0));
+
+        // check if get methods reflect recent changes after set methods
+        ASSERT_EQ(element::i8, relaxed_op->get_origin_input_type(0));
+        ASSERT_EQ(element::i8, relaxed_op->get_origin_input_type(1));
+        ASSERT_EQ(element::f32, relaxed_op->get_overridden_output_type(0));
+
+        // Now, a more advanced trick: set real orig/overridden type for a not existing input/output
+        // it shouldn't affect inference as corresponding inputs/outputs don't exist.
+        // This scenario is tested for cases when we want to set new types for operation that will
+        // be further modified in the code by adding new inputs (Concat) or outputs (Split) and this code
+        // is not aware of TypeRelaxed and shouldn't bother about setting types for new items
+        // (a bit hypothetical though).
+        relaxed_op->set_origin_input_type(element::i32, 2);
+        relaxed_op->set_overridden_output_type(element::i32, 1);
+        ngraph->validate_nodes_and_infer_types();
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(1));
+        ASSERT_EQ(element::f32, relaxed_op->get_output_element_type(0));
+        ASSERT_EQ(2, relaxed_op->get_input_size());
+        ASSERT_EQ(1, relaxed_op->get_output_size());
+
+        // lets try to reset types to undefined again and make sure that all original types are restored
+        relaxed_op->set_origin_input_type(element::undefined, 0);
+        relaxed_op->set_origin_input_type(element::undefined, 1);
+        relaxed_op->set_overridden_output_type(element::undefined, 0);
+        ngraph->validate_nodes_and_infer_types();
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(0));
+        ASSERT_EQ(element::u8, relaxed_op->get_input_element_type(1));
+        ASSERT_EQ(element::u8, relaxed_op->get_output_element_type(0));
+
+        ASSERT_EQ(element::undefined, relaxed_op->get_origin_input_type(0));
+        ASSERT_EQ(element::undefined, relaxed_op->get_origin_input_type(1));
+        // Bug fix: this third check previously duplicated get_origin_input_type(0);
+        // it was clearly meant to verify that the output override was cleared too.
+        ASSERT_EQ(element::undefined, relaxed_op->get_overridden_output_type(0));
+    }
+
+    // 2 Parameters + Add + Result
+    ASSERT_EQ(4, ngraph->get_ops().size());
+}