Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / ops / argmax_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18
19 import numpy as np
20
21 from extensions.ops.argmax import ArgMaxOp
22 from mo.graph.graph import Node
23 from mo.utils.unittest.graph import build_graph
24
# Node templates shared by every test graph in this file: an Identity feeding
# the ArgMax under test, an Identity consuming its output, and an 'OpOutput'
# node that terminates the graph.
nodes_attributes = dict(
    node_1=dict(type='Identity', kind='op'),
    argmax=dict(type='ArgMax', kind='op'),
    node_3=dict(type='Identity', kind='op'),
    op_output=dict(kind='op', op='OpOutput'),
)
30
31
class TestArgMaxOp(unittest.TestCase):
    """Shape-inference tests for ``ArgMaxOp.argmax_infer`` with Caffe-style attributes."""

    @staticmethod
    def _infer(input_shape, **argmax_attrs):
        """Run ``ArgMaxOp.argmax_infer`` on a ``node_1 -> argmax -> node_3 -> op_output`` graph.

        :param input_shape: shape assigned to ``node_1`` (may be None).
        :param argmax_attrs: attributes set on the ``argmax`` node (e.g. ``out_max_val``,
            ``top_k``, ``axis``).
        :return: tuple ``(argmax_node, res_shape)`` where ``res_shape`` is the shape that
            inference wrote to ``node_3``.
        """
        graph = build_graph(nodes_attributes,
                            [('node_1', 'argmax'),
                             ('argmax', 'node_3'),
                             ('node_3', 'op_output')
                             ],
                            {'node_3': {'shape': None},
                             'node_1': {'shape': input_shape},
                             'argmax': argmax_attrs
                             })
        argmax_node = Node(graph, 'argmax')
        ArgMaxOp.argmax_infer(argmax_node)
        return argmax_node, graph.node['node_3']['shape']

    def _assert_shape_equal(self, res_shape, exp_shape):
        """Assert that ``res_shape`` equals ``exp_shape`` exactly, including its rank.

        Comparing whole lists (rather than iterating over ``exp_shape`` only) also fails
        when inference produces extra or missing dimensions.
        """
        self.assertIsNotNone(res_shape)
        self.assertListEqual([int(dim) for dim in res_shape], list(exp_shape))

    def test_caffe_argmax_axis(self):
        """With an explicit axis, only that dimension is replaced by top_k."""
        _, res_shape = self._infer(np.array([1, 3, 1025, 2049]),
                                   out_max_val=True, top_k=100, axis=2)
        self._assert_shape_equal(res_shape, [1, 3, 100, 2049])

    def test_caffe_argmax_axis_negative(self):
        """A negative axis is normalized (-1 -> 3 for a 4D input) and then replaced by top_k."""
        argmax_node, res_shape = self._infer(np.array([1, 3, 1025, 2049]),
                                             out_max_val=True, top_k=100, axis=-1)
        self.assertEqual(argmax_node.axis, 3)
        self._assert_shape_equal(res_shape, [1, 3, 1025, 100])

    def test_caffe_argmax_no_axis(self):
        """Without an axis and with out_max_val=True, a 4D input yields [1, 2, top_k, 1]."""
        _, res_shape = self._infer(np.array([1, 3, 1025, 2049]),
                                   out_max_val=True, top_k=100)
        self._assert_shape_equal(res_shape, [1, 2, 100, 1])

    def test_caffe_argmax_extend_shape(self):
        """A 2D input without an axis yields a 3D output [1, 2, top_k]."""
        _, res_shape = self._infer(np.array([1, 3]),
                                   out_max_val=True, top_k=100)
        self._assert_shape_equal(res_shape, [1, 2, 100])

    def test_caffe_argmax_out_max_val_false(self):
        """With out_max_val=False the second output dimension is 1 rather than 2."""
        _, res_shape = self._infer(np.array([1, 3]),
                                   out_max_val=False, top_k=100)
        self._assert_shape_equal(res_shape, [1, 1, 100])

    def test_caffe_argmax_no_shape(self):
        """When the input shape is unknown (None), the output shape stays None."""
        _, res_shape = self._infer(None, out_max_val=False, top_k=100)
        self.assertIsNone(res_shape)