Publishing 2019 R1 content
[platform/upstream/dldt.git] / tools / accuracy_checker / accuracy_checker / data_readers / data_reader.py
1 """
2 Copyright (c) 2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """
16
17 from functools import singledispatch
18 from collections import OrderedDict
19 import re
20 import cv2
21 from PIL import Image
22 import scipy.misc
23 import numpy as np
24 import nibabel as nib
25
26 from ..utils import get_path, read_json
27 from ..dependency import ClassProvider
28 from ..config import BaseField, StringField, ConfigValidator, ConfigError, DictField
29
30
31 class DataReaderField(BaseField):
32     def validate(self, entry_, field_uri=None):
33         super().validate(entry_, field_uri)
34
35         if entry_ is None:
36             return
37
38         field_uri = field_uri or self.field_uri
39         if isinstance(entry_, str):
40             StringField(choices=BaseReader.providers).validate(entry_, 'reader')
41         elif isinstance(entry_, dict):
42             class DictReaderValidator(ConfigValidator):
43                 type = StringField(choices=BaseReader.providers)
44             dict_reader_validator = DictReaderValidator(
45                 'reader', on_extra_argument=DictReaderValidator.IGNORE_ON_EXTRA_ARGUMENT
46             )
47             dict_reader_validator.validate(entry_)
48         else:
49             self.raise_error(entry_, field_uri, 'reader must be either string or dictionary')
50
51
52 class BaseReader(ClassProvider):
53     __provider_type__ = 'reader'
54
55     def __init__(self, config=None):
56         self.config = config
57         self.data_source_is_dir = True
58         self.data_source_optional = False
59         self.read_dispatcher = singledispatch(self.read)
60         self.read_dispatcher.register(list, self._read_list)
61
62         self.validate_config()
63         self.configure()
64
65     def __call__(self, *args, **kwargs):
66         return self.read_dispatcher(*args, **kwargs)
67
68     def configure(self):
69         pass
70
71     def validate_config(self):
72         pass
73
74     def read(self, data_id, data_dir):
75         raise NotImplementedError
76
77     def _read_list(self, data_id, data_dir):
78         return [self.read(identifier, data_dir) for identifier in data_id]
79
80
81 class ReaderCombinerConfig(ConfigValidator):
82     type = StringField()
83     scheme = DictField(
84         value_type=DataReaderField(), key_type=StringField(), allow_empty=False
85     )
86
87
88 class ReaderCombiner(BaseReader):
89     __provider__ = 'combine_reader'
90
91     def validate_config(self):
92         config_validator = ReaderCombinerConfig('reader_combiner_config')
93         config_validator.validate(self.config)
94
95     def configure(self):
96         scheme = self.config['scheme']
97         reading_scheme = OrderedDict()
98         for pattern, reader_config in scheme.items():
99             reader = BaseReader.provide(
100                 reader_config['type'] if isinstance(reader_config, dict) else reader_config, reader_config
101             )
102             pattern = re.compile(pattern)
103             reading_scheme[pattern] = reader
104
105         self.reading_scheme = reading_scheme
106
107     def read(self, data_id, data_dir):
108         for pattern, reader in self.reading_scheme.items():
109             if pattern.match(str(data_id)):
110                 return reader.read(data_id, data_dir)
111
112         raise ConfigError('suitable data reader for {} not found'.format(data_id))
113
114
115 class OpenCVImageReader(BaseReader):
116     __provider__ = 'opencv_imread'
117
118     def read(self, data_id, data_dir):
119         return cv2.imread(str(get_path(data_dir / data_id)))
120
121
122 class PillowImageReader(BaseReader):
123     __provider__ = 'pillow_imread'
124
125     def read(self, data_id, data_dir):
126         return np.array(Image.open(str(get_path(data_dir / data_id))))
127
128
129 class ScipyImageReader(BaseReader):
130     __provider__ = 'scipy_imread'
131
132     def read(self, data_id, data_dir):
133         return np.array(scipy.misc.imread(str(get_path(data_dir / data_id))))
134
135 class OpenCVFrameReader(BaseReader):
136     __provider__ = 'opencv_capture'
137
138     def __init__(self, config=None):
139         super().__init__(config)
140         self.data_source_is_dir = False
141         self.source = None
142         self.current = -1
143
144     def read(self, data_id, data_dir):
145         # source video changed, capture initialization
146         if data_dir != self.source:
147             self.source = data_dir
148             self.videocap = cv2.VideoCapture(str(self.source))
149             self.current = -1
150
151         if data_id < 0:
152             raise IndexError('frame with {} index can not be grabbed, non-negative index is expected')
153         if data_id < self.current:
154             self.videocap.set(cv2.CAP_PROP_POS_FRAMES, data_id)
155             self.current = data_id - 1
156
157         return self._read_sequence(data_id)
158
159     def _read_sequence(self, data_id):
160         frame = None
161         while self.current != data_id:
162             success, frame = self.videocap.read()
163             self.current += 1
164             if not success:
165                 raise EOFError('frame with {} index does not exists in {}'.format(self.current, self.source))
166         return frame
167
168
169 class JSONReaderConfig(ConfigValidator):
170     type = StringField()
171     key = StringField(optional=True, case_sensitive=True)
172
173
174 class JSONReader(BaseReader):
175     __provider__ = 'json_reader'
176
177     def validate_config(self):
178         config_validator = JSONReaderConfig('json_reader_config')
179         config_validator.validate(self.config)
180
181     def configure(self):
182         self.key = self.config.get('key')
183
184     def read(self, data_id, data_dir):
185         data = read_json(str(data_dir / data_id))
186         if self.key:
187             data = data.get(self.key)
188
189             if not data:
190                 raise ConfigError('{} does not contain {}'.format(data_id, self.key))
191
192         return np.array(data).astype(np.float32)
193
194 class NCF_DataReader(BaseReader):
195     __provider__ = 'ncf_data_reader'
196
197     def __init__(self, config=None):
198         super().__init__(config)
199         self.data_source_optional = True
200
201     def read(self, data_id, data_dir):
202         if not isinstance(data_id, str):
203             raise IndexError('Data identifier must be a string')
204
205         return float(data_id.split(":")[1])
206
207 class NiftiImageReader(BaseReader):
208     __provider__ = 'nifti_reader'
209
210     def read(self, data_id, data_dir):
211         nib_image = nib.load(str(get_path(data_dir / data_id)))
212         image = np.array(nib_image.dataobj)
213         if len(image.shape) != 4:  # Make sure 4D
214             image = np.expand_dims(image, -1)
215         image = np.swapaxes(np.array(image), 0, -2)
216         return image