Publishing 2019 R1 content
[platform/upstream/dldt.git] / tools / accuracy_checker / accuracy_checker / representation / detection_representation.py
1 """
2 Copyright (c) 2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """
16
17 import numpy as np
18
19 from ..utils import remove_difficult
20 from .base_representation import BaseRepresentation
21
22
23 class Detection(BaseRepresentation):
24     def __init__(self, identifier='', labels=None, x_mins=None, y_mins=None, x_maxs=None, y_maxs=None, metadata=None):
25         super().__init__(identifier, metadata)
26
27         self.labels = np.array(labels) if labels is not None else np.array([])
28         self.x_mins = np.array(x_mins) if x_mins is not None else np.array([])
29         self.y_mins = np.array(y_mins) if y_mins is not None else np.array([])
30         self.x_maxs = np.array(x_maxs) if x_maxs is not None else np.array([])
31         self.y_maxs = np.array(y_maxs) if y_maxs is not None else np.array([])
32
33     def remove(self, indexes):
34         self.labels = np.delete(self.labels, indexes)
35         self.x_mins = np.delete(self.x_mins, indexes)
36         self.y_mins = np.delete(self.y_mins, indexes)
37         self.x_maxs = np.delete(self.x_maxs, indexes)
38         self.y_maxs = np.delete(self.y_maxs, indexes)
39
40         difficult_boxes = self.metadata.get('difficult_boxes')
41         if not difficult_boxes:
42             return
43
44         new_difficult_boxes = remove_difficult(difficult_boxes, indexes)
45
46         self.metadata['difficult_boxes'] = new_difficult_boxes
47
48     @property
49     def size(self):
50         return len(self.x_mins)
51
52     def __eq__(self, other):
53         if not isinstance(other, type(self)):
54             return False
55
56         def are_bounding_boxes_equal():
57             if not np.array_equal(self.labels, other.labels):
58                 return False
59             if not np.array_equal(self.x_mins, other.x_mins):
60                 return False
61             if not np.array_equal(self.y_mins, other.y_mins):
62                 return False
63             if not np.array_equal(self.x_maxs, other.x_maxs):
64                 return False
65             if not np.array_equal(self.y_maxs, other.y_maxs):
66                 return False
67             return True
68
69         return self.identifier == other.identifier and are_bounding_boxes_equal() and self.metadata == other.metadata
70
71
72 class DetectionAnnotation(Detection):
73     pass
74
75
76 class DetectionPrediction(Detection):
77     def __init__(self, identifier='', labels=None, scores=None, x_mins=None, y_mins=None, x_maxs=None, y_maxs=None,
78                  metadata=None):
79         super().__init__(identifier, labels, x_mins, y_mins, x_maxs, y_maxs, metadata)
80         self.scores = np.array(scores) if scores is not None else np.array([])
81
82     def remove(self, indexes):
83         super().remove(indexes)
84         self.scores = np.delete(self.scores, indexes)
85
86     def __eq__(self, other):
87         return np.array_equal(self.scores, other.scores) if super().__eq__(other) else False