"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
from mo.front.kaldi.loader.utils import find_next_tag, read_placeholder, find_next_component, get_name_from_path, \
find_end_of_component, end_of_nnet_tag, read_binary_integer32_token, get_parameters, read_token_value, collect_until_token, \
create_edge_attrs
-from mo.graph.graph import unique_id, Node
+from mo.graph.graph import Node, Graph
from mo.utils.error import Error
from mo.utils.utils import refer_to_faq_msg
counts_line = file_content[0].strip().replace('[', '').replace(']', '')
try:
- counts = np.fromstring(counts_line, dtype=int, sep=' ')
+ counts = np.fromstring(counts_line, dtype=float, sep=' ')
except TypeError:
raise Error('Expect counts file to contain list of integers.' +
refer_to_faq_msg(90))
cutoff_idxs = np.where(counts < cutoff)
counts[cutoff_idxs] = cutoff
scale = 1.0 / np.sum(counts)
- counts = np.log(counts * scale)
+ counts = np.log(counts * scale) # pylint: disable=assignment-from-no-return
counts[cutoff_idxs] += np.finfo(np.float32).max / 2
return counts
-def load_parallel_component(file_descr, graph: nx.MultiDiGraph, prev_layer_id):
+def load_parallel_component(file_descr, graph: Graph, prev_layer_id):
"""
Load ParallelComponent of the Kaldi model.
ParallelComponent contains parallel nested networks.
nnet_count = read_token_value(file_descr, b'<NestedNnetCount>')
log.debug('Model contains parallel component with {} nested networks'.format(nnet_count))
- slice_id = unique_id(graph, prefix='Slice')
+ slice_id = graph.unique_id(prefix='Slice')
graph.add_node(slice_id, parameters=None, op='slice', kind='op')
slice_node = Node(graph, slice_id)
if i != nnet_count - 1:
slices_points.append(shape[1])
g.remove_node(input_nodes[0][0])
- mapping = {node: unique_id(graph, node) for node in g.nodes(data=False) if node in graph}
+ mapping = {node: graph.unique_id(node) for node in g.nodes(data=False) if node in graph}
g = nx.relabel_nodes(g, mapping)
for val in mapping.values():
g.node[val]['name'] = val
for i in slices_points:
packed_sp += struct.pack("I", i)
slice_node.parameters = io.BytesIO(packed_sp)
- concat_id = unique_id(graph, prefix='Concat')
+ concat_id = graph.unique_id(prefix='Concat')
graph.add_node(concat_id, parameters=None, op='concat', kind='op')
for i, output in enumerate(outputs):
edge_attrs = create_edge_attrs(output, concat_id)
Structure of the file is the following:
magic-number(16896)<Nnet> <Next Layer Name> weights etc.
:param nnet_path:
- :param check_sum:
:return:
"""
nnet_name = None
def load_kalid_nnet1_model(file_descr, name):
- graph = nx.MultiDiGraph(name=name)
+ graph = Graph(name=name)
prev_layer_id = 'Input'
graph.add_node(prev_layer_id, name=prev_layer_id, kind='op', op='Input', parameters=None)
start_index = file_descr.tell()
end_tag, end_index = find_end_of_component(file_descr, component_type)
end_index -= len(end_tag)
- layer_id = unique_id(graph, prefix=component_type)
+ layer_id = graph.unique_id(prefix=component_type)
graph.add_node(layer_id,
parameters=get_parameters(file_descr, start_index, end_index),
op=component_type,
def load_kalid_nnet2_model(file_descr, nnet_name):
- graph = nx.MultiDiGraph(name=nnet_name)
+ graph = Graph(name=nnet_name)
input_name = 'Input'
+ input_shape = np.array([])
graph.add_node(input_name, name=input_name, kind='op', op='Input', parameters=None, shape=None)
prev_layer_id = input_name
break
start_index = file_descr.tell()
end_tag, end_index = find_end_of_component(file_descr, component_type)
- layer_id = unique_id(graph, prefix=component_type)
+ layer_id = graph.unique_id(prefix=component_type)
graph.add_node(layer_id,
parameters=get_parameters(file_descr, start_index, end_index),
op=component_type,