2 Copyright (c) 2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
from collections import defaultdict, namedtuple
from sklearn.metrics import auc, precision_recall_curve
# noinspection PyProtectedMember
from sklearn.metrics.base import _average_binary_score

import numpy as np

from ..representation import (
    ReIdentificationClassificationAnnotation,
    ReIdentificationAnnotation,
    ReIdentificationPrediction
)
from ..config import BaseField, BoolField, NumberField
from .metric import BaseMetricConfig, FullDatasetEvaluationMetric
# Description of one image pair used by the pairwise metrics:
# `image1` / `image2` are positions of the two images in the prediction list
# (see get_embedding_distances), `same` tells whether the pair is a positive one.
PairDesc = namedtuple('PairDesc', 'image1 image2 same')
class CMCScore(FullDatasetEvaluationMetric):
    """
    Cumulative Matching Characteristics (CMC) score.

    Config:
        annotation: reid annotation.
        prediction: predicted embeddings.
        top_k: number of k highest ranked samples to consider when matching.
        separate_camera_set: should identities from the same camera view be filtered out.
        single_gallery_shot: each identity has only one instance in the gallery.
        number_single_shot_repeats: number of repeats for single_gallery_shot setting.
        first_match_break: break on first matched gallery sample.
    """

    __provider__ = 'cmc'

    annotation_types = (ReIdentificationAnnotation, )
    prediction_types = (ReIdentificationPrediction, )

    def validate_config(self):
        """Validate the metric config; arguments not declared below are rejected."""
        class _CMCConfigValidator(BaseMetricConfig):
            top_k = NumberField(floats=False, min_value=1, optional=True)
            separate_camera_set = BoolField(optional=True)
            single_gallery_shot = BoolField(optional=True)
            first_match_break = BoolField(optional=True)
            number_single_shot_repeats = NumberField(floats=False, optional=True)

        validator = _CMCConfigValidator('cmc', on_extra_argument=_CMCConfigValidator.ERROR_ON_EXTRA_ARGUMENT)
        validator.validate(self.config)

    def configure(self):
        """Read metric parameters from the config, falling back to defaults."""
        self.top_k = self.config.get('top_k', 1)
        self.separate_camera_set = self.config.get('separate_camera_set', False)
        self.single_gallery_shot = self.config.get('single_gallery_shot', False)
        self.first_match_break = self.config.get('first_match_break', True)
        self.number_single_shot_repeats = self.config.get('number_single_shot_repeats', 10)

    def evaluate(self, annotations, predictions):
        """
        Compute the CMC score at rank `top_k` over the whole dataset.

        Args:
            annotations: reid annotations (provide query flag, person_id, camera_id).
            predictions: predictions holding embeddings, aligned with annotations.

        Returns:
            CMC value at rank `top_k`.
        """
        dist_matrix = distance_matrix(annotations, predictions)
        gallery_cameras, gallery_pids, query_cameras, query_pids = get_gallery_query_pids(annotations)

        # The CMC curve is requested with exactly top_k ranks. Relying on the
        # default curve length of eval_cmc (100) would make the indexing below
        # fail for top_k > 100; values for ranks <= top_k are unaffected.
        _cmc_score = eval_cmc(
            dist_matrix, query_pids, gallery_pids, query_cameras, gallery_cameras, self.separate_camera_set,
            self.single_gallery_shot, self.first_match_break, self.number_single_shot_repeats,
            top_k=self.top_k
        )

        return _cmc_score[self.top_k - 1]
class ReidMAP(FullDatasetEvaluationMetric):
    """
    Mean Average Precision score.

    Config:
        annotation: reid annotation.
        prediction: predicted embeddings.
        interpolated_auc: should area under precision recall curve be computed using trapezoidal rule or directly.
    """

    __provider__ = 'reid_map'

    annotation_types = (ReIdentificationAnnotation, )
    prediction_types = (ReIdentificationPrediction, )

    def validate_config(self):
        """Validate the metric config; arguments not declared below are rejected."""
        class _ReidMapConfig(BaseMetricConfig):
            interpolated_auc = BoolField(optional=True)

        validator = _ReidMapConfig('reid_map', on_extra_argument=_ReidMapConfig.ERROR_ON_EXTRA_ARGUMENT)
        validator.validate(self.config)

    def configure(self):
        """Read metric parameters from the config, falling back to defaults."""
        self.interpolated_auc = self.config.get('interpolated_auc', True)

    def evaluate(self, annotations, predictions):
        """
        Compute mean average precision over all valid queries.

        Args:
            annotations: reid annotations (provide query flag, person_id, camera_id).
            predictions: predictions holding embeddings, aligned with annotations.

        Returns:
            Mean of per-query average precision values, as produced by eval_map.
        """
        dist_matrix = distance_matrix(annotations, predictions)
        gallery_cameras, gallery_pids, query_cameras, query_pids = get_gallery_query_pids(annotations)

        return eval_map(
            dist_matrix, query_pids, gallery_pids, query_cameras, gallery_cameras, self.interpolated_auc
        )
class PairwiseAccuracy(FullDatasetEvaluationMetric):
    """
    Accuracy of same / not-same decisions on annotated image pairs.

    A pair is predicted as "same" when its embedding distance is below a threshold.

    Config:
        min_score: distance threshold, or 'train_median' (default) to use the median
            distance of the pairs whose first image is marked as train.
    """

    __provider__ = 'pairwise_accuracy'

    annotation_types = (ReIdentificationClassificationAnnotation, )
    prediction_types = (ReIdentificationPrediction, )

    def validate_config(self):
        """Validate the metric config; arguments not declared below are rejected."""
        class _PWAccConfig(BaseMetricConfig):
            min_score = BaseField(optional=True)

        validator = _PWAccConfig('pairwise_accuracy', on_extra_argument=_PWAccConfig.ERROR_ON_EXTRA_ARGUMENT)
        validator.validate(self.config)

    def configure(self):
        """Read metric parameters from the config, falling back to defaults."""
        self.min_score = self.config.get('min_score', 'train_median')

    def evaluate(self, annotations, predictions):
        """
        Compute the fraction of non-train pairs classified correctly.

        Args:
            annotations: annotations providing identifier, metadata, positive_pairs, negative_pairs.
            predictions: predictions providing identifier and embedding.

        Returns:
            Share of pairs whose thresholded distance agrees with the pair label.
        """
        embed_distances, pairs = get_embedding_distances(annotations, predictions)

        min_score = self.min_score
        if min_score == 'train_median':
            # Threshold is derived from the pairs of the train-marked annotations.
            train_distances, _train_pairs = get_embedding_distances(annotations, predictions, train=True)
            min_score = np.median(train_distances)

        embed_same_class = embed_distances < min_score

        # A prediction is correct when the label and the thresholded distance agree.
        accuracy = sum(
            1 for pair, out_same in zip(pairs, embed_same_class)
            if bool(pair.same) == bool(out_same)
        )

        return float(accuracy) / len(pairs)
class PairwiseAccuracySubsets(FullDatasetEvaluationMetric):
    """
    Pairwise accuracy averaged over several test/train splits of the annotations.

    Config:
        subset_number: number of contiguous subsets the annotations are divided into (default 10).
    """

    __provider__ = 'pairwise_accuracy_subsets'

    annotation_types = (ReIdentificationClassificationAnnotation, )
    prediction_types = (ReIdentificationPrediction, )

    def validate_config(self):
        """Validate the metric config; arguments not declared below are rejected."""
        class _PWAccConfig(BaseMetricConfig):
            subset_number = NumberField(optional=True, min_value=1, floats=False)

        # NOTE(review): the validator is labelled 'pairwise_accuracy' although this
        # provider is 'pairwise_accuracy_subsets' - confirm whether that is intended.
        validator = _PWAccConfig('pairwise_accuracy', on_extra_argument=_PWAccConfig.ERROR_ON_EXTRA_ARGUMENT)
        validator.validate(self.config)

    def configure(self):
        """Read metric parameters from the config and create the wrapped accuracy metric."""
        self.meta['scale'] = 1
        self.meta['postfix'] = ' '
        self.subset_num = self.config.get('subset_number', 10)
        self.accuracy_metric = PairwiseAccuracy(self.config, self.dataset)

    def evaluate(self, annotations, predictions):
        """
        Evaluate pairwise accuracy on every test/train split and average the results.

        Only annotations that declare at least one positive or negative pair take
        part. Note that the `train` flag in their metadata is overwritten in place
        for every split.

        Returns:
            Mean of the per-split pairwise accuracies.
        """
        subset_results = []
        first_images_annotations = [
            annotation for annotation in annotations
            if len(annotation.negative_pairs) > 0 or len(annotation.positive_pairs) > 0
        ]

        idx_subsets = self.make_subsets(self.subset_num, len(first_images_annotations))
        for subset in range(self.subset_num):
            test_subset = self.get_subset(first_images_annotations, idx_subsets[subset]['test'])
            test_subset = self.mark_subset(test_subset, False)

            train_subset = self.get_subset(first_images_annotations, idx_subsets[subset]['train'])
            train_subset = self.mark_subset(train_subset)

            subset_result = self.accuracy_metric.evaluate(test_subset+train_subset, predictions)
            subset_results.append(subset_result)

        return np.mean(subset_results)

    @staticmethod
    def make_subsets(subset_num, dataset_size):
        """
        Split the index range [0, dataset_size) into `subset_num` contiguous test folds.

        Returns:
            List of dicts; 'test' holds one (lower, upper) bound pair and 'train'
            holds the bound pairs of the remaining indices before and after it.

        Raises:
            ValueError: if more subsets than annotations are requested.
        """
        subsets = []
        if subset_num > dataset_size:
            raise ValueError('It is impossible to divide dataset on more than number of annotations subsets.')

        for subset in range(subset_num):
            lower_bnd = subset * dataset_size // subset_num
            upper_bnd = (subset + 1) * dataset_size // subset_num
            subset_test = [(lower_bnd, upper_bnd)]

            subset_train = [(0, lower_bnd), (upper_bnd, dataset_size)]
            subsets.append({'test': subset_test, 'train': subset_train})

        return subsets

    @staticmethod
    def mark_subset(subset_annotations, train=True):
        """Set metadata['train'] of every annotation to `train` (in place) and return the same list."""
        for annotation in subset_annotations:
            annotation.metadata['train'] = train

        return subset_annotations

    @staticmethod
    def get_subset(container, subset_bounds):
        """Concatenate the slices container[lower:upper] for every (lower, upper) pair in subset_bounds."""
        subset = []
        for bound in subset_bounds:
            subset += container[bound[0]: bound[1]]

        return subset
def extract_embeddings(annotation, prediction, query):
    """
    Stack the embeddings of all predictions whose annotation has the given query flag.

    Annotations and predictions are walked in lockstep; an embedding is selected
    when `ann.query == query`.
    """
    selected = []
    for ann, pred in zip(annotation, prediction):
        if ann.query == query:
            selected.append(pred.embedding)

    return np.stack(selected)
def get_gallery_query_pids(annotation):
    """
    Split person and camera ids of the annotations into gallery and query arrays.

    Annotations with a truthy `query` attribute belong to the query set, all
    others to the gallery.

    Returns:
        Tuple (gallery_cameras, gallery_pids, query_cameras, query_pids) of numpy arrays.
    """
    person_ids = {True: [], False: []}
    camera_ids = {True: [], False: []}
    for ann in annotation:
        is_query = bool(ann.query)
        person_ids[is_query].append(ann.person_id)
        camera_ids[is_query].append(ann.camera_id)

    gallery_cameras = np.asarray(camera_ids[False])
    gallery_pids = np.asarray(person_ids[False])
    query_cameras = np.asarray(camera_ids[True])
    query_pids = np.asarray(person_ids[True])

    return gallery_cameras, gallery_pids, query_cameras, query_pids
def distance_matrix(annotation, prediction):
    """
    Build the query-by-gallery distance matrix from the predicted embeddings.

    The distance is one minus the dot product of the embeddings; rows correspond
    to query samples and columns to gallery samples.
    """
    queries = extract_embeddings(annotation, prediction, query=True)
    gallery = extract_embeddings(annotation, prediction, query=False)

    # gallery x query similarities, transposed so that queries index the rows
    similarity = np.matmul(gallery, np.transpose(queries))

    return 1. - similarity.T
def unique_sample(ids_dict, num):
    """
    Build a boolean mask of length `num` selecting one random index per identity.

    Args:
        ids_dict: mapping from identity to the list of candidate indices.
        num: length of the returned mask.

    Returns:
        Boolean numpy array with True at the one index drawn for every identity.
    """
    # The builtin `bool` is used as dtype: the `np.bool` alias was removed from NumPy.
    mask = np.zeros(num, dtype=bool)
    for indices in ids_dict.values():
        mask[np.random.choice(indices)] = True

    return mask
def eval_map(distance_mat, query_ids, gallery_ids, query_cams, gallery_cams, interpolated_auc=False):
    """
    Compute mean average precision over all queries that have a valid positive.

    Args:
        distance_mat: 2-D array of distances, one row per query and one column per gallery sample.
        query_ids: person ids of the queries.
        gallery_ids: person ids of the gallery samples.
        query_cams: camera ids of the queries.
        gallery_cams: camera ids of the gallery samples.
        interpolated_auc: forwarded to binary_average_precision.

    Returns:
        Mean of the per-query average precisions.

    Raises:
        RuntimeError: if no query has a matching gallery sample after filtering.
    """
    number_queries, _number_gallery = distance_mat.shape
    # Sort and find correct matches
    indices = np.argsort(distance_mat, axis=1)
    matches = (gallery_ids[indices] == query_ids[:, np.newaxis])  # type: np.ndarray

    # Compute AP for each query
    average_precisions = []
    for query in range(number_queries):
        ranking = indices[query]
        # Filter out the same id and same camera
        valid = (gallery_ids[ranking] != query_ids[query]) | (gallery_cams[ranking] != query_cams[query])

        y_true = matches[query, valid]
        # Negated distances serve as scores: a smaller distance means a higher score.
        y_score = -distance_mat[query][ranking][valid]
        if not np.any(y_true):
            # No positive left for this query, so it cannot contribute an AP value.
            continue

        average_precisions.append(binary_average_precision(y_true, y_score, interpolated_auc=interpolated_auc))

    if not average_precisions:
        raise RuntimeError("No valid query")

    return np.mean(average_precisions)
def eval_cmc(distance_mat, query_ids, gallery_ids, query_cams, gallery_cams, separate_camera_set=False,
             single_gallery_shot=False, first_match_break=False, number_single_shot_repeats=10, top_k=100):
    """
    Compute the Cumulative Matching Characteristics curve.

    Args:
        distance_mat: 2-D array of distances, one row per query and one column per gallery sample.
        query_ids: person ids of the queries.
        gallery_ids: person ids of the gallery samples.
        query_cams: camera ids of the queries.
        gallery_cams: camera ids of the gallery samples.
        separate_camera_set: additionally drop gallery samples taken by the query camera.
        single_gallery_shot: sample one gallery instance per identity before matching.
        first_match_break: count only the first correct match of every query.
        number_single_shot_repeats: number of random samplings (used only with single_gallery_shot).
        top_k: number of ranks in the returned curve.

    Returns:
        Array of length `top_k`: cumulative sum of the rank hits divided by the number of valid queries.

    Raises:
        RuntimeError: if no query has a matching gallery sample after filtering.
    """
    number_queries, _number_gallery = distance_mat.shape

    if not single_gallery_shot:
        number_single_shot_repeats = 1

    # Sort and find correct matches
    indices = np.argsort(distance_mat, axis=1)
    matches = gallery_ids[indices] == query_ids[:, np.newaxis]  # type: np.ndarray

    # Compute CMC for each query
    ret = np.zeros(top_k)
    num_valid_queries = 0
    for query in range(number_queries):
        valid = get_valid_subset(
            gallery_cams, gallery_ids, query, indices, query_cams, query_ids, separate_camera_set
        )  # type: np.ndarray

        if not np.any(matches[query, valid]):
            continue

        ids_dict = defaultdict(list)
        if single_gallery_shot:
            # Group the valid ranking positions by the identity found at them.
            gallery_indexes = gallery_ids[indices[query][valid]]
            for j, x in zip(np.where(valid)[0], gallery_indexes):
                ids_dict[x].append(j)

        for _ in range(number_single_shot_repeats):
            if single_gallery_shot:
                # Randomly choose one instance for each id
                # required for correct validation on CUHK datasets
                # http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html
                sampled = (valid & unique_sample(ids_dict, len(valid)))
                index = np.nonzero(matches[query, sampled])[0]
            else:
                index = np.nonzero(matches[query, valid])[0]

            delta = 1. / (len(index) * number_single_shot_repeats)
            for j, k in enumerate(index):
                # k - j is the rank of the j-th correct match once earlier matches are removed.
                if k - j >= top_k:
                    break
                if first_match_break:
                    ret[k - j] += 1
                    break
                ret[k - j] += delta

        num_valid_queries += 1

    if num_valid_queries == 0:
        raise RuntimeError("No valid query")

    return ret.cumsum() / num_valid_queries
def get_valid_subset(gallery_cams, gallery_ids, query_index, indices, query_cams, query_ids, separate_camera_set):
    """
    Build the mask of gallery samples that may be matched against one query.

    The mask follows the ranking order given by `indices[query_index]`.

    Args:
        gallery_cams: camera ids of the gallery samples.
        gallery_ids: person ids of the gallery samples.
        query_index: row of the query in `indices`.
        indices: per-query ranking of gallery indices.
        query_cams: camera ids of the queries.
        query_ids: person ids of the queries.
        separate_camera_set: if set, every sample from the query camera is dropped as well.

    Returns:
        Boolean numpy array, True for the gallery samples kept for this query.
    """
    ranking = indices[query_index]
    other_identity = gallery_ids[ranking] != query_ids[query_index]
    other_camera = gallery_cams[ranking] != query_cams[query_index]

    # Filter out the same id and same camera
    valid = other_identity | other_camera
    if separate_camera_set:
        # Filter out samples from same camera
        valid &= other_camera

    return valid
def get_embedding_distances(annotation, prediction, train=False):
    """
    Compute embedding distances for all annotated positive and negative pairs.

    Only annotations whose metadata flag 'train' (False when absent) equals
    `train` contribute pairs.

    Args:
        annotation: annotations providing identifier, metadata, positive_pairs and negative_pairs.
        prediction: predictions providing identifier and embedding.
        train: select pairs of train-marked (True) or non-train (False) annotations.

    Returns:
        Tuple (distances, pairs): distances is 0.5 * (1 - dot product) per pair,
        pairs is the list of PairDesc records in the same order.
    """
    # Position of every image in the prediction list, keyed by identifier.
    image_indexes = {pred.identifier: i for i, pred in enumerate(prediction)}

    pairs = []
    for image1 in annotation:
        if train != image1.metadata.get("train", False):
            continue

        first_index = image_indexes[image1.identifier]
        for image2 in image1.positive_pairs:
            pairs.append(PairDesc(first_index, image_indexes[image2], True))
        for image2 in image1.negative_pairs:
            pairs.append(PairDesc(first_index, image_indexes[image2], False))

    embed1 = np.asarray([prediction[idx].embedding for idx, _, _ in pairs])
    embed2 = np.asarray([prediction[idx].embedding for _, idx, _ in pairs])

    return 0.5 * (1 - np.sum(embed1 * embed2, axis=1)), pairs
def binary_average_precision(y_true, y_score, interpolated_auc=True):
    """
    Compute average precision for binary relevance labels.

    Args:
        y_true: binary labels of the ranked samples.
        y_score: scores of the samples (higher means more likely positive).
        interpolated_auc: if True the area under the precision-recall curve is
            computed with `auc` (trapezoidal rule), otherwise as the step function integral.

    Returns:
        Average precision value.
    """
    def _average_precision(y_true_, y_score_, sample_weight=None):
        # sample_weight must be passed by keyword: the third positional parameter
        # of precision_recall_curve is pos_label, not sample_weight.
        precision, recall, _ = precision_recall_curve(y_true_, y_score_, sample_weight=sample_weight)
        if not interpolated_auc:
            # Return the step function integral
            # The following works because the last entry of precision is
            # guaranteed to be 1, as returned by precision_recall_curve
            return -1 * np.sum(np.diff(recall) * np.array(precision)[:-1])

        return auc(recall, precision)

    return _average_binary_score(_average_precision, y_true, y_score, average="macro")