"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
limitations under the License.
"""
import logging as log
+
import numpy as np
+from extensions.back.CreateConstNodes import CreateConstNodesReplacement
from extensions.back.kaldi_remove_memory_output import KaldiRemoveMemoryOutputBackReplacementPattern
from extensions.back.remove_last_softmax_pattern import RemoveLastSoftMaxPattern
from extensions.front.kaldi.eliminate_redundant_reshape import EliminateRedundantReshape
from extensions.front.kaldi.fuse_repeated_reshape import FuseRepeatedReshapes
from extensions.middle.EltwiseChecker import EltwiseChecker
from mo.front.common.register_custom_ops import update_extractors_with_extensions
-from mo.front.extractor import create_tensor_nodes, extract_node_attrs, add_output_ops, remove_output_ops
+from mo.front.extractor import extract_node_attrs, remove_output_ops
from mo.front.kaldi.extractor import kaldi_extractor, kaldi_type_extractors
from mo.front.kaldi.loader.loader import load_kaldi_model, read_counts_file
+from mo.graph.graph import Node
+from mo.middle.passes.eliminate import graph_clean_up, remove_const_ops
+from mo.middle.passes.infer import partial_infer
+from mo.pipeline.common import prepare_emit_ir
from mo.utils import class_registration
from mo.utils.cli_parser import get_meta_info
from mo.utils.error import Error
from mo.utils.find_inputs import find_outputs
-from mo.graph.graph import print_graph_stat, Node, check_empty_graph
-from mo.middle.passes.eliminate import graph_clean_up
-from mo.middle.passes.infer import override_placeholder_shapes, partial_infer, mark_outputs, override_batch
-from mo.pipeline.common import prepare_emit_ir
from mo.utils.utils import refer_to_faq_msg
biases_node = target_node.in_nodes()[2] # first - input, second - weights, third - biases
if biases_node.value is not None:
- biases_node.value = np.subtract(biases_node.value, counts)
+ biases_node.value = np.subtract(biases_node.value, counts) # pylint: disable=assignment-from-no-return
else:
biases_node.value = counts * -1
biases_node.shape = counts.shape
-def driver(argv, input_model, output_model_name, outputs, output_dir, scale, placeholder_shapes=None,
- mean_scale_values=()):
+def driver(argv, input_model, output_model_name, output_dir):
meta_info = get_meta_info(argv)
EltwiseChecker.enabled = False
except Exception as e:
raise Error('Model Optimizer is not able to read Kaldi model {}. '.format(input_model) +
refer_to_faq_msg(91)) from e
- check_empty_graph(graph, 'load_kaldi_nnet_model')
+ graph.check_empty_graph('load_kaldi_nnet_model')
graph.graph['cmd_params'] = argv
graph.graph['fw'] = 'kaldi'
- graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 4
-
+ graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5
update_extractors_with_extensions(kaldi_type_extractors)
-
extract_node_attrs(graph, lambda node: kaldi_extractor(node))
+ # --------------------------------- LOAD END ------------------------------------------------------
class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
- output_op_nodes = add_output_ops(graph, outputs) # TODO pass real outputs instead of None
- log.debug("After adding specific nodes for outputs")
- print_graph_stat(graph)
-
- check_empty_graph(graph, 'add_output_ops')
- create_tensor_nodes(graph)
-
- graph_clean_up(graph)
- log.debug("After removing specific nodes for output")
- print_graph_stat(graph)
-
- override_placeholder_shapes(graph, placeholder_shapes)
- override_batch(graph, argv.batch)
-
- graph_clean_up(graph)
- log.debug("After setting input shapes")
- print_graph_stat(graph)
- graph_clean_up(graph)
- remove_output_ops(graph)
- log.debug("After removing specific nodes for output")
- print_graph_stat(graph)
-
- # You need to pass required network outputs here
- # but we don't have a way yet, so just passing all discovered sinks
- mark_outputs(graph)
- graph_clean_up(graph)
- log.debug("After graph_cleanup")
- print_graph_stat(graph)
graph = partial_infer(graph)
# The order is intentional, firstly eliminate repeated, then remove redundant
FuseRepeatedReshapes().find_and_replace_pattern(graph)
EliminateRedundantReshape().find_and_replace_pattern(graph)
- check_empty_graph(graph, 'partial_infer')
+ graph.check_empty_graph('partial_infer')
if argv.counts:
try:
counts = read_counts_file(argv.counts)
RemoveLastSoftMaxPattern().find_and_replace_pattern(graph)
graph_clean_up(graph)
log.debug("After removing softmax")
- print_graph_stat(graph)
+ graph.print_graph_stat()
# Intentionally after all transformations
KaldiRemoveMemoryOutputBackReplacementPattern().find_and_replace_pattern(graph)
+
+ remove_const_ops(graph)
+ CreateConstNodesReplacement().find_and_replace_pattern(graph)
+
+ remove_output_ops(graph)
+
prepare_emit_ir(graph, argv.data_type, output_dir, output_model_name, meta_info=meta_info)
return 0