Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / ops / tile_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18
19 import numpy as np
20
21 from mo.graph.graph import Node
22 from mo.ops.tile import Tile
23 from mo.utils.unittest.graph import build_graph
24
25 nodes_attributes = {'data': {'value': None, 'shape': np.array([10, 20, 30, 40]), 'kind': 'data'},
26                     'tile_values': {'value': None, 'shape': np.array([4]), 'kind': 'data'},
27                     'tile': {'type': 'Tile', 'kind': 'op'},
28                     'tile_out': {'value': None, 'shape': None, 'kind': 'data'},
29                     }
30
31
32 class TestTileInfer(unittest.TestCase):
33     def test_tile_infer_correct(self):
34         graph = build_graph(nodes_attributes,
35                             [('data', 'tile'),
36                              ('tile_values', 'tile'),
37                              ('tile', 'tile_out')],
38                             {'tile_values': {'value': np.array([7, 1, 1, 1])}})
39         tile_node = Node(graph, 'tile')
40         Tile.infer(tile_node)
41         self.assertTrue(np.all(np.array([70, 20, 30, 40]) == graph.node['tile_out']['shape']))
42         self.assertEqual(tile_node.axis, 0)
43         self.assertEqual(tile_node.tiles, 7)
44
45     def test_tile_infer_correct_2(self):
46         graph = build_graph(nodes_attributes,
47                             [('data', 'tile'),
48                              ('tile_values', 'tile'),
49                              ('tile', 'tile_out')],
50                             {'tile_values': {'value': np.array([1, 7, 1, 1])}})
51         tile_node = Node(graph, 'tile')
52         Tile.infer(tile_node)
53         self.assertTrue(np.all(np.array([10, 140, 30, 40]) == graph.node['tile_out']['shape']))
54         self.assertEqual(tile_node.axis, 1)
55         self.assertEqual(tile_node.tiles, 7)
56
57     def test_tile_infer_correct_2d_tensor(self):
58         graph = build_graph(nodes_attributes,
59                             [('data', 'tile'),
60                              ('tile_values', 'tile'),
61                              ('tile', 'tile_out')],
62                             {'data': {'shape': np.array([3, 7])},
63                              'tile_values': {'value': np.array([5, 1])}})
64         tile_node = Node(graph, 'tile')
65         Tile.infer(tile_node)
66         self.assertTrue(np.all(np.array([15, 7]) == graph.node['tile_out']['shape']))
67         self.assertEqual(tile_node.axis, 0)
68         self.assertEqual(tile_node.tiles, 5)
69
70     def test_tile_infer_all_ones(self):
71         graph = build_graph(nodes_attributes,
72                             [('data', 'tile'),
73                              ('tile_values', 'tile'),
74                              ('tile', 'tile_out')],
75                             {'tile_values': {'value': np.array([1, 1, 1, 1])}})
76         tile_node = Node(graph, 'tile')
77         Tile.infer(tile_node)
78         self.assertTrue(np.all(np.array([10, 20, 30, 40]) == graph.node['tile_out']['shape']))
79         self.assertEqual(tile_node.axis, 0)
80         self.assertEqual(tile_node.tiles, 1)
81
82     def test_tile_infer_two_non_one(self):
83         graph = build_graph(nodes_attributes,
84                             [('data', 'tile'),
85                              ('tile_values', 'tile'),
86                              ('tile', 'tile_out')],
87                             {'tile_values': {'value': np.array([2, 1, 1, 2])}})
88         tile_node = Node(graph, 'tile')
89         Tile.infer(tile_node)
90         self.assertIsNone(graph.node['tile']['type'])
91         self.assertTrue(np.all(np.array([20, 20, 30, 80]) == graph.node['tile_out']['shape']))
92         self.assertFalse(tile_node.has_and_set('axis'))
93         self.assertFalse(tile_node.has_and_set('tiles'))
94
95     def test_tile_infer_three_non_one(self):
96         graph = build_graph(nodes_attributes,
97                             [('data', 'tile'),
98                              ('tile_values', 'tile'),
99                              ('tile', 'tile_out')],
100                             {'tile_values': {'value': np.array([2, 1, 5, 2])}})
101         tile_node = Node(graph, 'tile')
102         Tile.infer(tile_node)
103         self.assertIsNone(graph.node['tile']['type'])
104         self.assertTrue(np.all(np.array([20, 20, 150, 80]) == graph.node['tile_out']['shape']))
105
106     def test_tile_infer_none_input_shape(self):
107         graph = build_graph(nodes_attributes,
108                             [('data', 'tile'),
109                              ('tile_values', 'tile'),
110                              ('tile', 'tile_out')],
111                             {'data': {'shape': None},
112                              'tile_values': {'value': np.array([1, 7, 1, 1])}})
113         tile_node = Node(graph, 'tile')
114         Tile.infer(tile_node)
115         self.assertIsNone(graph.node['tile_out']['shape'])
116
117     def test_tile_infer_values_test(self):
118         input_data = np.arange(-30, 60, 0.25).reshape([2, 4, 3, -1])
119         tile_values = np.array([3, 1, 1, 1])
120         graph = build_graph(nodes_attributes,
121                             [('data', 'tile'),
122                              ('tile_values', 'tile'),
123                              ('tile', 'tile_out')],
124                             {'data': {'shape': input_data.shape, 'value': input_data},
125                              'tile_values': {'value': tile_values}})
126         tile_node = Node(graph, 'tile')
127         Tile.infer(tile_node)
128         self.assertTrue(np.all(np.tile(input_data, tile_values) == graph.node['tile_out']['value']))
129         self.assertEqual(tile_node.axis, 0)
130         self.assertEqual(tile_node.tiles, 3)
131
132     def test_tile_infer_values_const_propagation(self):
133         """
134         Test for constant propagation even if tile with multiple tile indices is not supported
135         """
136         input_data = np.arange(-30, 60, 0.25).reshape([2, 4, 3, -1])
137         tile_values = np.array([4, 3, 2, 5])
138         graph = build_graph(nodes_attributes,
139                             [('data', 'tile'),
140                              ('tile_values', 'tile'),
141                              ('tile', 'tile_out')],
142                             {'data': {'shape': input_data.shape, 'value': input_data},
143                              'tile_values': {'value': tile_values}})
144         tile_node = Node(graph, 'tile')
145         Tile.infer(tile_node)
146         self.assertTrue(np.all(np.tile(input_data, tile_values) == graph.node['tile_out']['value']))
147         self.assertIsNone(tile_node.type)
148
149     def test_tile_infer_undefined_tile_values(self):
150         graph = build_graph(nodes_attributes,
151                             [('data', 'tile'),
152                              ('tile_values', 'tile'),
153                              ('tile', 'tile_out')],
154                             {'tile_values': {'value': None}})
155         tile_node = Node(graph, 'tile')
156         Tile.infer(tile_node)
157         self.assertIsNone(graph.node['tile_out']['shape'])
158
159     def test_tile_infer_shapes_mismatch(self):
160         graph = build_graph(nodes_attributes,
161                             [('data', 'tile'),
162                              ('tile_values', 'tile'),
163                              ('tile', 'tile_out')],
164                             {'tile_values': {'value': np.array([1, 2, 1]), 'shape': np.array([3])}})
165         tile_node = Node(graph, 'tile')
166         Tile.infer(tile_node)
167         self.assertIsNone(graph.node['tile_out']['shape'])
168
169     def test_tile_infer_one_input_correct(self):
170         graph = build_graph(nodes_attributes,
171                             [('data', 'tile'),
172                              ('tile', 'tile_out')],
173                             {'tile': {'axis': 1, 'tiles': 7}})
174         tile_node = Node(graph, 'tile')
175         Tile.infer(tile_node)
176         self.assertTrue(np.all(np.array([10, 140, 30, 40]) == graph.node['tile_out']['shape']))
177         self.assertEqual(tile_node.axis, 1)
178         self.assertEqual(tile_node.tiles, 7)
179
180     def test_tile_infer_one_input_correct_missing_axis(self):
181         graph = build_graph(nodes_attributes,
182                             [('data', 'tile'),
183                              ('tile', 'tile_out')],
184                             {'tile': {'tiles': 7}})
185         tile_node = Node(graph, 'tile')
186         Tile.infer(tile_node)
187         self.assertIsNone(graph.node['tile_out']['shape'])
188
189     def test_tile_infer_one_input_correct_missing_tiles(self):
190         graph = build_graph(nodes_attributes,
191                             [('data', 'tile'),
192                              ('tile', 'tile_out')],
193                             {'tile': {'axis': 1}})
194         tile_node = Node(graph, 'tile')
195         Tile.infer(tile_node)
196         self.assertIsNone(graph.node['tile_out']['shape'])