Copyright (c) 2019 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
from collections import namedtuple

import numpy as np

from .dependency import ClassProvider
from .logging import print_info
# Outcome of evaluating one metric. Presenters unpack it positionally as
# (value, reference, name, threshold, meta).
EvaluationResult = namedtuple(
    'EvaluationResult',
    ('evaluated_value', 'reference_value', 'name', 'threshold', 'meta'),
)
def color_format(s, color=Color.PASSED):
    """Wrap ``s`` in ANSI color escape sequences.

    ``Color.PASSED`` selects green (SGR code 32); any other value selects
    red (SGR code 31). The sequence is reset with ``\\x1b[0m`` afterwards.
    """
    is_passed = color == Color.PASSED
    template = "\x1b[0;32m{}\x1b[0m" if is_passed else "\x1b[0;31m{}\x1b[0m"
    return template.format(s)
class BasePresenter(ClassProvider):
    """Common interface for objects that report an ``EvaluationResult``.

    Concrete presenters define ``__provider__`` and override ``write_result``.
    ``__provider_type__`` is presumably how ``ClassProvider`` groups these
    subclasses (confirm against ``.dependency``).
    """

    __provider_type__ = "presenter"

    def write_result(self, evaluation_result, output_callback=None, ignore_results_formatting=False):
        """Report ``evaluation_result``; every concrete presenter must override this.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
class ScalarPrintPresenter(BasePresenter):
    """Print a metric reduced to a single number.

    Registered under the provider name ``print_scalar``. Whatever
    ``evaluated_value`` holds is averaged with ``np.mean`` before printing.
    """

    __provider__ = "print_scalar"

    def write_result(self, evaluation_result: EvaluationResult, output_callback=None, ignore_results_formatting=False):
        """Print the mean of the evaluated value as one line via ``write_scalar_result``.

        Args:
            evaluation_result: result tuple, unpacked positionally as
                (value, reference, name, threshold, meta).
            output_callback: accepted for interface compatibility; unused here.
            ignore_results_formatting: forwarded to ``get_result_format_parameters``;
                when true, the formatting entries of ``meta`` are not consulted.
        """
        raw_value, reference, metric_name, threshold, meta = evaluation_result
        averaged = np.mean(raw_value)
        format_params = get_result_format_parameters(meta, ignore_results_formatting)
        postfix, scale, result_format = format_params
        write_scalar_result(
            averaged, metric_name, reference, threshold,
            postfix=postfix, scale=scale, result_format=result_format,
        )
class VectorPrintPresenter(BasePresenter):
    """Print every component of a (possibly vector-valued) metric.

    Registered under the provider name ``print_vector``. Each element is
    printed on its own line via ``write_scalar_result``. When there is more
    than one element and ``meta['calculate_mean']`` is not false, the mean of
    the scaled elements is printed as an extra ``@mean`` line.
    """

    __provider__ = "print_vector"

    def write_result(self, evaluation_result: EvaluationResult, output_callback=None, ignore_results_formatting=False):
        """Print each element of the evaluated value, plus an optional mean line.

        Args:
            evaluation_result: result tuple, unpacked positionally as
                (value, reference, name, threshold, meta). ``meta`` is read
                with ``.get`` for ``names`` and ``calculate_mean``.
            output_callback: accepted for interface compatibility; unused here.
            ignore_results_formatting: forwarded to ``get_result_format_parameters``;
                when true, the formatting entries of ``meta`` are not consulted.
        """
        value, reference, name, threshold, meta = evaluation_result
        # The threshold is optional: float(None) would raise TypeError.
        if threshold:
            threshold = float(threshold)

        value_names = meta.get('names')
        postfix, scale, result_format = get_result_format_parameters(meta, ignore_results_formatting)
        if np.isscalar(value) or np.size(value) == 1:
            # Unwrap to a true scalar. A one-element ndarray would be kept as an
            # array, and str.format with a spec like '{:.2f}' raises TypeError on it.
            value = [np.ravel(value)[0]]

        for index, res in enumerate(value):
            write_scalar_result(
                res, name, reference, threshold,
                value_name=value_names[index] if value_names else None,
                # postfix/scale may be given once for all elements or per element.
                postfix=postfix[index] if not np.isscalar(postfix) else postfix,
                scale=scale[index] if not np.isscalar(scale) else scale,
                result_format=result_format
            )

        if len(value) > 1 and meta.get('calculate_mean', True):
            # Elements are scaled before averaging, so the mean itself uses scale=1.
            write_scalar_result(
                np.mean(np.multiply(value, scale)), name, reference, threshold, value_name='mean',
                postfix=postfix[-1] if not np.isscalar(postfix) else postfix, scale=1,
                result_format=result_format
            )
def write_scalar_result(res_value, name, reference, threshold, value_name=None, postfix='%', scale=100,
                        result_format='{:.2f}'):
    """Print one formatted metric line, optionally checked against a reference.

    The line has the form ``<name>[@<value_name>]: <formatted value><postfix>``.
    When ``reference`` is not None, the absolute difference between the
    reference and the scaled value is appended as a colored status:
    ``[OK]`` if it is within ``threshold``, otherwise ``[FAILED: error = ...]``.

    Args:
        res_value: raw metric value; multiplied by ``scale`` before display.
        name: metric name.
        reference: value to compare the *scaled* value with, or None to skip the check.
        threshold: maximum tolerated absolute difference; a falsy value means 0.
        value_name: optional component name, appended to ``name`` after ``@``.
        postfix: text appended directly after the formatted value.
        scale: multiplier applied to ``res_value``.
        result_format: ``str.format`` template applied to the scaled value.
    """
    scaled_value = res_value * scale
    display_name = "{}@{}".format(name, value_name) if value_name else name
    display_result = result_format.format(scaled_value)
    message = '{}: {}{}'.format(display_name, display_result, postfix)

    # Without a reference there is nothing to compare against (abs(None - x) would raise).
    if reference is not None:
        threshold = threshold or 0

        difference = abs(reference - scaled_value)
        # The tolerance is inclusive, so an exact match passes even with a zero threshold.
        if difference > threshold:
            fail_message = "[FAILED: error = {:.4}]".format(difference)
            message = "{} {}".format(message, color_format(fail_message, Color.FAILED))
        else:
            message = "{} {}".format(message, color_format("[OK]", Color.PASSED))

    print_info(message)
class ReturnValuePresenter(BasePresenter):
    """Hand the raw result to a caller-supplied callback instead of printing it.

    Registered under the provider name ``return_value``.
    """

    __provider__ = "return_value"

    def write_result(self, evaluation_result: EvaluationResult, output_callback=None, ignore_results_formatting=False):
        """Pass ``evaluation_result`` unchanged to ``output_callback``.

        Args:
            evaluation_result: the result to forward as-is.
            output_callback: callable invoked with ``evaluation_result``. If it is
                None (the default), nothing happens instead of raising TypeError.
            ignore_results_formatting: accepted for interface compatibility; unused.
        """
        if output_callback is not None:
            output_callback(evaluation_result)
def get_result_format_parameters(meta, use_default_formatting):
    """Return the ``(postfix, scale, result_format)`` triple used to print a metric.

    Args:
        meta: metric metadata mapping, or None. Its optional ``postfix``,
            ``scale`` and ``data_format`` entries override the percentage
            defaults ``'%'``, ``100`` and ``'{:.2f}'``.
        use_default_formatting: when true, ``meta`` is ignored and neutral
            formatting is returned: postfix ``' '``, no scaling (``1``) and
            the plain ``'{}'`` template.

    Returns:
        tuple: ``(postfix, scale, result_format)``.
    """
    if use_default_formatting:
        # Neutral formatting: show the raw value unscaled.
        return ' ', 1, '{}'

    meta = meta or {}
    return (
        meta.get('postfix', '%'),
        meta.get('scale', 100),
        meta.get('data_format', '{:.2f}'),
    )