Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / back / PermuteForReshape_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import unittest
17
18 import numpy as np
19
20 from extensions.back.PermuteForReshape import PermuteForReshape
21 from mo.graph.graph import Node
22 from mo.ops.op import PermuteAttrs
23 from mo.utils.unittest.graph import build_graph_with_attrs, compare_graphs
24
25
class ReshapeToPermuteTest(unittest.TestCase):
    """Tests for the ``PermuteForReshape`` back transformation.

    Every test builds the graph ``input_data -> reshape -> reshape_data`` with
    ``graph.graph['layout'] = 'NHWC'``, runs ``PermuteForReshape`` on it and
    compares the result with an independently built reference graph.
    """

    # Base graph: input_data -> reshape -> reshape_data. Shapes and the
    # reshape 'dim' are filled in per test.
    nodes = [
        ('input_data', {'kind': 'data', 'shape': None}),
        ('reshape', {'kind': 'op', 'op': 'Squeeze', 'type': 'Reshape', 'dim': None}),
        ('reshape_data', {'kind': 'data'}),
    ]
    edges = [
        ('input_data', 'reshape'),
        ('reshape', 'reshape_data'),
    ]

    # Extra nodes/edges of a reference graph in which a Permute sits between
    # the input and the reshape: input_data -> permute -> permute_data -> reshape.
    permute_nodes = [
        ('permute', {'kind': 'op', 'op': 'Permute'}),
        ('permute_data', {'kind': 'data', 'shape': None})
    ]
    permute_edges = [
        ('input_data', 'permute'),
        ('permute', 'permute_data'),
        ('permute_data', 'reshape'),
    ]

    def _build_graph(self, input_shape, new_shape):
        """Build the source graph reshaping ``input_shape`` into ``new_shape``.

        The graph layout is set to NHWC and the reshape node gets permute
        attributes binding its ``dim`` attribute to ``output:0``.
        """
        graph = build_graph_with_attrs(
            nodes_with_attrs=self.nodes,
            edges_with_attrs=self.edges,
            update_nodes_attributes=[('input_data', {'shape': input_shape}),
                                     ('reshape', {'dim': new_shape}),
                                     ('reshape_data', {'shape': new_shape})]
        )
        graph.graph['layout'] = 'NHWC'
        reshape = Node(graph, 'reshape')
        PermuteAttrs.create_permute_attrs(reshape, attrs=[('dim', 'output:0')])
        return graph

    def _build_ref_graph(self, input_shape, new_shape, permuted_shape=None):
        """Build the graph expected after the transformation.

        :param input_shape: shape of ``input_data``.
        :param new_shape: target shape of the reshape (``dim`` and output shape).
        :param permuted_shape: ``None`` if the graph is expected to stay
            unchanged; otherwise the expected output shape of the Permute
            inserted between ``input_data`` and ``reshape``.
        """
        node_attrs = [('input_data', {'shape': input_shape}),
                      ('reshape', {'dim': new_shape}),
                      ('reshape_data', {'shape': new_shape})]
        if permuted_shape is None:
            return build_graph_with_attrs(
                nodes_with_attrs=self.nodes,
                edges_with_attrs=self.edges,
                update_nodes_attributes=node_attrs
            )

        # The direct input -> reshape edge is replaced by the path through the Permute.
        direct_edge = ('input_data', 'reshape')
        return build_graph_with_attrs(
            nodes_with_attrs=self.nodes + self.permute_nodes,
            edges_with_attrs=[edge for edge in self.edges if edge != direct_edge] + self.permute_edges,
            update_nodes_attributes=node_attrs + [('permute_data', {'shape': permuted_shape})]
        )

    def _check_transformation(self, input_shape, new_shape, permuted_shape=None):
        """Run ``PermuteForReshape`` and compare with the expected graph.

        Returns the transformed graph so callers can make further checks.
        """
        graph = self._build_graph(input_shape, new_shape)
        PermuteForReshape().find_and_replace_pattern(graph)

        # Compare against an independently built graph (never ``graph`` itself),
        # so an unexpected or missing Permute makes the comparison fail.
        graph_ref = self._build_ref_graph(input_shape, new_shape, permuted_shape)
        flag, resp = compare_graphs(graph, graph_ref, last_node='reshape_data')
        self.assertTrue(flag, resp)
        return graph

    def _assert_permute_order(self, graph, expected_order):
        """Check the ``order`` of the Permute node named 'reshape/Permute_'."""
        permute_order = graph.node['reshape/Permute_']['order']
        # assert_array_equal also checks the length and reports a readable diff.
        np.testing.assert_array_equal(permute_order, np.array(expected_order))

    def test_from3D_to3D(self):
        """A 3D -> 3D reshape is expected to leave the graph unchanged."""
        self._check_transformation(input_shape=np.array([2, 3, 4]),
                                   new_shape=np.array([2, 3, 4]))

    def test_from4D_to3D(self):
        """A 4D -> 3D reshape is expected to get a Permute on its input."""
        graph = self._check_transformation(input_shape=np.array([1, 2, 3, 4]),
                                           new_shape=np.array([3, 4, 2]),
                                           permuted_shape=np.array([1, 3, 4, 2]))
        # check the right order in the new permutation node: from NCHW to NHWC
        self._assert_permute_order(graph, [0, 2, 3, 1])

    def test_from_5D_to_3D(self):
        """A 5D -> 3D reshape is expected to get a Permute on its input."""
        # Input is NCDHW [1, 2, 1, 3, 4]; the same tensor in NDHWC is [1, 1, 3, 4, 2].
        graph = self._check_transformation(input_shape=np.array([1, 2, 1, 3, 4]),
                                           new_shape=np.array([3, 4, 2]),
                                           permuted_shape=np.array([1, 1, 3, 4, 2]))
        # check the right order in the new permutation node: from NCDHW to NDHWC
        self._assert_permute_order(graph, [0, 2, 3, 4, 1])