66c3c4d6b895ebd3201a17b8a2b48db844551feb
[platform/upstream/dldt.git] / tools / accuracy_checker / accuracy_checker / data_readers / data_reader.py
1 """
2 Copyright (c) 2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """
16 from pathlib import Path
17 from functools import singledispatch
18 from collections import OrderedDict, namedtuple
19 import re
20 import cv2
21 from PIL import Image
22 import scipy.misc
23 import numpy as np
24 import nibabel as nib
25
26 from ..utils import get_path, read_json, zipped_transform, set_image_metadata
27 from ..dependency import ClassProvider
28 from ..config import BaseField, StringField, ConfigValidator, ConfigError, DictField
29
30
31 class DataRepresentation:
32     def __init__(self, data, meta=None, identifier=''):
33         self.identifier = identifier
34         self.data = data
35         self.metadata = meta or {}
36         if np.isscalar(data):
37             self.metadata['image_size'] = 1
38         elif isinstance(data, list) and np.isscalar(data[0]):
39             self.metadata['image_size'] = len(data)
40         else:
41             self.metadata['image_size'] = data.shape if not isinstance(data, list) else data[0].shape
42
43
44 ClipIdentifier = namedtuple('ClipIdentifier', ['video', 'clip_id', 'frames'])
45
46
47 def create_reader(config):
48     return BaseReader.provide(config.get('type', 'opencv_imread'), config.get('data_source'), config=config)
49
50
51 class DataReaderField(BaseField):
52     def validate(self, entry_, field_uri=None):
53         super().validate(entry_, field_uri)
54
55         if entry_ is None:
56             return
57
58         field_uri = field_uri or self.field_uri
59         if isinstance(entry_, str):
60             StringField(choices=BaseReader.providers).validate(entry_, 'reader')
61         elif isinstance(entry_, dict):
62             class DictReaderValidator(ConfigValidator):
63                 type = StringField(choices=BaseReader.providers)
64             dict_reader_validator = DictReaderValidator(
65                 'reader', on_extra_argument=DictReaderValidator.IGNORE_ON_EXTRA_ARGUMENT
66             )
67             dict_reader_validator.validate(entry_)
68         else:
69             self.raise_error(entry_, field_uri, 'reader must be either string or dictionary')
70
71
72 class BaseReader(ClassProvider):
73     __provider_type__ = 'reader'
74
75     def __init__(self, data_source, config=None):
76         self.config = config
77         self.data_source = data_source
78         self.read_dispatcher = singledispatch(self.read)
79         self.read_dispatcher.register(list, self._read_list)
80         self.read_dispatcher.register(ClipIdentifier, self._read_clip)
81
82         self.validate_config()
83         self.configure()
84
85     def __call__(self, context=None, identifier=None, **kwargs):
86         if identifier is not None:
87             return self.read_item(identifier)
88
89         if not context:
90             raise ValueError('identifier or context should be specified')
91
92         read_data = [self.read_item(identifier) for identifier in context.identifiers_batch]
93         context.data_batch = read_data
94         context.annotation_batch, context.data_batch = zipped_transform(
95             set_image_metadata,
96             context.annotation_batch,
97             context.data_batch
98         )
99         return context
100
101     def configure(self):
102         self.data_source = get_path(self.data_source, is_directory=True)
103
104     def validate_config(self):
105         pass
106
107     def read(self, data_id):
108         raise NotImplementedError
109
110     def _read_list(self, data_id):
111         return [self.read(identifier) for identifier in data_id]
112
113     def _read_clip(self, data_id):
114         video = Path(data_id.video)
115         frames_identifiers = [video / frame for frame in data_id.frames]
116         return self.read_dispatcher(frames_identifiers)
117
118     def read_item(self, data_id):
119         return DataRepresentation(self.read_dispatcher(data_id), identifier=data_id)
120
121
122
123 class ReaderCombinerConfig(ConfigValidator):
124     type = StringField()
125     scheme = DictField(
126         value_type=DataReaderField(), key_type=StringField(), allow_empty=False
127     )
128
129
130 class ReaderCombiner(BaseReader):
131     __provider__ = 'combine_reader'
132
133     def validate_config(self):
134         config_validator = ReaderCombinerConfig('reader_combiner_config')
135         config_validator.validate(self.config)
136
137     def configure(self):
138         scheme = self.config['scheme']
139         reading_scheme = OrderedDict()
140         for pattern, reader_config in scheme.items():
141             reader = BaseReader.provide(
142                 reader_config['type'] if isinstance(reader_config, dict) else reader_config,
143                 self.data_source, reader_config
144             )
145             pattern = re.compile(pattern)
146             reading_scheme[pattern] = reader
147
148         self.reading_scheme = reading_scheme
149
150     def read(self, data_id):
151         for pattern, reader in self.reading_scheme.items():
152             if pattern.match(str(data_id)):
153                 return reader.read(data_id)
154
155         raise ConfigError('suitable data reader for {} not found'.format(data_id))
156
157
158 class OpenCVImageReader(BaseReader):
159     __provider__ = 'opencv_imread'
160
161     def read(self, data_id):
162         return cv2.imread(str(get_path(self.data_source / data_id)))
163
164
165 class PillowImageReader(BaseReader):
166     __provider__ = 'pillow_imread'
167
168     def __init__(self, config=None):
169         super().__init__(config)
170         self.convert_to_rgb = True
171
172     def read(self, data_id):
173         with open(str(self.data_source / data_id), 'rb') as f:
174             img = Image.open(f)
175
176             return np.array(img.convert('RGB') if self.convert_to_rgb else img)
177
178
179 class ScipyImageReader(BaseReader):
180     __provider__ = 'scipy_imread'
181
182     def read(self, data_id):
183         return np.array(scipy.misc.imread(str(get_path(self.data_source / data_id))))
184
185
186 class OpenCVFrameReader(BaseReader):
187     __provider__ = 'opencv_capture'
188
189     def __init__(self, data_source, config=None):
190         super().__init__(data_source, config)
191         self.current = -1
192
193     def read(self, data_id):
194         if data_id < 0:
195             raise IndexError('frame with {} index can not be grabbed, non-negative index is expected')
196         if data_id < self.current:
197             self.videocap.set(cv2.CAP_PROP_POS_FRAMES, data_id)
198             self.current = data_id - 1
199
200         return self._read_sequence(data_id)
201
202     def _read_sequence(self, data_id):
203         frame = None
204         while self.current != data_id:
205             success, frame = self.videocap.read()
206             self.current += 1
207             if not success:
208                 raise EOFError('frame with {} index does not exists in {}'.format(self.current, self.data_source))
209
210         return frame
211
212     def configure(self):
213         self.data_source = get_path(self.data_source)
214         self.videocap = cv2.VideoCapture(str(self.data_source))
215
216
217 class JSONReaderConfig(ConfigValidator):
218     type = StringField()
219     key = StringField(optional=True, case_sensitive=True)
220
221
222 class JSONReader(BaseReader):
223     __provider__ = 'json_reader'
224
225     def validate_config(self):
226         config_validator = JSONReaderConfig('json_reader_config')
227         config_validator.validate(self.config)
228
229     def configure(self):
230         self.key = self.config.get('key')
231
232     def read(self, data_id):
233         data = read_json(str(self.data_source / data_id))
234         if self.key:
235             data = data.get(self.key)
236
237             if not data:
238                 raise ConfigError('{} does not contain {}'.format(data_id, self.key))
239
240         return np.array(data).astype(np.float32)
241
242
243
244 class NCF_DataReader(BaseReader):
245     __provider__ = 'ncf_data_reader'
246
247     def configure(self):
248         pass
249
250     def read(self, data_id):
251         if not isinstance(data_id, str):
252             raise IndexError('Data identifier must be a string')
253
254         return float(data_id.split(":")[1])
255
256
257 class NiftiImageReader(BaseReader):
258     __provider__ = 'nifti_reader'
259
260     def read(self, data_id):
261         nib_image = nib.load(str(get_path(self.data_source / data_id)))
262         image = np.array(nib_image.dataobj)
263         if len(image.shape) != 4:  # Make sure 4D
264             image = np.expand_dims(image, -1)
265         image = np.swapaxes(np.array(image), 0, -2)
266
267         return image