# Publishing 2019 R1 content
# [platform/upstream/dldt.git] / tools/accuracy_checker/accuracy_checker/dataset.py
1 """
2 Copyright (c) 2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """
16
17 from pathlib import Path
18 import numpy as np
19
20 from .annotation_converters import BaseFormatConverter, save_annotation, make_subset
21 from .data_readers import BaseReader, DataReaderField
22 from .config import ConfigValidator, StringField, PathField, ListField, DictField, BaseField, NumberField, ConfigError
23 from .utils import JSONDecoderWithAutoConversion, read_json, get_path, contains_all
24 from .representation import BaseRepresentation
25
26
class DataRepresentation:
    """Container that couples one raw input item with its identifier and metadata.

    On construction the input's size is recorded under ``metadata['image_size']``:
    ``1`` for a scalar, ``len(data)`` for a list of scalars (including the empty
    list), otherwise the ``shape`` of the array-like data (the first element's
    shape when a list of arrays is given).
    """

    def __init__(self, data, meta=None, identifier=''):
        """
        Args:
            data: raw input — a scalar, a list of scalars, an array-like object
                  with a ``shape`` attribute, or a list of such array-likes.
            meta: optional pre-existing metadata dict; a new dict is created
                  when omitted.
            identifier: identifier of the input (e.g. image file name).
        """
        self.identifier = identifier
        self.data = data
        self.metadata = meta or {}
        if np.isscalar(data):
            self.metadata['image_size'] = 1
        elif isinstance(data, list) and (not data or np.isscalar(data[0])):
            # A (possibly empty) list of scalars: size is the element count.
            # The `not data` guard fixes an IndexError the previous code
            # raised when indexing data[0] on an empty list.
            self.metadata['image_size'] = len(data)
        else:
            # Array-like input: use its shape; for a list of arrays, the
            # shape of the first one.
            self.metadata['image_size'] = data.shape if not isinstance(data, list) else data[0].shape
38
39
class DatasetConfig(ConfigValidator):
    """
    Specifies configuration structure for dataset.

    Each class attribute declares one recognized key of the ``datasets``
    config entry; validation itself is performed by the ConfigValidator base.
    """
    name = StringField()  # dataset identifier
    annotation = BaseField(optional=True)  # path to a stored converted annotation file
    data_source = PathField()  # location of input data; is_directory/optional are adjusted per reader in Dataset.__init__
    dataset_meta = BaseField(optional=True)  # path to dataset meta file (e.g. label map)
    metrics = ListField(allow_empty=False)  # metric configurations to compute
    postprocessing = ListField(allow_empty=False, optional=True)
    preprocessing = ListField(allow_empty=False, optional=True)
    reader = DataReaderField(optional=True)  # data reader name or its config dict
    annotation_conversion = DictField(optional=True)  # parameters for on-the-fly annotation conversion
    subsample_size = BaseField(optional=True)  # absolute count or percentage string like '10%'
    subsample_seed = NumberField(floats=False, min_value=0, optional=True)  # seed for reproducible subsampling
55
56
class Dataset:
    """Holds converted annotation plus input data location and serves batches.

    Annotation is either converted on the fly ('annotation_conversion' config
    key) or loaded from a previously saved file ('annotation' key), optionally
    subsampled, and iterated batch-by-batch via __getitem__.
    """

    def __init__(self, config_entry, preprocessor):
        """
        Args:
            config_entry: dict-like dataset configuration (see DatasetConfig).
            preprocessor: object with a process(images, annotations) method
                          applied to each read batch.

        Raises:
            ConfigError: when the reader config is neither str nor dict, or
                when no annotation source is configured.
        """
        self._config = config_entry
        self._preprocessor = preprocessor

        self.batch = 1

        dataset_config = DatasetConfig('Dataset')
        # Reader may be given as a plain name or as a dict with a 'type' key
        # plus reader-specific options.
        data_reader_config = self._config.get('reader', 'opencv_imread')
        if isinstance(data_reader_config, str):
            self.read_image_fn = BaseReader.provide(data_reader_config)
        elif isinstance(data_reader_config, dict):
            self.read_image_fn = BaseReader.provide(data_reader_config['type'], data_reader_config)
        else:
            raise ConfigError('reader should be dict or string')

        # The chosen reader decides whether data_source must be a directory
        # and whether it is required at all, so the schema is patched before
        # validation runs.
        dataset_config.fields['data_source'].is_directory = self.read_image_fn.data_source_is_dir
        dataset_config.fields['data_source'].optional = self.read_image_fn.data_source_optional
        dataset_config.validate(self._config)
        annotation, meta = None, None
        self._images_dir = Path(self._config.get('data_source', ''))
        # On-the-fly conversion takes precedence over a stored annotation file.
        if 'annotation_conversion' in self._config:
            annotation, meta = self._convert_annotation()
        else:
            stored_annotation = self._config.get('annotation')
            if stored_annotation:
                annotation = read_annotation(get_path(stored_annotation))
                meta = self._load_meta()

        if not annotation:
            raise ConfigError('path to converted annotation or data for conversion should be specified')

        subsample_size = self._config.get('subsample_size')
        if subsample_size:
            subsample_seed = self._config.get('subsample_seed', 666)
            # A string ending in '%' is interpreted as a percentage of the
            # full annotation; other strings are taken as absolute counts.
            if isinstance(subsample_size, str):
                if subsample_size.endswith('%'):
                    subsample_size = float(subsample_size[:-1]) / 100 * len(annotation)
            subsample_size = int(subsample_size)
            annotation = make_subset(annotation, subsample_size, subsample_seed)

        # When both keys are present, the (possibly subsampled) converted
        # annotation is saved to the path given by 'annotation' for reuse.
        if contains_all(self._config, ['annotation', 'annotation_conversion']):
            annotation_name = self._config['annotation']
            meta_name = self._config.get('dataset_meta')
            if meta_name:
                meta_name = Path(meta_name)
            save_annotation(annotation, meta, Path(annotation_name), meta_name)

        self._annotation = annotation
        self._meta = meta
        self.size = len(self._annotation)
        self.name = self._config.get('name')

    @property
    def annotation(self):
        """Full list of annotation representations."""
        return self._annotation

    def __len__(self):
        return self.size

    @property
    def metadata(self):
        """Dataset meta information (may be None when no meta was loaded)."""
        return self._meta

    @property
    def labels(self):
        """Label map from meta, empty dict when absent."""
        return self._meta.get('label_map', {})

    def __getitem__(self, item):
        """Return (batch_annotation, preprocessed_images) for batch index *item*.

        Note: *item* is a batch index, not a sample index — each step covers
        self.batch samples.
        """
        if self.size <= item * self.batch:
            raise IndexError

        batch_start = item * self.batch
        batch_end = min(self.size, batch_start + self.batch)
        batch_annotation = self._annotation[batch_start:batch_end]

        identifiers = [annotation.identifier for annotation in batch_annotation]
        images = self._read_images(identifiers)

        # Propagate actual image sizes and source dir into annotations so
        # postprocessing can map predictions back to original coordinates.
        for image, annotation in zip(images, batch_annotation):
            self.set_annotation_metadata(annotation, image)

        preprocessed = self._preprocessor.process(images, batch_annotation)

        return batch_annotation, preprocessed

    @staticmethod
    def set_image_metadata(annotation, images):
        """Store per-image sizes on *annotation*: (1,) for scalars, else shape."""
        image_sizes = []
        if not isinstance(images, list):
            images = [images]
        for image in images:
            if np.isscalar(image):
                image_sizes.append((1,))
            else:
                image_sizes.append(image.shape)
        annotation.set_image_size(image_sizes)

    def set_annotation_metadata(self, annotation, image):
        """Attach image size and data source directory to *annotation*."""
        self.set_image_metadata(annotation, image.data)
        annotation.set_data_source(self._images_dir)

    def _read_images(self, identifiers):
        """Read each identifier via the configured reader into DataRepresentation."""
        images = []
        for identifier in identifiers:
            images.append(DataRepresentation(self.read_image_fn(identifier, self._images_dir), identifier=identifier))

        return images

    def _load_meta(self):
        """Load dataset meta JSON when 'dataset_meta' is configured, else None."""
        meta_data_file = self._config.get('dataset_meta')
        return read_json(meta_data_file, cls=JSONDecoderWithAutoConversion) if meta_data_file else None

    def _convert_annotation(self):
        """Run the configured annotation converter; return (annotation, meta)."""
        conversion_params = self._config.get('annotation_conversion')
        converter = conversion_params['converter']
        annotation_converter = BaseFormatConverter.provide(converter, conversion_params)
        annotation, meta = annotation_converter.convert()

        return annotation, meta
177
178
def read_annotation(annotation_file: Path):
    """Deserialize every BaseRepresentation record stored in *annotation_file*.

    Records are read sequentially from the binary file until end-of-file;
    the collected representations are returned as a list.
    """
    representations = []
    with get_path(annotation_file).open('rb') as content:
        try:
            while True:
                representations.append(BaseRepresentation.load(content))
        except EOFError:
            # Normal termination: no more serialized records in the file.
            pass

    return representations