2 Copyright (c) 2018 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
21 from extensions.front.tf.ObjectDetectionAPI import calculate_shape_keeping_aspect_ratio, \
22 calculate_placeholder_spatial_shape
23 from mo.front.subgraph_matcher import SubgraphMatch
24 from mo.utils.custom_replacement_config import CustomReplacementDescriptor
25 from mo.utils.error import Error
class FakePipelineConfig:
    """Minimal stand-in for a pipeline config, used by the tests below.

    It wraps a plain dictionary of model parameters and exposes only
    ``get_param``, which yields ``None`` for any name not in the dictionary.
    """

    def __init__(self, model_params: dict):
        # Kept as the very same dict object so tests can add entries to it later.
        self._model_params = model_params

    def get_param(self, param: str):
        """Return the value stored under *param*, or ``None`` when it is absent."""
        return self._model_params.get(param)
class TestCalculateShape(unittest.TestCase):
    """Tests for ``calculate_shape_keeping_aspect_ratio``.

    Every case resizes an input ``(height, width)`` using the class-level
    limits ``min_size`` and ``max_size`` and compares the returned
    ``(height, width)`` tuple with the expected one.
    """

    min_size = 600
    max_size = 1024

    def _assert_resized(self, height, width, expected):
        """Check that ``(height, width)`` resizes to *expected* under the class limits."""
        resized = calculate_shape_keeping_aspect_ratio(height, width, self.min_size, self.max_size)
        self.assertTupleEqual(resized, expected)

    # Inputs where height <= width.

    def test_calculate_shape_1(self):
        self._assert_resized(100, 300, (341, 1024))

    def test_calculate_shape_2(self):
        self._assert_resized(100, 600, (171, 1024))

    def test_calculate_shape_3(self):
        self._assert_resized(100, 3000, (34, 1024))

    def test_calculate_shape_4(self):
        self._assert_resized(300, 300, (600, 600))

    def test_calculate_shape_5(self):
        self._assert_resized(300, 400, (600, 800))

    def test_calculate_shape_6(self):
        self._assert_resized(300, 600, (512, 1024))

    def test_calculate_shape_7(self):
        self._assert_resized(1000, 2500, (410, 1024))

    def test_calculate_shape_8(self):
        self._assert_resized(1800, 2000, (600, 667))

    # Mirrored inputs where height > width.

    def test_calculate_shape_11(self):
        self._assert_resized(300, 100, (1024, 341))

    def test_calculate_shape_12(self):
        self._assert_resized(600, 100, (1024, 171))

    def test_calculate_shape_13(self):
        self._assert_resized(3000, 100, (1024, 34))

    def test_calculate_shape_15(self):
        self._assert_resized(400, 300, (800, 600))

    def test_calculate_shape_16(self):
        self._assert_resized(600, 300, (1024, 512))

    def test_calculate_shape_17(self):
        self._assert_resized(2500, 1000, (1024, 410))

    def test_calculate_shape_18(self):
        self._assert_resized(2000, 1800, (667, 600))
class TestCalculatePlaceholderSpatialShape(unittest.TestCase):
    """Tests for ``calculate_placeholder_spatial_shape``.

    The scenarios covered are:
    - a fixed-shape resizer;
    - a keep-aspect-ratio resizer;
    - either resizer combined with a user-supplied ``image_tensor`` shape;
    - no shape information at all, where an ``Error`` is expected.
    """

    def setUp(self):
        # Fresh fixtures for every test: a graph without user shapes, a dummy
        # replacement descriptor and match, and a config with no parameters.
        self.graph = nx.MultiDiGraph()
        self.graph.graph['user_shapes'] = None
        self.replacement_desc = CustomReplacementDescriptor('dummy_id', {})
        self.match = SubgraphMatch(self.graph, self.replacement_desc, [], [], [], '')
        self.pipeline_config = FakePipelineConfig({})

    def _set_model_params(self, **params):
        """Add *params* to the fake pipeline config's parameter dictionary."""
        self.pipeline_config._model_params.update(params)

    def _set_user_image_shape(self, shape):
        """Record *shape* as the user-specified shape of the 'image_tensor' input."""
        self.graph.graph['user_shapes'] = {'image_tensor': [{'shape': shape}]}

    def _placeholder_shape(self):
        """Run the function under test against the current fixtures."""
        return calculate_placeholder_spatial_shape(self.graph, self.match, self.pipeline_config)

    def test_default_fixed_shape_resizer(self):
        self._set_model_params(resizer_image_height=300, resizer_image_width=600)
        self.assertTupleEqual((300, 600), self._placeholder_shape())

    def test_fixed_shape_resizer_overrided_by_user(self):
        self._set_model_params(resizer_image_height=300, resizer_image_width=600)
        self._set_user_image_shape([1, 400, 500, 3])
        self.assertTupleEqual((400, 500), self._placeholder_shape())

    def test_default_keep_aspect_ratio_resizer(self):
        self._set_model_params(resizer_min_dimension=600, resizer_max_dimension=1024)
        self.assertTupleEqual((600, 600), self._placeholder_shape())

    def test_keep_aspect_ratio_resizer_overrided_by_user(self):
        self._set_model_params(resizer_min_dimension=600, resizer_max_dimension=1024)
        self._set_user_image_shape([1, 400, 300, 3])
        self.assertTupleEqual((800, 600), self._placeholder_shape())

    def test_missing_input_shape_information(self):
        # Neither resizer parameters nor a user shape are provided.
        with self.assertRaises(Error):
            self._placeholder_shape()