Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / kaldi / loader / utils.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import io
17
18 import numpy as np
19 import os
20 import struct
21
22 from mo.utils.error import Error
23 from mo.utils.utils import refer_to_faq_msg
24
25 end_of_nnet_tag = '</Nnet>'
26 end_of_component_tag = '<!EndOfComponent>'
27
28 supported_components = [
29     'addshift',
30     'affinecomponent',
31     'affinetransform',
32     'convolutional1dcomponent',
33     'convolutionalcomponent',
34     'copy',
35     'fixedaffinecomponent',
36     'lstmprojected',
37     'lstmprojectedstreams',
38     'maxpoolingcomponent',
39     'parallelcomponent',
40     'rescale',
41     'sigmoid',
42     'softmax',
43     'softmaxcomponent',
44     'splicecomponent',
45     'tanhcomponent',
46     'normalizecomponent',
47     'affinecomponentpreconditionedonline',
48     'rectifiedlinearcomponent'
49 ]
50
51
52 def get_bool(s: bytes) -> bool:
53     """
54     Get bool value from bytes
55     :param s: bytes array contains bool value
56     :return: bool value from bytes array
57     """
58     return struct.unpack('?', s)[0]
59
60
61 def get_uint16(s: bytes) -> int:
62     """
63     Get unsigned int16 value from bytes
64     :param s: bytes array contains unsigned int16 value
65     :return: unsigned int16 value from bytes array
66     """
67     return struct.unpack('H', s)[0]
68
69
70 def get_uint32(s: bytes) -> int:
71     """
72     Get unsigned int32 value from bytes
73     :param s: bytes array contains unsigned int32 value
74     :return: unsigned int32 value from bytes array
75     """
76     return struct.unpack('I', s)[0]
77
78
79 def get_uint64(s: bytes) -> int:
80     """
81     Get unsigned int64 value from bytes
82     :param s: bytes array contains unsigned int64 value
83     :return: unsigned int64 value from bytes array
84     """
85     return struct.unpack('q', s)[0]
86
87
88 def read_binary_bool_token(file_desc: io.BufferedReader) -> bool:
89     """
90     Get next bool value from file
91     The carriage moves forward to 1 position.
92     :param file_desc: file descriptor
93     :return: next boolean value in file
94     """
95     return get_bool(file_desc.read(1))
96
97
98 def read_binary_integer32_token(file_desc: io.BufferedReader) -> int:
99     """
100     Get next int32 value from file
101     The carriage moves forward to 5 position.
102     :param file_desc: file descriptor
103     :return: next uint32 value in file
104     """
105     buffer_size = file_desc.read(1)
106     return get_uint32(file_desc.read(buffer_size[0]))
107
108
109 def read_binary_integer64_token(file_desc: io.BufferedReader) -> int:
110     """
111     Get next int64 value from file
112     The carriage moves forward to 9 position.
113     :param file_desc: file descriptor
114     :return: next uint64 value in file
115     """
116     buffer_size = file_desc.read(1)
117     return get_uint64(file_desc.read(buffer_size[0]))
118
119
120 def find_next_tag(file_desc: io.BufferedReader) -> str:
121     """
122     Get next tag in the file
123     :param file_desc:file descriptor
124     :return: string like '<sometag>'
125     """
126     tag = b''
127     while True:
128         symbol = file_desc.read(1)
129         if symbol == b'':
130             raise Error('Unexpected end of Kaldi model')
131         if tag == b'' and symbol != b'<':
132             continue
133         elif symbol == b'<':
134             tag = b''
135         tag += symbol
136         if symbol != b'>':
137             continue
138         try:
139             return tag.decode('ascii')
140         except UnicodeDecodeError:
141             # Tag in Kaldi model always in ascii encoding
142             tag = b''
143
144
145 def read_placeholder(file_desc: io.BufferedReader, size=3) -> bytes:
146     """
147     Read size bytes from file
148     :param file_desc:file descriptor
149     :param size:number of reading bytes
150     :return: bytes
151     """
152     return file_desc.read(size)
153
154
155 def find_next_component(file_desc: io.BufferedReader) -> str:
156     """
157     Read next component in the file.
158     All components are contained in supported_components
159     :param file_desc:file descriptor
160     :return: string like '<component>'
161     """
162     while True:
163         tag = find_next_tag(file_desc)
164         # Tag is <NameOfTheLayer>. But we want get without '<' and '>'
165         component_name = tag[1:-1].lower()
166         if component_name in supported_components or tag == end_of_nnet_tag:
167             # There is whitespace after component's name
168             read_placeholder(file_desc, 1)
169             return component_name
170
171
172 def get_name_from_path(path: str) -> str:
173     """
174     Get name from path to the file
175     :param path: path to the file
176     :return: name of the file
177     """
178     return os.path.splitext(os.path.basename(path))[0]
179
180
181 def find_end_of_component(file_desc: io.BufferedReader, component: str, end_tags: tuple = ()):
182     """
183     Find an index and a tag of the ent of the component
184     :param file_desc: file descriptor
185     :param component: component from supported_components
186     :param end_tags: specific end tags
187     :return: the index and the tag of the end of the component
188     """
189     end_tags_of_component = ['</{}>'.format(component),
190                              end_of_component_tag.lower(),
191                              end_of_nnet_tag.lower(),
192                              *end_tags,
193                              *['<{}>'.format(component) for component in supported_components]]
194     next_tag = find_next_tag(file_desc)
195     while next_tag.lower() not in end_tags_of_component:
196         next_tag = find_next_tag(file_desc)
197     return next_tag, file_desc.tell()
198
199
200 def get_parameters(file_desc: io.BufferedReader, start_index: int, end_index: int):
201     """
202     Get part of file
203     :param file_desc: file descriptor
204     :param start_index: Index of the start reading
205     :param end_index:  Index of the end reading
206     :return: part of the file
207     """
208     file_desc.seek(start_index)
209     buffer = file_desc.read(end_index - start_index)
210     return io.BytesIO(buffer)
211
212
213 def read_token_value(file_desc: io.BufferedReader, token: bytes = b'', value_type: type = np.uint32):
214     """
215     Get value of the token.
216     Read next token (until whitespace) and check if next teg equals token
217     :param file_desc: file descriptor
218     :param token: token
219     :param value_type:  type of the reading value
220     :return: value of the token
221     """
222     getters = {
223         np.uint32: read_binary_integer32_token,
224         np.uint64: read_binary_integer64_token,
225         bool: read_binary_bool_token
226     }
227     current_token = collect_until_whitespace(file_desc)
228     if token != b'' and token != current_token:
229         raise Error('Can not load token {} from Kaldi model'.format(token) +
230                     refer_to_faq_msg(94))
231     return getters[value_type](file_desc)
232
233
234 def collect_until_whitespace(file_desc: io.BufferedReader):
235     """
236     Read from file until whitespace
237     :param file_desc: file descriptor
238     :return:
239     """
240     res = b''
241     while True:
242         new_sym = file_desc.read(1)
243         if new_sym == b' ' or new_sym == b'':
244             break
245         res += new_sym
246     return res
247
248
249 def collect_until_token(file_desc: io.BufferedReader, token):
250     """
251     Read from file until the token
252     :param file_desc: file descriptor
253     :return:
254     """
255     while True:
256         # usually there is the following structure <CellDim> DIM<ClipGradient> VALUEFM
257         res = collect_until_whitespace(file_desc)
258         if res == token or res[-len(token):] == token:
259             return
260         if isinstance(file_desc, io.BytesIO):
261             size = len(file_desc.getbuffer())
262         elif isinstance(file_desc, io.BufferedReader):
263             size = os.fstat(file_desc.fileno()).st_size
264         if file_desc.tell() == size:
265             raise Error('End of the file. Token {} not found. {}'.format(token, file_desc.tell()))
266
267
268 def create_edge_attrs(prev_layer_id: str, next_layer_id: str) -> dict:
269     """
270     Create common edge's attributes
271     :param prev_layer_id: id of previous layer
272     :param next_layer_id: id of next layer
273     :return: dictionary contains common attributes for edge
274     """
275     return {
276         'out': 0,
277         'in': 0,
278         'name': next_layer_id,
279         'fw_tensor_debug_info': [(prev_layer_id, next_layer_id)],
280         'in_attrs': ['in', 'name'],
281         'out_attrs': ['out', 'name'],
282         'data_attrs': ['fw_tensor_debug_info']
283     }
284
285
286 def read_blob(file_desc: io.BufferedReader, size: int, dtype=np.float32):
287     """
288     Read blob from the file
289     :param file_desc: file descriptor
290     :param size: size of the blob
291     :param dtype: type of values of the blob
292     :return: np array contains blob
293     """
294     dsizes = {
295         np.float32: 4,
296         np.int32: 4
297     }
298     data = file_desc.read(size * dsizes[dtype])
299     return np.fromstring(data, dtype=dtype)