"""
Copyright (c) 2019 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import re
from collections import OrderedDict, namedtuple
from functools import singledispatch
from pathlib import Path

import cv2
import nibabel as nib
import numpy as np
import scipy.misc
from PIL import Image

from ..utils import get_path, read_json, zipped_transform, set_image_metadata
from ..dependency import ClassProvider
from ..config import BaseField, StringField, ConfigValidator, ConfigError, DictField
class DataRepresentation:
    """Container pairing one piece of read data with its identifier and metadata.

    ``metadata['image_size']`` is derived from ``data``:

    * a scalar gives ``1``;
    * a list of scalars, or an empty list, gives its length;
    * any other non-list value gives ``data.shape``;
    * a non-empty list of non-scalars gives the shape of its first element.

    Args:
        data: value produced by a reader.
        meta: optional metadata dict. A non-empty dict is used (and updated)
            in place; when omitted or empty a fresh dict is created.
        identifier: identifier the data was read for.
    """

    def __init__(self, data, meta=None, identifier=''):
        self.identifier = identifier
        self.data = data
        self.metadata = meta or {}
        if np.isscalar(data):
            self.metadata['image_size'] = 1
        elif isinstance(data, list) and (not data or np.isscalar(data[0])):
            # Guard the empty list: ``data[0]`` would raise IndexError, and its size is simply 0.
            self.metadata['image_size'] = len(data)
        else:
            self.metadata['image_size'] = data.shape if not isinstance(data, list) else data[0].shape
# Identifies a video clip: its source video, the clip index, and the frame
# identifiers (joined onto ``video`` when the clip is read).
ClipIdentifier = namedtuple('ClipIdentifier', 'video clip_id frames')
def create_reader(config):
    """Instantiate the reader described by ``config``.

    ``config['type']`` names the reader provider (``'opencv_imread'`` when
    absent); ``config['data_source']`` (``None`` when absent) and the whole
    config are passed on to it.
    """
    reader_type = config.get('type', 'opencv_imread')
    data_source = config.get('data_source')
    return BaseReader.provide(reader_type, data_source, config=config)
class DataReaderField(BaseField):
    """Config field holding a reader specification.

    Accepts either a reader provider name (string) or a dictionary whose
    ``type`` entry is a reader provider name. Other dictionary keys are
    ignored at this level.
    """

    def validate(self, entry_, field_uri=None):
        super().validate(entry_, field_uri)

        # NOTE(review): assumes BaseField.validate already rejects None for
        # non-optional fields, so None here means "not provided" -- confirm.
        if entry_ is None:
            return

        field_uri = field_uri or self.field_uri
        if isinstance(entry_, str):
            StringField(choices=BaseReader.providers).validate(entry_, 'reader')
        elif isinstance(entry_, dict):
            # Declared inside the method: BaseReader is defined later in this
            # module, and its providers are looked up at validation time.
            class DictReaderValidator(ConfigValidator):
                type = StringField(choices=BaseReader.providers)

            dict_reader_validator = DictReaderValidator(
                'reader', on_extra_argument=DictReaderValidator.IGNORE_ON_EXTRA_ARGUMENT
            )
            dict_reader_validator.validate(entry_)
        else:
            self.raise_error(entry_, field_uri, 'reader must be either string or dictionary')
class BaseReader(ClassProvider):
    """Base class for data readers.

    Subclasses implement :meth:`read` for a single identifier. Lists of
    identifiers and :class:`ClipIdentifier` values are routed through
    ``read_dispatcher`` to :meth:`_read_list` / :meth:`_read_clip`.
    Construction stores the arguments, then calls :meth:`validate_config`
    followed by :meth:`configure`.
    """

    __provider_type__ = 'reader'

    def __init__(self, data_source, config=None):
        self.config = config
        self.data_source = data_source
        # Dispatch on the identifier's type: lists and clips get dedicated
        # handlers; anything else goes straight to ``read``.
        self.read_dispatcher = singledispatch(self.read)
        self.read_dispatcher.register(list, self._read_list)
        self.read_dispatcher.register(ClipIdentifier, self._read_clip)

        self.validate_config()
        self.configure()

    def __call__(self, context=None, identifier=None, **kwargs):
        """Read a single identifier, or the whole batch held by ``context``.

        With ``identifier`` given, return its :class:`DataRepresentation`.
        Otherwise read every entry of ``context.identifiers_batch`` into
        ``context.data_batch`` and pass the annotation and data batches
        through ``zipped_transform(set_image_metadata, ...)``. Return ``context``.

        Raises:
            ValueError: if neither ``identifier`` nor ``context`` is given.
        """
        if identifier is not None:
            return self.read_item(identifier)

        if context is None:
            raise ValueError('identifier or context should be specified')

        context.data_batch = [self.read_item(batch_id) for batch_id in context.identifiers_batch]
        context.annotation_batch, context.data_batch = zipped_transform(
            set_image_metadata,
            context.annotation_batch,
            context.data_batch
        )
        return context

    def configure(self):
        """Pass ``data_source`` through ``get_path(..., is_directory=True)``.

        Override in readers whose data source is not a directory.
        """
        self.data_source = get_path(self.data_source, is_directory=True)

    def validate_config(self):
        """Hook for validating ``self.config``; the base reader accepts anything."""

    def read(self, data_id):
        """Read the data for a single identifier; implemented by subclasses."""
        raise NotImplementedError

    def _read_list(self, data_id):
        # Go through the dispatcher so elements that are themselves lists or
        # ClipIdentifiers are handled; plain identifiers still reach ``read``.
        return [self.read_dispatcher(identifier) for identifier in data_id]

    def _read_clip(self, data_id):
        # Frame identifiers are relative to the clip's video.
        video = Path(data_id.video)
        frames_identifiers = [video / frame for frame in data_id.frames]
        return self.read_dispatcher(frames_identifiers)

    def read_item(self, data_id):
        """Read ``data_id`` via the dispatcher and wrap it in a DataRepresentation."""
        return DataRepresentation(self.read_dispatcher(data_id), identifier=data_id)
class ReaderCombinerConfig(ConfigValidator):
    """Validator for ``combine_reader`` configs.

    ``scheme`` must be a non-empty mapping from identifier regex patterns to
    reader specifications.
    """

    # The reader config carries its own provider name under 'type'.
    type = StringField()
    scheme = DictField(
        value_type=DataReaderField(),
        key_type=StringField(),
        allow_empty=False,
    )
class ReaderCombiner(BaseReader):
    """Reader that delegates each identifier to one of several nested readers.

    ``config['scheme']`` maps regular-expression patterns to reader
    specifications (a provider name or a reader config dict). An identifier is
    read by the first reader, in scheme order, whose pattern matches at the
    start of ``str(identifier)``.
    """

    __provider__ = 'combine_reader'

    def validate_config(self):
        config_validator = ReaderCombinerConfig('reader_combiner_config')
        config_validator.validate(self.config)

    def configure(self):
        scheme = self.config['scheme']
        reading_scheme = OrderedDict()
        for pattern, reader_config in scheme.items():
            # Normalize the short form ("reader_name") to a config dict so every
            # nested reader receives a mapping (e.g. JSONReader calls
            # ``self.config.get``); a bare string config would break it.
            if isinstance(reader_config, str):
                reader_config = {'type': reader_config}
            reader = BaseReader.provide(reader_config['type'], self.data_source, config=reader_config)
            reading_scheme[re.compile(pattern)] = reader

        self.reading_scheme = reading_scheme

    def read(self, data_id):
        """Read ``data_id`` with the first reader whose pattern matches it.

        Raises:
            ConfigError: if no pattern in the scheme matches.
        """
        identifier = str(data_id)
        for pattern, reader in self.reading_scheme.items():
            if pattern.match(identifier):
                return reader.read(data_id)

        raise ConfigError('suitable data reader for {} not found'.format(data_id))
class OpenCVImageReader(BaseReader):
    """Read image files under ``data_source`` with ``cv2.imread``."""

    __provider__ = 'opencv_imread'

    def read(self, data_id):
        image_path = get_path(self.data_source / data_id)
        return cv2.imread(str(image_path))
class PillowImageReader(BaseReader):
    """Read image files under ``data_source`` with Pillow, converting to RGB."""

    __provider__ = 'pillow_imread'

    def __init__(self, data_source, config=None):
        # Mirror BaseReader's signature: readers are built as
        # ``provide(name, data_source, config=...)``. The former
        # ``__init__(self, config=None)`` made that call raise TypeError.
        # A single positional argument still binds exactly as before.
        super().__init__(data_source, config)
        self.convert_to_rgb = True

    def read(self, data_id):
        with open(str(get_path(self.data_source / data_id)), 'rb') as image_file:
            img = Image.open(image_file)
            # Pillow's Image.open is lazy, so pixels must be materialized
            # before the file is closed.
            return np.array(img.convert('RGB') if self.convert_to_rgb else img)
class ScipyImageReader(BaseReader):
    """Read image files under ``data_source`` the way ``scipy.misc.imread`` does.

    ``scipy.misc.imread`` was removed in SciPy 1.2. When it is unavailable,
    the image is read through Pillow, reproducing imread's default
    (``mode=None``, ``flatten=False``) conversions.
    """

    __provider__ = 'scipy_imread'

    def read(self, data_id):
        image_path = str(get_path(self.data_source / data_id))
        imread = getattr(scipy.misc, 'imread', None)
        if imread is not None:
            return np.array(imread(image_path))
        return self._pillow_imread(image_path)

    @staticmethod
    def _pillow_imread(image_path):
        """Pillow-based equivalent of ``scipy.misc.imread(image_path)``."""
        with Image.open(image_path) as img:
            if img.mode == 'P':
                # Expand palette indices to colors, as scipy's fromimage did.
                img = img.convert('RGBA' if 'transparency' in img.info else 'RGB')
            elif img.mode == '1':
                # scipy converted 1-bit images to 8-bit before building the array.
                img = img.convert('L')
            return np.array(img)
class OpenCVFrameReader(BaseReader):
    """Read frames by index from the video file ``data_source`` via ``cv2.VideoCapture``.

    Frames are read sequentially. Requesting a frame at or before the last
    one read first seeks the capture back to it.
    """

    __provider__ = 'opencv_capture'

    def __init__(self, data_source, config=None):
        super().__init__(data_source, config)
        # Index of the last frame read; -1 means nothing has been read yet.
        self.current = -1

    def read(self, data_id):
        """Return frame number ``data_id``.

        Raises:
            IndexError: if ``data_id`` is negative.
            EOFError: if the capture stops delivering frames before ``data_id``.
        """
        if data_id < 0:
            raise IndexError(
                'frame with {} index can not be grabbed, non-negative index is expected'.format(data_id)
            )
        # ``<=`` rather than ``<``: re-reading the last frame must also seek back,
        # otherwise the sequential loop below reads nothing and returns None.
        if data_id <= self.current:
            self.videocap.set(cv2.CAP_PROP_POS_FRAMES, data_id)
            self.current = data_id - 1

        return self._read_sequence(data_id)

    def _read_sequence(self, data_id):
        """Read forward frame by frame until ``data_id`` is reached and return it."""
        frame = None
        while self.current != data_id:
            success, frame = self.videocap.read()
            self.current += 1
            if not success:
                raise EOFError('frame with {} index does not exists in {}'.format(self.current, self.data_source))

        return frame

    def configure(self):
        self.data_source = get_path(self.data_source)
        self.videocap = cv2.VideoCapture(str(self.data_source))
class JSONReaderConfig(ConfigValidator):
    """Validator for ``json_reader`` configs."""

    # The reader config carries its own provider name under 'type'.
    type = StringField()
    # Optional name of the entry to extract from each loaded JSON document.
    key = StringField(case_sensitive=True, optional=True)
class JSONReader(BaseReader):
    """Read JSON files under ``data_source`` as float32 numpy arrays.

    When ``config['key']`` is set, only that entry of each loaded document is used.
    """

    __provider__ = 'json_reader'

    def validate_config(self):
        config_validator = JSONReaderConfig('json_reader_config')
        config_validator.validate(self.config)

    def configure(self):
        # Resolve data_source like the other directory-based readers so that
        # ``self.data_source / data_id`` works for a plain string source.
        super().configure()
        self.key = self.config.get('key')

    def read(self, data_id):
        """Load ``data_id`` and return it (or its ``key`` entry) as a float32 array.

        Raises:
            ConfigError: if ``key`` is set and the document has no value for it.
        """
        data = read_json(str(get_path(self.data_source / data_id)))
        if self.key:
            data = data.get(self.key)
            # Test for None explicitly: 0, [] or false are legitimate values.
            if data is None:
                raise ConfigError('{} does not contain {}'.format(data_id, self.key))

        return np.array(data).astype(np.float32)
class NCF_DataReader(BaseReader):
    """Reader whose value is encoded in the identifier itself.

    An identifier of the form ``"<first>:<second>[:...]"`` yields
    ``float(<second>)``; ``data_source`` is never consulted.
    """

    __provider__ = 'ncf_data_reader'

    def configure(self):
        # read() never touches data_source, so skip the base-class directory
        # resolution (it would fail when no directory is configured).
        pass

    def read(self, data_id):
        """Return the second ':'-separated field of ``data_id`` as a float.

        Raises:
            IndexError: if ``data_id`` is not a string or has no ':' separator.
        """
        if isinstance(data_id, str):
            fields = data_id.split(":")
            return float(fields[1])
        raise IndexError('Data identifier must be a string')
257 class NiftiImageReader(BaseReader):
258 __provider__ = 'nifti_reader'
260 def read(self, data_id):
261 nib_image = nib.load(str(get_path(self.data_source / data_id)))
262 image = np.array(nib_image.dataobj)
263 if len(image.shape) != 4: # Make sure 4D
264 image = np.expand_dims(image, -1)
265 image = np.swapaxes(np.array(image), 0, -2)