2 Copyright (c) 2018 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
21 from mo.front.common.partial_infer.split import tf_split_infer, tf_unpack_infer, tf_split_v_infer, split
22 from mo.front.common.partial_infer.utils import int64_array
23 from mo.graph.graph import Node
24 from mo.utils.unittest.graph import build_graph, build_graph_with_edge_attrs
class TestTFSplitInfer(unittest.TestCase):
    """Checks ``tf_split_infer`` on a Split node with ``num_split=3`` and three outputs."""

    def setUp(self):
        """Build a fresh graph for every test; the split axis value and the input shape start as ``None``."""
        self.graph = build_graph({'split_dim': {'value': None, 'kind': 'data'},
                                  'data_to_split': {'value': None, 'shape': None, 'kind': 'data'},
                                  'split_node': {'kind': 'op', 'op': 'Split', 'num_split': 3, 'axis': None},
                                  'out_data_1': {'value': None, 'shape': None, 'kind': 'data'},
                                  'out_data_2': {'value': None, 'shape': None, 'kind': 'data'},
                                  'out_data_3': {'value': None, 'shape': None, 'kind': 'data'},
                                  },
                                 [('split_dim', 'split_node'),
                                  ('data_to_split', 'split_node'),
                                  ('split_node', 'out_data_1'),
                                  ('split_node', 'out_data_2'),
                                  ('split_node', 'out_data_3'),
                                  ])

    def _assert_output_shapes(self, node, expected):
        """Assert that every output data node of *node* has exactly the *expected* shape."""
        for out_node in node.out_nodes().values():
            # np.array_equal compares the array shapes as well, so a mismatch that merely
            # broadcasts against the expectation (or a shape of None) cannot pass.
            self.assertTrue(np.array_equal(expected, out_node.shape))

    def _assert_output_shapes_not_inferred(self, node):
        """Assert that the shape of every output data node of *node* is still ``None``."""
        for out_node in node.out_nodes().values():
            self.assertIsNone(out_node.shape)

    def test_tf_split_infer(self):
        """Axis 1 of size 12 split into 3 pieces gives outputs of size 4 along that axis."""
        split_node = Node(self.graph, 'split_node')
        self.graph.node['split_dim']['value'] = np.array(1)
        self.graph.node['data_to_split']['shape'] = int64_array([2, 12, 25, 30])

        tf_split_infer(split_node)

        self._assert_output_shapes(split_node, int64_array([2, 4, 25, 30]))
        self.assertEqual(1, split_node.input_port)

    def test_tf_split_infer_negative_index(self):
        """A split axis of -3 on a 4D input is expected to give the same result as axis 1."""
        split_node = Node(self.graph, 'split_node')
        self.graph.node['split_dim']['value'] = np.array(-3)
        self.graph.node['data_to_split']['shape'] = int64_array([2, 12, 25, 30])

        tf_split_infer(split_node)

        self._assert_output_shapes(split_node, int64_array([2, 4, 25, 30]))
        self.assertEqual(1, split_node.input_port)

    def test_tf_split_infer_unknown_index(self):
        """With the split axis value left as ``None`` the output shapes must stay ``None``."""
        split_node = Node(self.graph, 'split_node')
        self.graph.node['data_to_split']['shape'] = int64_array([2, 12, 25, 30])

        tf_split_infer(split_node)

        self._assert_output_shapes_not_inferred(split_node)

    def test_tf_split_infer_input_shape_is_None(self):
        """With the input shape left as ``None`` the output shapes must stay ``None``."""
        split_node = Node(self.graph, 'split_node')
        self.graph.node['split_dim']['value'] = np.array(1)

        tf_split_infer(split_node)

        self._assert_output_shapes_not_inferred(split_node)

    def test_tf_split_infer_wrong_num_split(self):
        """Axis 0 has size 2 while ``num_split`` is 3; the output shapes must stay ``None``."""
        split_node = Node(self.graph, 'split_node')
        self.graph.node['split_dim']['value'] = np.array(0)
        self.graph.node['data_to_split']['shape'] = int64_array([2, 12, 25, 30])

        tf_split_infer(split_node)

        self._assert_output_shapes_not_inferred(split_node)
class TestTFSplitVInfer(unittest.TestCase):
    """Checks ``tf_split_v_infer`` on a Split node fed by data, ``size_splits`` and the split axis."""

    def setUp(self):
        """Build a fresh graph for every test; ``size_splits`` is fixed to ``[3, 5, 4]``."""
        self.graph = build_graph({'data_to_split': {'value': None, 'shape': None, 'kind': 'data'},
                                  'size_splits': {'value': [3, 5, 4], 'kind': 'data'},
                                  'split_dim': {'value': None, 'kind': 'data'},
                                  'split_node': {'kind': 'op', 'op': 'Split', 'axis': None},
                                  'out_data_1': {'value': None, 'shape': None, 'kind': 'data'},
                                  'out_data_2': {'value': None, 'shape': None, 'kind': 'data'},
                                  'out_data_3': {'value': None, 'shape': None, 'kind': 'data'},
                                  },
                                 [('data_to_split', 'split_node'),
                                  ('size_splits', 'split_node'),
                                  ('split_dim', 'split_node'),
                                  ('split_node', 'out_data_1'),
                                  ('split_node', 'out_data_2'),
                                  ('split_node', 'out_data_3'),
                                  ])

    def test_tf_split_infer_three_inputs(self):
        """Axis 1 of size 12 is cut into pieces of 3, 5 and 4 elements, one per output."""
        split_node = Node(self.graph, 'split_node')
        self.graph.node['split_dim']['value'] = np.array(1)
        self.graph.node['data_to_split']['shape'] = int64_array([2, 12, 25, 30])

        tf_split_v_infer(split_node)

        exp_shape = [int64_array([2, 3, 25, 30]), int64_array([2, 5, 25, 30]), int64_array([2, 4, 25, 30])]
        # The keys of out_nodes() are used as indices into the expected list.
        for ind, out_node in split_node.out_nodes().items():
            # np.array_equal compares the array shapes as well, so a mismatch that merely
            # broadcasts against the expectation (or a shape of None) cannot pass.
            self.assertTrue(np.array_equal(exp_shape[ind], out_node.shape))
class TestTFUnpack(unittest.TestCase):
    """Checks ``tf_unpack_infer`` on an ``unpack`` node with ``num_split=3`` and three outputs."""

    def setUp(self):
        """Build a fresh graph for every test; the axis and the input shape start as ``None``."""
        # NOTE(review): 'out_data_4' is declared but never connected by an edge - looks like
        # a leftover; confirm before removing.
        self.graph = build_graph({'data_to_split': {'value': None, 'shape': None, 'kind': 'data'},
                                  'unpack': {'kind': 'op', 'op': 'Split', 'num_split': 3, 'axis': None},
                                  'out_data_1': {'value': None, 'shape': None, 'kind': 'data'},
                                  'out_data_2': {'value': None, 'shape': None, 'kind': 'data'},
                                  'out_data_3': {'value': None, 'shape': None, 'kind': 'data'},
                                  'out_data_4': {'value': None, 'shape': None, 'kind': 'data'},
                                  },
                                 [('data_to_split', 'unpack'),
                                  ('unpack', 'out_data_1'),
                                  ('unpack', 'out_data_2'),
                                  ('unpack', 'out_data_3'),
                                  ])

    def _assert_output_shapes(self, node, expected):
        """Assert that every output data node of *node* has exactly the *expected* shape."""
        for out_node in node.out_nodes().values():
            # np.array_equal compares the array shapes as well, so a mismatch that merely
            # broadcasts against the expectation (or a shape of None) cannot pass.
            self.assertTrue(np.array_equal(expected, out_node.shape))

    def test_tf_unpack_infer(self):
        """Axis 1 of size 3 unpacked into 3 pieces gives outputs of size 1 along that axis."""
        unpack_node = Node(self.graph, 'unpack')
        self.graph.node['unpack']['axis'] = np.array(1)
        self.graph.node['data_to_split']['shape'] = int64_array([2, 3, 25, 30])

        tf_unpack_infer(unpack_node)

        self._assert_output_shapes(unpack_node, int64_array([2, 1, 25, 30]))

    def test_tf_unpack_infer_default_number_of_pieces(self):
        """With ``num_split`` set to ``None`` the same per-output shape is still expected."""
        unpack_node = Node(self.graph, 'unpack')
        self.graph.node['unpack']['axis'] = np.array(1)
        self.graph.node['unpack']['num_split'] = None
        self.graph.node['data_to_split']['shape'] = int64_array([2, 3, 25, 30])

        tf_unpack_infer(unpack_node)

        self._assert_output_shapes(unpack_node, int64_array([2, 1, 25, 30]))

    def test_tf_unpack_infer_not_supported(self):
        """Axis 1 has size 6 while ``num_split`` is 3; the output shapes must stay ``None``."""
        # the case when the size of the dimension being unpacked is not equal to number of pieces is not supported
        unpack_node = Node(self.graph, 'unpack')
        self.graph.node['unpack']['axis'] = np.array(1)
        self.graph.node['data_to_split']['shape'] = int64_array([2, 6, 25, 30])

        tf_unpack_infer(unpack_node)

        for out_node in unpack_node.out_nodes().values():
            self.assertIsNone(out_node.shape)
class TestSplitFunc(unittest.TestCase):
    """Checks ``split`` when only output ports 2, 5 and 7 of the Split node are connected."""

    # Piece sizes along the last axis; they add up to 44, the size of that axis in the input.
    SIZE_SPLITS = [3, 2, 7, 5, 6, 4, 9, 8]

    # Expected shape of each connected output: the piece size is SIZE_SPLITS[<output port>].
    EXPECTED_SHAPES = {'out_data_2': [2, 12, 25, 7],
                       'out_data_5': [2, 12, 25, 4],
                       'out_data_7': [2, 12, 25, 8]}

    def setUp(self):
        """Build a fresh graph for every test with explicit, non-sequential output port numbers."""
        self.graph = build_graph_with_edge_attrs(
            {'data_to_split': {'value': None, 'shape': int64_array([2, 12, 25, 44]), 'kind': 'data'},
             'split_node': {'kind': 'op', 'op': 'Split', 'axis': None},
             'out_data_2': {'value': None, 'shape': None, 'kind': 'data'},
             'out_data_5': {'value': None, 'shape': None, 'kind': 'data'},
             'out_data_7': {'value': None, 'shape': None, 'kind': 'data'},
             },
            [('data_to_split', 'split_node', {'in': 0}),
             ('split_node', 'out_data_2', {'out': 2}),
             ('split_node', 'out_data_5', {'out': 5}),
             ('split_node', 'out_data_7', {'out': 7}),
             ])

    def _assert_expected_shapes(self):
        """Assert that each connected output data node has its expected shape."""
        for node_name, expected in self.EXPECTED_SHAPES.items():
            # np.array_equal compares the array shapes as well, so a mismatch that merely
            # broadcasts against the expectation (or a shape of None) cannot pass.
            self.assertTrue(np.array_equal(Node(self.graph, node_name).shape, expected))

    def test_split_non_sequential_output_port(self):
        """Shapes of the connected outputs follow the piece size at their own port index."""
        split(Node(self.graph, 'data_to_split'), Node(self.graph, 'split_node'), -1, self.SIZE_SPLITS)
        self._assert_expected_shapes()

    def test_split_value_infer_non_sequential_output_port(self):
        """When the input has a value, each connected output gets the matching slice of it."""
        data_node = Node(self.graph, 'data_to_split')
        value = np.array(range(2 * 12 * 25 * 44)).reshape(data_node.shape)
        # Hand over a copy so the reference slices below come from an array the graph does not hold.
        data_node.value = value.copy()

        split(data_node, Node(self.graph, 'split_node'), -1, self.SIZE_SPLITS)
        self._assert_expected_shapes()

        # Slice bounds are the running sums of SIZE_SPLITS: port 2 -> 5:12, port 5 -> 23:27, port 7 -> 36:44.
        self.assertTrue(np.array_equal(Node(self.graph, 'out_data_2').value, value[:, :, :, 5:12]))
        self.assertTrue(np.array_equal(Node(self.graph, 'out_data_5').value, value[:, :, :, 23:27]))
        self.assertTrue(np.array_equal(Node(self.graph, 'out_data_7').value, value[:, :, :, 36:]))