Publishing 2019 R1 content
[platform/upstream/dldt.git] / tools / accuracy_checker / tests / test_model_evaluator.py
1 """
2 Copyright (c) 2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """
16
17 from unittest.mock import Mock, MagicMock
18
19 from accuracy_checker.model_evaluator import ModelEvaluator
20
21
class TestModelEvaluator:
    """Tests for ``ModelEvaluator.process_dataset`` with fully mocked collaborators.

    Each test picks one of three modes through the ``stored_predictions`` argument
    (``None``; a path that does not exist; or a path that resolves because
    ``get_path`` is patched) and toggles ``postprocessor.has_dataset_processors``,
    then checks which collaborators were invoked and how often metrics were updated.
    """

    def setup_method(self):
        """Create mocked launcher, processors, dataset, metric and the evaluator under test."""
        self.launcher = Mock()
        self.launcher.predict.return_value = []

        self.preprocessor = Mock()
        self.postprocessor = Mock()

        # One annotation container per batch; each container's values() returns a
        # single annotation whose identifier matches the batch index.
        containers = []
        for identifier in (0, 1):
            annotation = Mock()
            annotation.identifier = identifier
            container = Mock()
            container.values = Mock(return_value=[annotation])
            containers.append(container)

        def make_batches():
            # Two (annotations, predictions) pairs. A fresh structure is built for
            # every consumer so no two mocks share the same mutable lists.
            return [([container], [container]) for container in containers]

        self.annotations = make_batches()

        self.dataset = MagicMock()
        self.dataset.__iter__.return_value = self.annotations

        # process_batch returns one pair per call; the dataset-level hooks and
        # load() return all pairs at once as a tuple.
        self.postprocessor.process_batch = Mock(side_effect=make_batches())
        self.postprocessor.process_dataset = Mock(return_value=tuple(make_batches()))
        self.postprocessor.full_process = Mock(return_value=tuple(make_batches()))

        self.metric = Mock()
        self.metric.update_metrics_on_batch = Mock()

        self.evaluator = ModelEvaluator(self.launcher, self.preprocessor, self.postprocessor, self.dataset, self.metric)
        # Persistence is mocked so the tests can observe whether it is invoked.
        self.evaluator.store_predictions = Mock()
        self.evaluator.load = Mock(return_value=tuple(make_batches()))

    def _assert_inference_path(self, expect_store, expected_metric_updates):
        """Assert predictions were produced by running the launcher rather than loaded.

        :param expect_store: whether ``store_predictions`` is expected to have been called.
        :param expected_metric_updates: expected number of ``update_metrics_on_batch`` calls.
        """
        assert self.evaluator.store_predictions.called == expect_store
        assert not self.evaluator.load.called
        assert self.launcher.predict.called
        assert self.postprocessor.process_batch.called
        assert self.metric.update_metrics_on_batch.call_count == expected_metric_updates
        assert self.postprocessor.process_dataset.called
        assert not self.postprocessor.full_process.called

    def _assert_loading_path(self):
        """Assert stored predictions were loaded and post-processed instead of running the launcher."""
        assert not self.evaluator.store_predictions.called
        assert self.evaluator.load.called
        assert not self.launcher.predict.called
        assert not self.postprocessor.process_batch.called
        assert self.metric.update_metrics_on_batch.call_count == 1
        assert not self.postprocessor.process_dataset.called
        assert self.postprocessor.full_process.called

    def test_process_dataset_without_storing_predictions_and_dataset_processors(self):
        self.postprocessor.has_dataset_processors = False

        self.evaluator.process_dataset(None, None)

        # Without dataset processors, metrics are expected to be updated once per batch.
        self._assert_inference_path(expect_store=False, expected_metric_updates=len(self.annotations))

    def test_process_dataset_without_storing_predictions_and_with_dataset_processors(self):
        self.postprocessor.has_dataset_processors = True

        self.evaluator.process_dataset(None, None)

        # With dataset processors, a single metric update is expected for the whole dataset.
        self._assert_inference_path(expect_store=False, expected_metric_updates=1)

    def test_process_dataset_with_storing_predictions_and_without_dataset_processors(self, tmpdir):
        self.postprocessor.has_dataset_processors = False
        # Use a path that is guaranteed not to exist; a bare relative 'path' could be
        # picked up as existing predictions if such a file happened to be in the CWD.
        predictions_file = str(tmpdir.join('stored_predictions'))

        self.evaluator.process_dataset(predictions_file, None)

        self._assert_inference_path(expect_store=True, expected_metric_updates=len(self.annotations))

    def test_process_dataset_with_storing_predictions_and_with_dataset_processors(self, tmpdir):
        self.postprocessor.has_dataset_processors = True
        # Use a path that is guaranteed not to exist; a bare relative 'path' could be
        # picked up as existing predictions if such a file happened to be in the CWD.
        predictions_file = str(tmpdir.join('stored_predictions'))

        self.evaluator.process_dataset(predictions_file, None)

        self._assert_inference_path(expect_store=True, expected_metric_updates=1)

    def test_process_dataset_with_loading_predictions_and_without_dataset_processors(self, mocker):
        # Patching get_path makes 'path' resolve successfully, which the evaluator
        # is expected to treat as existing stored predictions.
        mocker.patch('accuracy_checker.model_evaluator.get_path')
        self.postprocessor.has_dataset_processors = False

        self.evaluator.process_dataset('path', None)

        self._assert_loading_path()

    def test_process_dataset_with_loading_predictions_and_with_dataset_processors(self, mocker):
        # Patching get_path makes 'path' resolve successfully, which the evaluator
        # is expected to treat as existing stored predictions.
        mocker.patch('accuracy_checker.model_evaluator.get_path')
        self.postprocessor.has_dataset_processors = True

        self.evaluator.process_dataset('path', None)

        self._assert_loading_path()