Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / common / partial_infer / expand_dims_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18
19 import numpy as np
20
21 from mo.front.common.partial_infer.expand_dims import tf_expand_dims_infer
22 from mo.graph.graph import Node
23 from mo.utils.unittest.graph import build_graph
24
# Node templates shared by every graph built in the tests below: the data tensor to
# expand ('input_1'), the optional axis input ('input_2'), the ExpandDims op itself and
# its output data node. Each test supplies per-node overrides (shapes, values,
# 'expand_axis') through build_graph's third argument.
nodes_attributes = {
    'input_1': {'kind': 'data', 'value': None},
    'input_2': {'kind': 'data', 'value': None},
    'expand_dims': {'kind': 'op'},
    'out': {'value': None, 'shape': None, 'kind': 'data'},
}
30
31
class TestExpandDimsInfer(unittest.TestCase):
    """Tests for the ``tf_expand_dims_infer`` shape-inference function.

    Two graph layouts are exercised:

    * two inputs -- the axis is the constant value of a second data input ('input_2');
    * one input  -- the axis is stored in the 'expand_axis' attribute of the op node.

    Positive cases check the inferred output shape; negative cases check that the
    output shape is left as None.
    """

    @staticmethod
    def _two_inputs_graph(input_shape, axis_shape, axis_value):
        """Build input_1, input_2 -> expand_dims -> out, with the axis taken from input_2's value."""
        return build_graph(nodes_attributes,
                           [('input_1', 'expand_dims'),
                            ('input_2', 'expand_dims'),
                            ('expand_dims', 'out')],
                           {'input_1': {'shape': input_shape},
                            'input_2': {'shape': axis_shape, 'value': axis_value},
                            })

    @staticmethod
    def _one_input_graph(input_shape, expand_axis):
        """Build input_1 -> expand_dims -> out, with the axis in the 'expand_axis' attribute."""
        return build_graph(nodes_attributes,
                           [('input_1', 'expand_dims'),
                            ('expand_dims', 'out')],
                           {'input_1': {'shape': input_shape},
                            'expand_dims': {'expand_axis': expand_axis}
                            })

    @staticmethod
    def _run_infer(graph):
        """Run tf_expand_dims_infer on the graph's 'expand_dims' node and return the node."""
        expand_dims_node = Node(graph, 'expand_dims')
        tf_expand_dims_infer(expand_dims_node)
        return expand_dims_node

    def _assert_out_shape(self, expand_dims_node, exp_shape):
        """Assert that the inferred output shape equals exp_shape (same rank, same dims)."""
        res_shape = expand_dims_node.out_node().shape
        # Fail with a clear message instead of a TypeError from len(None).
        self.assertIsNotNone(res_shape)
        # Rejects rank mismatches as well as differing dimension values.
        np.testing.assert_array_equal(res_shape, np.array(exp_shape))

    def _assert_out_shape_unset(self, expand_dims_node):
        """Assert that inference left the output shape undefined."""
        self.assertIsNone(expand_dims_node.out_node().shape)

    # ---- axis provided as a second input -------------------------------------------------

    def test_expand_dims_infer_two_inputs(self):
        graph = self._two_inputs_graph(np.array([3, 256, 256]), np.array([1]),
                                       np.array([1], dtype=np.int32))
        self._assert_out_shape(self._run_infer(graph), [3, 1, 256, 256])

    def test_expand_dims_infer_two_inputs_2(self):
        graph = self._two_inputs_graph(np.array([3, 256, 256]), np.array([1]),
                                       np.array([2], dtype=np.int32))
        self._assert_out_shape(self._run_infer(graph), [3, 256, 1, 256])

    def test_expand_dims_infer_two_inputs_3(self):
        # Scalar (0-d) axis value.
        graph = self._two_inputs_graph(np.array([3, 256, 256]), np.array([]),
                                       np.array(3, dtype=np.int32))
        self._assert_out_shape(self._run_infer(graph), [3, 256, 256, 1])

    def test_expand_dims_infer_two_inputs_negative(self):
        # A multi-element axis is not supported. The declared shape matches the 2-element value.
        graph = self._two_inputs_graph(np.array([3, 256, 256]), np.array([2]),
                                       np.array([2, 3], dtype=np.int32))
        self._assert_out_shape_unset(self._run_infer(graph))

    def test_expand_dims_infer_two_inputs_negative_2(self):
        # Unknown input shape with an otherwise valid axis. The only reason to bail out
        # is the missing input shape, so this case is tested in isolation.
        graph = self._two_inputs_graph(None, np.array([1]),
                                       np.array([2], dtype=np.int32))
        self._assert_out_shape_unset(self._run_infer(graph))

    # ---- axis provided as the 'expand_axis' attribute ------------------------------------

    def test_expand_dims_infer_one_input(self):
        graph = self._one_input_graph(np.array([3, 256, 256]), 1)
        self._assert_out_shape(self._run_infer(graph), [3, 1, 256, 256])

    def test_expand_dims_infer_one_input_2(self):
        graph = self._one_input_graph(np.array([3, 256, 256]), 2)
        self._assert_out_shape(self._run_infer(graph), [3, 256, 1, 256])

    def test_expand_dims_infer_one_input_3(self):
        # A negative axis counts from the end of the *output* shape: -1 appends a trailing 1.
        graph = self._one_input_graph(np.array([3, 256, 256]), -1)
        self._assert_out_shape(self._run_infer(graph), [3, 256, 256, 1])

    def test_expand_dims_infer_one_input_4(self):
        graph = self._one_input_graph(np.array([3, 256, 256]), -2)
        self._assert_out_shape(self._run_infer(graph), [3, 256, 1, 256])

    def test_expand_dims_infer_one_input_negative(self):
        graph = self._one_input_graph(np.array([3, 256, 256]), None)
        self._assert_out_shape_unset(self._run_infer(graph))