2 Copyright (c) 2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
17 from unittest.mock import Mock, MagicMock
19 from accuracy_checker.model_evaluator import ModelEvaluator
class TestModelEvaluator:
    """Unit tests for ``ModelEvaluator.process_dataset`` with fully mocked collaborators.

    The six tests cover every combination of:
      * predictions neither stored nor loaded (``None``), stored to ``'path'``,
        or loaded from ``'path'`` (``get_path`` patched);
      * a postprocessor with or without dataset-level processors.
    """

    def setup_method(self):
        """Build the mocked launcher, dataset, postprocessor and metric plus the evaluator under test.

        The dataset yields two batches. Each batch is an ``(annotations, predictions)``
        pair of one-element lists that both hold the same annotation container.
        """
        self.launcher = Mock()
        self.launcher.predict.return_value = []

        self.preprocessor = Mock()
        self.postprocessor = Mock()

        annotation_0 = Mock()
        annotation_0.identifier = 0
        annotation_1 = Mock()
        annotation_1.identifier = 1
        annotation_container_0 = Mock()
        annotation_container_0.values = Mock(return_value=[annotation_0])
        annotation_container_1 = Mock()
        annotation_container_1.values = Mock(return_value=[annotation_1])

        def make_batches():
            # Build fresh (annotations, predictions) pairs on every call, so the
            # dataset and each mock below never share the same list objects.
            return (
                ([annotation_container_0], [annotation_container_0]),
                ([annotation_container_1], [annotation_container_1]),
            )

        self.annotations = list(make_batches())

        self.dataset = MagicMock()
        self.dataset.__iter__.return_value = self.annotations

        # side_effect list: each process_batch call returns the next batch in turn.
        self.postprocessor.process_batch = Mock(side_effect=list(make_batches()))
        self.postprocessor.process_dataset = Mock(return_value=make_batches())
        self.postprocessor.full_process = Mock(return_value=make_batches())

        self.metric = Mock()
        self.metric.update_metrics_on_batch = Mock()

        self.evaluator = ModelEvaluator(
            self.launcher, self.preprocessor, self.postprocessor, self.dataset, self.metric
        )
        # Stub out persistence so no test touches the filesystem through these methods.
        self.evaluator.store_predictions = Mock()
        self.evaluator.load = Mock(return_value=make_batches())

    def _assert_inferred(self, expected_metric_calls):
        """Assert predictions came from the launcher batch by batch, not from ``load``.

        :param expected_metric_calls: expected ``update_metrics_on_batch`` call count.
        """
        assert not self.evaluator.load.called
        assert self.launcher.predict.called
        assert self.postprocessor.process_batch.called
        assert self.metric.update_metrics_on_batch.call_count == expected_metric_calls
        assert self.postprocessor.process_dataset.called
        assert not self.postprocessor.full_process.called

    def _assert_loaded(self):
        """Assert predictions came from ``load`` and the inference path was skipped."""
        assert not self.evaluator.store_predictions.called
        assert self.evaluator.load.called
        assert not self.launcher.predict.called
        assert not self.postprocessor.process_batch.called
        assert self.metric.update_metrics_on_batch.call_count == 1
        assert not self.postprocessor.process_dataset.called
        assert self.postprocessor.full_process.called

    def test_process_dataset_without_storing_predictions_and_dataset_processors(self):
        self.postprocessor.has_dataset_processors = False

        self.evaluator.process_dataset(None, None)

        assert not self.evaluator.store_predictions.called
        # Without dataset processors, metrics are expected to update once per batch.
        self._assert_inferred(expected_metric_calls=len(self.annotations))

    def test_process_dataset_without_storing_predictions_and_with_dataset_processors(self):
        self.postprocessor.has_dataset_processors = True

        self.evaluator.process_dataset(None, None)

        assert not self.evaluator.store_predictions.called
        # With dataset processors, metrics are expected to update once for the whole dataset.
        self._assert_inferred(expected_metric_calls=1)

    def test_process_dataset_with_storing_predictions_and_without_dataset_processors(self):
        self.postprocessor.has_dataset_processors = False

        self.evaluator.process_dataset('path', None)

        assert self.evaluator.store_predictions.called
        self._assert_inferred(expected_metric_calls=len(self.annotations))

    def test_process_dataset_with_storing_predictions_and_with_dataset_processors(self):
        self.postprocessor.has_dataset_processors = True

        self.evaluator.process_dataset('path', None)

        assert self.evaluator.store_predictions.called
        self._assert_inferred(expected_metric_calls=1)

    def test_process_dataset_with_loading_predictions_and_without_dataset_processors(self, mocker):
        # Patching get_path presumably makes 'path' count as existing stored
        # predictions -- NOTE(review): confirm against ModelEvaluator.
        mocker.patch('accuracy_checker.model_evaluator.get_path')
        self.postprocessor.has_dataset_processors = False

        self.evaluator.process_dataset('path', None)

        self._assert_loaded()

    def test_process_dataset_with_loading_predictions_and_with_dataset_processors(self, mocker):
        mocker.patch('accuracy_checker.model_evaluator.get_path')
        self.postprocessor.has_dataset_processors = True

        self.evaluator.process_dataset('path', None)

        self._assert_loaded()