"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import numpy as np
-from mo.middle.passes.shared_weights_duplication import duplicate_shared_weights
+from extensions.middle.SharedWeightsDuplication import SharedWeightsDuplication
+from mo.middle.passes.eliminate import graph_clean_up
from mo.utils.unittest.graph import build_graph, compare_graphs
nodes_attributes = {
+ 'const': {'shape': None, 'type': 'Const', 'kind': 'op', 'op': 'Const'},
# Mul and Add operations
'mul_1': {'type': None, 'kind': 'op', 'op': 'Mul'},
'mul_1_w': {'value': None, 'shape': None, 'kind': 'data'},
# Concat1 operation
'concat_1': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'},
'concat_1_data': {'value': None, 'shape': None, 'kind': 'data'},
+ 'op_output': {'op': 'OpOutput', 'kind': 'op'}
}
class DuplicateSharedWeightsTests(unittest.TestCase):
def test_duplicate_shared_weights_1(self):
graph = build_graph(nodes_attributes,
- [('mul_1_w', 'mul_1'),
+ [('const', 'mul_1_w'),
+ ('mul_1_w', 'mul_1'),
('mul_1', 'mul_1_data'),
('mul_1_w', 'mul_2'),
('mul_2', 'mul_2_data'),
('mul_1_data', 'concat_1'),
('mul_2_data', 'concat_1'),
('mul_3_data', 'concat_1'),
- ('concat_1', 'concat_1_data')
+ ('concat_1', 'concat_1_data'),
+ ('concat_1_data', 'op_output')
],
- {'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])}})
+ {'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])}},
+ nodes_with_edges_only=True
+ )
graph_ref = build_graph(nodes_attributes,
- [('mul_1_w', 'mul_1'),
+ [
+ ('mul_1_w', 'mul_1'),
('mul_1', 'mul_1_data'),
('mul_2_w', 'mul_2'),
('mul_2', 'mul_2_data'),
('mul_1_data', 'concat_1'),
('mul_2_data', 'concat_1'),
('mul_3_data', 'concat_1'),
- ('concat_1', 'concat_1_data')
- ],
+ ('concat_1', 'concat_1_data'),
+ ('concat_1_data', 'op_output')
+ ],
{'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
'mul_2_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
'mul_3_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
- })
-
- duplicate_shared_weights(graph)
+ }, nodes_with_edges_only=True)
+ SharedWeightsDuplication().find_and_replace_pattern(graph)
+ graph_clean_up(graph)
+ graph_clean_up(graph_ref)
(flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
- self.assertTrue(flag, resp)
+ self.assertTrue(flag, resp)
\ No newline at end of file