using namespace std;
using namespace ngraph;
-namespace
+namespace opset1_upgrade
{
template <typename OpV0, typename OpV1>
shared_ptr<Node> op_cast_binary_elementwise_node(const shared_ptr<OpV0>& node)
return op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
}
- shared_ptr<Node> op_cast(shared_ptr<op::Broadcast> node)
- {
- auto replacement_node = ngraph::builder::opset1::make_broadcast(
- node->input_value(0), node->get_broadcast_shape(), node->get_broadcast_axes());
- replace_node(node, replacement_node.get_node_shared_ptr());
- return replacement_node.get_node_shared_ptr();
- }
-
- shared_ptr<Node> op_cast(shared_ptr<op::BroadcastLike> node) { return nullptr; }
shared_ptr<Node> op_cast(shared_ptr<op::v0::Convolution> node)
{
auto strides = node->get_window_movement_strides();
return op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
}
- shared_ptr<Node> op_cast(shared_ptr<op::Gather> node)
- {
- int64_t axis = node->get_axis();
-
- auto axis_node = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{axis});
- auto replacement_node =
- make_shared<op::v1::Gather>(node->input_value(0), node->input_value(1), axis_node);
- replace_node(node, replacement_node);
- return replacement_node;
- }
-
shared_ptr<Node> op_cast(shared_ptr<op::Greater> node)
{
return op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
return op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
}
- shared_ptr<Node> op_cast(shared_ptr<op::Max> node)
- {
- bool keep_dims = false;
- auto replacement_node =
- make_shared<op::v1::ReduceMax>(node->input_value(0), node->input_value(1), keep_dims);
- replace_node(node, replacement_node);
- return replacement_node;
- }
-
shared_ptr<Node> op_cast(shared_ptr<op::Maximum> node)
{
return op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
}
- shared_ptr<Node> op_cast(shared_ptr<op::Min> node)
- {
- bool keep_dims = false;
- auto replacement_node =
- make_shared<op::v1::ReduceMin>(node->input_value(0), node->input_value(1), keep_dims);
- replace_node(node, replacement_node);
- return replacement_node;
- }
-
shared_ptr<Node> op_cast(shared_ptr<op::Minimum> node)
{
return op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
return op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
}
- shared_ptr<Node> op_cast(shared_ptr<op::Not> node)
- {
- auto replacement_node = make_shared<op::v1::LogicalNot>(node->input_value(0));
- replace_node(node, replacement_node);
- return replacement_node;
- }
-
shared_ptr<Node> op_cast(shared_ptr<op::NotEqual> node)
{
return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
}
- shared_ptr<Node> op_cast(shared_ptr<op::OneHot> node)
- {
- const auto indices = node->input_value(0).get_node_shared_ptr();
- const auto one_hot_axis = node->get_one_hot_axis();
-
- const auto output_pshape = node->get_output_partial_shape(0);
- NGRAPH_CHECK(output_pshape[one_hot_axis].is_static(),
- "OneHot:v0 one hot axis dimension must be static ",
- *node);
- const auto depth = output_pshape[one_hot_axis].get_length();
- const auto depth_node = op::Constant::create(element::i64, Shape{}, {depth});
-
- const auto on_value = op::Constant::create(element::i64, Shape{}, {1});
- const auto off_value = op::Constant::create(element::i64, Shape{}, {0});
-
- auto replacement_node =
- make_shared<op::v1::OneHot>(indices, depth_node, on_value, off_value, one_hot_axis);
- replace_node(node, replacement_node);
- return replacement_node;
- }
-
shared_ptr<Node> op_cast(shared_ptr<op::Or> node)
{
return op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
return op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
}
- shared_ptr<Node> op_cast(shared_ptr<op::Product> node)
- {
- bool keep_dims = false;
- auto replacement_node =
- make_shared<op::v1::ReduceProd>(node->input_value(0), node->input_value(1), keep_dims);
- replace_node(node, replacement_node);
- return replacement_node;
- }
-
shared_ptr<Node> op_cast(shared_ptr<op::Reverse> node)
{
// creates a Constant node from the v0::Reverse reversed_axes attribute
return dispatch_map;
NGRAPH_SUPPRESS_DEPRECATED_END
}
-} // namespace
+} // namespace opset1_upgrade
bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
{
bool modified = false;
- auto& dispatch_map = get_dispatch_map();
+ auto& dispatch_map = opset1_upgrade::get_dispatch_map();
auto it = dispatch_map.find(node->get_type_info());
if (it != dispatch_map.end())
{