"""
Copyright (c) 2019 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
17 from collections import namedtuple
19 from ..presenters import BasePresenter, EvaluationResult
20 from ..config import StringField
21 from ..utils import zipped_transform
22 from .metric import BaseMetricConfig, Metric
23 from ..config import ConfigError
# Record bundling everything needed to evaluate and report one metric:
# its display name, the callable metric object, the expected reference value,
# the pass/fail threshold and the presenter used to render the result.
MetricInstance = namedtuple('MetricInstance', ['name', 'metric_fn', 'reference', 'threshold', 'presenter'])
class MetricConfig(BaseMetricConfig):
    """Validation schema for one metric entry: 'type' must name a registered Metric provider."""
    type = StringField(choices=Metric.providers)
class MetricsExecutor:
    """
    Class for evaluating metrics according to dataset configuration entry.
    """

    def __init__(self, dataset_config, dataset, state=None):
        """
        Builds the list of metric instances described by the dataset configuration.

        Args:
            dataset_config: dict-like dataset configuration; must contain a
                non-empty 'metrics' list of metric configuration entries.
            dataset: dataset object handed to each metric provider.
            state: optional mutable mapping shared between metrics
                (a new dict is created when omitted).

        Raises:
            ConfigError: if the configuration does not specify any metrics.
        """
        dataset_name = dataset_config.get('name', '')
        message_prefix = '{}'.format(dataset_name)

        self.state = state or {}
        self._token = 'metrics'

        dataset_metrics = dataset_config.get(self._token)
        if not dataset_metrics:
            raise ConfigError('{} dataset config must specify "{}"'.format(message_prefix, self._token))

        self.dataset = dataset
        self.metrics = []

        # Keys of a single metric configuration entry.
        type_ = 'type'
        identifier = 'name'
        reference = 'reference'
        threshold = 'threshold'
        presenter = 'presenter'

        for metric_config_entry in dataset_metrics:
            # Unknown keys are tolerated: entries may carry metric-specific options
            # validated later by the concrete metric itself.
            metric_config = MetricConfig(
                "{}.metrics".format(dataset_name), on_extra_argument=MetricConfig.IGNORE_ON_EXTRA_ARGUMENT
            )
            metric_type = metric_config_entry.get(type_)
            metric_config.validate(metric_config_entry, type_)

            # Fall back to the metric type when no explicit name is given.
            metric_identifier = metric_config_entry.get(identifier, metric_type)

            metric_fn = Metric.provide(
                metric_type, metric_config_entry, self.dataset, metric_identifier, state=self.state
            )
            metric_presenter = BasePresenter.provide(metric_config_entry.get(presenter, 'print_scalar'))

            self.metrics.append(MetricInstance(
                metric_identifier,
                metric_fn,
                metric_config_entry.get(reference),
                metric_config_entry.get(threshold),
                metric_presenter
            ))

    def update_metrics_on_object(self, annotation, prediction):
        """
        Updates metric value corresponding given annotation and prediction objects.
        """
        for metric in self.metrics:
            metric.metric_fn.submit(annotation, prediction)

    def update_metrics_on_batch(self, annotation, prediction):
        """
        Updates metric value corresponding given batch.

        Args:
            annotation: list of batch number of annotation objects.
            prediction: list of batch number of prediction objects.
        """
        zipped_transform(self.update_metrics_on_object, annotation, prediction)

    def iterate_metrics(self, annotations, predictions):
        """
        Evaluates every configured metric on the full result sets and yields
        (presenter, EvaluationResult) pairs, one per metric.
        """
        for name, functor, reference, threshold, presenter in self.metrics:
            yield presenter, EvaluationResult(
                name=name,
                evaluated_value=functor(annotations, predictions),
                reference_value=reference,
                threshold=threshold
            )