Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / onnx / detection_output_test.py
1 """
2  Copyright (c) 2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import onnx
18 import unittest
19
20 import numpy as np
21
22 from extensions.front.onnx.detection_output import DetectionOutputFrontExtractor
23 from extensions.ops.DetectionOutput import DetectionOutput
24 from mo.ops.op import Op
25 from mo.utils.unittest.extractors import PB
26
27
class TestDetectionOutputExt(unittest.TestCase):
    """Tests for the ONNX front extractor of the 'DetectionOutput' operation."""

    # Attributes whose expected values are non-integral floats; they are
    # compared approximately rather than exactly (presumably because ONNX keeps
    # float attributes in single precision -- confirm against the ONNX spec).
    _FLOAT_ATTRS = ('confidence_threshold', 'visualize_threshold', 'nms_threshold', 'eta')

    @staticmethod
    def _create_do_node(num_classes=0, share_location=0, background_label_id=0,
                        code_type="", variance_encoded_in_target=0, keep_top_k=0,
                        confidence_threshold=0, nms_threshold=0, top_k=0, eta=0):
        """Build a 'DetectionOutput' ONNX node and wrap it for the extractor.

        Every keyword argument is forwarded unchanged as an attribute of the
        ONNX node, which has a single input 'x' and a single output 'y'.

        :return: a ``PB`` object holding the created node proto under key 'pb'.
        """
        pb = onnx.helper.make_node(
            'DetectionOutput',
            inputs=['x'],
            outputs=['y'],
            num_classes=num_classes,
            share_location=share_location,
            background_label_id=background_label_id,
            code_type=code_type,
            variance_encoded_in_target=variance_encoded_in_target,
            keep_top_k=keep_top_k,
            confidence_threshold=confidence_threshold,
            # nms_param
            nms_threshold=nms_threshold,
            top_k=top_k,
            eta=eta,
        )

        return PB({'pb': pb})

    @classmethod
    def setUpClass(cls):
        """Register the DetectionOutput op class under its name before the tests run."""
        Op.registered_ops['DetectionOutput'] = DetectionOutput

    def test_do_no_pb_no_ml(self):
        """Extracting from ``None`` instead of a node must raise AttributeError."""
        self.assertRaises(AttributeError, DetectionOutputFrontExtractor.extract, None)

    def test_do_ext_ideal_numbers(self):
        """Check the attributes the extractor sets on a node with typical values."""
        node = self._create_do_node(num_classes=21, share_location=1,
                                    code_type="CENTER_SIZE", keep_top_k=200,
                                    confidence_threshold=0.01, nms_threshold=0.45, top_k=400, eta=1.0)

        # The extractor's result is read back from ``node`` itself below.
        DetectionOutputFrontExtractor.extract(node)

        exp_res = {
            'op': 'DetectionOutput',
            'type': 'DetectionOutput',
            'num_classes': 21,
            'share_location': 1,
            'background_label_id': 0,
            'code_type': "caffe.PriorBoxParameter.CENTER_SIZE",
            'variance_encoded_in_target': 0,
            'keep_top_k': 200,
            'confidence_threshold': 0.01,
            'visualize_threshold': 0.6,
            # nms_param
            'nms_threshold': 0.45,
            'top_k': 400,
            'eta': 1.0,
            # ONNX has no such parameters
            # save_output_param.resize_param
            'prob': 0,
            'resize_mode': "",
            'height': 0,
            'width': 0,
            'height_scale': 0,
            'width_scale': 0,
            'pad_mode': "",
            'pad_value': "",
            'interp_mode': "",
            'input_width': 1,
            'input_height': 1,
            'normalized': 1,
        }

        for key, expected in exp_res.items():
            # Name the attribute in the failure message so a mismatch is easy to locate.
            if key in self._FLOAT_ATTRS:
                np.testing.assert_almost_equal(node[key], expected,
                                               err_msg="attribute '{}'".format(key))
            else:
                self.assertEqual(node[key], expected, "attribute '{}'".format(key))
101             else:
102                 self.assertEqual(node[key], exp_res[key])