{ngraph::element::u16, ngraph::element::i32},
{ngraph::element::u32, ngraph::element::i32},
{ngraph::element::f16, ngraph::element::f32},
+ {ngraph::element::boolean, ngraph::element::u8},
};
for (auto & precision : convert_precision_list) {
clonedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, *clonedNetwork);
- // WA: ngraph::pass:ConvertPrecision doesn't support BOOL to U8 conversion
- // so we temporary have to call CNNNetwork ConvertPrecision transformation
- NetPass::ConvertPrecision(*clonedNetwork, Precision::BOOL, Precision::U8);
-
// WA: after conversion to CNNNetwork user precision can redefine input/output precisions
// so we need to apply additional precision conversion but only for inputs and outputs
for (auto & precision : convert_precision_list) {
template <typename BaseOp>
void TypeRelaxed<BaseOp>::validate_and_infer_types() {
- // Remember all input data types and reset them to m_output_data_type.
+ // Remember all input data types
element::TypeVector old_input_types;
for (size_t i = 0; i < BaseOp::get_input_size(); ++i) {
old_input_types.push_back(BaseOp::get_input_element_type(i));
+ }
+
+ // Reset input data types to m_output_data_type.
+ for (size_t i = 0; i < BaseOp::get_input_size(); ++i) {
auto origin_input_type = get_origin_input_type(i);
if (origin_input_type != element::undefined) {
BaseOp::get_input_tensor(i).set_tensor_type(origin_input_type, BaseOp::get_input_partial_shape(i));
* u64 -> i32
* i64 -> i32
* f16 -> f32
+ * bool -> u8
+ * bool -> i32
+ *
* For all operations from opset1-opset4 these conversions can be applied without adding Conversion operations.
* That is possible because every operation that produces the "FROM" type can also produce the "TO" type. For these
* operations we have created a special fuse_type_to_<type> function (can be found in the cpp file) that performs type fusion
* into operation.
- * List of operations that are supported by this transformations:
+ *
+ * List of operations that are supported by this transformation for i64 -> i32 conversion:
* opset4::Parameter
* opset4::Convert
* opset4::ShapeOf
* opset4::TopK
* opset4::NonZero
* opset4::Bucketize
+ *
+ * List of operations that are supported by this transformation for bool -> u8 conversion:
+ * LogicalAnd
+ * LogicalNot
+ * LogicalOr
+ * LogicalXor
+ * ReduceLogicalAnd
+ * ReduceLogicalOr
+ * Equal
+ * NotEqual
+ * Greater
+ * GreaterEqual
+ * Less
+ * LessEqual
*/
class ngraph::pass::ConvertPrecision : public ngraph::pass::FunctionPass {
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset1.hpp>
+#include <ngraph_ops/type_relaxed.hpp>
using namespace ngraph;
bool fuse_type_to_constant(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, const std::vector<ngraph::Input<ngraph::Node>> & consumers);
bool fuse_type_to_shapeof(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
+bool fuse_type_to_shapeof_v0(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_parameter(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_convert(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_nms3(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_bucketize(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_generic_ie(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
-static std::map<ngraph::NodeTypeInfo, std::function<bool(std::shared_ptr<Node>&, element::Type, size_t idx)>> type_to_fuse {
+bool extend_select_type(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
+
+template <typename T>
+bool fuse_type_to_binary_comparision(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
+ if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
+ type_relaxed->set_overridden_output_type(to);
+ return true;
+ } else if (auto casted = std::dynamic_pointer_cast<T>(node)) {
+ auto relaxed_op = std::make_shared<ngraph::op::TypeRelaxed<T>>(*casted, element::TypeVector{}, element::TypeVector{to});
+ replace_node(node, relaxed_op);
+ return true;
+ }
+ return false;
+}
+
+template <typename T>
+bool fuse_type_to_logical(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
+ if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
+ type_relaxed->set_overridden_output_type(to);
+ type_relaxed->set_origin_input_type(element::boolean, 0);
+ type_relaxed->set_origin_input_type(element::boolean, 1);
+ return true;
+ } else if (auto casted = std::dynamic_pointer_cast<T>(node)) {
+ auto relaxed_op = std::make_shared<ngraph::op::TypeRelaxed<T>>(*casted,
+ element::TypeVector{element::boolean, element::boolean}, element::TypeVector{to});
+ replace_node(node, relaxed_op);
+ return true;
+ }
+ return false;
+}
+
+template <class T>
+bool fuse_type_to_reduce_logical(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
+ if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
+ type_relaxed->set_overridden_output_type(to);
+ type_relaxed->set_origin_input_type(element::boolean, 0);
+ return true;
+ } else if (auto casted = std::dynamic_pointer_cast<T>(node)) {
+ auto relaxed_op = std::make_shared<ngraph::op::TypeRelaxed<T>>(*casted,
+ element::TypeVector{element::boolean}, element::TypeVector{to});
+ replace_node(node, relaxed_op);
+ return true;
+ }
+ return false;
+}
+
+bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Function> f) {
+ static std::map<ngraph::NodeTypeInfo, std::function<bool(std::shared_ptr<Node>&, element::Type, size_t idx)>> type_to_fuse {
{opset4::Parameter::type_info, fuse_type_to_parameter},
{opset4::Convert::type_info, fuse_type_to_convert},
{opset4::ShapeOf::type_info, fuse_type_to_shapeof},
{opset4::NonZero::type_info, fuse_type_to_nonzero},
{opset4::Bucketize::type_info, fuse_type_to_bucketize},
{NodeTypeInfo("GenericIE", 1), fuse_type_to_generic_ie},
-};
+ {opset4::Equal::type_info, fuse_type_to_binary_comparision<opset4::Equal>},
+ {opset4::NotEqual::type_info, fuse_type_to_binary_comparision<opset4::NotEqual>},
+ {opset4::Greater::type_info, fuse_type_to_binary_comparision<opset4::Greater>},
+ {opset4::GreaterEqual::type_info, fuse_type_to_binary_comparision<opset4::GreaterEqual>},
+ {opset4::Less::type_info, fuse_type_to_binary_comparision<opset4::Less>},
+ {opset4::LessEqual::type_info, fuse_type_to_binary_comparision<opset4::LessEqual>},
+ {opset4::LogicalAnd::type_info, fuse_type_to_logical<opset4::LogicalAnd>},
+ {opset4::LogicalOr::type_info, fuse_type_to_logical<opset4::LogicalOr>},
+ {opset4::LogicalXor::type_info, fuse_type_to_logical<opset4::LogicalXor>},
+ {opset4::LogicalNot::type_info, fuse_type_to_logical<opset4::LogicalNot>},
+ {opset4::ReduceLogicalAnd::type_info, fuse_type_to_reduce_logical<opset4::ReduceLogicalAnd>},
+ {opset4::ReduceLogicalOr::type_info, fuse_type_to_reduce_logical<opset4::ReduceLogicalOr>},
+ {opset1::ShapeOf::type_info, fuse_type_to_shapeof_v0}
+ };
+
+ static std::map<ngraph::NodeTypeInfo, std::function<bool(std::shared_ptr<Node>&, element::Type, size_t idx)>> type_to_extend {
+ {opset4::Select::type_info, extend_select_type},
+ };
-bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Function> f) {
// As Constant operations can be shared between multiple nGraph Functions so before
// changing precision we need to understand which Constant consumers belongs
// to the current nGraph Function
std::function<void(const std::shared_ptr<Function> &)> register_constants =
[&const_to_internal_output, ®ister_constants](const std::shared_ptr<Function> & f) {
for (auto & node : f->get_ordered_ops()) {
- // Recursively run for TensorIterator body function
- if (auto ti = std::dynamic_pointer_cast<opset4::TensorIterator>(node)) {
- register_constants(ti->get_body()->to_function());
- }
for (auto & input : node->inputs()) {
if (auto const_node = std::dynamic_pointer_cast<opset4::Constant>(input.get_source_output().get_node_shared_ptr())) {
const_to_internal_output[const_node].emplace_back(input);
}
};
- register_constants(f);
-
- auto convert_node_precision = [this, &const_to_internal_output](std::shared_ptr<Node> & node) {
- // As input type could changed we need to propagate output type calculation manually
- node->validate_and_infer_types();
-
+ auto convert_node_output_precision = [this, &const_to_internal_output](std::shared_ptr<Node> & node) {
for (auto output : node->outputs()) {
if (output.get_element_type() == m_from) {
// Handle case with Constants as they can have consumers from other nGraph Function object
if (ngraph::op::is_constant(node) && const_to_internal_output.count(node)) {
fuse_type_to_constant(node, m_to, const_to_internal_output.at(node));
- continue;
+ break;
}
// If node type in map and convert can be fused into node we skip Convert creation
if (type_to_fuse.count(node->get_type_info()) &&
type_to_fuse.at(node->get_type_info())(node, m_to, output.get_index())) {
- node->validate_and_infer_types();
- continue;
+ break;
}
+ }
+ }
+ };
- // Create Convert operation and reconnect consumers
- auto consumers = output.get_target_inputs();
- auto convert = std::make_shared<opset4::Convert>(output, m_to);
- for (auto & input : consumers) {
- input.replace_source_output(convert);
+ auto convert_node_input_precision = [this](std::shared_ptr<Node> & node) {
+ for (auto input : node->inputs()) {
+ if (input.get_element_type() == m_from) {
+ // For some operations we need to extend their input types to support new type
+ if (type_to_extend.count(node->get_type_info()) &&
+ type_to_extend.at(node->get_type_info())(node, m_to, input.get_index())) {
+ break;
}
}
}
};
std::function<void(const std::shared_ptr<Function> &)> convert_function_precision =
- [this, &const_to_internal_output, &convert_node_precision, &convert_function_precision](const std::shared_ptr<Function> & f) {
+ [this, &const_to_internal_output,
+ ®ister_constants,
+ &convert_node_output_precision,
+ &convert_node_input_precision,
+ &convert_function_precision] (const std::shared_ptr<Function> & f) {
// Iterate over all nodes in topological order and then iterate over node outputs.
// If output type mismatch given type we try to fuse type into this operation
// otherwise we insert Convert operation.
if (auto ti = std::dynamic_pointer_cast<opset4::TensorIterator>(node)) {
convert_function_precision(ti->get_body()->to_function());
}
- convert_node_precision(node);
+ convert_node_input_precision(node);
+ }
+ // Register internal constants only after fixing input type that could lead to nodes replacement
+ register_constants(f);
+
+ for (auto &node : f->get_ordered_ops()) {
+ convert_node_output_precision(node);
}
};
convert_function_precision(f);
+ f->validate_nodes_and_infer_types();
// TODO: we need to split NopElimination pass to separate MatcherPasses and call Convert elimination here
for (auto &node : f->get_ordered_ops()) {
bool fuse_type_to_parameter(std::shared_ptr<Node> & node, element::Type to, size_t idx) {
if (auto param = as_type_ptr<opset4::Parameter>(node)) {
param->set_element_type(to);
+ param->validate_and_infer_types();
return true;
}
return false;
return true;
}
+bool fuse_type_to_shapeof_v0(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
+ if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
+ type_relaxed->set_overridden_output_type(to);
+ return true;
+ } else if (auto casted = std::dynamic_pointer_cast<opset1::ShapeOf>(node)) {
+ auto relaxed_op = std::make_shared<ngraph::op::TypeRelaxed<opset1::ShapeOf>>(*casted,
+ element::TypeVector{}, element::TypeVector{to});
+ replace_node(node, relaxed_op);
+ return true;
+ }
+ return false;
+}
+
+bool extend_select_type(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
+ if (auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(node)) {
+ type_relaxed->set_origin_input_type(element::boolean, 0);
+ return true;
+ } else if (auto casted = std::dynamic_pointer_cast<opset4::Select>(node)) {
+ auto relaxed_op = std::make_shared<op::TypeRelaxed<opset4::Select>>(*casted,
+ element::TypeVector{element::boolean},
+ element::TypeVector{});
+ replace_node(node, relaxed_op);
+ return true;
+ }
+ return false;
+}
+
template <element::Type_t PREC_FROM, element::Type_t PREC_TO>
std::shared_ptr<Node> change_constant_precision(std::shared_ptr<opset4::Constant> & constant) {
using src_type = typename element_type_traits<PREC_FROM>::value_type;
#include <transformations/convert_precision.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/pass/manager.hpp>
+#include <ngraph_ops/type_relaxed.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
}
}
+TEST(TransformationTests, ConvertPrecision_Equal) {
+ std::shared_ptr<Function> f(nullptr);
+ {
+ auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+ auto input2 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+ auto node = std::make_shared<ngraph::opset4::Equal>(input1, input2);
+
+ f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
+
+ pass::Manager manager;
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
+ manager.run_passes(f);
+ }
+
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+ ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_NotEqual) {
+ std::shared_ptr<Function> f(nullptr);
+ {
+ auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+ auto input2 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+ auto node = std::make_shared<ngraph::opset4::NotEqual>(input1, input2);
+
+ f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
+
+ pass::Manager manager;
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
+ manager.run_passes(f);
+ }
+
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+ ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_Greater) {
+ std::shared_ptr<Function> f(nullptr);
+ {
+ auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+ auto input2 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+ auto node = std::make_shared<ngraph::opset4::Greater>(input1, input2);
+
+ f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
+
+ pass::Manager manager;
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
+ manager.run_passes(f);
+ }
+
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+ ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_GreaterEqual) {
+ std::shared_ptr<Function> f(nullptr);
+ {
+ auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+ auto input2 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+ auto node = std::make_shared<ngraph::opset4::GreaterEqual>(input1, input2);
+
+ f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
+
+ pass::Manager manager;
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
+ manager.run_passes(f);
+ }
+
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+ ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_Less) {
+ std::shared_ptr<Function> f(nullptr);
+ {
+ auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+ auto input2 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+ auto node = std::make_shared<ngraph::opset4::Less>(input1, input2);
+
+ f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
+
+ pass::Manager manager;
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
+ manager.run_passes(f);
+ }
+
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+ ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_LessEqual) {
+ std::shared_ptr<Function> f(nullptr);
+ {
+ auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+ auto input2 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{15, 20, 3});
+ auto node = std::make_shared<ngraph::opset4::LessEqual>(input1, input2);
+
+ f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
+
+ pass::Manager manager;
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
+ manager.run_passes(f);
+ }
+
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+ ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_LogicalAnd) {
+ std::shared_ptr<Function> f(nullptr);
+ {
+ auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::boolean, ngraph::Shape{15, 20, 3});
+ auto input2 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::boolean, ngraph::Shape{15, 20, 3});
+ auto node = std::make_shared<ngraph::opset4::LogicalAnd>(input1, input2);
+
+ f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
+
+ pass::Manager manager;
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+ manager.run_passes(f);
+ }
+
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+ ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_LogicalOr) {
+ std::shared_ptr<Function> f(nullptr);
+ {
+ auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::boolean, ngraph::Shape{15, 20, 3});
+ auto input2 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::boolean, ngraph::Shape{15, 20, 3});
+ auto node = std::make_shared<ngraph::opset4::LogicalOr>(input1, input2);
+
+ f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
+
+ pass::Manager manager;
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+ manager.run_passes(f);
+ }
+
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+ ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_LogicalXor) {
+ std::shared_ptr<Function> f(nullptr);
+ {
+ auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::boolean, ngraph::Shape{15, 20, 3});
+ auto input2 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::boolean, ngraph::Shape{15, 20, 3});
+ auto node = std::make_shared<ngraph::opset4::LogicalXor>(input1, input2);
+
+ f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1, input2});
+
+ pass::Manager manager;
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+ manager.run_passes(f);
+ }
+
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+ ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_LogicalNot) {
+ std::shared_ptr<Function> f(nullptr);
+ {
+ auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::boolean, ngraph::Shape{15, 20, 3});
+ auto node = std::make_shared<ngraph::opset4::LogicalNot>(input1);
+
+ f = std::make_shared<Function>(OutputVector{node}, ParameterVector{input1});
+
+ pass::Manager manager;
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+ manager.run_passes(f);
+ }
+
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+ ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_Select) {
+ std::shared_ptr<Function> f(nullptr);
+ {
+ auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::boolean, ngraph::Shape{15, 20, 3});
+ auto node = std::make_shared<ngraph::opset4::LogicalNot>(input1);
+ auto select = std::make_shared<ngraph::opset4::Select>(node, input1, input1);
+
+ f = std::make_shared<Function>(OutputVector{select}, ParameterVector{input1});
+
+ pass::Manager manager;
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
+ manager.run_passes(f);
+ }
+
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+ ASSERT_TRUE(has_type<ngraph::element::Type_t::u8>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_TypeRelaxedWithSelect) {
+ std::shared_ptr<Function> f(nullptr);
+ {
+ auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::boolean, ngraph::Shape{15, 20, 3});
+ auto node = std::make_shared<ngraph::opset4::LogicalNot>(input1);
+ auto select = std::make_shared<ngraph::opset4::Select>(node, input1, input1);
+
+ f = std::make_shared<Function>(OutputVector{select}, ParameterVector{input1});
+
+ pass::Manager manager;
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::i32);
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i32, ngraph::element::i64);
+ manager.run_passes(f);
+ }
+
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::i32>(f));
+ ASSERT_TRUE(has_type<ngraph::element::Type_t::i64>(f));
+}
+
+TEST(TransformationTests, ConvertPrecision_TypeRelaxed) {
+ std::shared_ptr<Function> f(nullptr);
+ {
+ auto input1 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::boolean, ngraph::Shape{15, 20, 3});
+ auto select = std::make_shared<ngraph::opset4::Select>(input1, input1, input1);
+ auto type_relaxed = std::make_shared<op::TypeRelaxed<opset4::Select>>(*select, element::TypeVector{}, element::TypeVector{element::i64});
+
+ f = std::make_shared<Function>(OutputVector{type_relaxed}, ParameterVector{input1});
+
+ pass::Manager manager;
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::i32);
+ manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i32, ngraph::element::i64);
+ manager.run_passes(f);
+
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::boolean>(f));
+ ASSERT_FALSE(has_type<ngraph::element::Type_t::i32>(f));
+ ASSERT_TRUE(has_type<ngraph::element::Type_t::i64>(f));
+ }
+}
+
TEST(TransformationTests, ConvertPrecision_Variables) {
std::shared_ptr<ngraph::Function> f(nullptr);
{
ASSERT_EQ(4, ngraph->get_ops().size());
}
+
+TEST_F(TypeRelaxedTests, OneOutputMultipleInputPorts) {
+ std::shared_ptr<ngraph::Function> f;
+ {
+ auto param1 = make_shared<ngraph::opset1::Parameter>(element::boolean, ngraph::Shape{1, 3, 22, 22});
+ auto op = ngraph::opset1::Select(param1, param1, param1);
+ auto relaxed_op = make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Select>>(
+ op, TypeVector{}, TypeVector{element::i64});
+
+ f = make_shared<ngraph::Function>(ngraph::OutputVector{relaxed_op}, ngraph::ParameterVector{param1});
+
+ // Prepare relaxed op for input change
+ relaxed_op->set_origin_input_type(element::boolean, 0);
+
+ // Change Parameter element type
+ param1->set_element_type(element::i64);
+ param1->validate_and_infer_types();
+ ASSERT_EQ(param1->output(0).get_element_type(), element::i64);
+
+ // Check that after restoring original precisions inside validate_and_infer_types
+ // function we do not corrupt original types
+ relaxed_op->validate_and_infer_types();
+ ASSERT_EQ(param1->output(0).get_element_type(), element::i64);
+ }
+}
\ No newline at end of file
class NGRAPH_API Equal : public util::BinaryElementwiseComparison
{
public:
- static constexpr NodeTypeInfo type_info{"Equal", 1};
- const NodeTypeInfo& get_type_info() const override { return type_info; }
+ NGRAPH_RTTI_DECLARATION;
/// \brief Constructs an equal operation.
Equal()
: util::BinaryElementwiseComparison(AutoBroadcastSpec::NUMPY)
class NGRAPH_API Greater : public util::BinaryElementwiseComparison
{
public:
- static constexpr NodeTypeInfo type_info{"Greater", 1};
- const NodeTypeInfo& get_type_info() const override { return type_info; }
+ NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a greater-than operation.
Greater()
: util::BinaryElementwiseComparison(AutoBroadcastSpec::NUMPY)
class NGRAPH_API GreaterEqual : public util::BinaryElementwiseComparison
{
public:
- static constexpr NodeTypeInfo type_info{"GreaterEqual", 1};
- const NodeTypeInfo& get_type_info() const override { return type_info; }
+ NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a greater-than-or-equal operation.
GreaterEqual()
: util::BinaryElementwiseComparison(AutoBroadcastSpec::NUMPY)
class NGRAPH_API Less : public util::BinaryElementwiseComparison
{
public:
- static constexpr NodeTypeInfo type_info{"Less", 1};
- const NodeTypeInfo& get_type_info() const override { return type_info; }
+ NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a less-than operation.
Less()
: util::BinaryElementwiseComparison(AutoBroadcastSpec::NUMPY)
class NGRAPH_API LessEqual : public util::BinaryElementwiseComparison
{
public:
- static constexpr NodeTypeInfo type_info{"LessEqual", 1};
- const NodeTypeInfo& get_type_info() const override { return type_info; }
+ NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a less-than-or-equal operation.
LessEqual()
: util::BinaryElementwiseComparison(AutoBroadcastSpec::NUMPY)
class NGRAPH_API LogicalNot : public Op
{
public:
- static constexpr NodeTypeInfo type_info{"LogicalNot", 1};
- const NodeTypeInfo& get_type_info() const override { return type_info; }
+ NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a logical negation operation.
LogicalNot() = default;
/// \brief Constructs a logical negation operation.
class NGRAPH_API NotEqual : public util::BinaryElementwiseComparison
{
public:
- static constexpr NodeTypeInfo type_info{"NotEqual", 1};
- const NodeTypeInfo& get_type_info() const override { return type_info; }
+ NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a not-equal operation.
NotEqual()
: util::BinaryElementwiseComparison(AutoBroadcastSpec::NUMPY)
class NGRAPH_API LogicalOr : public util::BinaryElementwiseLogical
{
public:
- static constexpr NodeTypeInfo type_info{"LogicalOr", 1};
- const NodeTypeInfo& get_type_info() const override { return type_info; }
+ NGRAPH_RTTI_DECLARATION;
LogicalOr() = default;
/// \brief Constructs a logical-or operation.
///
class NGRAPH_API ReduceLogicalAnd : public util::LogicalReductionKeepDims
{
public:
- static constexpr NodeTypeInfo type_info{"ReduceLogicalAnd", 1};
- const NodeTypeInfo& get_type_info() const override { return type_info; }
+ NGRAPH_RTTI_DECLARATION;
ReduceLogicalAnd() = default;
/// \brief Constructs a ReduceLogicalAnd node.
///
class NGRAPH_API ReduceLogicalOr : public util::LogicalReductionKeepDims
{
public:
- static constexpr NodeTypeInfo type_info{"ReduceLogicalOr", 1};
- const NodeTypeInfo& get_type_info() const override { return type_info; }
+ NGRAPH_RTTI_DECLARATION;
ReduceLogicalOr() = default;
/// \brief Constructs a ReduceLogicalOr node.
///
class NGRAPH_API Select : public Op
{
public:
- static constexpr NodeTypeInfo type_info{"Select", 1};
- const NodeTypeInfo& get_type_info() const override { return type_info; }
+ NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a selection operation.
Select()
: m_auto_broadcast(AutoBroadcastSpec(AutoBroadcastType::NUMPY))
class NGRAPH_API ShapeOf : public Op
{
public:
- static constexpr NodeTypeInfo type_info{"ShapeOf", 0};
- const NodeTypeInfo& get_type_info() const override { return type_info; }
+ NGRAPH_RTTI_DECLARATION;
ShapeOf() = default;
/// \brief Constructs a shape-of operation.
ShapeOf(const Output<Node>& arg);
class NGRAPH_API LogicalXor : public util::BinaryElementwiseLogical
{
public:
- static constexpr NodeTypeInfo type_info{"LogicalXor", 1};
- const NodeTypeInfo& get_type_info() const override { return type_info; }
+ NGRAPH_RTTI_DECLARATION;
LogicalXor() = default;
/// \brief Constructs a logical-xor operation.
///
//------------------------------- v1 -------------------------------------------
-constexpr NodeTypeInfo op::v1::Equal::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::Equal, "Equal", 1);
op::v1::Equal::Equal(const Output<Node>& arg0,
const Output<Node>& arg1,
//-------------------------------------- v1 ------------------------------------
-constexpr NodeTypeInfo op::v1::Greater::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::Greater, "Greater", 1);
op::v1::Greater::Greater(const Output<Node>& arg0,
const Output<Node>& arg1,
//---------------------------------- v1 ----------------------------------------
-constexpr NodeTypeInfo op::v1::GreaterEqual::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::GreaterEqual, "GreaterEqual", 1);
op::v1::GreaterEqual::GreaterEqual(const Output<Node>& arg0,
const Output<Node>& arg1,
// ----------------------------- v1 --------------------------------------------
-constexpr NodeTypeInfo op::v1::Less::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::Less, "Less", 1);
op::v1::Less::Less(const Output<Node>& arg0,
const Output<Node>& arg1,
// ---------------------------------- v1 ---------------------------------------
-constexpr NodeTypeInfo op::v1::LessEqual::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::LessEqual, "LessEqual", 1);
op::v1::LessEqual::LessEqual(const Output<Node>& arg0,
const Output<Node>& arg1,
using namespace ngraph;
using namespace std;
-constexpr NodeTypeInfo op::v1::LogicalNot::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::LogicalNot, "LogicalNot", 1);
op::v1::LogicalNot::LogicalNot(const Output<Node>& arg)
: Op({arg})
// ----------------------------------- v1 --------------------------------------
-constexpr NodeTypeInfo op::v1::NotEqual::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::NotEqual, "NotEqual", 1);
op::v1::NotEqual::NotEqual(const Output<Node>& arg0,
const Output<Node>& arg1,
using namespace std;
using namespace ngraph;
-constexpr NodeTypeInfo op::v1::LogicalOr::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::LogicalOr, "LogicalOr", 1, util::BinaryElementwiseLogical);
op::v1::LogicalOr::LogicalOr(const Output<Node>& arg0,
const Output<Node>& arg1,
using namespace ngraph;
using namespace std;
-constexpr NodeTypeInfo op::v1::ReduceLogicalAnd::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::ReduceLogicalAnd, "ReduceLogicalAnd", 1);
op::v1::ReduceLogicalAnd::ReduceLogicalAnd(const Output<Node>& data,
const Output<Node>& reduction_axes,
using namespace ngraph;
using namespace std;
-constexpr NodeTypeInfo op::v1::ReduceLogicalOr::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::ReduceLogicalOr, "ReduceLogicalOr", 1);
op::v1::ReduceLogicalOr::ReduceLogicalOr(const Output<Node>& data,
const Output<Node>& reduction_axes,
using namespace std;
using namespace ngraph;
-constexpr NodeTypeInfo op::v1::Select::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::Select, "Select", 1);
op::v1::Select::Select(const Output<Node>& arg0,
const Output<Node>& arg1,
}
// op::v0::ShapeOf
-constexpr NodeTypeInfo op::v0::ShapeOf::type_info;
+NGRAPH_RTTI_DEFINITION(op::v0::ShapeOf, "ShapeOf", 0);
op::v0::ShapeOf::ShapeOf(const Output<Node>& arg)
: Op({arg})
using namespace std;
using namespace ngraph;
-constexpr NodeTypeInfo op::v1::LogicalXor::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::LogicalXor, "LogicalXor", 1, util::BinaryElementwiseLogical);
op::v1::LogicalXor::LogicalXor(const Output<Node>& arg0,
const Output<Node>& arg1,