Publishing 2019 R1 content
[platform/upstream/dldt.git] / tools / accuracy_checker / accuracy_checker / metrics / metric_executor.py
1 """
2 Copyright (c) 2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """
16
17 from collections import namedtuple
18
19 from ..presenters import BasePresenter, EvaluationResult
20 from ..config import StringField
21 from ..utils import zipped_transform
22 from .metric import BaseMetricConfig, Metric
23 from ..config import ConfigError
24
25 MetricInstance = namedtuple('MetricInstance', ['name', 'metric_fn', 'reference', 'threshold', 'presenter'])
26
27
28 class MetricConfig(BaseMetricConfig):
29     type = StringField(choices=Metric.providers)
30
31
32 class MetricsExecutor:
33     """
34     Class for evaluating metrics according to dataset configuration entry.
35     """
36
37     def __init__(self, dataset_config, dataset, state=None):
38         dataset_name = dataset_config.get('name', '')
39         message_prefix = '{}'.format(dataset_name)
40
41         self.state = state or {}
42         self._token = 'metrics'
43
44         dataset_metrics = dataset_config.get(self._token)
45         if not dataset_metrics:
46             raise ConfigError('{} dataset config must specify "{}"'.format(message_prefix, self._token))
47
48         self.dataset = dataset
49
50         self.metrics = []
51         type_ = 'type'
52         identifier = 'name'
53         reference = 'reference'
54         threshold = 'threshold'
55         presenter = 'presenter'
56
57         for metric_config_entry in dataset_metrics:
58             metric_config = MetricConfig(
59                 "{}.metrics".format(dataset_name), on_extra_argument=MetricConfig.IGNORE_ON_EXTRA_ARGUMENT
60             )
61             metric_type = metric_config_entry.get(type_)
62             metric_config.validate(metric_config_entry, type_)
63
64             metric_identifier = metric_config_entry.get(identifier, metric_type)
65
66             metric_fn = Metric.provide(
67                 metric_type, metric_config_entry, self.dataset, metric_identifier, state=self.state
68             )
69             metric_presenter = BasePresenter.provide(metric_config_entry.get(presenter, 'print_scalar'))
70
71             self.metrics.append(MetricInstance(
72                 metric_identifier,
73                 metric_fn,
74                 metric_config_entry.get(reference),
75                 metric_config_entry.get(threshold),
76                 metric_presenter
77             ))
78
79     def update_metrics_on_object(self, annotation, prediction):
80         """
81         Updates metric value corresponding given annotation and prediction objects.
82         """
83
84         for metric in self.metrics:
85             metric.metric_fn.submit(annotation, prediction)
86
87     def update_metrics_on_batch(self, annotation, prediction):
88         """
89         Updates metric value corresponding given batch.
90
91         Args:
92             annotation: list of batch number of annotation objects.
93             prediction: list of batch number of prediction objects.
94         """
95
96         zipped_transform(self.update_metrics_on_object, annotation, prediction)
97
98     def iterate_metrics(self, annotations, predictions):
99         for name, functor, reference, threshold, presenter in self.metrics:
100             yield presenter, EvaluationResult(
101                 name=name,
102                 evaluated_value=functor(annotations, predictions),
103                 reference_value=reference,
104                 threshold=threshold,
105                 meta=functor.meta,
106             )