"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import unittest

import numpy as np

from mo.graph.graph import Node
from mo.ops.tile import Tile
from mo.utils.unittest.graph import build_graph
# Node attributes shared by every test graph: a 4D input data node, a 1D
# data node holding the per-axis repeat counts, the Tile op itself and its
# output data node (whose shape/value Tile.infer is expected to fill in).
nodes_attributes = {'data': {'value': None, 'shape': np.array([10, 20, 30, 40]), 'kind': 'data'},
                    'tile_values': {'value': None, 'shape': np.array([4]), 'kind': 'data'},
                    'tile': {'type': 'Tile', 'kind': 'op'},
                    'tile_out': {'value': None, 'shape': None, 'kind': 'data'},
                    }
class TestTileInfer(unittest.TestCase):
    """Shape and value inference tests for the Tile operation.

    Most tests wire ``data`` and ``tile_values`` into ``tile`` and check what
    ``Tile.infer`` writes to ``tile_out`` and to the node's ``axis``/``tiles``
    attributes. The ``one_input`` tests instead supply ``axis``/``tiles`` as
    attributes of the ``tile`` node and omit the ``tile_values`` input.
    """

    # Two-input form: both the data and the repeat counts feed the Tile op.
    TWO_INPUT_EDGES = [('data', 'tile'),
                       ('tile_values', 'tile'),
                       ('tile', 'tile_out')]
    # Single-input form: repeat parameters come from the op's attributes.
    ONE_INPUT_EDGES = [('data', 'tile'),
                       ('tile', 'tile_out')]

    @staticmethod
    def _infer(edges, update_attributes):
        """Build the test graph, run Tile.infer on node 'tile' and return (graph, tile_node).

        :param edges: edge list passed to build_graph (one of the *_EDGES constants).
        :param update_attributes: per-node attribute overrides applied on top of nodes_attributes.
        """
        graph = build_graph(nodes_attributes, list(edges), update_attributes)
        tile_node = Node(graph, 'tile')
        Tile.infer(tile_node)
        return graph, tile_node

    def test_tile_infer_correct(self):
        graph, tile_node = self._infer(self.TWO_INPUT_EDGES,
                                       {'tile_values': {'value': np.array([7, 1, 1, 1])}})
        np.testing.assert_array_equal(graph.node['tile_out']['shape'], np.array([70, 20, 30, 40]))
        self.assertEqual(tile_node.axis, 0)
        self.assertEqual(tile_node.tiles, 7)

    def test_tile_infer_correct_2(self):
        graph, tile_node = self._infer(self.TWO_INPUT_EDGES,
                                       {'tile_values': {'value': np.array([1, 7, 1, 1])}})
        np.testing.assert_array_equal(graph.node['tile_out']['shape'], np.array([10, 140, 30, 40]))
        self.assertEqual(tile_node.axis, 1)
        self.assertEqual(tile_node.tiles, 7)

    def test_tile_infer_correct_2d_tensor(self):
        graph, tile_node = self._infer(self.TWO_INPUT_EDGES,
                                       {'data': {'shape': np.array([3, 7])},
                                        'tile_values': {'value': np.array([5, 1])}})
        np.testing.assert_array_equal(graph.node['tile_out']['shape'], np.array([15, 7]))
        self.assertEqual(tile_node.axis, 0)
        self.assertEqual(tile_node.tiles, 5)

    def test_tile_infer_all_ones(self):
        graph, tile_node = self._infer(self.TWO_INPUT_EDGES,
                                       {'tile_values': {'value': np.array([1, 1, 1, 1])}})
        np.testing.assert_array_equal(graph.node['tile_out']['shape'], np.array([10, 20, 30, 40]))
        self.assertEqual(tile_node.axis, 0)
        self.assertEqual(tile_node.tiles, 1)

    def test_tile_infer_two_non_one(self):
        graph, tile_node = self._infer(self.TWO_INPUT_EDGES,
                                       {'tile_values': {'value': np.array([2, 1, 1, 2])}})
        # More than one repeated axis: the op is expected to lose its 'Tile' type
        # and not get a single axis/tiles pair, while the shape is still inferred.
        self.assertIsNone(graph.node['tile']['type'])
        np.testing.assert_array_equal(graph.node['tile_out']['shape'], np.array([20, 20, 30, 80]))
        self.assertFalse(tile_node.has_and_set('axis'))
        self.assertFalse(tile_node.has_and_set('tiles'))

    def test_tile_infer_three_non_one(self):
        graph, _ = self._infer(self.TWO_INPUT_EDGES,
                               {'tile_values': {'value': np.array([2, 1, 5, 2])}})
        self.assertIsNone(graph.node['tile']['type'])
        np.testing.assert_array_equal(graph.node['tile_out']['shape'], np.array([20, 20, 150, 80]))

    def test_tile_infer_none_input_shape(self):
        graph, _ = self._infer(self.TWO_INPUT_EDGES,
                               {'data': {'shape': None},
                                'tile_values': {'value': np.array([1, 7, 1, 1])}})
        self.assertIsNone(graph.node['tile_out']['shape'])

    def test_tile_infer_values_test(self):
        input_data = np.arange(-30, 60, 0.25).reshape([2, 4, 3, -1])
        tile_values = np.array([3, 1, 1, 1])
        graph, tile_node = self._infer(self.TWO_INPUT_EDGES,
                                       {'data': {'shape': input_data.shape, 'value': input_data},
                                        'tile_values': {'value': tile_values}})
        np.testing.assert_array_equal(graph.node['tile_out']['value'], np.tile(input_data, tile_values))
        self.assertEqual(tile_node.axis, 0)
        self.assertEqual(tile_node.tiles, 3)

    def test_tile_infer_values_const_propagation(self):
        """
        Test for constant propagation even if tile with multiple tile indices is not supported
        """
        input_data = np.arange(-30, 60, 0.25).reshape([2, 4, 3, -1])
        tile_values = np.array([4, 3, 2, 5])
        graph, tile_node = self._infer(self.TWO_INPUT_EDGES,
                                       {'data': {'shape': input_data.shape, 'value': input_data},
                                        'tile_values': {'value': tile_values}})
        np.testing.assert_array_equal(graph.node['tile_out']['value'], np.tile(input_data, tile_values))
        self.assertIsNone(tile_node.type)

    def test_tile_infer_undefined_tile_values(self):
        graph, _ = self._infer(self.TWO_INPUT_EDGES,
                               {'tile_values': {'value': None}})
        self.assertIsNone(graph.node['tile_out']['shape'])

    def test_tile_infer_shapes_mismatch(self):
        # Three repeat counts for a 4D input: inference must not produce a shape.
        graph, _ = self._infer(self.TWO_INPUT_EDGES,
                               {'tile_values': {'value': np.array([1, 2, 1]), 'shape': np.array([3])}})
        self.assertIsNone(graph.node['tile_out']['shape'])

    def test_tile_infer_one_input_correct(self):
        graph, tile_node = self._infer(self.ONE_INPUT_EDGES,
                                       {'tile': {'axis': 1, 'tiles': 7}})
        np.testing.assert_array_equal(graph.node['tile_out']['shape'], np.array([10, 140, 30, 40]))
        self.assertEqual(tile_node.axis, 1)
        self.assertEqual(tile_node.tiles, 7)

    def test_tile_infer_one_input_correct_missing_axis(self):
        graph, _ = self._infer(self.ONE_INPUT_EDGES,
                               {'tile': {'tiles': 7}})
        self.assertIsNone(graph.node['tile_out']['shape'])

    def test_tile_infer_one_input_correct_missing_tiles(self):
        graph, _ = self._infer(self.ONE_INPUT_EDGES,
                               {'tile': {'axis': 1}})
        self.assertIsNone(graph.node['tile_out']['shape'])