Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / kaldi / loader / loader.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import io
17
18 import numpy as np
19 import struct
20 from io import IOBase
21
22 import networkx as nx
23 import logging as log
24
25 from mo.front.kaldi.loader.utils import find_next_tag, read_placeholder, find_next_component, get_name_from_path, \
26     find_end_of_component, end_of_nnet_tag, read_binary_integer32_token, get_parameters, read_token_value, collect_until_token, \
27     create_edge_attrs
28 from mo.graph.graph import Node, Graph
29 from mo.utils.error import Error
30 from mo.utils.utils import refer_to_faq_msg
31
32
def read_counts_file(file_path):
    """
    Read a one-line counts file and convert the counts to log-scaled normalized values.

    The single line is expected to hold space-separated numbers, optionally wrapped
    in square brackets (the brackets are stripped before parsing).

    :param file_path: path to the counts file
    :return: numpy array with log(count / sum(counts)) for every entry. Counts below the
             cutoff are replaced by the cutoff before normalization, and half of the
             float32 maximum is added to their resulting values.
    :raises Error: if the file does not consist of exactly one line, or if numpy reports
                   a TypeError/ValueError while parsing the line
    """
    with open(file_path, 'r') as f:
        file_content = f.readlines()
    # An empty file has no line to parse; report it with the same error as a
    # multi-line file instead of failing below with an IndexError.
    if len(file_content) != 1:
        raise Error('Expect counts file to be one-line file. ' +
                    refer_to_faq_msg(90))

    counts_line = file_content[0].strip().replace('[', '').replace(']', '')
    try:
        counts = np.fromstring(counts_line, dtype=float, sep=' ')
    except (TypeError, ValueError):
        raise Error('Expect counts file to contain list of integers.' +
                    refer_to_faq_msg(90))
    cutoff = 1.00000001e-10
    # Remember which entries were clamped: they receive an extra offset after the log.
    cutoff_idxs = np.where(counts < cutoff)
    counts[cutoff_idxs] = cutoff
    scale = 1.0 / np.sum(counts)
    counts = np.log(counts * scale)  # pylint: disable=assignment-from-no-return
    counts[cutoff_idxs] += np.finfo(np.float32).max / 2
    return counts
53
54
def load_parallel_component(file_descr, graph: Graph, prev_layer_id):
    """
    Load ParallelComponent of the Kaldi model.
    ParallelComponent contains parallel nested networks.
    Slice is inserted before nested networks.
    Outputs of nested networks concatenate with layer Concat.

    Every nested network is loaded into its own graph with load_kalid_nnet1_model,
    its Input node is dropped, node ids clashing with ids already present in `graph`
    are renamed, and the result is merged into `graph`.

    :param file_descr: descriptor of the model file
    :param graph: graph with the topology.
    :param prev_layer_id: id of the input layers for parallel component layer
    :return: id of the concat layer - last layer of the parallel component layers
    """
    nnet_count = read_token_value(file_descr, b'<NestedNnetCount>')
    log.debug('Model contains parallel component with {} nested networks'.format(nnet_count))

    slice_id = graph.unique_id(prefix='Slice')
    graph.add_node(slice_id, parameters=None, op='slice', kind='op')

    slice_node = Node(graph, slice_id)
    graph.add_edge(prev_layer_id, slice_id, **create_edge_attrs(prev_layer_id, slice_id))
    slices_points = []

    outputs = []

    for i in range(nnet_count):
        read_token_value(file_descr, b'<NestedNnet>')
        collect_until_token(file_descr, b'<Nnet>')
        g, shape = load_kalid_nnet1_model(file_descr, 'Nested_net_{}'.format(i))
        # The Input node to drop belongs to the nested graph, so it has to be looked up
        # there; searching the outer graph only worked because both graphs happen to
        # use the same id for their Input node.
        input_nodes = [n for n in g.nodes(data=True) if n[1]['op'] == 'Input']
        # The last nested network needs no slice point of its own.
        if i != nnet_count - 1:
            slices_points.append(shape[1])
        g.remove_node(input_nodes[0][0])
        # Rename only the nested nodes whose ids are already taken in the outer graph.
        mapping = {node: graph.unique_id(node) for node in g.nodes(data=False) if node in graph}
        g = nx.relabel_nodes(g, mapping)
        for val in mapping.values():
            g.node[val]['name'] = val
        graph.add_nodes_from(g.nodes(data=True))
        graph.add_edges_from(g.edges(data=True))
        sorted_nodes = tuple(nx.topological_sort(g))
        # First node of the nested network is fed from the i-th output of the slice.
        edge_attrs = create_edge_attrs(slice_id, sorted_nodes[0])
        edge_attrs['out'] = i
        graph.add_edge(slice_id, sorted_nodes[0], **edge_attrs)
        outputs.append(sorted_nodes[-1])
    # Slice parameters blob: one byte with value 4, the number of slice points, then
    # the points themselves. Each value is packed separately so that struct adds no
    # alignment padding between them.
    packed_sp = struct.pack("B", 4) + struct.pack("I", len(slices_points)) + \
        b''.join(struct.pack("I", point) for point in slices_points)
    slice_node.parameters = io.BytesIO(packed_sp)
    concat_id = graph.unique_id(prefix='Concat')
    graph.add_node(concat_id, parameters=None, op='concat', kind='op')
    for i, output in enumerate(outputs):
        edge_attrs = create_edge_attrs(output, concat_id)
        edge_attrs['in'] = i
        graph.add_edge(output, concat_id, **edge_attrs)
    return concat_id
109
110
def load_kaldi_model(nnet_path):
    """
    Load a Kaldi model and dispatch to the loader matching its first tag.

    Structure of the file is the following:
    magic-number(16896)<Nnet> <Next Layer Name> weights etc.

    :param nnet_path: path to the model file (str) or an already opened binary stream (IOBase)
    :return: result of the selected loader (load_kalid_nnet1_model for '<Nnet>',
             load_kalid_nnet2_model for '<TransitionModel>')
    :raises Error: if nnet_path has an unsupported type or the model does not start
                   with a supported tag
    """
    nnet_name = None
    opened_here = False
    if isinstance(nnet_path, str):
        file_desc = open(nnet_path, "rb")
        opened_here = True
        nnet_name = get_name_from_path(nnet_path)
    elif isinstance(nnet_path, IOBase):
        file_desc = nnet_path
    else:
        raise Error('Unsupported type of Kaldi model')

    name = find_next_tag(file_desc)
    # start new model / submodel
    if name == '<Nnet>':
        load_function = load_kalid_nnet1_model
    elif name == '<TransitionModel>':
        load_function = load_kalid_nnet2_model
    else:
        # Do not leak the handle opened above when the model is rejected.
        if opened_here:
            file_desc.close()
        # The FAQ reference is concatenated into the message, as in the other raises
        # of this file, rather than passed as a separate positional argument.
        raise Error('Kaldi model should start with <Nnet> or <TransitionModel> tag. ' +
                    refer_to_faq_msg(89))
    read_placeholder(file_desc, 1)

    # NOTE(review): the file stays open on the success path; it is not visible here
    # whether the loaded parameters still read from it - confirm before closing it.
    return load_function(file_desc, nnet_name)
139
140
def load_kalid_nnet1_model(file_descr, name):
    """
    Build a graph from a model stream positioned after the '<Nnet>' tag.

    Components are read one after another until the end-of-nnet component is met.
    Each regular component becomes an 'op' node chained to the previously loaded layer;
    a 'parallelcomponent' is delegated to load_parallel_component.

    :param file_descr: descriptor of the model file
    :param name: name assigned to the created graph
    :return: tuple (graph, input_shape); input_shape stays an empty list unless a regular
             component directly follows the Input node
    """
    graph = Graph(name=name)

    input_id = 'Input'
    graph.add_node(input_id, name=input_id, kind='op', op='Input', parameters=None)
    input_shape = []
    last_id = input_id

    # Component type name that marks the end of the network (tag without brackets, lower-cased).
    nnet_end = end_of_nnet_tag.lower()[1:-1]

    # Keep pulling component types from the stream until the end marker shows up.
    for component_type in iter(lambda: find_next_component(file_descr), nnet_end):
        # Two int32 tokens follow the component tag: the first is stored as layer_o,
        # the second as layer_i.
        out_dim = read_binary_integer32_token(file_descr)
        in_dim = read_binary_integer32_token(file_descr)

        if component_type == 'parallelcomponent':
            last_id = load_parallel_component(file_descr, graph, last_id)
            continue

        params_begin = file_descr.tell()
        closing_tag, params_end = find_end_of_component(file_descr, component_type)
        # Exclude the closing tag itself from the parameters range.
        params_end -= len(closing_tag)

        layer_id = graph.unique_id(prefix=component_type)
        node_attrs = dict(parameters=get_parameters(file_descr, params_begin, params_end),
                          op=component_type,
                          kind='op',
                          layer_i=in_dim,
                          layer_o=out_dim)
        graph.add_node(layer_id, **node_attrs)

        predecessor = Node(graph, last_id)
        if predecessor.op == 'Input':
            # The layer right after Input defines the input shape of the network.
            predecessor['shape'] = np.array([1, in_dim], dtype=np.int64)
            input_shape = np.array([1, in_dim], dtype=np.int64)

        graph.add_edge(last_id, layer_id, **create_edge_attrs(last_id, layer_id))
        last_id = layer_id
        log.debug('{} (type is {}) was loaded'.format(last_id, component_type))

    return graph, input_shape
179
180
def load_kalid_nnet2_model(file_descr, nnet_name):
    """
    Build a graph from a model stream that starts with the '<TransitionModel>' tag.

    The stream is advanced to '<Nnet>', the '<NumComponents>' value is read, and up to
    that many components following '<Components>' are loaded as a chain of 'op' nodes.

    :param file_descr: descriptor of the model file
    :param nnet_name: name assigned to the created graph
    :return: tuple (graph, input_shape); input_shape stays an empty numpy array when no
             component is loaded
    """
    input_id = 'Input'
    graph = Graph(name=nnet_name)
    graph.add_node(input_id, name=input_id, kind='op', op='Input', parameters=None, shape=None)
    input_shape = np.array([])
    last_id = input_id

    collect_until_token(file_descr, b'<Nnet>')
    num_components = read_token_value(file_descr, b'<NumComponents>')
    log.debug('Network contains {} components'.format(num_components))
    collect_until_token(file_descr, b'<Components>')

    # Component type name that marks the end of the network (tag without brackets, lower-cased).
    nnet_end = end_of_nnet_tag.lower()[1:-1]

    for _ in range(num_components):
        component_type = find_next_component(file_descr)
        if component_type == nnet_end:
            break

        params_begin = file_descr.tell()
        # NOTE(review): unlike the nnet1 loader, the closing tag length is not subtracted
        # from the end index here - confirm this difference is intended.
        _closing_tag, params_end = find_end_of_component(file_descr, component_type)

        layer_id = graph.unique_id(prefix=component_type)
        node_attrs = dict(parameters=get_parameters(file_descr, params_begin, params_end),
                          op=component_type,
                          kind='op')
        graph.add_node(layer_id, **node_attrs)

        predecessor = Node(graph, last_id)
        if predecessor.op == 'Input':
            # The input dimension is read from the parameters of the first loaded component.
            layer_params = Node(graph, layer_id).parameters
            input_dim = read_token_value(layer_params, b'<InputDim>')
            predecessor['shape'] = np.array([1, input_dim], dtype=np.int64)
            input_shape = np.array([1, input_dim], dtype=np.int64)

        graph.add_edge(last_id, layer_id, **create_edge_attrs(last_id, layer_id))
        last_id = layer_id
        log.debug('{} (type is {}) was loaded'.format(last_id, component_type))

    return graph, input_shape
214         log.debug('{} (type is {}) was loaded'.format(prev_layer_id, component_type))
215     return graph, input_shape