2 Copyright (c) 2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from ..utils import remove_difficult
20 from .base_representation import BaseRepresentation
class Detection(BaseRepresentation):
    """Bounding-box representation stored as parallel NumPy arrays.

    ``labels``, ``x_mins``, ``y_mins``, ``x_maxs`` and ``y_maxs`` are kept as separate
    arrays; position ``i`` in each array refers to the same box, which is why
    ``remove`` deletes the same positions from all of them.
    """

    def __init__(self, identifier='', labels=None, x_mins=None, y_mins=None, x_maxs=None, y_maxs=None, metadata=None):
        """Create the representation.

        Args:
            identifier: forwarded to ``BaseRepresentation.__init__``.
            labels: per-box labels; converted with ``np.array``, ``None`` becomes an empty array.
            x_mins, y_mins, x_maxs, y_maxs: per-box coordinates; converted with ``np.array``,
                ``None`` becomes an empty array.
            metadata: forwarded to ``BaseRepresentation.__init__``.
        """
        super().__init__(identifier, metadata)

        self.labels = np.array(labels) if labels is not None else np.array([])
        self.x_mins = np.array(x_mins) if x_mins is not None else np.array([])
        self.y_mins = np.array(y_mins) if y_mins is not None else np.array([])
        self.x_maxs = np.array(x_maxs) if x_maxs is not None else np.array([])
        self.y_maxs = np.array(y_maxs) if y_maxs is not None else np.array([])

    def remove(self, indexes):
        """Delete the boxes at ``indexes`` from every per-box array (in place on ``self``).

        If ``self.metadata`` holds a non-empty ``'difficult_boxes'`` entry, it is passed
        to ``remove_difficult`` together with ``indexes`` and the result is stored back
        (presumably re-indexing the remaining difficult boxes; see ``remove_difficult``).

        Args:
            indexes: positions to delete, in any form accepted by ``numpy.delete``.
        """
        self.labels = np.delete(self.labels, indexes)
        self.x_mins = np.delete(self.x_mins, indexes)
        self.y_mins = np.delete(self.y_mins, indexes)
        self.x_maxs = np.delete(self.x_maxs, indexes)
        self.y_maxs = np.delete(self.y_maxs, indexes)

        difficult_boxes = self.metadata.get('difficult_boxes')
        # Explicit length check instead of truthiness: ``not ndarray`` raises
        # "truth value of an array is ambiguous" for arrays with more than one element.
        if difficult_boxes is None or len(difficult_boxes) == 0:
            return

        self.metadata['difficult_boxes'] = remove_difficult(difficult_boxes, indexes)

    @property
    def size(self):
        """Number of boxes, taken as the length of ``x_mins``."""
        return len(self.x_mins)

    def __eq__(self, other):
        """Compare type, identifier, all per-box arrays (element-wise) and metadata.

        NOTE: the check is ``isinstance(other, type(self))``, so comparing a base instance
        against a subclass instance may succeed from the base side only.
        """
        if not isinstance(other, type(self)):
            return False

        box_fields = ('labels', 'x_mins', 'y_mins', 'x_maxs', 'y_maxs')

        # Same short-circuit order as before: identifier, then boxes, then metadata.
        return (
            self.identifier == other.identifier
            and all(np.array_equal(getattr(self, field), getattr(other, field)) for field in box_fields)
            and self.metadata == other.metadata
        )
class DetectionAnnotation(Detection):
    """Ground-truth detection: a ``Detection`` with no additional fields or behavior."""
class DetectionPrediction(Detection):
    """``Detection`` extended with a per-box ``scores`` array kept parallel to the box arrays."""

    def __init__(self, identifier='', labels=None, scores=None, x_mins=None, y_mins=None, x_maxs=None, y_maxs=None,
                 metadata=None):
        """Create the prediction.

        Args:
            identifier, labels, x_mins, y_mins, x_maxs, y_maxs, metadata: forwarded to
                ``Detection.__init__``.
            scores: per-box scores (presumably confidences); converted with ``np.array``,
                ``None`` becomes an empty array.
        """
        super().__init__(identifier, labels, x_mins, y_mins, x_maxs, y_maxs, metadata)
        self.scores = np.array(scores) if scores is not None else np.array([])

    def remove(self, indexes):
        """Delete the boxes at ``indexes``, including their entries in ``scores``."""
        super().remove(indexes)
        # Keep ``scores`` aligned with the box arrays trimmed by the base class.
        self.scores = np.delete(self.scores, indexes)

    def __eq__(self, other):
        """Equal when the base ``Detection`` comparison holds and ``scores`` match element-wise."""
        # ``super().__eq__`` runs first and rejects other types, so ``other.scores`` is safe here.
        return super().__eq__(other) and np.array_equal(self.scores, other.scores)