"""
Copyright (c) 2019 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
17 from functools import singledispatch
18 from collections import OrderedDict
26 from ..utils import get_path, read_json
27 from ..dependency import ClassProvider
28 from ..config import BaseField, StringField, ConfigValidator, ConfigError, DictField
class DataReaderField(BaseField):
    """Config field accepting a data reader given by provider name or as a dictionary."""

    def validate(self, entry_, field_uri=None):
        """Validate a reader specification.

        ``entry_`` may be a string naming a registered ``BaseReader`` provider,
        or a dictionary whose ``type`` key names one; other keys of the
        dictionary are ignored by this check.

        Args:
            entry_: reader specification to validate.
            field_uri: name used in error messages; defaults to ``self.field_uri``.
        """
        super().validate(entry_, field_uri)

        # NOTE(review): assumes BaseField.validate already rejects a missing value
        # for non-optional fields, so None here means "optional and not set" — confirm.
        if entry_ is None:
            return

        field_uri = field_uri or self.field_uri or 'reader'
        if isinstance(entry_, str):
            StringField(choices=BaseReader.providers).validate(entry_, field_uri)
        elif isinstance(entry_, dict):
            # Built inside the method because BaseReader is defined after this
            # class in the module.
            class DictReaderValidator(ConfigValidator):
                type = StringField(choices=BaseReader.providers)

            dict_reader_validator = DictReaderValidator(
                field_uri, on_extra_argument=DictReaderValidator.IGNORE_ON_EXTRA_ARGUMENT
            )
            dict_reader_validator.validate(entry_)
        else:
            # Only reached for entries that are neither strings nor dictionaries.
            self.raise_error(entry_, field_uri, 'reader must be either string or dictionary')
class BaseReader(ClassProvider):
    """Base class for data readers registered under the ``'reader'`` provider type.

    Instances are callable: when the first argument (``data_id``) is a list it
    is handled by ``_read_list``, which reads each identifier in turn; any
    other identifier is passed to ``read``.
    """

    __provider_type__ = 'reader'

    def __init__(self, config=None):
        # Subclasses validate and consume their configuration through self.config.
        self.config = config
        self.data_source_is_dir = True
        self.data_source_optional = False
        # singledispatch on the bound method dispatches on the type of data_id.
        self.read_dispatcher = singledispatch(self.read)
        self.read_dispatcher.register(list, self._read_list)

        self.validate_config()

    def __call__(self, *args, **kwargs):
        return self.read_dispatcher(*args, **kwargs)

    def validate_config(self):
        """Validate ``self.config``; called at the end of ``__init__``. No-op by default."""

    def read(self, data_id, data_dir):
        """Read the single item ``data_id`` from ``data_dir``; implemented by subclasses."""
        raise NotImplementedError

    def _read_list(self, data_id, data_dir):
        """Read every identifier of the list ``data_id`` and return the results in order."""
        return [self.read(identifier, data_dir) for identifier in data_id]
class ReaderCombinerConfig(ConfigValidator):
    """Configuration schema of the ``combine_reader`` reader."""

    # Provider name of the reader itself (presumably 'combine_reader'); declared
    # so the key is accepted by validation.
    type = BaseField()
    # Maps a regular expression, matched against each data identifier, to the
    # reader (provider name or dictionary) used for matching identifiers.
    scheme = DictField(
        value_type=DataReaderField(), key_type=StringField(), allow_empty=False
    )
class ReaderCombiner(BaseReader):
    """Reader that delegates each identifier to the first reader whose pattern matches it."""

    __provider__ = 'combine_reader'

    def validate_config(self):
        """Validate the configuration and build the ordered pattern -> reader table.

        ``self.config['scheme']`` maps regular expressions to reader
        specifications: either a provider name or a dictionary with a ``type``
        key. The scheme's order is kept, so earlier patterns win in ``read``.
        """
        config_validator = ReaderCombinerConfig('reader_combiner_config')
        config_validator.validate(self.config)

        reading_scheme = OrderedDict()
        for pattern, reader_config in self.config['scheme'].items():
            if isinstance(reader_config, dict):
                reader_type, reader_params = reader_config['type'], reader_config
            else:
                # A bare provider name carries no options; do not hand the name
                # itself to the sub-reader as its configuration.
                reader_type, reader_params = reader_config, None
            reading_scheme[re.compile(pattern)] = BaseReader.provide(reader_type, reader_params)

        self.reading_scheme = reading_scheme

    def read(self, data_id, data_dir):
        """Read ``data_id`` with the first reader whose pattern matches the start of ``str(data_id)``.

        Raises:
            ConfigError: if no pattern matches.
        """
        for pattern, reader in self.reading_scheme.items():
            if pattern.match(str(data_id)):
                return reader.read(data_id, data_dir)

        raise ConfigError('suitable data reader for {} not found'.format(data_id))
class OpenCVImageReader(BaseReader):
    """Image reader backed by OpenCV's ``cv2.imread``."""

    __provider__ = 'opencv_imread'

    def read(self, data_id, data_dir):
        """Load the image stored at ``data_dir / data_id`` with OpenCV."""
        # NOTE(review): cv2.imread returns None instead of raising when the file
        # cannot be decoded; confirm callers handle that.
        image_path = get_path(data_dir / data_id)
        return cv2.imread(str(image_path))
class PillowImageReader(BaseReader):
    """Image reader backed by Pillow, returning images as NumPy arrays."""

    __provider__ = 'pillow_imread'

    def read(self, data_id, data_dir):
        """Return the image at ``data_dir / data_id`` as a NumPy array.

        ``np.array`` copies the pixel data out of the image, so the file handle
        opened by ``Image.open`` can be closed right after; the ``with`` block
        guarantees that instead of leaving it to garbage collection.
        """
        with Image.open(str(get_path(data_dir / data_id))) as image:
            return np.array(image)
class ScipyImageReader(BaseReader):
    """Image reader backed by ``scipy.misc.imread``."""

    __provider__ = 'scipy_imread'

    def read(self, data_id, data_dir):
        """Load the image at ``data_dir / data_id`` and return it as a NumPy array."""
        # NOTE(review): scipy.misc.imread was deprecated in SciPy 1.0 and removed in
        # SciPy 1.2, so this reader only works with older SciPy releases — confirm pin.
        image_path = str(get_path(data_dir / data_id))
        loaded = scipy.misc.imread(image_path)
        return np.array(loaded)
class OpenCVFrameReader(BaseReader):
    """Read single frames of a video with ``cv2.VideoCapture``.

    ``data_dir`` is the video file itself (hence ``data_source_is_dir = False``)
    and ``data_id`` is a non-negative frame index. Reading continues forward from
    the current position; asking for a frame at or before the last one read
    seeks back first.
    """

    __provider__ = 'opencv_capture'

    def __init__(self, config=None):
        super().__init__(config)
        self.data_source_is_dir = False
        self.source = None  # video currently opened
        self.videocap = None  # capture object opened on self.source
        self.current = -1  # index of the last frame read; -1 before any read

    def read(self, data_id, data_dir):
        """Return frame ``data_id`` of the video ``data_dir``.

        Raises:
            IndexError: if ``data_id`` is negative.
            EOFError: if the video ends before frame ``data_id``.
        """
        # source video changed, capture initialization
        if data_dir != self.source:
            if self.videocap is not None:
                # Free the previous capture instead of waiting for garbage collection.
                self.videocap.release()
            self.source = data_dir
            self.videocap = cv2.VideoCapture(str(self.source))
            self.current = -1

        if data_id < 0:
            raise IndexError(
                'frame with {} index can not be grabbed, non-negative index is expected'.format(data_id)
            )
        if data_id <= self.current:
            # Seek so the next read() yields frame data_id. "<=" (rather than "<")
            # also covers re-reading the frame just returned, which would
            # otherwise skip the read loop and return None.
            self.videocap.set(cv2.CAP_PROP_POS_FRAMES, data_id)
            self.current = data_id - 1

        return self._read_sequence(data_id)

    def _read_sequence(self, data_id):
        """Read frames forward until frame ``data_id`` is reached and return it.

        Raises:
            EOFError: if a read fails before ``data_id`` is reached.
        """
        frame = None
        while self.current != data_id:
            success, frame = self.videocap.read()
            self.current += 1
            if not success:
                raise EOFError('frame with {} index does not exists in {}'.format(self.current, self.source))

        return frame
class JSONReaderConfig(ConfigValidator):
    """Configuration schema of the ``json_reader`` reader."""

    # Provider name carried by dictionary reader configs.
    # NOTE(review): declared on the assumption that ConfigValidator rejects
    # undeclared keys by default (DataReaderField opts into
    # IGNORE_ON_EXTRA_ARGUMENT explicitly) — confirm.
    type = BaseField()
    # Optional top-level key whose value is extracted from each JSON file.
    key = StringField(optional=True, case_sensitive=True)
class JSONReader(BaseReader):
    """Read numeric data from JSON files and return it as a ``float32`` array."""

    __provider__ = 'json_reader'

    def validate_config(self):
        """Validate the configuration and remember the optional ``key``."""
        config_validator = JSONReaderConfig('json_reader_config')
        config_validator.validate(self.config)

        # Without a configuration no key is used and whole files are returned.
        self.key = (self.config or {}).get('key')

    def read(self, data_id, data_dir):
        """Load ``data_dir / data_id`` and return its content as ``np.float32``.

        When ``key`` is configured, only the value stored under that top-level
        key is returned.

        Raises:
            ConfigError: if ``key`` is configured but missing from the file.
        """
        data = read_json(str(data_dir / data_id))
        if self.key:
            data = data.get(self.key)
            # Compare with None so legitimately falsy values (0, []) are kept.
            if data is None:
                raise ConfigError('{} does not contain {}'.format(data_id, self.key))

        return np.array(data).astype(np.float32)
class NCF_DataReader(BaseReader):
    """Reader whose value is encoded in the identifier itself as ``'<prefix>:<value>...'``.

    The data source is marked optional because ``read`` never touches ``data_dir``.
    """

    __provider__ = 'ncf_data_reader'

    def __init__(self, config=None):
        super().__init__(config)
        self.data_source_optional = True

    def read(self, data_id, data_dir):
        """Return the second ``':'``-separated field of ``data_id`` as a float.

        Raises:
            IndexError: if ``data_id`` is not a string or contains no ``':'``.
            ValueError: if that field is not a valid float.
        """
        if isinstance(data_id, str):
            fields = data_id.split(":")
            return float(fields[1])
        raise IndexError('Data identifier must be a string')
207 class NiftiImageReader(BaseReader):
208 __provider__ = 'nifti_reader'
210 def read(self, data_id, data_dir):
211 nib_image = nib.load(str(get_path(data_dir / data_id)))
212 image = np.array(nib_image.dataobj)
213 if len(image.shape) != 4: # Make sure 4D
214 image = np.expand_dims(image, -1)
215 image = np.swapaxes(np.array(image), 0, -2)