"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import unittest

from extensions.front.tf.ObjectDetectionAPI import calculate_shape_keeping_aspect_ratio, \
    calculate_placeholder_spatial_shape
from mo.front.subgraph_matcher import SubgraphMatch
from mo.graph.graph import Graph
from mo.utils.custom_replacement_config import CustomReplacementDescriptor
from mo.utils.error import Error
class FakePipelineConfig:
    """Minimal in-memory stand-in for a pipeline config, used only by these tests.

    It exposes just ``get_param``. Tests configure it by adding keys to
    ``_model_params`` directly after construction.
    """

    def __init__(self, model_params: dict):
        # Kept as a mutable dict: tests add resizer settings to it after construction.
        self._model_params = model_params

    def get_param(self, param: str):
        """Return the value stored for ``param``, or ``None`` when it is not set.

        Tests populate only a subset of the resizer keys, so an unset key is
        reported as ``None`` rather than raising.

        NOTE(review): the body of the original "missing parameter" branch was
        absent from the reviewed source; returning ``None`` is a reconstruction
        -- confirm it matches the real pipeline config's ``get_param``.
        """
        return self._model_params.get(param)
37 class TestCalculateShape(unittest.TestCase):
41 def test_calculate_shape_1(self):
42 self.assertTupleEqual(calculate_shape_keeping_aspect_ratio(100, 300, self.min_size, self.max_size), (341, 1024))
44 def test_calculate_shape_2(self):
45 self.assertTupleEqual(calculate_shape_keeping_aspect_ratio(100, 600, self.min_size, self.max_size), (171, 1024))
47 def test_calculate_shape_3(self):
48 self.assertTupleEqual(calculate_shape_keeping_aspect_ratio(100, 3000, self.min_size, self.max_size), (34, 1024))
50 def test_calculate_shape_4(self):
51 self.assertTupleEqual(calculate_shape_keeping_aspect_ratio(300, 300, self.min_size, self.max_size), (600, 600))
53 def test_calculate_shape_5(self):
54 self.assertTupleEqual(calculate_shape_keeping_aspect_ratio(300, 400, self.min_size, self.max_size), (600, 800))
56 def test_calculate_shape_6(self):
57 self.assertTupleEqual(calculate_shape_keeping_aspect_ratio(300, 600, self.min_size, self.max_size), (512, 1024))
59 def test_calculate_shape_7(self):
60 self.assertTupleEqual(calculate_shape_keeping_aspect_ratio(1000, 2500, self.min_size, self.max_size),
63 def test_calculate_shape_8(self):
64 self.assertTupleEqual(calculate_shape_keeping_aspect_ratio(1800, 2000, self.min_size, self.max_size),
67 def test_calculate_shape_11(self):
68 self.assertTupleEqual(calculate_shape_keeping_aspect_ratio(300, 100, self.min_size, self.max_size), (1024, 341))
70 def test_calculate_shape_12(self):
71 self.assertTupleEqual(calculate_shape_keeping_aspect_ratio(600, 100, self.min_size, self.max_size), (1024, 171))
73 def test_calculate_shape_13(self):
74 self.assertTupleEqual(calculate_shape_keeping_aspect_ratio(3000, 100, self.min_size, self.max_size), (1024, 34))
76 def test_calculate_shape_15(self):
77 self.assertTupleEqual(calculate_shape_keeping_aspect_ratio(400, 300, self.min_size, self.max_size), (800, 600))
79 def test_calculate_shape_16(self):
80 self.assertTupleEqual(calculate_shape_keeping_aspect_ratio(600, 300, self.min_size, self.max_size), (1024, 512))
82 def test_calculate_shape_17(self):
83 self.assertTupleEqual(calculate_shape_keeping_aspect_ratio(2500, 1000, self.min_size, self.max_size),
86 def test_calculate_shape_18(self):
87 self.assertTupleEqual(calculate_shape_keeping_aspect_ratio(2000, 1800, self.min_size, self.max_size),
91 class TestCalculatePlaceholderSpatialShape(unittest.TestCase):
94 self.graph.graph['user_shapes'] = None
95 self.replacement_desc = CustomReplacementDescriptor('dummy_id', {})
96 self.match = SubgraphMatch(self.graph, self.replacement_desc, [], [], [], '')
97 self.pipeline_config = FakePipelineConfig({})
99 def test_default_fixed_shape_resizer(self):
100 self.pipeline_config._model_params['resizer_image_height'] = 300
101 self.pipeline_config._model_params['resizer_image_width'] = 600
102 self.assertTupleEqual((300, 600),
103 calculate_placeholder_spatial_shape(self.graph, self.match, self.pipeline_config))
105 def test_fixed_shape_resizer_overrided_by_user(self):
106 self.pipeline_config._model_params['resizer_image_height'] = 300
107 self.pipeline_config._model_params['resizer_image_width'] = 600
108 self.graph.graph['user_shapes'] = {'image_tensor': [{'shape': [1, 400, 500, 3]}]}
109 self.assertTupleEqual((400, 500),
110 calculate_placeholder_spatial_shape(self.graph, self.match, self.pipeline_config))
112 def test_default_keep_aspect_ratio_resizer(self):
113 self.pipeline_config._model_params['resizer_min_dimension'] = 600
114 self.pipeline_config._model_params['resizer_max_dimension'] = 1024
115 self.assertTupleEqual((600, 600),
116 calculate_placeholder_spatial_shape(self.graph, self.match, self.pipeline_config))
118 def test_keep_aspect_ratio_resizer_overrided_by_user(self):
119 self.pipeline_config._model_params['resizer_min_dimension'] = 600
120 self.pipeline_config._model_params['resizer_max_dimension'] = 1024
121 self.graph.graph['user_shapes'] = {'image_tensor': [{'shape': [1, 400, 300, 3]}]}
122 self.assertTupleEqual((800, 600),
123 calculate_placeholder_spatial_shape(self.graph, self.match, self.pipeline_config))
125 def test_missing_input_shape_information(self):
126 self.assertRaises(Error, calculate_placeholder_spatial_shape, self.graph, self.match, self.pipeline_config)