Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / pipeline / mx.py
index 03ac18f..e382cd6 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -13,6 +13,9 @@
  See the License for the specific language governing permissions and
  limitations under the License.
 """
+from extensions.back.CreateConstNodes import CreateConstNodesReplacement
+from extensions.front.restore_ports import RestorePorts
+from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
 from mo.utils.error import Error, FrameworkError
 from mo.utils.utils import refer_to_faq_msg
 
@@ -22,31 +25,23 @@ except ImportError:
     raise Error('Module mxnet was not found. Please install appropriate version of mxnet via install_prerequisites '
                 'script.' + refer_to_faq_msg(52))
 
-import logging as log
-
-import numpy as np
 import argparse
-import networkx as nx
 
-from mo.front.extractor import add_output_ops, extract_node_attrs, create_tensor_nodes, \
-    add_input_ops, remove_output_ops, user_data_repack
+from mo.front.extractor import extract_node_attrs, remove_output_ops
 from mo.front.mxnet.extractor import mxnet_op_extractor
 from mo.front.mxnet.loader import symbol2nx, load_symbol_def
 from mo.middle.passes.fusing.decomposition import convert_batch_norm, convert_scale_shift_to_mul_add
 from mo.middle.passes.conv import convert_muladd_to_scaleshift_or_power, \
-    convert_add_to_scaleshift, convert_mul_to_scaleshift, fuse_pad
-from mo.middle.passes.eliminate import graph_clean_up, remove_op_nodes
+    convert_add_or_mul_to_scaleshift, fuse_pad
+from mo.middle.passes.eliminate import graph_clean_up, remove_const_ops
 from mo.middle.passes.fusing.fuse_linear_ops import fuse_linear_ops
 from mo.middle.passes.fusing.fuse_linear_seq import fuse_mul_add_sequence
 from mo.middle.passes.fusing.mark_unfused_nodes import mark_unfused_nodes
-from mo.middle.passes.shared_weights_duplication import duplicate_shared_weights
 from mo.middle.passes.fusing.resnet_optimization import stride_optimization
-from mo.middle.passes.infer import mark_outputs, override_placeholder_shapes, partial_infer, add_mean_scale_values, \
-    scale_input, convert_mul_add_to_power
+from mo.middle.passes.infer import convert_mul_add_to_power
 from mo.middle.passes.mean_scale_values import move_scaleshift_to_preprocess
 from mo.middle.passes.shape import reverse_input_channels
 from mo.pipeline.common import prepare_emit_ir
-from mo.graph.graph import create_edge, Node, print_graph_stat, check_empty_graph
 from mo.front.mxnet.nd_to_params import save_params_file
 from mo.front.common.register_custom_ops import update_extractors_with_extensions
 from mo.front.mxnet.extractor import mxnet_op_extractors
@@ -55,48 +50,7 @@ from mo.utils.cli_parser import get_meta_info
 from extensions.middle.EltwiseInputNormalization import EltwiseInputNormalize
 
 
-def add_input_data_to_prior_boxes(graph: nx.MultiDiGraph, input_names: str = ''):
-    """
-    PriorBox layer has data input unlike mxnet.
-    Need to add data input to _contrib_MultiBoxPrior for
-    for correct conversion to PriorBox layer.
-
-    Parameters
-    ----------
-    graph : nx.MultiDiGraph
-       Graph with loaded model.
-    """
-    if not input_names:
-        input_names = ('data',)
-    else:
-        input_names = input_names.split(',')
-
-    input_nodes = {}
-    for node in graph.nodes():
-        node = Node(graph, node)
-        if node.has_valid('op') and node.name in input_names:
-            input_nodes.update({node.id: node})
-
-    if len(input_nodes) > 0:
-        for node in graph.nodes():
-            node = Node(graph, node)
-            if node.has_valid('op') and node.op == '_contrib_MultiBoxPrior':
-                create_edge(list(input_nodes.values())[0], node, out_port=0, in_port=1)
-
-
-#TODO Remove the func after 'add_output_ops' will be moved to front replacer.
-def check_softmax_node_inputs(graph: nx.MultiDiGraph):
-    for i, attrs in list(graph.nodes(data=True)):
-        if 'op' in attrs and attrs['op'] == 'SoftMax':
-            node = Node(graph, i)
-            if len(node.in_nodes()) > 1:
-                graph.remove_node(node.in_node(1).id)
-
-
-def driver(argv: argparse.Namespace, input_model: str, output_model_name: str, outputs: list, output_dir: str,
-           scale: float,
-           placeholder_shapes: [None, list, np.array] = None,
-           mean_scale_values: [dict, list] = ()):
+def driver(argv: argparse.Namespace, input_model: str, output_model_name: str, output_dir: str):
     meta_info = get_meta_info(argv)
 
     try:
@@ -118,61 +72,20 @@ def driver(argv: argparse.Namespace, input_model: str, output_model_name: str, o
 
     update_extractors_with_extensions(mxnet_op_extractors)
     graph = symbol2nx(model_nodes, model_params, argv.input)
-    check_empty_graph(graph, 'symbol2nx. It may happen due to problems with loaded model')
+    graph.check_empty_graph('symbol2nx. It may happen due to problems with loaded model')
 
     graph.__setattr__('name', output_model_name)
     graph.graph['layout'] = 'NCHW'
     graph.graph['cmd_params'] = argv
     graph.graph['fw'] = 'mxnet'
     graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
-    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 4
-    graph = extract_node_attrs(graph, mxnet_op_extractor)
-    check_softmax_node_inputs(graph)
-
-    user_shapes, packed_outputs, _ = user_data_repack(graph, placeholder_shapes, outputs, None)
-    output_op_nodes = add_output_ops(graph, packed_outputs)
-    input_op_nodes = add_input_ops(graph, user_shapes, True)
-
-    try:
-        override_placeholder_shapes(graph, user_shapes, argv.batch)
-    except ValueError as err:
-        raise Error(
-            'The following error happened while processing input shapes: {}. ' +
-            refer_to_faq_msg(54),
-            str(err)
-        ) from err
-    check_empty_graph(graph, 'add_output_ops and add_input_ops')
+    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5
+    extract_node_attrs(graph, mxnet_op_extractor)
 
+    # --------------------------------- LOAD END ------------------------------------------------------
     class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
-    add_input_data_to_prior_boxes(graph, argv.input)
-
-    graph = create_tensor_nodes(graph)
-
-    graph_clean_up(graph)
-    remove_output_ops(graph)
-    mark_outputs(graph)
-    remove_output_ops(graph)
-
-    graph_clean_up(graph)
-
-    log.debug("After removing specific nodes for output")
-
-    print_graph_stat(graph)
-
-    graph = partial_infer(graph)
-    graph_clean_up(graph)
-    check_empty_graph(graph, 'partial_infer')
-
-    duplicate_shared_weights(graph)
-
-    scale_input(graph, scale)
-    add_mean_scale_values(graph, mean_scale_values)
-
-    remove_op_nodes(graph, {'identity': True})
-
-    graph_clean_up(graph)
-
     class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)
+
     fuse_pad(graph)
 
     # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
@@ -205,8 +118,9 @@ def driver(argv: argparse.Namespace, input_model: str, output_model_name: str, o
     graph_clean_up(graph)
 
     convert_mul_add_to_power(graph)
-    convert_add_to_scaleshift(graph)  # scale = 1
-    convert_mul_to_scaleshift(graph)  # biases = 0
+    graph_clean_up(graph)
+    convert_add_or_mul_to_scaleshift(graph)  # scale = 1
+    graph_clean_up(graph)
 
     if argv.reverse_input_channels:
         reverse_input_channels(graph)
@@ -220,6 +134,11 @@ def driver(argv: argparse.Namespace, input_model: str, output_model_name: str, o
 
     class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)
 
+    for_graph_and_each_sub_graph_recursively(graph, remove_const_ops)
+    CreateConstNodesReplacement().find_and_replace_pattern(graph)
+
+    for_graph_and_each_sub_graph_recursively(graph, remove_output_ops)
+
     prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,
                     meta_info=meta_info)
     return 0