Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / ObjectDetectionAPI_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18
19 from extensions.front.tf.ObjectDetectionAPI import calculate_shape_keeping_aspect_ratio, \
20     calculate_placeholder_spatial_shape
21 from mo.front.subgraph_matcher import SubgraphMatch
22 from mo.graph.graph import Graph
23 from mo.utils.custom_replacement_config import CustomReplacementDescriptor
24 from mo.utils.error import Error
25
26
class FakePipelineConfig:
    """Test double passed as ``pipeline_config`` to calculate_placeholder_spatial_shape.

    Only ``get_param`` is implemented. Tests may also edit ``_model_params``
    directly to inject configuration values after construction.
    """

    def __init__(self, model_params: dict):
        # Keep a reference rather than a copy, so later edits to the dict
        # (including via ``_model_params``) are visible through get_param.
        self._model_params = model_params

    def get_param(self, param: str):
        """Return the value stored for ``param``, or ``None`` when it is not set."""
        # A single lookup replaces the membership test followed by indexing.
        return self._model_params.get(param)
35
36
class TestCalculateShape(unittest.TestCase):
    """Compare calculate_shape_keeping_aspect_ratio against precomputed (height, width) pairs.

    All cases share the bounds below. The expected values show two behaviours:
    - the shorter side is scaled to ``min_size`` when that keeps the longer side
      within ``max_size`` (e.g. 300x400 -> 600x800);
    - otherwise the longer side is capped at ``max_size`` (e.g. 100x300 -> 341x1024).
    """
    min_size = 600
    max_size = 1024

    def _check(self, height, width, expected):
        """Assert that resizing (height, width) with the class bounds gives ``expected``."""
        actual = calculate_shape_keeping_aspect_ratio(height, width, self.min_size, self.max_size)
        self.assertTupleEqual(actual, expected)

    # Landscape and square inputs (height <= width).

    def test_calculate_shape_1(self):
        self._check(100, 300, (341, 1024))

    def test_calculate_shape_2(self):
        self._check(100, 600, (171, 1024))

    def test_calculate_shape_3(self):
        self._check(100, 3000, (34, 1024))

    def test_calculate_shape_4(self):
        self._check(300, 300, (600, 600))

    def test_calculate_shape_5(self):
        self._check(300, 400, (600, 800))

    def test_calculate_shape_6(self):
        self._check(300, 600, (512, 1024))

    def test_calculate_shape_7(self):
        self._check(1000, 2500, (410, 1024))

    def test_calculate_shape_8(self):
        self._check(1800, 2000, (600, 667))

    # Portrait inputs: the landscape cases above with height and width swapped.

    def test_calculate_shape_11(self):
        self._check(300, 100, (1024, 341))

    def test_calculate_shape_12(self):
        self._check(600, 100, (1024, 171))

    def test_calculate_shape_13(self):
        self._check(3000, 100, (1024, 34))

    def test_calculate_shape_15(self):
        self._check(400, 300, (800, 600))

    def test_calculate_shape_16(self):
        self._check(600, 300, (1024, 512))

    def test_calculate_shape_17(self):
        self._check(2500, 1000, (1024, 410))

    def test_calculate_shape_18(self):
        self._check(2000, 1800, (667, 600))
89
90
class TestCalculatePlaceholderSpatialShape(unittest.TestCase):
    """Tests for calculate_placeholder_spatial_shape.

    Every test starts from a graph with no user-provided shapes, a match built
    from a dummy custom replacement descriptor, and an empty fake pipeline
    config. Each test then fills in only what its scenario needs.
    """

    def setUp(self):
        graph = Graph()
        graph.graph['user_shapes'] = None
        self.graph = graph
        self.replacement_desc = CustomReplacementDescriptor('dummy_id', {})
        self.match = SubgraphMatch(graph, self.replacement_desc, [], [], [], '')
        self.pipeline_config = FakePipelineConfig({})

    def _set_model_params(self, **params):
        """Store ``params`` in the fake pipeline config read by the function under test."""
        self.pipeline_config._model_params.update(params)

    def _set_user_image_shape(self, shape):
        """Simulate a user-specified shape for the 'image_tensor' input."""
        self.graph.graph['user_shapes'] = {'image_tensor': [{'shape': shape}]}

    def _placeholder_shape(self):
        """Run the function under test with the fixtures prepared in setUp."""
        return calculate_placeholder_spatial_shape(self.graph, self.match, self.pipeline_config)

    def test_default_fixed_shape_resizer(self):
        self._set_model_params(resizer_image_height=300, resizer_image_width=600)
        self.assertTupleEqual((300, 600), self._placeholder_shape())

    def test_fixed_shape_resizer_overrided_by_user(self):
        # A user-provided shape takes precedence over the fixed resizer size.
        self._set_model_params(resizer_image_height=300, resizer_image_width=600)
        self._set_user_image_shape([1, 400, 500, 3])
        self.assertTupleEqual((400, 500), self._placeholder_shape())

    def test_default_keep_aspect_ratio_resizer(self):
        # With no user shape, the expected result is a min_dimension square.
        self._set_model_params(resizer_min_dimension=600, resizer_max_dimension=1024)
        self.assertTupleEqual((600, 600), self._placeholder_shape())

    def test_keep_aspect_ratio_resizer_overrided_by_user(self):
        # The user's 400x300 input is rescaled with the resizer bounds, not used verbatim.
        self._set_model_params(resizer_min_dimension=600, resizer_max_dimension=1024)
        self._set_user_image_shape([1, 400, 300, 3])
        self.assertTupleEqual((800, 600), self._placeholder_shape())

    def test_missing_input_shape_information(self):
        # Neither resizer parameters nor a user shape are set, so an Error is expected.
        with self.assertRaises(Error):
            self._placeholder_shape()
125     def test_missing_input_shape_information(self):
126         self.assertRaises(Error, calculate_placeholder_spatial_shape, self.graph, self.match, self.pipeline_config)