Added unit tests and readme for model optimizer (#79)
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / ObjectDetectionAPI_test.py
1 """
2  Copyright (c) 2018 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18
19 import networkx as nx
20
21 from extensions.front.tf.ObjectDetectionAPI import calculate_shape_keeping_aspect_ratio, \
22     calculate_placeholder_spatial_shape
23 from mo.front.subgraph_matcher import SubgraphMatch
24 from mo.utils.custom_replacement_config import CustomReplacementDescriptor
25 from mo.utils.error import Error
26
27
class FakePipelineConfig:
    """Minimal stand-in for a pipeline config object, used by the tests in this module.

    Only ``get_param`` is provided. Lookups go straight to the dictionary given to
    the constructor. That dictionary is kept by reference (not copied), so entries
    added to ``_model_params`` after construction are visible to ``get_param``.
    """

    def __init__(self, model_params: dict):
        # Deliberately no copy: tests populate this dict after building the config.
        self._model_params = model_params

    def get_param(self, param: str):
        """Return the value stored under ``param``, or ``None`` when the key is absent."""
        return self._model_params.get(param)
36
37
class TestCalculateShape(unittest.TestCase):
    """Checks ``calculate_shape_keeping_aspect_ratio`` with dimension limits of 600 (min) and 1024 (max).

    Each case feeds a 2-element input shape and compares the result against the
    expected resized shape (same element order as the input).
    """
    min_size = 600
    max_size = 1024

    def _check(self, input_shape, expected):
        """Assert that resizing ``input_shape`` with the class limits yields ``expected``."""
        resized = calculate_shape_keeping_aspect_ratio(*input_shape, self.min_size, self.max_size)
        self.assertTupleEqual(resized, expected)

    # Cases where the second input dimension is >= the first.
    def test_calculate_shape_1(self):
        self._check((100, 300), (341, 1024))

    def test_calculate_shape_2(self):
        self._check((100, 600), (171, 1024))

    def test_calculate_shape_3(self):
        self._check((100, 3000), (34, 1024))

    def test_calculate_shape_4(self):
        self._check((300, 300), (600, 600))

    def test_calculate_shape_5(self):
        self._check((300, 400), (600, 800))

    def test_calculate_shape_6(self):
        self._check((300, 600), (512, 1024))

    def test_calculate_shape_7(self):
        self._check((1000, 2500), (410, 1024))

    def test_calculate_shape_8(self):
        self._check((1800, 2000), (600, 667))

    # Mirrored cases: the first input dimension is the larger one.
    def test_calculate_shape_11(self):
        self._check((300, 100), (1024, 341))

    def test_calculate_shape_12(self):
        self._check((600, 100), (1024, 171))

    def test_calculate_shape_13(self):
        self._check((3000, 100), (1024, 34))

    def test_calculate_shape_15(self):
        self._check((400, 300), (800, 600))

    def test_calculate_shape_16(self):
        self._check((600, 300), (1024, 512))

    def test_calculate_shape_17(self):
        self._check((2500, 1000), (1024, 410))

    def test_calculate_shape_18(self):
        self._check((2000, 1800), (667, 600))
90
91
class TestCalculatePlaceholderSpatialShape(unittest.TestCase):
    """Checks ``calculate_placeholder_spatial_shape`` for fixed-shape and keep-aspect-ratio
    resizer parameters, with and without user-provided input shapes."""

    def setUp(self):
        self.graph = nx.MultiDiGraph()
        # No user shape overrides unless an individual test installs them.
        self.graph.graph['user_shapes'] = None
        self.replacement_desc = CustomReplacementDescriptor('dummy_id', {})
        self.match = SubgraphMatch(self.graph, self.replacement_desc, [], [], [], '')
        self.pipeline_config = FakePipelineConfig({})

    def _configure(self, model_params):
        """Add the entries of ``model_params`` to the fake pipeline config, in order."""
        self.pipeline_config._model_params.update(model_params)

    def _spatial_shape(self):
        """Run the function under test against this test's graph, match and config."""
        return calculate_placeholder_spatial_shape(self.graph, self.match, self.pipeline_config)

    def test_default_fixed_shape_resizer(self):
        self._configure({'resizer_image_height': 300, 'resizer_image_width': 600})
        self.assertTupleEqual((300, 600), self._spatial_shape())

    def test_fixed_shape_resizer_overrided_by_user(self):
        self._configure({'resizer_image_height': 300, 'resizer_image_width': 600})
        self.graph.graph['user_shapes'] = {'image_tensor': [{'shape': [1, 400, 500, 3]}]}
        self.assertTupleEqual((400, 500), self._spatial_shape())

    def test_default_keep_aspect_ratio_resizer(self):
        self._configure({'resizer_min_dimension': 600, 'resizer_max_dimension': 1024})
        self.assertTupleEqual((600, 600), self._spatial_shape())

    def test_keep_aspect_ratio_resizer_overrided_by_user(self):
        self._configure({'resizer_min_dimension': 600, 'resizer_max_dimension': 1024})
        self.graph.graph['user_shapes'] = {'image_tensor': [{'shape': [1, 400, 300, 3]}]}
        self.assertTupleEqual((800, 600), self._spatial_shape())

    def test_missing_input_shape_information(self):
        # With no resizer parameters and no user shapes, an Error is expected.
        with self.assertRaises(Error):
            self._spatial_shape()
126     def test_missing_input_shape_information(self):
127         self.assertRaises(Error, calculate_placeholder_spatial_shape, self.graph, self.match, self.pipeline_config)