Publishing 2019 R1 content
[platform/upstream/dldt.git] / tools / calibration / inference_result.py
1 """
2 Copyright (C) 2018-2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """
16
17 from .aggregated_statistics import AggregatedStatistics
18 from .calibration_metrics import CalibrationMetrics
19 from .infer_raw_results import InferRawResults
20
21
class InferenceResult:
    '''
    Container for the outcome of one calibration inference run.

    Bundles the raw layer outputs (``result``) with the computed metrics,
    aggregated statistics and performance counters. The raw results are
    released through ``release()``; the object can also be used as a context
    manager, in which case ``release()`` runs automatically on exit.
    '''

    def __init__(self,
                 result: InferRawResults,
                 metrics: CalibrationMetrics,
                 aggregated_statistics: AggregatedStatistics,
                 performance_counters: dict):
        '''
        :param result: raw inference results; released by ``release()``
        :param metrics: calibration metrics computed for this run
        :param aggregated_statistics: statistics aggregated during the run
        :param performance_counters: performance counters collected for the run
        '''
        self._result = result
        self._metrics = metrics
        self._aggregated_statistics = aggregated_statistics
        self._performance_counters = performance_counters

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always release the raw results; returning None lets any exception
        # raised inside the `with` body propagate unchanged.
        self.release()

    def release(self):
        '''
        Release the raw inference results.

        Safe to call more than once: after the first call ``result`` is
        ``None`` and subsequent calls do nothing.
        '''
        # Compare against None explicitly: a results container that happens to
        # evaluate as falsy (e.g. one defining __len__ that is empty) must
        # still have its release() called.
        if self._result is not None:
            self._result.release()
            self._result = None

    @property
    def result(self) -> InferRawResults:
        '''Raw inference results, or ``None`` once ``release()`` was called.'''
        return self._result

    @property
    def metrics(self) -> CalibrationMetrics:
        '''Calibration metrics passed at construction.'''
        return self._metrics

    @property
    def aggregated_statistics(self) -> AggregatedStatistics:
        '''Aggregated statistics passed at construction.'''
        return self._aggregated_statistics

    @property
    def performance_counters(self) -> dict:
        '''Performance counters passed at construction.'''
        return self._performance_counters

    def get_class_ids(self, output_layer_name: str) -> list:
        '''
        Return class identifier list for classification networks.

        For every item yielded by the raw results, reads the output of
        ``output_layer_name`` and picks the index (as accepted by
        ``layer_result.item``) of its largest value; on ties the first such
        index wins.

        :param output_layer_name: name of the output layer to read
        :return: one class id per result item, in iteration order
        :raises KeyError: if an item has no output for ``output_layer_name``
        :raises ValueError: if a layer output is empty
        '''
        # NOTE(review): must be called before release(); afterwards
        # self._result is None and iterating it raises TypeError.
        result_classes_id_list = []
        for layers_result in self._result:
            if output_layer_name not in layers_result:
                raise KeyError("layer '{}' is not included in results".format(output_layer_name))

            # presumably a numpy ndarray, where item(i) reads the i-th element
            # of the flattened array — confirm against InferRawResults
            layer_result = layers_result[output_layer_name]
            if layer_result.size == 0:
                raise ValueError("result array is empty")

            # max() starts from index 0 and only moves on a strictly greater
            # key, so ties resolve to the first index, same as a manual ">" scan.
            max_class_id = max(range(layer_result.size), key=layer_result.item)
            result_classes_id_list.append(max_class_id)

        return result_classes_id_list