return op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
}
- shared_ptr<Node> op_cast(shared_ptr<op::Gather> node)
- {
- int64_t axis = node->get_axis();
-
- auto axis_node = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{axis});
- auto replacement_node =
- make_shared<op::v1::Gather>(node->input_value(0), node->input_value(1), axis_node);
- replace_node(node, replacement_node);
- return replacement_node;
- }
-
shared_ptr<Node> op_cast(shared_ptr<op::Greater> node)
{
return op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
return op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
}
- shared_ptr<Node> op_cast(shared_ptr<op::Max> node)
- {
- bool keep_dims = false;
- auto replacement_node =
- make_shared<op::v1::ReduceMax>(node->input_value(0), node->input_value(1), keep_dims);
- replace_node(node, replacement_node);
- return replacement_node;
- }
-
shared_ptr<Node> op_cast(shared_ptr<op::Maximum> node)
{
return op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
}
- shared_ptr<Node> op_cast(shared_ptr<op::Min> node)
- {
- bool keep_dims = false;
- auto replacement_node =
- make_shared<op::v1::ReduceMin>(node->input_value(0), node->input_value(1), keep_dims);
- replace_node(node, replacement_node);
- return replacement_node;
- }
-
shared_ptr<Node> op_cast(shared_ptr<op::Minimum> node)
{
return op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
return op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
}
- shared_ptr<Node> op_cast(shared_ptr<op::Not> node)
- {
- auto replacement_node = make_shared<op::v1::LogicalNot>(node->input_value(0));
- replace_node(node, replacement_node);
- return replacement_node;
- }
-
shared_ptr<Node> op_cast(shared_ptr<op::NotEqual> node)
{
return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
}
- shared_ptr<Node> op_cast(shared_ptr<op::OneHot> node)
- {
- const auto indices = node->input_value(0).get_node_shared_ptr();
- const auto one_hot_axis = node->get_one_hot_axis();
-
- const auto output_pshape = node->get_output_partial_shape(0);
- NGRAPH_CHECK(output_pshape[one_hot_axis].is_static(),
- "OneHot:v0 one hot axis dimension must be static ",
- *node);
- const auto depth = output_pshape[one_hot_axis].get_length();
- const auto depth_node = op::Constant::create(element::i64, Shape{}, {depth});
-
- const auto on_value = op::Constant::create(element::i64, Shape{}, {1});
- const auto off_value = op::Constant::create(element::i64, Shape{}, {0});
-
- auto replacement_node =
- make_shared<op::v1::OneHot>(indices, depth_node, on_value, off_value, one_hot_axis);
- replace_node(node, replacement_node);
- return replacement_node;
- }
-
shared_ptr<Node> op_cast(shared_ptr<op::Or> node)
{
return op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
return op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
}
- shared_ptr<Node> op_cast(shared_ptr<op::Product> node)
- {
- bool keep_dims = false;
- auto replacement_node =
- make_shared<op::v1::ReduceProd>(node->input_value(0), node->input_value(1), keep_dims);
- replace_node(node, replacement_node);
- return replacement_node;
- }
-
shared_ptr<Node> op_cast(shared_ptr<op::Reverse> node)
{
// creates a Constant node from the v0::Reverse reversed_axes attribute