2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
25 from mo.front.kaldi.loader.utils import find_next_tag, read_placeholder, find_next_component, get_name_from_path, \
26 find_end_of_component, end_of_nnet_tag, read_binary_integer32_token, get_parameters, read_token_value, collect_until_token, \
28 from mo.graph.graph import Node, Graph
29 from mo.utils.error import Error
30 from mo.utils.utils import refer_to_faq_msg
def read_counts_file(file_path):
    """
    Read a one-line counts file and convert the counts to log-probabilities.

    Square brackets are stripped from the line and the remaining
    whitespace-separated values are parsed as floats. Values below the cutoff
    are replaced by the cutoff before normalisation; after the logarithm is
    taken, those same positions are shifted up by half of the float32 maximum.

    :param file_path: path to the counts file
    :return: numpy array with log(count / sum(counts)) for every count
    :raises Error: if the file does not consist of exactly one line, or the
                   line does not hold a non-empty list of numbers
    """
    with open(file_path, 'r') as f:
        file_content = f.readlines()
    # An empty file has no first line to parse, so it is rejected together
    # with multi-line files instead of failing later with IndexError.
    # NOTE(review): FAQ entry number 90 is assumed for both messages below - confirm.
    if len(file_content) != 1:
        raise Error('Expect counts file to be one-line file. ' +
                    refer_to_faq_msg(90))

    counts_line = file_content[0].strip().replace('[', '').replace(']', '')
    try:
        # Explicit conversion raises ValueError on any malformed token, so bad
        # content is reported instead of being silently truncated.
        counts = np.array(counts_line.split(), dtype=float)
    except ValueError as err:
        raise Error('Expect counts file to contain list of integers.' +
                    refer_to_faq_msg(90)) from err
    if counts.size == 0:
        # Nothing to normalise: the sum would be zero.
        raise Error('Expect counts file to contain list of integers.' +
                    refer_to_faq_msg(90))

    cutoff = 1.00000001e-10
    cutoff_idxs = np.where(counts < cutoff)
    # Clamp tiny/zero counts so that the logarithm below stays finite.
    counts[cutoff_idxs] = cutoff
    scale = 1.0 / np.sum(counts)
    counts = np.log(counts * scale)  # pylint: disable=assignment-from-no-return
    # Positions that were clamped get a very large positive offset.
    counts[cutoff_idxs] += np.finfo(np.float32).max / 2
    return counts
def load_parallel_component(file_descr, graph: Graph, prev_layer_id):
    """
    Load ParallelComponent of the Kaldi model.
    ParallelComponent contains parallel nested networks.
    Slice is inserted before nested networks.
    Outputs of nested networks concatenate with layer Concat.

    :param file_descr: descriptor of the model file
    :param graph: graph with the topology.
    :param prev_layer_id: id of the input layers for parallel component layer
    :return: id of the concat layer - last layer of the parallel component layers
    """
    nnet_count = read_token_value(file_descr, b'<NestedNnetCount>')
    log.debug('Model contains parallel component with {} nested networks'.format(nnet_count))

    slice_id = graph.unique_id(prefix='Slice')
    graph.add_node(slice_id, parameters=None, op='slice', kind='op')

    slice_node = Node(graph, slice_id)
    graph.add_edge(prev_layer_id, slice_id, **create_edge_attrs(prev_layer_id, slice_id))

    slices_points = []  # input widths of all nested nets except the last one
    outputs = []        # last node (in topological order) of every nested net

    for i in range(nnet_count):
        read_token_value(file_descr, b'<NestedNnet>')
        collect_until_token(file_descr, b'<Nnet>')
        g, shape = load_kalid_nnet1_model(file_descr, 'Nested_net_{}'.format(i))
        # The nested net brings its own Input node; it is replaced by the Slice
        # node, so look it up in the nested graph itself and drop it.
        input_ids = [node_id for node_id, attrs in g.nodes(data=True) if attrs.get('op') == 'Input']
        if i != nnet_count - 1:
            slices_points.append(shape[1])
        g.remove_node(input_ids[0])
        # Rename nested nodes whose ids clash with nodes already in the outer graph.
        mapping = {node: graph.unique_id(node) for node in g.nodes(data=False) if node in graph}
        g = nx.relabel_nodes(g, mapping)
        for val in mapping.values():
            g.node[val]['name'] = val
        graph.add_nodes_from(g.nodes(data=True))
        graph.add_edges_from(g.edges(data=True))
        sorted_nodes = tuple(nx.topological_sort(g))
        edge_attrs = create_edge_attrs(slice_id, sorted_nodes[0])
        # NOTE(review): assumes create_edge_attrs() returns a dict with an 'out'
        # port key and that branch i hangs off Slice output port i - confirm.
        edge_attrs['out'] = i
        graph.add_edge(slice_id, sorted_nodes[0], **edge_attrs)
        outputs.append(sorted_nodes[-1])

    # Serialised Slice parameters: one byte with value 4, the number of slice
    # points as an unsigned int, then every slice point as an unsigned int.
    # NOTE(review): the leading 4 is presumably the Kaldi size marker of a
    # 4-byte integer - confirm.
    packed_sp = b''.join(
        [struct.pack("B", 4), struct.pack("I", len(slices_points))] +
        [struct.pack("I", point) for point in slices_points])
    slice_node.parameters = io.BytesIO(packed_sp)

    concat_id = graph.unique_id(prefix='Concat')
    graph.add_node(concat_id, parameters=None, op='concat', kind='op')
    for i, output in enumerate(outputs):
        edge_attrs = create_edge_attrs(output, concat_id)
        # NOTE(review): assumes an 'in' port key, branch i feeding Concat input i - confirm.
        edge_attrs['in'] = i
        graph.add_edge(output, concat_id, **edge_attrs)
    return concat_id
def load_kaldi_model(nnet_path):
    """
    Load a Kaldi model and dispatch to the loader matching its first tag.

    Structure of the file is the following:
    magic-number(16896)<Nnet> <Next Layer Name> weights etc.

    :param nnet_path: path to the model file (str) or an already opened stream (IOBase)
    :return: whatever the selected loader returns (a tuple of graph and input shape)
    :raises Error: if nnet_path has an unsupported type or the model does not
                   start with the <Nnet> or <TransitionModel> tag
    """
    # A caller-provided stream carries no file name, so the name stays None.
    nnet_name = None
    opened_here = isinstance(nnet_path, str)
    if opened_here:
        file_desc = open(nnet_path, "rb")
        nnet_name = get_name_from_path(nnet_path)
    elif isinstance(nnet_path, IOBase):
        file_desc = nnet_path
    else:
        raise Error('Unsupported type of Kaldi model')

    try:
        name = find_next_tag(file_desc)
        # start new model / submodel
        if name == '<Nnet>':
            load_function = load_kalid_nnet1_model
        elif name == '<TransitionModel>':
            load_function = load_kalid_nnet2_model
        else:
            # Concatenate the FAQ reference so that it is part of the message.
            raise Error('Kaldi model should start with <Nnet> or <TransitionModel> tag. ' +
                        refer_to_faq_msg(89))
        read_placeholder(file_desc, 1)

        return load_function(file_desc, nnet_name)
    finally:
        # Only close what was opened here; a caller-provided stream stays open.
        # NOTE(review): assumes the loaders copy everything they need out of the
        # stream (node parameters look like separate in-memory buffers) - confirm.
        if opened_here:
            file_desc.close()
def load_kalid_nnet1_model(file_descr, name):
    """
    Read components from the model stream one by one and chain them into a graph.

    The graph starts with an 'Input' node. Every regular component becomes a
    node connected to the previously loaded layer; a 'parallelcomponent' is
    delegated to load_parallel_component. Reading stops at the end-of-nnet tag.

    :param file_descr: descriptor of the model file
    :param name: name of the created graph
    :return: tuple (graph, input_shape); input_shape is [1, layer_i] taken from
             the layer attached directly to the Input node, or an empty array
             if no such layer was read
    """
    graph = Graph(name=name)

    prev_layer_id = 'Input'
    graph.add_node(prev_layer_id, name=prev_layer_id, kind='op', op='Input', parameters=None)
    # Bound up front so that the return below is valid even when no regular
    # layer follows the Input node.
    input_shape = np.array([])

    # Lower-cased end tag without its first and last character; this is the
    # form find_next_component() results are compared against.
    end_component = end_of_nnet_tag.lower()[1:-1]

    while True:
        component_type = find_next_component(file_descr)
        if component_type == end_component:
            break

        # Two 32-bit integers follow the component tag: first one is bound to
        # layer_o, second one to layer_i.
        layer_o = read_binary_integer32_token(file_descr)
        layer_i = read_binary_integer32_token(file_descr)

        if component_type == 'parallelcomponent':
            prev_layer_id = load_parallel_component(file_descr, graph, prev_layer_id)
            continue

        start_index = file_descr.tell()
        end_tag, end_index = find_end_of_component(file_descr, component_type)
        # Exclude the closing tag itself from the component parameters.
        end_index -= len(end_tag)
        layer_id = graph.unique_id(prefix=component_type)
        # NOTE(review): 'op'/'kind' follow the other add_node calls in this file;
        # storing layer_i/layer_o on the node is an assumption - confirm.
        graph.add_node(layer_id,
                       parameters=get_parameters(file_descr, start_index, end_index),
                       op=component_type,
                       kind='op',
                       layer_i=layer_i,
                       layer_o=layer_o)

        prev_node = Node(graph, prev_layer_id)
        if prev_node.op == 'Input':
            # The first real layer defines the shape of the network input.
            prev_node['shape'] = np.array([1, layer_i], dtype=np.int64)
            input_shape = np.array([1, layer_i], dtype=np.int64)
        graph.add_edge(prev_layer_id, layer_id, **create_edge_attrs(prev_layer_id, layer_id))
        prev_layer_id = layer_id
        log.debug('{} (type is {}) was loaded'.format(prev_layer_id, component_type))
    return graph, input_shape
def load_kalid_nnet2_model(file_descr, nnet_name):
    """
    Read the components listed after the <Components> tag and chain them into a graph.

    The stream is advanced to <Nnet>, the number of components is read from
    <NumComponents>, and at most that many components are loaded. The input
    shape is taken from the <InputDim> token of the layer attached to the
    Input node.

    :param file_descr: descriptor of the model file
    :param nnet_name: name of the created graph
    :return: tuple (graph, input_shape); input_shape is [1, input_dim], or an
             empty array if no component was loaded
    """
    graph = Graph(name=nnet_name)
    input_name = 'Input'
    input_shape = np.array([])
    graph.add_node(input_name, name=input_name, kind='op', op='Input', parameters=None, shape=None)

    prev_layer_id = input_name

    collect_until_token(file_descr, b'<Nnet>')
    num_components = read_token_value(file_descr, b'<NumComponents>')
    log.debug('Network contains {} components'.format(num_components))
    collect_until_token(file_descr, b'<Components>')

    # Lower-cased end tag without its first and last character.
    end_component = end_of_nnet_tag.lower()[1:-1]

    for _ in range(num_components):
        component_type = find_next_component(file_descr)

        if component_type == end_component:
            break
        start_index = file_descr.tell()
        # NOTE(review): unlike load_kalid_nnet1_model, the end tag length is not
        # subtracted from end_index here - confirm this is intentional.
        _, end_index = find_end_of_component(file_descr, component_type)
        layer_id = graph.unique_id(prefix=component_type)
        # NOTE(review): 'op'/'kind' follow the other add_node calls in this file - confirm.
        graph.add_node(layer_id,
                       parameters=get_parameters(file_descr, start_index, end_index),
                       op=component_type,
                       kind='op')

        prev_node = Node(graph, prev_layer_id)
        if prev_node.op == 'Input':
            # The first component carries the network input dimension.
            parameters = Node(graph, layer_id).parameters
            input_dim = read_token_value(parameters, b'<InputDim>')
            prev_node['shape'] = np.array([1, input_dim], dtype=np.int64)
            input_shape = np.array([1, input_dim], dtype=np.int64)
        graph.add_edge(prev_layer_id, layer_id, **create_edge_attrs(prev_layer_id, layer_id))
        prev_layer_id = layer_id
        log.debug('{} (type is {}) was loaded'.format(prev_layer_id, component_type))
    return graph, input_shape