Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / kaldi / loader / loader.py
index 8bf9085..9f0bdf3 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -25,7 +25,7 @@ import logging as log
 from mo.front.kaldi.loader.utils import find_next_tag, read_placeholder, find_next_component, get_name_from_path, \
     find_end_of_component, end_of_nnet_tag, read_binary_integer32_token, get_parameters, read_token_value, collect_until_token, \
     create_edge_attrs
-from mo.graph.graph import unique_id, Node
+from mo.graph.graph import Node, Graph
 from mo.utils.error import Error
 from mo.utils.utils import refer_to_faq_msg
 
@@ -39,7 +39,7 @@ def read_counts_file(file_path):
 
     counts_line = file_content[0].strip().replace('[', '').replace(']', '')
     try:
-        counts = np.fromstring(counts_line, dtype=int, sep=' ')
+        counts = np.fromstring(counts_line, dtype=float, sep=' ')
     except TypeError:
         raise Error('Expect counts file to contain list of integers.' +
                     refer_to_faq_msg(90))
@@ -47,12 +47,12 @@ def read_counts_file(file_path):
     cutoff_idxs = np.where(counts < cutoff)
     counts[cutoff_idxs] = cutoff
     scale = 1.0 / np.sum(counts)
-    counts = np.log(counts * scale)
+    counts = np.log(counts * scale)  # pylint: disable=assignment-from-no-return
     counts[cutoff_idxs] += np.finfo(np.float32).max / 2
     return counts
 
 
-def load_parallel_component(file_descr, graph: nx.MultiDiGraph, prev_layer_id):
+def load_parallel_component(file_descr, graph: Graph, prev_layer_id):
     """
     Load ParallelComponent of the Kaldi model.
     ParallelComponent contains parallel nested networks.
@@ -67,7 +67,7 @@ def load_parallel_component(file_descr, graph: nx.MultiDiGraph, prev_layer_id):
     nnet_count = read_token_value(file_descr, b'<NestedNnetCount>')
     log.debug('Model contains parallel component with {} nested networks'.format(nnet_count))
 
-    slice_id = unique_id(graph, prefix='Slice')
+    slice_id = graph.unique_id(prefix='Slice')
     graph.add_node(slice_id, parameters=None, op='slice', kind='op')
 
     slice_node = Node(graph, slice_id)
@@ -84,7 +84,7 @@ def load_parallel_component(file_descr, graph: nx.MultiDiGraph, prev_layer_id):
         if i != nnet_count - 1:
             slices_points.append(shape[1])
         g.remove_node(input_nodes[0][0])
-        mapping = {node: unique_id(graph, node) for node in g.nodes(data=False) if node in graph}
+        mapping = {node: graph.unique_id(node) for node in g.nodes(data=False) if node in graph}
         g = nx.relabel_nodes(g, mapping)
         for val in mapping.values():
             g.node[val]['name'] = val
@@ -99,7 +99,7 @@ def load_parallel_component(file_descr, graph: nx.MultiDiGraph, prev_layer_id):
     for i in slices_points:
         packed_sp += struct.pack("I", i)
     slice_node.parameters = io.BytesIO(packed_sp)
-    concat_id = unique_id(graph, prefix='Concat')
+    concat_id = graph.unique_id(prefix='Concat')
     graph.add_node(concat_id, parameters=None, op='concat', kind='op')
     for i, output in enumerate(outputs):
         edge_attrs = create_edge_attrs(output, concat_id)
@@ -113,7 +113,6 @@ def load_kaldi_model(nnet_path):
     Structure of the file is the following:
     magic-number(16896)<Nnet> <Next Layer Name> weights etc.
     :param nnet_path:
-    :param check_sum:
     :return:
     """
     nnet_name = None
@@ -140,7 +139,7 @@ def load_kaldi_model(nnet_path):
 
 
 def load_kalid_nnet1_model(file_descr, name):
-    graph = nx.MultiDiGraph(name=name)
+    graph = Graph(name=name)
 
     prev_layer_id = 'Input'
     graph.add_node(prev_layer_id, name=prev_layer_id, kind='op', op='Input', parameters=None)
@@ -161,7 +160,7 @@ def load_kalid_nnet1_model(file_descr, name):
         start_index = file_descr.tell()
         end_tag, end_index = find_end_of_component(file_descr, component_type)
         end_index -= len(end_tag)
-        layer_id = unique_id(graph, prefix=component_type)
+        layer_id = graph.unique_id(prefix=component_type)
         graph.add_node(layer_id,
                        parameters=get_parameters(file_descr, start_index, end_index),
                        op=component_type,
@@ -180,8 +179,9 @@ def load_kalid_nnet1_model(file_descr, name):
 
 
 def load_kalid_nnet2_model(file_descr, nnet_name):
-    graph = nx.MultiDiGraph(name=nnet_name)
+    graph = Graph(name=nnet_name)
     input_name = 'Input'
+    input_shape = np.array([])
     graph.add_node(input_name, name=input_name, kind='op', op='Input', parameters=None, shape=None)
 
     prev_layer_id = input_name
@@ -197,7 +197,7 @@ def load_kalid_nnet2_model(file_descr, nnet_name):
             break
         start_index = file_descr.tell()
         end_tag, end_index = find_end_of_component(file_descr, component_type)
-        layer_id = unique_id(graph, prefix=component_type)
+        layer_id = graph.unique_id(prefix=component_type)
         graph.add_node(layer_id,
                        parameters=get_parameters(file_descr, start_index, end_index),
                        op=component_type,