[NGraph] Add scatterNDUpdate and scatterUpdate reference implementations (#1494)
[platform/upstream/dldt.git] / ngraph / test / runtime / opset1_downgrade.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include <algorithm>
18
19 #include "ngraph/node.hpp"
20 #include "ngraph/ops.hpp"
21 #include "ngraph/provenance.hpp"
22 #include "ngraph/validation_util.hpp"
23 #include "opset1_downgrade.hpp"
24
25 using namespace std;
26 using namespace ngraph;
27
28 namespace
29 {
30     shared_ptr<Node> op_cast(shared_ptr<op::v3::Broadcast> node)
31     {
32         const auto data = node->input_value(0).get_node_shared_ptr();
33         const auto target_shape = node->input_value(1).get_node_shared_ptr();
34
35         shared_ptr<Node> replacement_node;
36         switch (node->get_broadcast_spec().m_type)
37         {
38         case op::BroadcastType::BIDIRECTIONAL:
39         {
40             const auto const_filled_with_ones = make_shared<op::v1::Broadcast>(
41                 op::Constant::create(data->get_element_type(), {}, {1}), target_shape);
42             replacement_node = make_shared<op::v1::Multiply>(data, const_filled_with_ones);
43             break;
44         }
45         case op::BroadcastType::EXPLICIT:
46         {
47             const auto axes_mapping = node->input_value(2).get_node_shared_ptr();
48             replacement_node = make_shared<op::v1::Broadcast>(
49                 data, target_shape, axes_mapping, op::AutoBroadcastType::EXPLICIT);
50             break;
51         }
52         case op::BroadcastType::NUMPY:
53         {
54             replacement_node =
55                 make_shared<op::v1::Broadcast>(data, target_shape, op::AutoBroadcastType::NUMPY);
56             break;
57         }
58         case op::BroadcastType::PDPD:
59         {
60             op::AutoBroadcastSpec broadcast_spec;
61             broadcast_spec.m_type = op::AutoBroadcastType::PDPD;
62             broadcast_spec.m_axis = node->get_broadcast_spec().m_axis;
63             replacement_node = make_shared<op::v1::Broadcast>(data, target_shape, broadcast_spec);
64             break;
65         }
66         default:
67         {
68             NGRAPH_CHECK(
69                 true,
70                 "Not supported broadcast type during Broadcast:v3 to Broadcast:v1 conversion. ",
71                 "Node: ",
72                 *node);
73         }
74         }
75         replace_node(node, replacement_node);
76         return replacement_node;
77     }
78
79     shared_ptr<Node> op_cast(shared_ptr<op::v3::TopK> node)
80     {
81         const auto data = node->input_value(0);
82         const auto k = node->input_value(1);
83         const auto replacement_node = make_shared<op::v1::TopK>(data,
84                                                                 k,
85                                                                 node->get_axis(),
86                                                                 node->get_mode(),
87                                                                 node->get_sort_type(),
88                                                                 node->get_index_element_type());
89         replace_node(node, replacement_node);
90         return replacement_node;
91     }
92
93     using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>;
94
95     template <typename T>
96     bool op_cast_thunk(shared_ptr<Node> node)
97     {
98         auto downgraded_node = op_cast(as_type_ptr<T>(node));
99         if (downgraded_node)
100         {
101             if (ngraph::get_provenance_enabled())
102             {
103                 const std::string provenance_tag =
104                     "<Opset1_Downgrade (v3 " + std::string(node->get_type_name()) + ")>";
105                 downgraded_node->add_provenance_tags_above(node->input_values(), {provenance_tag});
106             }
107             return true;
108         }
109         return false;
110     }
111
112     DispatchMap& get_dispatch_map()
113     {
114         static DispatchMap dispatch_map{
115 #define NGRAPH_OP(NAME, NAMESPACE) {NAMESPACE::NAME::type_info, op_cast_thunk<NAMESPACE::NAME>},
116             NGRAPH_OP(Broadcast, op::v3) NGRAPH_OP(TopK, op::v3)
117 #undef NGRAPH_OP
118         };
119         return dispatch_map;
120     }
121 } // namespace
122
123 bool pass::Opset1Downgrade::run_on_node(shared_ptr<Node> node)
124 {
125     bool modified = false;
126     auto& dispatch_map = get_dispatch_map();
127     auto it = dispatch_map.find(node->get_type_info());
128     if (it != dispatch_map.end())
129     {
130         modified = it->second(node);
131     }
132     return modified;
133 }