Publishing 2019 R1 content
[platform/upstream/dldt.git] / tools / accuracy_checker / accuracy_checker / presenters.py
1 """
2 Copyright (c) 2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """
16
17 from collections import namedtuple
18 from enum import Enum
19 import numpy as np
20
21 from .dependency import ClassProvider
22 from .logging import print_info
23
# Record handed to presenters; they unpack it positionally as
# (value, reference, name, threshold, meta), so the field order matters.
EvaluationResult = namedtuple(
    'EvaluationResult',
    ['evaluated_value', 'reference_value', 'name', 'threshold', 'meta'],
)
25
26
class Color(Enum):
    """Status categories that select the console colour of a message."""

    PASSED = 0  # color_format renders this in green
    FAILED = 1  # color_format renders this in red
30
31
def color_format(s, color=Color.PASSED):
    """Wrap *s* in ANSI escape sequences so a terminal shows it coloured.

    Args:
        s: object to embed in the coloured string (formatted with ``str.format``).
        color: ``Color.PASSED`` selects green; any other value selects red.

    Returns:
        The formatted text surrounded by the colour code and the reset code.
    """
    template = "\x1b[0;32m{}\x1b[0m" if color == Color.PASSED else "\x1b[0;31m{}\x1b[0m"
    return template.format(s)
36
37
class BasePresenter(ClassProvider):
    """Base class for result presenters, registered under the "presenter" provider type.

    Subclasses set ``__provider__`` and implement ``write_result``.
    """

    __provider_type__ = "presenter"

    def write_result(self, evaluation_result, output_callback=None, ignore_results_formatting=False):
        """Present one evaluation result; must be overridden by subclasses.

        Args:
            evaluation_result: the result to present.
            output_callback: optional callable a presenter may hand the result to.
            ignore_results_formatting: when true, metric-specific formatting is not applied.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError
43
44
class ScalarPrintPresenter(BasePresenter):
    """Presenter that prints a metric as one number (the mean of the evaluated value)."""

    __provider__ = "print_scalar"

    def write_result(self, evaluation_result: EvaluationResult, output_callback=None, ignore_results_formatting=False):
        """Average the evaluated value and print it through ``write_scalar_result``.

        Args:
            evaluation_result: 5-item result (value, reference, name, threshold, meta).
            output_callback: accepted for interface compatibility; not used here.
            ignore_results_formatting: forwarded to ``get_result_format_parameters``
                to choose between metric-specific and plain formatting.
        """
        raw_value, expected, metric_name, tolerance, metric_meta = evaluation_result
        averaged = np.mean(raw_value)
        suffix, multiplier, template = get_result_format_parameters(metric_meta, ignore_results_formatting)
        write_scalar_result(
            averaged,
            metric_name,
            expected,
            tolerance,
            postfix=suffix,
            scale=multiplier,
            result_format=template,
        )
55
56
class VectorPrintPresenter(BasePresenter):
    """Presenter that prints every component of a vector metric, then optionally their mean."""

    __provider__ = "print_vector"

    def write_result(self, evaluation_result: EvaluationResult, output_callback=None, ignore_results_formatting=False):
        """Print each element of the evaluated value on its own line.

        Args:
            evaluation_result: 5-item result (value, reference, name, threshold, meta).
                ``meta['names']`` may supply a label per element, and
                ``meta['calculate_mean']`` (default ``True``) controls the extra mean line.
            output_callback: accepted for interface compatibility; not used here.
            ignore_results_formatting: forwarded to ``get_result_format_parameters``
                to choose between metric-specific and plain formatting.
        """
        value, reference, name, threshold, meta = evaluation_result
        if threshold:
            threshold = float(threshold)

        value_names = meta.get('names')
        postfix, scale, result_format = get_result_format_parameters(meta, ignore_results_formatting)
        if np.isscalar(value):
            value = [value]
        elif np.size(value) == 1:
            # A one-element container must be unwrapped, not wrapped again:
            # otherwise the loop below would try to scale and format a
            # list/array instead of a number and fail with TypeError.
            value = [np.ravel(value)[0]]

        for index, res in enumerate(value):
            write_scalar_result(
                res, name, reference, threshold,
                value_name=value_names[index] if value_names else None,
                postfix=postfix[index] if not np.isscalar(postfix) else postfix,
                scale=scale[index] if not np.isscalar(scale) else scale,
                result_format=result_format
            )

        if len(value) > 1 and meta.get('calculate_mean', True):
            # The mean is taken over already-scaled values, so it is printed
            # with scale=1 to avoid applying the scale twice.
            write_scalar_result(
                np.mean(np.multiply(value, scale)), name, reference, threshold, value_name='mean',
                postfix=postfix[-1] if not np.isscalar(postfix) else postfix, scale=1,
                result_format=result_format
            )
85
86
def write_scalar_result(res_value, name, reference, threshold, value_name=None, postfix='%', scale=100,
                        result_format='{:.2f}'):
    """Format a single metric value and pass the resulting line to ``print_info``.

    Args:
        res_value: metric value; it is multiplied by ``scale`` before display
            and before comparison with ``reference``.
        name: metric name shown at the start of the line.
        reference: expected (already scaled) value, or ``None`` to skip the
            comparison. A reference of ``0`` is a valid value and is compared.
        threshold: largest acceptable absolute deviation from ``reference``;
            a falsy value is treated as ``0``.
        value_name: optional label appended to the name as ``name@value_name``.
        postfix: text appended right after the formatted number.
        scale: multiplier applied to ``res_value``.
        result_format: ``str.format`` template used for the scaled value.
    """
    scaled_value = res_value * scale
    display_name = "{}@{}".format(name, value_name) if value_name else name
    display_result = result_format.format(scaled_value)
    message = '{}: {}{}'.format(display_name, display_result, postfix)

    # Compare against None explicitly: a reference of 0 is legitimate and
    # would be silently skipped by a plain truthiness check.
    if reference is not None:
        threshold = threshold or 0

        difference = abs(reference - scaled_value)
        # Fail only when the deviation exceeds the threshold, so that an exact
        # match (difference == 0 with no threshold) is reported as OK.
        if difference > threshold:
            fail_message = "[FAILED: error = {:.4}]".format(difference)
            message = "{} {}".format(message, color_format(fail_message, Color.FAILED))
        else:
            message = "{} {}".format(message, color_format("[OK]", Color.PASSED))

    print_info(message)
104
105
class ReturnValuePresenter(BasePresenter):
    """Presenter that prints nothing and hands the result to a caller-supplied callback."""

    __provider__ = "return_value"

    def write_result(self, evaluation_result: EvaluationResult, output_callback=None, ignore_results_formatting=False):
        """Pass ``evaluation_result`` unchanged to ``output_callback`` when one is given.

        Args:
            evaluation_result: the result to forward.
            output_callback: callable invoked with the result; nothing happens if it is falsy.
            ignore_results_formatting: accepted for interface compatibility; not used here.
        """
        if not output_callback:
            return
        output_callback(evaluation_result)
112
113
def get_result_format_parameters(meta, use_default_formatting):
    """Choose the postfix, scale and format template used to print a metric.

    Args:
        meta: mapping that may define 'postfix', 'scale' and 'data_format';
            it is only consulted when ``use_default_formatting`` is falsy.
        use_default_formatting: when truthy, ignore ``meta`` and return the
            plain settings (single-space postfix, scale 1, '{}' template).

    Returns:
        Tuple ``(postfix, scale, result_format)``.
    """
    if use_default_formatting:
        return ' ', 1, '{}'

    return (
        meta.get('postfix', '%'),
        meta.get('scale', 100),
        meta.get('data_format', '{:.2f}'),
    )