Publishing 2019 R1 content
[platform/upstream/dldt.git] / tools / accuracy_checker / accuracy_checker / annotation_converters / imagenet.py
1 import numpy as np
2
3 from ..config import PathField, BoolField
4 from ..representation import ClassificationAnnotation
5 from ..utils import read_txt, get_path
6
7 from .format_converter import BaseFormatConverter, BaseFormatConverterConfig
8
9
class ImageNetFormatConverterConfig(BaseFormatConverterConfig):
    """Configuration schema for the 'imagenet' annotation converter."""

    # Required: text file read by the converter, one "<image_name> <label>" pair per line.
    annotation_file = PathField()
    # Optional: file with class names; when set, the converter builds a label_map meta.
    labels_file = PathField(optional=True)
    # Optional: when true, class ids are shifted by +1 so that id 0 is reserved for background.
    has_background = BoolField(optional=True)
14
15
class ImageNetFormatConverter(BaseFormatConverter):
    """Convert an ImageNet-style classification list into ClassificationAnnotation objects.

    Each line of the annotation file holds an image name and an integer class id
    separated by whitespace. An optional labels file provides class names that
    are returned as dataset meta.
    """

    __provider__ = 'imagenet'

    _config_validator_type = ImageNetFormatConverterConfig

    def configure(self):
        """Read converter parameters from the validated config."""
        self.annotation_file = self.config['annotation_file']
        self.labels_file = self.config.get('labels_file')
        self.has_background = self.config.get('has_background', False)

    def convert(self):
        """Parse the annotation file.

        Returns:
            tuple: ``(annotation, meta)`` where ``annotation`` is a list of
            ClassificationAnnotation and ``meta`` is the dict built by
            ``_create_meta`` when ``labels_file`` is configured, otherwise None.

        Raises:
            ValueError: if a line does not consist of exactly an image name and a
                label, or if the label is not an integer.
        """
        # With a background class, id 0 is reserved and every dataset id moves up by one.
        label_offset = 1 if self.has_background else 0
        annotation = []
        for line in read_txt(get_path(self.annotation_file)):
            parts = line.split()
            if len(parts) != 2:
                raise ValueError(
                    'Invalid line in annotation file {}: {!r}. '
                    'Expected "<image_name> <label>".'.format(self.annotation_file, line)
                )
            image_name, label = parts
            annotation.append(ClassificationAnnotation(image_name, np.int64(label) + label_offset))

        meta = self._create_meta(self.labels_file, self.has_background) if self.labels_file else None

        return annotation, meta

    @staticmethod
    def _create_meta(labels_file, has_background=False):
        """Build dataset meta from a labels file.

        Line ``i`` (0-based, as returned by ``read_txt``) of ``labels_file`` names
        class id ``i``, or ``i + 1`` when ``has_background`` is set. The class name
        is the text after the first space on the line; a line without a space is
        used whole.

        Args:
            labels_file: path to the labels file.
            has_background: reserve id 0 for a 'background' class.

        Returns:
            dict: ``{'label_map': {id: name}}``. With ``has_background`` the map also
            contains ``0: 'background'`` and the dict holds ``'background_label': 0``.
        """
        label_offset = 1 if has_background else 0
        labels = {}
        for i, line in enumerate(read_txt(get_path(labels_file))):
            line = line.strip()
            # Drop the first space-separated token. find() returns -1 when there is
            # no space, so the slice then starts at 0 and keeps the whole line.
            labels[i + label_offset] = line[line.find(' ') + 1:]

        meta = {}
        if has_background:
            labels[0] = 'background'
            meta['background_label'] = 0
            # Misspelled legacy key, kept so existing readers of it keep working.
            meta['backgound_label'] = 0

        meta['label_map'] = labels

        return meta