1 """
2 Copyright (c) 2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """

import copy
import pickle

from .utils import get_path
from .dataset import Dataset
from .launcher import create_launcher, DummyLauncher
from .launcher.loaders import PickleLoader
from .logging import print_info
from .metrics import MetricsExecutor
from .postprocessor import PostprocessingExecutor
from .preprocessor import PreprocessingExecutor


class ModelEvaluator:
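    """Coordinates a single evaluation run: dataset iteration, preprocessing,
    inference through the launcher, postprocessing and metric computation.
    """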
    def __init__(self, launcher, preprocessor, postprocessor, dataset, metric):
        self.launcher = launcher
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor
        self.dataset = dataset
        self.metric_executor = metric

        self._annotations = []
        self._predictions = []

    @classmethod
    def from_configs(cls, launcher_config, dataset_config):
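        """Builds a ready-to-use evaluator from configuration dictionaries:
        wires the preprocessor into the dataset, then creates the launcher,
        postprocessor and metric dispatcher around the dataset metadata.
        """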
        dataset_name = dataset_config['name']
        preprocessor = PreprocessingExecutor(dataset_config.get('preprocessing'), dataset_name)
        dataset = Dataset(dataset_config, preprocessor)

        launcher = create_launcher(launcher_config, dataset.metadata)
        postprocessor = PostprocessingExecutor(dataset_config.get('postprocessing'), dataset_name, dataset.metadata)
        metric_dispatcher = MetricsExecutor(dataset_config, dataset)

        return cls(launcher, preprocessor, postprocessor, dataset, metric_dispatcher)

    def process_dataset(self, stored_predictions, progress_reporter, *args, **kwargs):
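        """Runs the model over the whole dataset and feeds the metrics.

        If predictions were stored earlier (or a DummyLauncher is in use),
        they are loaded from disk instead of re-running inference.
        """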
        if self._is_stored(stored_predictions) or isinstance(self.launcher, DummyLauncher):
            self._annotations, self._predictions = self.load(stored_predictions, progress_reporter)
            self._annotations, self._predictions = self.postprocessor.full_process(self._annotations, self._predictions)

            self.metric_executor.update_metrics_on_batch(self._annotations, self._predictions)
            return self._annotations, self._predictions

        self.dataset.batch = self.launcher.batch
        predictions_to_store = []
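        # run inference batch by batch; metrics are updated on the fly unless
        # some postprocessor needs to see the whole dataset first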
        for batch_id, (batch_annotation, batch_input) in enumerate(self.dataset):
            batch_identifiers = [annotation.identifier for annotation in batch_annotation]
            batch_predictions = self.launcher.predict(batch_identifiers, batch_input, *args, **kwargs)

            if stored_predictions:
                predictions_to_store.extend(copy.deepcopy(batch_predictions))

            annotations, predictions = self.postprocessor.process_batch(batch_annotation, batch_predictions)
            if not self.postprocessor.has_dataset_processors:
                self.metric_executor.update_metrics_on_batch(annotations, predictions)

            self._annotations.extend(annotations)
            self._predictions.extend(predictions)

            if progress_reporter:
                progress_reporter.update(batch_id, len(batch_predictions))

        if progress_reporter:
            progress_reporter.finish()

        if stored_predictions:
            self.store_predictions(stored_predictions, predictions_to_store)

        if self.postprocessor.has_dataset_processors:
            self.metric_executor.update_metrics_on_batch(self._annotations, self._predictions)

        return self.postprocessor.process_dataset(self._annotations, self._predictions)

    @staticmethod
    def _is_stored(stored_predictions=None):
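        """Returns True if stored_predictions points to an existing file."""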
        if not stored_predictions:
            return False

        try:
            get_path(stored_predictions)
            return True
        except OSError:
            return False

    def compute_metrics(self, output_callback=None, ignore_results_formatting=False):
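        """Evaluates all configured metrics on the accumulated annotations and
        predictions, writing each result through its presenter.
        """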
        for result_presenter, evaluated_metric in self.metric_executor.iterate_metrics(
                self._annotations, self._predictions):
            result_presenter.write_result(evaluated_metric, output_callback, ignore_results_formatting)

    def load(self, stored_predictions, progress_reporter):
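        """Loads annotations from the dataset and previously stored predictions
        from disk, reading them back through a pickle-based DummyLauncher.
        """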
        self._annotations = self.dataset.annotation
        launcher = self.launcher
        if not isinstance(launcher, DummyLauncher):
            launcher = DummyLauncher({
                'framework': 'dummy',
                'loader': PickleLoader.__provider__,
                'data_path': stored_predictions
            }, adapter=None)

        predictions = launcher.predict([annotation.identifier for annotation in self._annotations])

        if progress_reporter:
            progress_reporter.finish(False)

        return self._annotations, predictions

    @staticmethod
    def store_predictions(stored_predictions, predictions):
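        """Serializes predictions to the stored_predictions path with pickle."""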
        # the target file may not exist yet, so we open it by plain path string
        # instead of resolving it through get_path as a pathlib.Path object
        with open(stored_predictions, "wb") as content:
            pickle.dump(predictions, content)

        print_info("prediction objects are saved to {}".format(stored_predictions))

    def release(self):
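        """Releases resources held by the launcher."""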
        self.launcher.release()
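

# A minimal usage sketch (illustrative only; launcher_config and dataset_config
# are assumed, pre-built configuration dictionaries, and progress reporting is
# disabled here by passing None):
#
#     evaluator = ModelEvaluator.from_configs(launcher_config, dataset_config)
#     evaluator.process_dataset(stored_predictions=None, progress_reporter=None)
#     evaluator.compute_metrics()
#     evaluator.release()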