Copyright (c) 2019 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
import copy
import pickle

from .utils import get_path
from .dataset import Dataset
from .launcher import create_launcher, DummyLauncher
from .launcher.loaders import PickleLoader
from .logging import print_info
from .metrics import MetricsExecutor
from .postprocessor import PostprocessingExecutor
from .preprocessor import PreprocessingExecutor
    def __init__(self, launcher, preprocessor, postprocessor, dataset, metric):
        """Assemble an evaluator from already-constructed pipeline stages.

        Args:
            launcher: inference backend used to produce predictions for batches.
            preprocessor: executor applied to input data before inference.
            postprocessor: executor applied to annotation/prediction pairs.
            dataset: dataset providing annotations and input batches.
            metric: metrics executor that accumulates metric values.
        """
        self.launcher = launcher
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor
        self.dataset = dataset
        self.metric_executor = metric

        # Accumulated incrementally by process_dataset()/load().
        self._annotations = []
        self._predictions = []
42 def from_configs(cls, launcher_config, dataset_config):
43 dataset_name = dataset_config['name']
44 preprocessor = PreprocessingExecutor(dataset_config.get('preprocessing'), dataset_name)
45 dataset = Dataset(dataset_config, preprocessor)
47 launcher = create_launcher(launcher_config, dataset.metadata)
48 postprocessor = PostprocessingExecutor(dataset_config.get('postprocessing'), dataset_name, dataset.metadata)
49 metric_dispatcher = MetricsExecutor(dataset_config, dataset)
51 return cls(launcher, preprocessor, postprocessor, dataset, metric_dispatcher)
53 def process_dataset(self, stored_predictions, progress_reporter, *args, **kwargs):
54 if self._is_stored(stored_predictions) or isinstance(self.launcher, DummyLauncher):
55 self._annotations, self._predictions = self.load(stored_predictions, progress_reporter)
56 self._annotations, self._predictions = self.postprocessor.full_process(self._annotations, self._predictions)
58 self.metric_executor.update_metrics_on_batch(self._annotations, self._predictions)
59 return self._annotations, self._predictions
61 self.dataset.batch = self.launcher.batch
62 predictions_to_store = []
63 for batch_id, (batch_annotation, batch_input) in enumerate(self.dataset):
64 batch_identifiers = [annotation.identifier for annotation in batch_annotation]
65 batch_predictions = self.launcher.predict(batch_identifiers, batch_input, *args, **kwargs)
67 if stored_predictions:
68 predictions_to_store.extend(copy.deepcopy(batch_predictions))
70 annotations, predictions = self.postprocessor.process_batch(batch_annotation, batch_predictions)
71 if not self.postprocessor.has_dataset_processors:
72 self.metric_executor.update_metrics_on_batch(annotations, predictions)
74 self._annotations.extend(annotations)
75 self._predictions.extend(predictions)
78 progress_reporter.update(batch_id, len(batch_predictions))
81 progress_reporter.finish()
83 if stored_predictions:
84 self.store_predictions(stored_predictions, predictions_to_store)
86 if self.postprocessor.has_dataset_processors:
87 self.metric_executor.update_metrics_on_batch(self._annotations, self._predictions)
89 return self.postprocessor.process_dataset(self._annotations, self._predictions)
92 def _is_stored(stored_predictions=None):
93 if not stored_predictions:
97 get_path(stored_predictions)
102 def compute_metrics(self, output_callback=None, ignore_results_formatting=False):
103 for result_presenter, evaluated_metric in self.metric_executor.iterate_metrics(
104 self._annotations, self._predictions):
105 result_presenter.write_result(evaluated_metric, output_callback, ignore_results_formatting)
107 def load(self, stored_predictions, progress_reporter):
108 self._annotations = self.dataset.annotation
109 launcher = self.launcher
110 if not isinstance(launcher, DummyLauncher):
111 launcher = DummyLauncher({
112 'framework': 'dummy',
113 'loader': PickleLoader.__provider__,
114 'data_path': stored_predictions
117 predictions = launcher.predict([annotation.identifier for annotation in self._annotations])
119 if progress_reporter:
120 progress_reporter.finish(False)
122 return self._annotations, predictions
125 def store_predictions(stored_predictions, predictions):
126 # since at the first time file does not exist and then created we can not use it as a pathlib.Path object
127 with open(stored_predictions, "wb") as content:
128 pickle.dump(predictions, content)
129 print_info("prediction objects are save to {}".format(stored_predictions))
132 self.launcher.release()