2 Copyright (c) 2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from ..config import BoolField
20 from ..representation import (
21 SegmentationAnnotation,
22 SegmentationPrediction,
23 BrainTumorSegmentationAnnotation,
24 BrainTumorSegmentationPrediction
26 from .metric import PerImageEvaluationMetric, BaseMetricConfig
27 from ..utils import finalize_metric_result
class SegmentationMetricConfig(BaseMetricConfig):
    """Configuration schema for the confusion-matrix based segmentation metrics.

    On top of the fields inherited from ``BaseMetricConfig`` it accepts one
    optional boolean, ``use_argmax``, which ``SegmentationMetric.update`` reads
    to decide whether the prediction mask is reduced with ``np.argmax`` or used
    directly as integer class ids.
    """

    # Optional flag; SegmentationMetric falls back to True when it is absent.
    use_argmax = BoolField(optional=True)
class SegmentationMetric(PerImageEvaluationMetric):
    """Base class for segmentation metrics built on a pixel-level confusion matrix.

    ``update`` accumulates an ``(n_classes, n_classes)`` confusion matrix in the
    metric state under ``CONFUSION_MATRIX_KEY`` (via the base-class
    ``_update_state`` helper). Cell ``[t, p]`` counts pixels whose ground truth is
    ``t`` and whose prediction is ``p``. Subclasses reduce that matrix to a score
    in ``evaluate``.
    """

    annotation_types = (SegmentationAnnotation, )
    prediction_types = (SegmentationPrediction, )

    CONFUSION_MATRIX_KEY = 'segmentation_confusion_matrix'

    def evaluate(self, annotations, predictions):
        raise NotImplementedError

    def validate_config(self):
        """Validate ``self.config`` and read the ``use_argmax`` option (default True)."""
        config_validator = SegmentationMetricConfig(
            'SemanticSegmentation_config', SegmentationMetricConfig.ERROR_ON_EXTRA_ARGUMENT
        )
        config_validator.validate(self.config)

        # Read right after validation so the flag is available to ``update``.
        self.use_argmax = self.config.get('use_argmax', True)

    def update(self, annotation, prediction):
        """Add this image's (ground truth, prediction) pixel pairs to the confusion matrix.

        ``n_classes`` is ``len(self.dataset.labels)``. Ground-truth pixels outside
        ``[0, n_classes)`` (e.g. ignore labels) are skipped.

        Raises:
            ValueError: if a predicted class id at a counted pixel lies outside
                ``[0, n_classes)``. Without this check such ids are encoded into
                the wrong matrix cell (or break the reshape).
        """
        n_classes = len(self.dataset.labels)
        if self.use_argmax:
            # assumes per-class scores are stacked along axis 0 -- TODO confirm layout
            prediction_mask = np.argmax(prediction.mask, axis=0)
        else:
            prediction_mask = prediction.mask.astype('int64')

        label_true = annotation.mask.flatten()
        label_pred = prediction_mask.flatten()

        valid = (label_true >= 0) & (label_true < n_classes)
        valid_true = label_true[valid].astype(int)
        valid_pred = label_pred[valid]

        # Validate before touching shared state so a bad image cannot corrupt it.
        if valid_pred.size and (valid_pred.min() < 0 or valid_pred.max() >= n_classes):
            raise ValueError(
                'predicted class ids must be in [0, {}), got values in [{}, {}]'.format(
                    n_classes, valid_pred.min(), valid_pred.max()
                )
            )

        # Encode each (true, pred) pair as one flat index so a single bincount
        # produces the whole row-major confusion matrix.
        hist = np.bincount(n_classes * valid_true + valid_pred, minlength=n_classes ** 2)
        hist = hist.reshape(n_classes, n_classes)

        def update_confusion_matrix(confusion_matrix):
            confusion_matrix += hist
            return confusion_matrix

        self._update_state(update_confusion_matrix, self.CONFUSION_MATRIX_KEY, lambda: np.zeros((n_classes, n_classes)))
class SegmentationAccuracy(SegmentationMetric):
    """Global pixel accuracy over every pixel counted in the confusion matrix."""

    __provider__ = 'segmentation_accuracy'

    def evaluate(self, annotations, predictions):
        """Return the number of correctly classified pixels divided by all counted pixels.

        Correct pixels are the diagonal of the accumulated matrix. The result is
        NaN (numpy 0/0) when no pixel has been counted.
        """
        matrix = self.state[self.CONFUSION_MATRIX_KEY]
        correct_pixels = np.diag(matrix).sum()
        all_pixels = matrix.sum()
        return correct_pixels / all_pixels
class SegmentationIOU(SegmentationMetric):
    """Per-class intersection over union from the accumulated confusion matrix.

    NOTE(review): despite the provider name, ``evaluate`` returns per-class
    values; presumably they are averaged downstream -- confirm.
    """

    __provider__ = 'mean_iou'

    def evaluate(self, annotations, predictions):
        """Return per-class IoU values as produced by ``finalize_metric_result``.

        For class ``c``: IoU = TP / (TP + FP + FN). TP is the diagonal entry,
        the row sum is TP + FN, and the column sum is TP + FP. Classes whose
        union is zero get 0.

        Side effect: stores the class names returned by ``finalize_metric_result``
        in ``self.meta['names']``.
        """
        confusion_matrix = self.state[self.CONFUSION_MATRIX_KEY]
        diagonal = np.diag(confusion_matrix)
        union = confusion_matrix.sum(axis=1) + confusion_matrix.sum(axis=0) - diagonal
        # Explicit float buffer: np.divide cannot write a true-division result into
        # an integer ``out`` array, so don't inherit the state matrix's dtype.
        iou = np.divide(diagonal, union, out=np.zeros(diagonal.shape, dtype=float), where=union != 0)

        values, names = finalize_metric_result(iou, list(self.dataset.labels.values()))
        self.meta['names'] = names

        return values
class SegmentationMeanAccuracy(SegmentationMetric):
    """Per-class pixel accuracy (recall) from the accumulated confusion matrix.

    NOTE(review): despite the provider name, ``evaluate`` returns per-class
    values; presumably they are averaged downstream -- confirm.
    """

    __provider__ = 'mean_accuracy'

    def evaluate(self, annotations, predictions):
        """Return per-class accuracy values as produced by ``finalize_metric_result``.

        For class ``c`` the accuracy is TP divided by the number of ground-truth
        pixels of ``c`` (diagonal entry over row sum). Classes with no
        ground-truth pixels get 0.

        Side effect: stores the class names returned by ``finalize_metric_result``
        in ``self.meta['names']``.
        """
        confusion_matrix = self.state[self.CONFUSION_MATRIX_KEY]
        diagonal = np.diag(confusion_matrix)
        per_class_count = confusion_matrix.sum(axis=1)
        # Explicit float buffer: np.divide cannot write a true-division result into
        # an integer ``out`` array, so don't inherit the state matrix's dtype.
        acc_cls = np.divide(
            diagonal, per_class_count, out=np.zeros(diagonal.shape, dtype=float), where=per_class_count != 0
        )

        values, names = finalize_metric_result(acc_cls, list(self.dataset.labels.values()))
        self.meta['names'] = names

        return values
class SegmentationFWAcc(SegmentationMetric):
    """Frequency-weighted IoU computed from the accumulated confusion matrix.

    Despite the provider name, the score is ``sum_c freq_c * IoU_c``, where
    ``freq_c`` is the share of counted pixels whose ground truth is class ``c``.
    Classes with zero frequency are left out of the sum.
    """

    __provider__ = 'frequency_weighted_accuracy'

    def evaluate(self, annotations, predictions):
        """Return the frequency-weighted IoU as a scalar."""
        matrix = self.state[self.CONFUSION_MATRIX_KEY]

        true_positives = np.diag(matrix)
        gt_per_class = matrix.sum(axis=1)
        pred_per_class = matrix.sum(axis=0)
        union = gt_per_class + pred_per_class - true_positives
        # Classes with an empty union keep the pre-filled 0 instead of dividing by zero.
        class_iou = np.divide(true_positives, union, out=np.zeros_like(true_positives), where=union != 0)

        frequency = gt_per_class / matrix.sum()
        present = frequency > 0
        weighted = frequency[present] * class_iou[present]
        return weighted.sum()
class SegmentationDSCAcc(PerImageEvaluationMetric):
    """Smoothed Dice similarity coefficient for brain-tumor segmentation.

    Each (prediction mask, annotation mask) pair contributes
    ``(2 * sum(pred * gt) + 1) / (sum(gt) + sum(pred) + 1)``. ``evaluate`` returns
    the plain mean of all contributions accumulated by this instance.
    """

    __provider__ = 'dice'
    annotation_types = (BrainTumorSegmentationAnnotation,)
    prediction_types = (BrainTumorSegmentationPrediction,)

    def __init__(self, *args, **kwargs):
        # Per-instance accumulator (never shared between metric instances). It is
        # created before the base initializer runs so it exists during that call.
        self.overall_metric = []
        super().__init__(*args, **kwargs)

    def update(self, annotation, prediction):
        """Append one smoothed Dice score per (prediction, annotation) mask pair.

        Pairs come from ``zip`` over ``prediction.mask`` and ``annotation.mask``,
        so any surplus masks on the longer side are ignored.
        """
        for prediction_mask, annotation_mask in zip(prediction.mask, annotation.mask):
            # The annotation mask's axes are reordered (2, 0, 1) and a leading axis is
            # added before multiplying with the prediction mask.
            # assumes annotation masks are channels-last (H, W, C) -- TODO confirm
            annotation_mask = np.transpose(annotation_mask, (2, 0, 1))
            annotation_mask = np.expand_dims(annotation_mask, 0)
            # The +1.0 terms smooth the ratio and keep it defined for empty masks.
            numerator = np.sum(prediction_mask * annotation_mask) * 2.0 + 1.0
            denominator = np.sum(annotation_mask) + np.sum(prediction_mask) + 1.0
            self.overall_metric.append(numerator / denominator)

    def evaluate(self, annotations, predictions):
        """Return the mean accumulated Dice score.

        Raises:
            ZeroDivisionError: if no mask pair has been accumulated yet.
        """
        return sum(self.overall_metric) / len(self.overall_metric)