Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / eliminate_test.py
index 79b892c..f253dde 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -18,9 +18,8 @@ import unittest
 
 import numpy as np
 
-from mo.graph.graph import Node, erase_node
-from mo.middle.passes.eliminate import mark_output_reachable_nodes, graph_clean_up, \
-    get_nodes_with_attributes, mark_const_producer_nodes
+from mo.graph.graph import Node, Graph
+from mo.middle.passes.eliminate import mark_output_reachable_nodes, graph_clean_up, mark_const_producer_nodes
 from mo.utils.unittest.graph import build_graph
 
 nodes_attributes = {'placeholder_1': {'type': 'Placeholder', 'kind': 'op'},
@@ -38,11 +37,14 @@ nodes_attributes = {'placeholder_1': {'type': 'Placeholder', 'kind': 'op'},
                     'data_node_3': {'value': None, 'kind': 'data'},
                     'data_node_3_2': {'value': None, 'kind': 'data'},
                     'data_node_4': {'value': None, 'kind': 'data'},
-                    'data_node_5': {'value': None, 'kind': 'data'},
-                    'data_node_6': {'value': None, 'kind': 'data'},
+                    'data_node_5': {'value': None, 'shape': None, 'kind': 'data'},
+                    'data_node_6': {'value': None, 'shape': None, 'kind': 'data'},
                     'tf_call_1': {'type': 'TFCustomSubgraphCall', 'kind': 'op'},
                     'tf_call_2': {'type': 'TFCustomSubgraphCall', 'kind': 'op'},
                     'tf_call_3': {'type': 'TFCustomSubgraphCall', 'kind': 'op'},
+                    'op_output': {'kind': 'op', 'op': 'OpOutput'},
+                    'op_output_1': {'kind': 'op', 'op': 'OpOutput'},
+                    'op_output_2': {'kind': 'op', 'op': 'OpOutput'}
                     }
 
 
@@ -63,15 +65,17 @@ class TestEliminatePass(unittest.TestCase):
                             [('placeholder_1', 'node_1'),
                              ('node_1', 'node_2'),
                              ('placeholder_1', 'node_3'),
-                             ('node_3', 'node_4')],
-                            {'node_4': {'is_output': True}},
+                             ('node_3', 'node_4'),
+                             ('node_4', 'op_output')
+                             ],
+                            {'node_4': {}},
                             nodes_with_edges_only=True)
         mark_output_reachable_nodes(graph)
 
-        self.assertListEqual(sorted(['placeholder_1', 'node_3', 'node_4']),
-                             sorted(get_nodes_with_attributes(graph, is_output_reachable=True)))
+        self.assertListEqual(sorted(['placeholder_1', 'node_3', 'op_output', 'node_4']),
+                             sorted(graph.get_nodes_with_attributes(is_output_reachable=True)))
         self.assertListEqual(sorted(['node_1', 'node_2']),
-                             sorted(get_nodes_with_attributes(graph, is_output_reachable=False)))
+                             sorted(graph.get_nodes_with_attributes(is_output_reachable=False)))
 
     def test_mark_output_unreachable_nodes_behind_output(self):
         """
@@ -86,13 +90,15 @@ class TestEliminatePass(unittest.TestCase):
         graph = build_graph(nodes_attributes,
                             [('placeholder_1', 'node_1'),
                              ('node_1', 'node_2'),
-                             ('node_2', 'node_3')],
-                            {'node_2': {'is_output': True}},
+                             ('node_2', 'node_3'),
+                             ('node_2', 'op_output')
+                             ],
+                            {'node_2': {}},
                             nodes_with_edges_only=True)
         mark_output_reachable_nodes(graph)
 
-        self.assertListEqual(sorted(['placeholder_1', 'node_1', 'node_2']),
-                             sorted(get_nodes_with_attributes(graph, is_output_reachable=True)))
+        self.assertListEqual(sorted(['node_1', 'node_2', 'op_output', 'placeholder_1']),
+                             sorted(graph.get_nodes_with_attributes(is_output_reachable=True)))
         self.assertFalse(graph.node['node_3']['is_output_reachable'])
 
     def test_mark_ops_producing_constant_values(self):
@@ -128,16 +134,19 @@ class TestEliminatePass(unittest.TestCase):
                              ('data_node_3_2', 'node_5'),
                              ('node_5', 'data_node_5'),
                              ('data_node_3', 'node_4'),
-                             ('data_node_4', 'node_1')],
-                            {'data_node_2': {'is_output': True},
-                             'data_node_5': {'is_output': True},
+                             ('data_node_4', 'node_1'),
+                             ('data_node_2', 'op_output'),
+                             ('data_node_5', 'op_output_1')
+                             ],
+                            {'data_node_2': {},
+                             'data_node_5': {},
                              'data_node_3': {'value': np.array(1)},
                              'data_node_6': {'value': np.array(1)}},
                             nodes_with_edges_only=True)
         mark_const_producer_nodes(graph)
         self.assertTrue((graph.node['node_6']['is_const_producer']))
         self.assertListEqual(sorted(['node_1', 'node_2', 'node_3', 'node_5', 'placeholder_1']),
-                             sorted(get_nodes_with_attributes(graph, is_const_producer=False, kind='op')))
+                             sorted(graph.get_nodes_with_attributes(is_const_producer=False, kind='op')))
 
         graph_clean_up(graph)
         self.assertTrue('node_3' in graph.nodes())
@@ -166,6 +175,6 @@ class TestEliminatePass(unittest.TestCase):
                              ('node_1', 'node_2'),
                              ('node_2', 'node_3')],
                             nodes_with_edges_only=True)
-        erase_node(Node(graph, 'node_2'))
+        graph.erase_node(Node(graph, 'node_2'))
 
         self.assertListEqual(sorted(['placeholder_1', 'node_1', 'node_3']), sorted(graph.nodes()))