"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import unittest
import numpy as np
+from argparse import Namespace
from extensions.middle.FusePermutesSequence import FusePermutesSequence
from mo.middle.passes.eliminate_test import build_graph
'permute_3': {'type': 'Permute', 'value': None, 'kind': 'op', 'op': 'Permute'},
'permute_3_data': {'value': None, 'shape': None, 'kind': 'data'},
+ 'op_output': { 'op': 'OpOutput', 'kind': 'op'}
}
('placeholder_1_data', 'permute_1'),
('permute_1', 'permute_1_data'),
('permute_1_data', 'permute_2'),
- ('permute_2', 'permute_2_data')
+ ('permute_2', 'permute_2_data'),
+ ('permute_2_data', 'op_output')
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
'permute_1_data': {'shape': np.array([1, 3, 227, 227])},
'permute_2': {'order': np.array([0, 2, 3, 1])},
- 'permute_2_data': {'shape': np.array([1, 227, 227, 3]), 'is_output': True},
+ 'permute_2_data': {'shape': np.array([1, 227, 227, 3])},
}, nodes_with_edges_only=True)
graph.graph['layout'] = 'NHWC'
+ graph.graph['cmd_params'] = Namespace(keep_shape_ops=False)
graph_ref = build_graph(nodes_attributes,
- [('placeholder_1', 'placeholder_1_data')],
- {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])}}, nodes_with_edges_only=True)
+ [('placeholder_1', 'placeholder_1_data'),
+ ('placeholder_1_data', 'op_output')
+ ],
+ {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])}},
+ nodes_with_edges_only=True)
pattern = FusePermutesSequence()
pattern.find_and_replace_pattern(graph)
('placeholder_1_data', 'permute_1'),
('permute_1', 'permute_1_data'),
('permute_1_data', 'permute_2'),
- ('permute_2', 'permute_2_data')
+ ('permute_2', 'permute_2_data'),
+ ('permute_2_data', 'op_output')
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
'permute_1_data': {'shape': np.array([1, 3, 227, 227])},
'permute_2': {'order': np.array([0, 1, 2, 3])},
- 'permute_2_data': {'shape': np.array([1, 3, 227, 227]), 'is_output': True},
+ 'permute_2_data': {'shape': np.array([1, 3, 227, 227])},
}, nodes_with_edges_only=True)
graph.graph['layout'] = 'NHWC'
+ graph.graph['cmd_params'] = Namespace(keep_shape_ops=False)
graph_ref = build_graph(nodes_attributes,
- [('placeholder_1', 'placeholder_1_data'),
- ('placeholder_1_data', 'permute_1'),
- ('permute_1', 'permute_1_data'),
- ],
- {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
- 'permute_1': {'order': np.array([0, 3, 1, 2])},
- 'permute_1_data': {'shape': np.array([1, 3, 227, 227])},
- }, nodes_with_edges_only=True)
+ [('placeholder_1', 'placeholder_1_data'),
+ ('placeholder_1_data', 'permute_1'),
+ ('permute_1', 'permute_1_data'),
+ ('permute_1_data', 'op_output')
+ ],
+ {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
+ 'permute_1': {'order': np.array([0, 3, 1, 2])},
+ 'permute_1_data': {'shape': np.array([1, 3, 227, 227])},
+ }, nodes_with_edges_only=True)
pattern = FusePermutesSequence()
pattern.find_and_replace_pattern(graph)