2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
import io
import os
import struct

import numpy as np

from mo.utils.error import Error
from mo.utils.utils import refer_to_faq_msg
# Tag that closes the network description; find_next_component returns when it sees it
end_of_nnet_tag = '</Nnet>'
# Tag that closes a single component (compared lower-cased in find_end_of_component)
end_of_component_tag = '<!EndOfComponent>'

# Component names the loader recognizes. They must be lower case, because
# find_next_component lower-cases the tag text before the membership test and
# find_end_of_component compares lower-cased '<name>' tags against them.
# NOTE(review): this copy of the list may be incomplete -- verify it against the
# Kaldi component extractors that are actually registered.
supported_components = [
    'convolutional1dcomponent',
    'convolutionalcomponent',
    'fixedaffinecomponent',
    'lstmprojectedstreams',
    'maxpoolingcomponent',
    'affinecomponentpreconditionedonline',
    'rectifiedlinearcomponent'
]
def get_bool(s: bytes) -> bool:
    """
    Decode a single boolean from raw bytes.
    :param s: exactly one byte holding the value (struct format '?')
    :return: the decoded bool
    """
    (flag,) = struct.unpack('?', s)
    return flag
def get_uint16(s: bytes) -> int:
    """
    Decode an unsigned 16-bit integer from raw bytes.
    :param s: exactly two bytes in native byte order (struct format 'H')
    :return: the decoded unsigned 16-bit value
    """
    (number,) = struct.unpack('H', s)
    return number
def get_uint32(s: bytes) -> int:
    """
    Decode an unsigned 32-bit integer from raw bytes.
    :param s: exactly four bytes in native byte order (struct format 'I')
    :return: the decoded unsigned 32-bit value
    """
    (number,) = struct.unpack('I', s)
    return number
def get_uint64(s: bytes) -> int:
    """
    Get 64-bit integer value from bytes.
    NOTE: despite the function name, the value is decoded with the *signed* struct
    format 'q' (native byte order), so byte patterns with the top bit set yield
    negative numbers. Switching to 'Q' would change results for such inputs --
    TODO(review): confirm which interpretation the Kaldi format requires.
    :param s: exactly eight bytes holding the integer
    :return: signed 64-bit value decoded from the bytes
    """
    return struct.unpack('q', s)[0]
def read_binary_bool_token(file_desc: io.BufferedReader) -> bool:
    """
    Read one byte from the file and decode it as a bool via get_bool.
    The stream position advances by 1 byte.
    :param file_desc: file descriptor
    :return: the decoded boolean
    """
    raw_byte = file_desc.read(1)
    return get_bool(raw_byte)
def read_binary_integer32_token(file_desc: io.BufferedReader) -> int:
    """
    Read the next length-prefixed 32-bit integer from the file.
    A single byte holding the payload length is read first; that many bytes are
    then decoded with get_uint32, which expects exactly 4 of them, so the
    position normally advances by 5 bytes.
    :param file_desc: file descriptor
    :return: the decoded unsigned 32-bit value
    """
    payload_length = file_desc.read(1)[0]
    return get_uint32(file_desc.read(payload_length))
def read_binary_integer64_token(file_desc: io.BufferedReader) -> int:
    """
    Read the next length-prefixed 64-bit integer from the file.
    A single byte holding the payload length is read first; that many bytes are
    then decoded with get_uint64, which expects exactly 8 of them, so the
    position normally advances by 9 bytes.
    :param file_desc: file descriptor
    :return: the value decoded by get_uint64
    """
    payload_length = file_desc.read(1)[0]
    return get_uint64(file_desc.read(payload_length))
def find_next_tag(file_desc: io.BufferedReader) -> str:
    """
    Get next tag in the file.
    Bytes are consumed one at a time. Everything before a '<' is skipped, a new
    '<' restarts the candidate tag, and a candidate ending in '>' is returned if
    it decodes as ASCII (otherwise it is discarded and scanning continues).
    :param file_desc: file descriptor
    :return: string like '<sometag>'; the stream is left right after the '>'
    :raises Error: if the end of the file is reached before a tag is complete
    """
    tag = b''
    while True:
        symbol = file_desc.read(1)
        if symbol == b'':
            raise Error('Unexpected end of Kaldi model')
        if tag == b'' and symbol != b'<':
            # Not inside a tag yet: skip payload bytes
            continue
        elif symbol == b'<':
            # A '<' always starts a fresh candidate tag
            tag = b''
        tag += symbol
        if symbol != b'>':
            continue
        try:
            return tag.decode('ascii')
        except UnicodeDecodeError:
            # Tag in Kaldi model always in ascii encoding, so this was binary data
            tag = b''
def read_placeholder(file_desc: io.BufferedReader, size=3) -> bytes:
    """
    Consume up to `size` bytes from the file and hand them back.
    :param file_desc: file descriptor
    :param size: number of bytes to read (3 by default)
    :return: the bytes that were read (fewer at end of file)
    """
    skipped = file_desc.read(size)
    return skipped
def find_next_component(file_desc: io.BufferedReader) -> str:
    """
    Read next component in the file.
    Tags are scanned until one names a component listed in supported_components
    (case-insensitively) or is exactly end_of_nnet_tag; all other tags are skipped.
    One byte after the matching tag is consumed as well.
    :param file_desc: file descriptor
    :return: lower-cased tag text without '<' and '>' (e.g. 'maxpoolingcomponent');
             '/nnet' when the end-of-network tag is reached
    """
    while True:
        tag = find_next_tag(file_desc)
        # Tag is <NameOfTheLayer>. But we want get without '<' and '>'
        component_name = tag[1:-1].lower()
        if component_name in supported_components or tag == end_of_nnet_tag:
            # There is whitespace after component's name
            read_placeholder(file_desc, 1)
            return component_name
def get_name_from_path(path: str) -> str:
    """
    Strip the directory part and the last extension from a file path.
    :param path: path to the file
    :return: the file's base name without its final extension
    """
    file_name = os.path.basename(path)
    stem, _extension = os.path.splitext(file_name)
    return stem
def find_end_of_component(file_desc: io.BufferedReader, component: str, end_tags: tuple = ()):
    """
    Find the tag that ends the component and the stream position after it.
    Tags are compared case-insensitively. Scanning stops at the first tag that is
    the component's closing tag, <!EndOfComponent>, </Nnet>, one of end_tags,
    or the opening tag of any supported component.
    :param file_desc: file descriptor
    :param component: component from supported_components
    :param end_tags: additional tags (strings) that also terminate the component
    :return: tuple (tag as read from the file, file position after reading that tag)
    """
    # Every candidate is lower-cased because the scanned tag is lower-cased before
    # the lookup; otherwise mixed-case candidates could never match. A set keeps
    # the per-tag membership test O(1).
    end_tags_of_component = {
        '</{}>'.format(component).lower(),
        end_of_component_tag.lower(),
        end_of_nnet_tag.lower(),
        *(tag.lower() for tag in end_tags),
        *('<{}>'.format(name).lower() for name in supported_components),
    }
    next_tag = find_next_tag(file_desc)
    while next_tag.lower() not in end_tags_of_component:
        next_tag = find_next_tag(file_desc)
    return next_tag, file_desc.tell()
def get_parameters(file_desc: io.BufferedReader, start_index: int, end_index: int):
    """
    Copy the byte range [start_index, end_index) of the file into memory.
    The source stream is left positioned where the read stopped.
    :param file_desc: file descriptor
    :param start_index: offset of the first byte to copy
    :param end_index: offset one past the last byte to copy
    :return: io.BytesIO holding the copied bytes
    """
    length = end_index - start_index
    file_desc.seek(start_index)
    return io.BytesIO(file_desc.read(length))
def read_token_value(file_desc: io.BufferedReader, token: bytes = b'', value_type: type = np.uint32):
    """
    Get value of the token.
    Read the next token (up to a space) and, when `token` is given, check that
    it matches; then read the binary value that follows it.
    :param file_desc: file descriptor
    :param token: expected token bytes; b'' skips the check
    :param value_type: type of the reading value: np.uint32, np.uint64 or bool
                       (any other type raises KeyError)
    :return: value of the token
    :raises Error: if `token` is given and differs from the token read from the file
    """
    getters = {
        np.uint32: read_binary_integer32_token,
        np.uint64: read_binary_integer64_token,
        bool: read_binary_bool_token
    }
    current_token = collect_until_whitespace(file_desc)
    if token != b'' and token != current_token:
        raise Error('Can not load token {} from Kaldi model'.format(token) +
                    refer_to_faq_msg(94))
    return getters[value_type](file_desc)
def collect_until_whitespace(file_desc: io.BufferedReader):
    """
    Read from file until a space character (b' ') or the end of the file.
    Only b' ' terminates the read; other whitespace bytes are kept. The
    terminating space is consumed but not included in the result.
    :param file_desc: file descriptor
    :return: bytes read before the first space (all remaining bytes at end of file)
    """
    # bytearray avoids re-allocating an immutable bytes object for every byte
    res = bytearray()
    while True:
        new_sym = file_desc.read(1)
        if new_sym == b' ' or new_sym == b'':
            break
        res += new_sym
    return bytes(res)
def collect_until_token(file_desc: io.BufferedReader, token):
    """
    Read from file until the token.
    Space-separated chunks are consumed until one equals `token` or ends with it;
    the stream is left right after that chunk and its terminating space.
    :param file_desc: file descriptor (io.BytesIO, io.BufferedReader or another
                      seekable binary stream)
    :param token: bytes to look for
    :raises Error: if the end of the file is reached before the token is found
    """
    size = None
    while True:
        # usually there is the following structure <CellDim> DIM<ClipGradient> VALUEFM
        res = collect_until_whitespace(file_desc)
        if res == token or res[-len(token):] == token:
            return
        if size is None:
            # The stream does not change while it is parsed, so measure it only once
            if isinstance(file_desc, io.BytesIO):
                size = len(file_desc.getbuffer())
            elif isinstance(file_desc, io.BufferedReader):
                size = os.fstat(file_desc.fileno()).st_size
            else:
                # Any other seekable stream: measure by seeking to the end and back
                position = file_desc.tell()
                size = file_desc.seek(0, io.SEEK_END)
                file_desc.seek(position)
        # '>=' also stops when the position is already past the end, where reads
        # keep returning b'' and an equality test would never fire
        if file_desc.tell() >= size:
            raise Error('End of the file. Token {} not found. {}'.format(token, file_desc.tell()))
def create_edge_attrs(prev_layer_id: str, next_layer_id: str, in_port: int = 0, out_port: int = 0) -> dict:
    """
    Create common edge's attributes.
    :param prev_layer_id: id of previous layer
    :param next_layer_id: id of next layer
    :param in_port: input port of the edge on the next layer (0 by default)
    :param out_port: output port of the edge on the previous layer (0 by default)
    :return: new dictionary of common edge attributes; 'in'/'out' hold the ports
             and are listed in 'in_attrs'/'out_attrs'
    """
    return {
        'out': out_port,
        'in': in_port,
        'name': next_layer_id,
        'fw_tensor_debug_info': [(prev_layer_id, next_layer_id)],
        'in_attrs': ['in', 'name'],
        'out_attrs': ['out', 'name'],
        'data_attrs': ['fw_tensor_debug_info']
    }
def read_blob(file_desc: io.BufferedReader, size: int, dtype=np.float32):
    """
    Read blob from the file.
    Reads `size` elements of `dtype` (size * itemsize bytes) and decodes them.
    :param file_desc: file descriptor
    :param size: number of elements in the blob
    :param dtype: numpy type of the blob's values (any fixed-size dtype)
    :return: writable 1-D np array containing the blob
    :raises ValueError: if the bytes read are not a whole number of elements
    """
    # The element width comes from the dtype itself, so any numpy dtype works
    data = file_desc.read(size * np.dtype(dtype).itemsize)
    # frombuffer replaces the deprecated binary mode of np.fromstring; copy() keeps
    # the result writable, as fromstring's was
    return np.frombuffer(data, dtype=dtype).copy()