Added unit tests and readme for model optimizer (#79)
[platform/upstream/dldt.git] / model-optimizer / mo / front / common / partial_infer / split_test.py
1 """
2  Copyright (c) 2018 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18
19 import numpy as np
20
21 from mo.front.common.partial_infer.split import tf_split_infer, tf_unpack_infer, tf_split_v_infer, split
22 from mo.front.common.partial_infer.utils import int64_array
23 from mo.graph.graph import Node
24 from mo.utils.unittest.graph import build_graph, build_graph_with_edge_attrs
25
26
27 class TestTFSplitInfer(unittest.TestCase):
28     graph = None
29
30     def setUp(self):
31         self.graph = build_graph({'split_dim': {'value': None, 'kind': 'data'},
32                                   'data_to_split': {'value': None, 'shape': None, 'kind': 'data'},
33                                   'split_node': {'kind': 'op', 'op': 'Split', 'num_split': 3, 'axis': None},
34                                   'out_data_1': {'value': None, 'shape': None, 'kind': 'data'},
35                                   'out_data_2': {'value': None, 'shape': None, 'kind': 'data'},
36                                   'out_data_3': {'value': None, 'shape': None, 'kind': 'data'},
37                                   },
38                                  [('split_dim', 'split_node'),
39                                   ('data_to_split', 'split_node'),
40                                   ('split_node', 'out_data_1'),
41                                   ('split_node', 'out_data_2'),
42                                   ('split_node', 'out_data_3'),
43                                   ])
44
45     def test_tf_split_infer(self):
46         split_node = Node(self.graph, 'split_node')
47         self.graph.node['split_dim']['value'] = np.array(1)
48         self.graph.node['data_to_split']['shape'] = int64_array([2, 12, 25, 30])
49
50         tf_split_infer(split_node)
51         exp_shape = int64_array([2, 4, 25, 30])
52         for out_node in split_node.out_nodes().values():
53             self.assertTrue(np.all(exp_shape == out_node.shape))
54         self.assertEqual(1, split_node.input_port)
55
56     def test_tf_split_infer_negative_index(self):
57         split_node = Node(self.graph, 'split_node')
58         self.graph.node['split_dim']['value'] = np.array(-3)
59         self.graph.node['data_to_split']['shape'] = int64_array([2, 12, 25, 30])
60
61         tf_split_infer(split_node)
62         exp_shape = int64_array([2, 4, 25, 30])
63         for out_node in split_node.out_nodes().values():
64             self.assertTrue(np.all(exp_shape == out_node.shape))
65         self.assertEqual(1, split_node.input_port)
66
67     def test_tf_split_infer_unknown_index(self):
68         split_node = Node(self.graph, 'split_node')
69         self.graph.node['data_to_split']['shape'] = int64_array([2, 12, 25, 30])
70
71         tf_split_infer(split_node)
72         for out_node in split_node.out_nodes().values():
73             self.assertIsNone(out_node.shape)
74
75     def test_tf_split_infer_input_shape_is_None(self):
76         split_node = Node(self.graph, 'split_node')
77         self.graph.node['split_dim']['value'] = np.array(1)
78
79         tf_split_infer(split_node)
80         for out_node in split_node.out_nodes().values():
81             self.assertIsNone(out_node.shape)
82
83     def test_tf_split_infer_wrong_num_split(self):
84         split_node = Node(self.graph, 'split_node')
85         self.graph.node['split_dim']['value'] = np.array(0)
86         self.graph.node['data_to_split']['shape'] = int64_array([2, 12, 25, 30])
87
88         tf_split_infer(split_node)
89         for out_node in split_node.out_nodes().values():
90             self.assertIsNone(out_node.shape)
91
92
93 class TestTFSplitVInfer(unittest.TestCase):
94     graph = None
95
96     def setUp(self):
97         self.graph = build_graph({'data_to_split': {'value': None, 'shape': None, 'kind': 'data'},
98                                   'size_splits': {'value': [3, 5, 4], 'kind': 'data'},
99                                   'split_dim': {'value': None, 'kind': 'data'},
100                                   'split_node': {'kind': 'op', 'op': 'Split', 'axis': None},
101                                   'out_data_1': {'value': None, 'shape': None, 'kind': 'data'},
102                                   'out_data_2': {'value': None, 'shape': None, 'kind': 'data'},
103                                   'out_data_3': {'value': None, 'shape': None, 'kind': 'data'},
104                                   },
105                                  [('data_to_split', 'split_node'),
106                                   ('size_splits', 'split_node'),
107                                   ('split_dim', 'split_node'),
108                                   ('split_node', 'out_data_1'),
109                                   ('split_node', 'out_data_2'),
110                                   ('split_node', 'out_data_3'),
111                                   ])
112
113     def test_tf_split_infer_three_inputs(self):
114         split_node = Node(self.graph, 'split_node')
115         self.graph.node['split_dim']['value'] = np.array(1)
116         self.graph.node['data_to_split']['shape'] = int64_array([2, 12, 25, 30])
117
118         tf_split_v_infer(split_node)
119         exp_shape = [int64_array([2, 3, 25, 30]), int64_array([2, 5, 25, 30]), int64_array([2, 4, 25, 30])]
120         for ind, out_node in split_node.out_nodes().items():
121             self.assertTrue(np.all(exp_shape[ind] == out_node.shape))
122
123
124 class TestTFUnpack(unittest.TestCase):
125     graph = None
126
127     def setUp(self):
128         self.graph = build_graph({'data_to_split': {'value': None, 'shape': None, 'kind': 'data'},
129                                   'unpack': {'kind': 'op', 'op': 'Split', 'num_split': 3, 'axis': None},
130                                   'out_data_1': {'value': None, 'shape': None, 'kind': 'data'},
131                                   'out_data_2': {'value': None, 'shape': None, 'kind': 'data'},
132                                   'out_data_3': {'value': None, 'shape': None, 'kind': 'data'},
133                                   'out_data_4': {'value': None, 'shape': None, 'kind': 'data'},
134                                   },
135                                  [('data_to_split', 'unpack'),
136                                   ('unpack', 'out_data_1'),
137                                   ('unpack', 'out_data_2'),
138                                   ('unpack', 'out_data_3'),
139                                   ])
140
141     def test_tf_unpack_infer(self):
142         unpack_node = Node(self.graph, 'unpack')
143         self.graph.node['unpack']['axis'] = np.array(1)
144         self.graph.node['data_to_split']['shape'] = int64_array([2, 3, 25, 30])
145
146         tf_unpack_infer(unpack_node)
147         exp_shape = int64_array([2, 1, 25, 30])
148         for out_node in unpack_node.out_nodes().values():
149             self.assertTrue(np.all(exp_shape == out_node.shape))
150
151     def test_tf_unpack_infer_default_number_of_pieces(self):
152         unpack_node = Node(self.graph, 'unpack')
153         self.graph.node['unpack']['axis'] = np.array(1)
154         self.graph.node['unpack']['num_split'] = None
155         self.graph.node['data_to_split']['shape'] = int64_array([2, 3, 25, 30])
156
157         tf_unpack_infer(unpack_node)
158         exp_shape = int64_array([2, 1, 25, 30])
159         for out_node in unpack_node.out_nodes().values():
160             self.assertTrue(np.all(exp_shape == out_node.shape))
161
162     def test_tf_unpack_infer_not_supported(self):
163         # the case when the size of the dimension being unpacked is not equal to number of pieces is not supported
164         unpack_node = Node(self.graph, 'unpack')
165         self.graph.node['unpack']['axis'] = np.array(1)
166         self.graph.node['data_to_split']['shape'] = int64_array([2, 6, 25, 30])
167
168         tf_unpack_infer(unpack_node)
169         for out_node in unpack_node.out_nodes().values():
170             self.assertIsNone(out_node.shape)
171
172
173 class TestSplitFunc(unittest.TestCase):
174     graph = None
175
176     def setUp(self):
177         self.graph = build_graph_with_edge_attrs(
178             {'data_to_split': {'value': None, 'shape': int64_array([2, 12, 25, 44]), 'kind': 'data'},
179              'split_node': {'kind': 'op', 'op': 'Split', 'axis': None},
180              'out_data_2': {'value': None, 'shape': None, 'kind': 'data'},
181              'out_data_5': {'value': None, 'shape': None, 'kind': 'data'},
182              'out_data_7': {'value': None, 'shape': None, 'kind': 'data'},
183              },
184             [('data_to_split', 'split_node', {'in': 0}),
185              ('split_node', 'out_data_2', {'out': 2}),
186              ('split_node', 'out_data_5', {'out': 5}),
187              ('split_node', 'out_data_7', {'out': 7}),
188              ])
189
190     def test_split_non_sequential_output_port(self):
191         split(Node(self.graph, 'data_to_split'), Node(self.graph, 'split_node'), -1, [3, 2, 7, 5, 6, 4, 9, 8])
192         self.assertTrue(np.all(Node(self.graph, 'out_data_2').shape == [2, 12, 25, 7]))
193         self.assertTrue(np.all(Node(self.graph, 'out_data_5').shape == [2, 12, 25, 4]))
194         self.assertTrue(np.all(Node(self.graph, 'out_data_7').shape == [2, 12, 25, 8]))
195
196     def test_split_value_infer_non_sequential_output_port(self):
197         data_node = Node(self.graph, 'data_to_split')
198         value = np.array(range(2 * 12 * 25 * 44)).reshape(data_node.shape)
199         data_node.value = value.copy()
200         split(data_node, Node(self.graph, 'split_node'), -1, [3, 2, 7, 5, 6, 4, 9, 8])
201         self.assertTrue(np.all(Node(self.graph, 'out_data_2').shape == [2, 12, 25, 7]))
202         self.assertTrue(np.all(Node(self.graph, 'out_data_5').shape == [2, 12, 25, 4]))
203         self.assertTrue(np.all(Node(self.graph, 'out_data_7').shape == [2, 12, 25, 8]))
204
205         self.assertTrue(np.all(Node(self.graph, 'out_data_2').value == value[:, :, :, 5:12]))
206         self.assertTrue(np.all(Node(self.graph, 'out_data_5').value == value[:, :, :, 23:27]))
207         self.assertTrue(np.all(Node(self.graph, 'out_data_7').value == value[:, :, :, 36:]))