1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
19 #include "ngraph/node.hpp"
20 #include "ngraph/ops.hpp"
21 #include "ngraph/provenance.hpp"
22 #include "ngraph/validation_util.hpp"
23 #include "opset1_downgrade.hpp"
26 using namespace ngraph;
30 shared_ptr<Node> op_cast(shared_ptr<op::v3::Broadcast> node)
32 const auto data = node->input_value(0).get_node_shared_ptr();
33 const auto target_shape = node->input_value(1).get_node_shared_ptr();
35 shared_ptr<Node> replacement_node;
36 switch (node->get_broadcast_spec().m_type)
38 case op::BroadcastType::BIDIRECTIONAL:
40 const auto const_filled_with_ones = make_shared<op::v1::Broadcast>(
41 op::Constant::create(data->get_element_type(), {}, {1}), target_shape);
42 replacement_node = make_shared<op::v1::Multiply>(data, const_filled_with_ones);
45 case op::BroadcastType::EXPLICIT:
47 const auto axes_mapping = node->input_value(2).get_node_shared_ptr();
48 replacement_node = make_shared<op::v1::Broadcast>(
49 data, target_shape, axes_mapping, op::AutoBroadcastType::EXPLICIT);
52 case op::BroadcastType::NUMPY:
55 make_shared<op::v1::Broadcast>(data, target_shape, op::AutoBroadcastType::NUMPY);
58 case op::BroadcastType::PDPD:
60 op::AutoBroadcastSpec broadcast_spec;
61 broadcast_spec.m_type = op::AutoBroadcastType::PDPD;
62 broadcast_spec.m_axis = node->get_broadcast_spec().m_axis;
63 replacement_node = make_shared<op::v1::Broadcast>(data, target_shape, broadcast_spec);
70 "Not supported broadcast type during Broadcast:v3 to Broadcast:v1 conversion. ",
75 replace_node(node, replacement_node);
76 return replacement_node;
79 shared_ptr<Node> op_cast(shared_ptr<op::v3::TopK> node)
81 const auto data = node->input_value(0);
82 const auto k = node->input_value(1);
83 const auto replacement_node = make_shared<op::v1::TopK>(data,
87 node->get_sort_type(),
88 node->get_index_element_type());
89 replace_node(node, replacement_node);
90 return replacement_node;
93 using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>;
96 bool op_cast_thunk(shared_ptr<Node> node)
98 auto downgraded_node = op_cast(as_type_ptr<T>(node));
101 if (ngraph::get_provenance_enabled())
103 const std::string provenance_tag =
104 "<Opset1_Downgrade (v3 " + std::string(node->get_type_name()) + ")>";
105 downgraded_node->add_provenance_tags_above(node->input_values(), {provenance_tag});
112 DispatchMap& get_dispatch_map()
114 static DispatchMap dispatch_map{
115 #define NGRAPH_OP(NAME, NAMESPACE) {NAMESPACE::NAME::type_info, op_cast_thunk<NAMESPACE::NAME>},
116 NGRAPH_OP(Broadcast, op::v3) NGRAPH_OP(TopK, op::v3)
// Downgrades a single node. If the node's exact type has a registered converter
// in the dispatch map (v3::Broadcast, v3::TopK), that converter is run and
// `modified` records whether it replaced the node. Nodes of any other type are
// left untouched and `modified` stays false.
123 bool pass::Opset1Downgrade::run_on_node(shared_ptr<Node> node)
125 bool modified = false;
126 auto& dispatch_map = get_dispatch_map();
// Lookup is keyed on the node's type info, so only the exact registered types match.
127 auto it = dispatch_map.find(node->get_type_info());
128 if (it != dispatch_map.end())
130 modified = it->second(node);