Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / SharedWeightsDuplication_test.py
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -18,10 +18,12 @@ import unittest
 
 import numpy as np
 
-from mo.middle.passes.shared_weights_duplication import duplicate_shared_weights
+from extensions.middle.SharedWeightsDuplication import SharedWeightsDuplication
+from mo.middle.passes.eliminate import graph_clean_up
 from mo.utils.unittest.graph import build_graph, compare_graphs
 
 nodes_attributes = {
+    'const': {'shape': None, 'type': 'Const', 'kind': 'op', 'op': 'Const'},
     # Mul and Add operations
     'mul_1': {'type': None, 'kind': 'op', 'op': 'Mul'},
     'mul_1_w': {'value': None, 'shape': None, 'kind': 'data'},
@@ -35,13 +37,15 @@ nodes_attributes = {
     # Concat1 operation
     'concat_1': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'},
     'concat_1_data': {'value': None, 'shape': None, 'kind': 'data'},
+    'op_output': {'op': 'OpOutput', 'kind': 'op'}
 }
 
 
 class DuplicateSharedWeightsTests(unittest.TestCase):
     def test_duplicate_shared_weights_1(self):
         graph = build_graph(nodes_attributes,
-                            [('mul_1_w', 'mul_1'),
+                            [('const', 'mul_1_w'),
+                             ('mul_1_w', 'mul_1'),
                              ('mul_1', 'mul_1_data'),
                              ('mul_1_w', 'mul_2'),
                              ('mul_2', 'mul_2_data'),
@@ -50,12 +54,16 @@ class DuplicateSharedWeightsTests(unittest.TestCase):
                              ('mul_1_data', 'concat_1'),
                              ('mul_2_data', 'concat_1'),
                              ('mul_3_data', 'concat_1'),
-                             ('concat_1', 'concat_1_data')
+                             ('concat_1', 'concat_1_data'),
+                             ('concat_1_data', 'op_output')
                              ],
-                            {'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])}})
+                            {'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])}},
+                            nodes_with_edges_only=True
+                            )
 
         graph_ref = build_graph(nodes_attributes,
-                                [('mul_1_w', 'mul_1'),
+                                [
+                                 ('mul_1_w', 'mul_1'),
                                  ('mul_1', 'mul_1_data'),
                                  ('mul_2_w', 'mul_2'),
                                  ('mul_2', 'mul_2_data'),
@@ -64,14 +72,16 @@ class DuplicateSharedWeightsTests(unittest.TestCase):
                                  ('mul_1_data', 'concat_1'),
                                  ('mul_2_data', 'concat_1'),
                                  ('mul_3_data', 'concat_1'),
-                                 ('concat_1', 'concat_1_data')
-                                 ],
+                                 ('concat_1', 'concat_1_data'),
+                                 ('concat_1_data', 'op_output')
+                                ],
                                 {'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
                                  'mul_2_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
                                  'mul_3_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
-                                 })
-
-        duplicate_shared_weights(graph)
+                                 }, nodes_with_edges_only=True)
 
+        SharedWeightsDuplication().find_and_replace_pattern(graph)
+        graph_clean_up(graph)
+        graph_clean_up(graph_ref)
         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
-        self.assertTrue(flag, resp)
+        self.assertTrue(flag, resp)
\ No newline at end of file