Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / pipeline / kaldi.py
index fcb3faa..e86b794 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  limitations under the License.
 """
 import logging as log
+
 import numpy as np
 
+from extensions.back.CreateConstNodes import CreateConstNodesReplacement
 from extensions.back.kaldi_remove_memory_output import KaldiRemoveMemoryOutputBackReplacementPattern
 from extensions.back.remove_last_softmax_pattern import RemoveLastSoftMaxPattern
 from extensions.front.kaldi.eliminate_redundant_reshape import EliminateRedundantReshape
 from extensions.front.kaldi.fuse_repeated_reshape import FuseRepeatedReshapes
 from extensions.middle.EltwiseChecker import EltwiseChecker
 from mo.front.common.register_custom_ops import update_extractors_with_extensions
-from mo.front.extractor import create_tensor_nodes, extract_node_attrs, add_output_ops, remove_output_ops
+from mo.front.extractor import extract_node_attrs, remove_output_ops
 from mo.front.kaldi.extractor import kaldi_extractor, kaldi_type_extractors
 from mo.front.kaldi.loader.loader import load_kaldi_model, read_counts_file
+from mo.graph.graph import Node
+from mo.middle.passes.eliminate import graph_clean_up, remove_const_ops
+from mo.middle.passes.infer import partial_infer
+from mo.pipeline.common import prepare_emit_ir
 from mo.utils import class_registration
 from mo.utils.cli_parser import get_meta_info
 from mo.utils.error import Error
 from mo.utils.find_inputs import find_outputs
-from mo.graph.graph import print_graph_stat, Node, check_empty_graph
-from mo.middle.passes.eliminate import graph_clean_up
-from mo.middle.passes.infer import override_placeholder_shapes, partial_infer, mark_outputs, override_batch
-from mo.pipeline.common import prepare_emit_ir
 from mo.utils.utils import refer_to_faq_msg
 
 
@@ -92,14 +94,13 @@ def apply_biases_to_last_layer(graph, counts):
 
     biases_node = target_node.in_nodes()[2]  # first - input, second - weights, third - biases
     if biases_node.value is not None:
-        biases_node.value = np.subtract(biases_node.value, counts)
+        biases_node.value = np.subtract(biases_node.value, counts)  # pylint: disable=assignment-from-no-return
     else:
         biases_node.value = counts * -1
         biases_node.shape = counts.shape
 
 
-def driver(argv, input_model, output_model_name, outputs, output_dir, scale, placeholder_shapes=None,
-           mean_scale_values=()):
+def driver(argv, input_model, output_model_name, output_dir):
     meta_info = get_meta_info(argv)
 
     EltwiseChecker.enabled = False
@@ -109,51 +110,22 @@ def driver(argv, input_model, output_model_name, outputs, output_dir, scale, pla
     except Exception as e:
         raise Error('Model Optimizer is not able to read Kaldi model {}. '.format(input_model) +
                     refer_to_faq_msg(91)) from e
-    check_empty_graph(graph, 'load_kaldi_nnet_model')
+    graph.check_empty_graph('load_kaldi_nnet_model')
     graph.graph['cmd_params'] = argv
     graph.graph['fw'] = 'kaldi'
-    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 4
-
+    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5
     update_extractors_with_extensions(kaldi_type_extractors)
-
     extract_node_attrs(graph, lambda node: kaldi_extractor(node))
 
+    # --------------------------------- LOAD END ------------------------------------------------------
     class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
 
-    output_op_nodes = add_output_ops(graph, outputs)  # TODO pass real outputs instead of None
-    log.debug("After adding specific nodes for outputs")
-    print_graph_stat(graph)
-
-    check_empty_graph(graph, 'add_output_ops')
-    create_tensor_nodes(graph)
-
-    graph_clean_up(graph)
-    log.debug("After removing specific nodes for output")
-    print_graph_stat(graph)
-
-    override_placeholder_shapes(graph, placeholder_shapes)
-    override_batch(graph, argv.batch)
-
-    graph_clean_up(graph)
-    log.debug("After setting input shapes")
-    print_graph_stat(graph)
-    graph_clean_up(graph)
-    remove_output_ops(graph)
-    log.debug("After removing specific nodes for output")
-    print_graph_stat(graph)
-
-    # You need to pass required network outputs here
-    # but we don't have a way yet, so just passing all discovered sinks
-    mark_outputs(graph)
-    graph_clean_up(graph)
-    log.debug("After graph_cleanup")
-    print_graph_stat(graph)
     graph = partial_infer(graph)
 
     # The order is intentional, firstly eliminate repeated, then remove redundant
     FuseRepeatedReshapes().find_and_replace_pattern(graph)
     EliminateRedundantReshape().find_and_replace_pattern(graph)
-    check_empty_graph(graph, 'partial_infer')
+    graph.check_empty_graph('partial_infer')
     if argv.counts:
         try:
             counts = read_counts_file(argv.counts)
@@ -167,9 +139,15 @@ def driver(argv, input_model, output_model_name, outputs, output_dir, scale, pla
         RemoveLastSoftMaxPattern().find_and_replace_pattern(graph)
         graph_clean_up(graph)
         log.debug("After removing softmax")
-        print_graph_stat(graph)
+        graph.print_graph_stat()
 
     # Intentionally after all transformations
     KaldiRemoveMemoryOutputBackReplacementPattern().find_and_replace_pattern(graph)
+
+    remove_const_ops(graph)
+    CreateConstNodesReplacement().find_and_replace_pattern(graph)
+
+    remove_output_ops(graph)
+
     prepare_emit_ir(graph, argv.data_type, output_dir, output_model_name, meta_info=meta_info)
     return 0