Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / test / runtime / pass / implicit_broadcast_elimination.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "implicit_broadcast_elimination.hpp"
18
19 #include "ngraph/builder/autobroadcast.hpp"
20 #include "ngraph/graph_util.hpp"
21 #include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
22 #include "ngraph/op/util/binary_elementwise_comparison.hpp"
23 #include "ngraph/op/util/binary_elementwise_logical.hpp"
24 #include "ngraph/op/util/op_types.hpp"
25
26 NGRAPH_SUPPRESS_DEPRECATED_START
27
28 using namespace std;
29 using namespace ngraph;
30
31 bool ngraph::pass::ImplicitBroadcastElimination::run_on_node(std::shared_ptr<Node> node)
32 {
33     if (ngraph::op::supports_auto_broadcast(node))
34     {
35         if (node->get_autob().m_type != op::AutoBroadcastType::NONE)
36         {
37             auto new_args = pass::explicit_broadcast(node);
38             for (size_t i = 0; i < new_args.size(); i++)
39             {
40                 node->input(i).replace_source_output(new_args[i]->output(0));
41             }
42             return true;
43         }
44     }
45     return false;
46 }
47
48 NodeVector ngraph::pass::explicit_broadcast(std::shared_ptr<Node>& node)
49 {
50     NodeVector rc;
51     if (ngraph::op::supports_auto_broadcast(node))
52     {
53         auto autob = node->get_autob();
54         if (autob.m_type == op::AutoBroadcastType::NONE)
55         {
56             for (auto& val : node->input_values())
57                 rc.emplace_back(val.get_node_shared_ptr());
58         }
59         else if (autob.m_type == op::AutoBroadcastType::NUMPY)
60         {
61             rc = as_node_vector(builder::numpy_broadcast_outputs(node->input_values()));
62         }
63         else if (autob.m_type == op::AutoBroadcastType::PDPD)
64         {
65             rc = as_node_vector(builder::pdpd_broadcast(node->input_values(), autob.m_axis));
66         }
67         else
68         {
69             throw ngraph_error("Unsupported implicit broadcast type");
70         }
71     }
72     return rc;
73 }