Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / back / kaldi_remove_memory_output_test.py
index c72351c..12269c6 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -31,21 +31,28 @@ class KaldiRemoveMemoryOutputTest(unittest.TestCase):
         },
         'output_node': {
             'kind': 'data'
+        },
+        'op_output': {
+            'kind': 'data',
+            'op': 'OpOutput',
         }
     }
 
     def test_remove_out_data_for_memory(self):
-        graph = build_graph(self.nodes, [('input_node', 'memory_node')])
-        # Need for matching in pattern. The edge memory_node->out_node must contain only the attribute 'out' = 0
-        # build_graph creates edge  memory_node->out_node with attributes 'in' and 'out'
-        graph.add_node('output_node', is_output=True, **self.nodes['output_node'])
-        graph.add_edge('memory_node', 'output_node', out=0)
+        graph = build_graph(self.nodes,
+                            [
+                                ('input_node', 'memory_node'),
+                                ('memory_node', 'output_node'),
+                                ('output_node', 'op_output')
+                            ])
         KaldiRemoveMemoryOutputBackReplacementPattern().find_and_replace_pattern(graph)
         self.assertNotIn('output_node', graph.node)
 
     def test_do_not_remove_out_data_for_memory(self):
-        graph = build_graph(self.nodes, [('input_node', 'memory_node')])
-        graph.add_node('output_node', **self.nodes['output_node'])
-        graph.add_edge('memory_node', 'output_node', out=0)
+        graph = build_graph(self.nodes,
+                            [
+                                ('input_node', 'memory_node'),
+                                ('memory_node', 'output_node'),
+                            ])
         KaldiRemoveMemoryOutputBackReplacementPattern().find_and_replace_pattern(graph)
         self.assertIn('output_node', graph.node)