Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / test / control_dependencies.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include <algorithm>
18 #include <cstdio>
19 #include <iostream>
20 #include <list>
21 #include <memory>
22
23 #include "gtest/gtest.h"
24 #include "ngraph/file_util.hpp"
25 #include "ngraph/graph_util.hpp"
26 #include "ngraph/log.hpp"
27 #include "ngraph/ngraph.hpp"
28 #include "ngraph/op/batch_norm.hpp"
29 #include "ngraph/op/parameter.hpp"
30 #include "ngraph/pass/manager.hpp"
31 #include "ngraph/pass/visualize_tree.hpp"
32 #include "ngraph/pattern/matcher.hpp"
33 #include "ngraph/util.hpp"
34 #include "util/all_close.hpp"
35 #include "util/ndarray.hpp"
36 #include "util/random.hpp"
37 #include "util/test_tools.hpp"
38
39 NGRAPH_SUPPRESS_DEPRECATED_START
40
41 using namespace ngraph;
42 using namespace std;
43
44 class ControlDependencyOp : public ngraph::op::Op
45 {
46 public:
47     static constexpr NodeTypeInfo type_info{"ControlDependencyOp", 0};
48     const NodeTypeInfo& get_type_info() const override { return type_info; }
49     virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override
50     {
51         auto clone = make_shared<ControlDependencyOp>(new_args, std::set<std::shared_ptr<Node>>{});
52         return move(clone);
53     }
54
55     ControlDependencyOp(const OutputVector& args, const std::set<std::shared_ptr<Node>>& deps)
56         : Op(args)
57     {
58         if (args.size() == 0 && deps.size() == 0)
59         {
60             throw ngraph_error("Expected some arguments or dependencies");
61         }
62
63         for (auto& node : deps)
64         {
65             add_control_dependency(node);
66         }
67
68         if (args.size() != 0)
69         {
70             set_output_type(0, args.at(0).get_element_type(), args.at(0).get_shape());
71         }
72         else
73         {
74             auto dn = *(deps.begin());
75             set_output_type(0, dn->get_element_type(), dn->get_shape());
76         }
77     }
78 };
79 constexpr NodeTypeInfo ControlDependencyOp::type_info;
80
81 TEST(control_dependencies, cdep_ops)
82 {
83     auto A = make_shared<op::Parameter>(element::f32, Shape{});
84     auto B = make_shared<op::Parameter>(element::f32, Shape{});
85     auto absn = make_shared<op::Abs>(A);
86     auto cdop =
87         make_shared<ControlDependencyOp>(OutputVector{A}, std::set<std::shared_ptr<Node>>{absn});
88
89     auto f = make_shared<Function>(cdop, ParameterVector{A, B});
90     test_ordered_ops(f, NodeVector{absn});
91 }
92
93 TEST(control_dependencies, two_cdep_ops)
94 {
95     auto A = make_shared<op::Parameter>(element::f32, Shape{});
96     auto B = make_shared<op::Parameter>(element::f32, Shape{});
97     auto absn = make_shared<op::Abs>(A);
98     auto C = make_shared<op::Parameter>(element::f32, Shape{});
99     auto absn_c = make_shared<op::Abs>(C);
100     auto cdop = make_shared<ControlDependencyOp>(OutputVector{A},
101                                                  std::set<std::shared_ptr<Node>>{absn, absn_c});
102
103     auto f = make_shared<Function>(cdop, ParameterVector{A, B, C});
104     test_ordered_ops(f, NodeVector{absn, absn_c});
105 }
106
107 TEST(control_dependencies, two_cdep_ops_op_on_top)
108 {
109     auto A = make_shared<op::Parameter>(element::f32, Shape{});
110     auto absn = make_shared<op::Abs>(A);
111     auto B = make_shared<op::Parameter>(element::f32, Shape{});
112     auto absn_b = make_shared<op::Abs>(B);
113     auto cdop = make_shared<ControlDependencyOp>(OutputVector{A},
114                                                  std::set<std::shared_ptr<Node>>{absn, absn_b});
115     auto absn_cdop = make_shared<op::Abs>(cdop);
116
117     auto f = make_shared<Function>(absn_cdop, ParameterVector{A, B});
118     test_ordered_ops(f, NodeVector{absn, absn_b});
119 }
120
121 TEST(control_dependencies, clone_function_cdop)
122 {
123     auto A = make_shared<op::Parameter>(element::f32, Shape{});
124     auto absn = make_shared<op::Abs>(A);
125     auto cdop =
126         make_shared<ControlDependencyOp>(OutputVector{A}, std::set<std::shared_ptr<Node>>{absn});
127
128     auto f = make_shared<Function>(cdop, ParameterVector{A});
129     test_ordered_ops(f, NodeVector{absn});
130     auto clone = ngraph::clone_function(*f.get());
131     auto matcher = std::make_shared<pattern::Matcher>(cdop);
132     auto cdop_clone = clone->get_results().at(0)->input_value(0).get_node_shared_ptr();
133     ASSERT_TRUE(matcher->match(cdop_clone));
134     auto cloned_deps = cdop_clone->get_control_dependencies();
135     ASSERT_EQ(cloned_deps.size(), 1);
136     auto cloned_abs = *begin(cloned_deps);
137     ASSERT_TRUE(is_type<op::Abs>(cloned_abs));
138 }
139
140 TEST(control_dependencies, clone_function_cdop_abs)
141 {
142     auto A = make_shared<op::Parameter>(element::f32, Shape{});
143     auto absn = make_shared<op::Abs>(A);
144     auto B = make_shared<op::Parameter>(element::f32, Shape{});
145     auto absn_b = make_shared<op::Abs>(B);
146     auto cdop = make_shared<ControlDependencyOp>(OutputVector{A},
147                                                  std::set<std::shared_ptr<Node>>{absn, absn_b});
148     auto absn_cdop = make_shared<op::Abs>(cdop);
149
150     auto f = make_shared<Function>(absn_cdop, ParameterVector{A, B});
151     auto clone = ngraph::clone_function(*f.get());
152     auto matcher = std::make_shared<pattern::Matcher>(cdop);
153     auto cdop_clone = clone->get_results()
154                           .at(0)
155                           ->input_value(0)
156                           .get_node_shared_ptr()
157                           ->input_value(0)
158                           .get_node_shared_ptr();
159     ASSERT_TRUE(matcher->match(cdop_clone));
160     auto cloned_deps = cdop_clone->get_control_dependencies();
161     ASSERT_EQ(cloned_deps.size(), 2);
162     for (auto ccdep : cloned_deps)
163     {
164         ASSERT_TRUE(is_type<op::Abs>(ccdep));
165     }
166 }
167
168 static size_t count_control_dependencies(const shared_ptr<Node>& node,
169                                          const shared_ptr<Node>& dependency)
170 {
171     auto& dependencies = node->get_control_dependencies();
172     return count(dependencies.begin(), dependencies.end(), dependency);
173 }
174
175 TEST(control_dependencies, replace_node)
176 {
177     Shape shape{2, 2};
178     auto A = make_shared<op::Parameter>(element::f32, shape);
179     auto B = make_shared<op::Parameter>(element::f32, shape);
180     auto MUL_AB = A * B;
181     auto MUL_BA = B * A;
182     auto ADD = A + B;
183     auto SUM = MUL_AB + ADD;
184     ADD->add_control_dependency(MUL_AB);
185     ASSERT_TRUE(1 == count_control_dependencies(ADD, MUL_AB));
186     ASSERT_TRUE(0 == count_control_dependencies(ADD, MUL_BA));
187     replace_node(MUL_AB, MUL_BA);
188     ASSERT_TRUE(0 == count_control_dependencies(ADD, MUL_AB));
189     ASSERT_TRUE(1 == count_control_dependencies(ADD, MUL_BA));
190 }