"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
See the License for the specific language governing permissions and
limitations under the License.
"""
+from extensions.back.CreateConstNodes import CreateConstNodesReplacement
+from extensions.front.restore_ports import RestorePorts
+from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
from mo.utils.error import Error, FrameworkError
from mo.utils.utils import refer_to_faq_msg
try:
    import mxnet
except ImportError:
    raise Error('Module mxnet was not found. Please install appropriate version of mxnet via install_prerequisites '
                'script.' + refer_to_faq_msg(52))
-import logging as log
-
-import numpy as np
import argparse
-import networkx as nx
-from mo.front.extractor import add_output_ops, extract_node_attrs, create_tensor_nodes, \
- add_input_ops, remove_output_ops, user_data_repack
+from mo.front.extractor import extract_node_attrs, remove_output_ops
from mo.front.mxnet.extractor import mxnet_op_extractor
from mo.front.mxnet.loader import symbol2nx, load_symbol_def
from mo.middle.passes.fusing.decomposition import convert_batch_norm, convert_scale_shift_to_mul_add
from mo.middle.passes.conv import convert_muladd_to_scaleshift_or_power, \
- convert_add_to_scaleshift, convert_mul_to_scaleshift, fuse_pad
-from mo.middle.passes.eliminate import graph_clean_up, remove_op_nodes
+ convert_add_or_mul_to_scaleshift, fuse_pad
+from mo.middle.passes.eliminate import graph_clean_up, remove_const_ops
from mo.middle.passes.fusing.fuse_linear_ops import fuse_linear_ops
from mo.middle.passes.fusing.fuse_linear_seq import fuse_mul_add_sequence
from mo.middle.passes.fusing.mark_unfused_nodes import mark_unfused_nodes
-from mo.middle.passes.shared_weights_duplication import duplicate_shared_weights
from mo.middle.passes.fusing.resnet_optimization import stride_optimization
-from mo.middle.passes.infer import mark_outputs, override_placeholder_shapes, partial_infer, add_mean_scale_values, \
- scale_input, convert_mul_add_to_power
+from mo.middle.passes.infer import convert_mul_add_to_power
from mo.middle.passes.mean_scale_values import move_scaleshift_to_preprocess
from mo.middle.passes.shape import reverse_input_channels
from mo.pipeline.common import prepare_emit_ir
-from mo.graph.graph import create_edge, Node, print_graph_stat, check_empty_graph
from mo.front.mxnet.nd_to_params import save_params_file
from mo.front.common.register_custom_ops import update_extractors_with_extensions
from mo.front.mxnet.extractor import mxnet_op_extractors
from extensions.middle.EltwiseInputNormalization import EltwiseInputNormalize
from mo.utils import class_registration
from mo.utils.cli_parser import get_meta_info
-def add_input_data_to_prior_boxes(graph: nx.MultiDiGraph, input_names: str = ''):
- """
- PriorBox layer has data input unlike mxnet.
- Need to add data input to _contrib_MultiBoxPrior for
- for correct conversion to PriorBox layer.
-
- Parameters
- ----------
- graph : nx.MultiDiGraph
- Graph with loaded model.
- """
- if not input_names:
- input_names = ('data',)
- else:
- input_names = input_names.split(',')
-
- input_nodes = {}
- for node in graph.nodes():
- node = Node(graph, node)
- if node.has_valid('op') and node.name in input_names:
- input_nodes.update({node.id: node})
-
- if len(input_nodes) > 0:
- for node in graph.nodes():
- node = Node(graph, node)
- if node.has_valid('op') and node.op == '_contrib_MultiBoxPrior':
- create_edge(list(input_nodes.values())[0], node, out_port=0, in_port=1)
-
-
-#TODO Remove the func after 'add_output_ops' will be moved to front replacer.
-def check_softmax_node_inputs(graph: nx.MultiDiGraph):
- for i, attrs in list(graph.nodes(data=True)):
- if 'op' in attrs and attrs['op'] == 'SoftMax':
- node = Node(graph, i)
- if len(node.in_nodes()) > 1:
- graph.remove_node(node.in_node(1).id)
-
-
-def driver(argv: argparse.Namespace, input_model: str, output_model_name: str, outputs: list, output_dir: str,
- scale: float,
- placeholder_shapes: [None, list, np.array] = None,
- mean_scale_values: [dict, list] = ()):
+def driver(argv: argparse.Namespace, input_model: str, output_model_name: str, output_dir: str):
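+    """Convert an MXNet model to IR: load the symbol and params, build the graph,
+    run the registered front/middle/back replacers and emit the IR files."""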
meta_info = get_meta_info(argv)
    try:
        model_nodes, model_params, model_name, iteration_number = load_symbol_def(input_model, argv.input_symbol,
                                                                                  argv.input,
                                                                                  argv.input_model_is_text,
                                                                                  argv.legacy_mxnet_model)
    except (ValueError, mxnet.base.MXNetError) as e:
        raise FrameworkError('The following error happened while loading mxnet model {}: {}. ' +
                             refer_to_faq_msg(53), input_model, str(e)) from e

    if argv.nd_prefix_name and argv.pretrained_model_name and argv.save_params_from_nd:
        save_params_file(model_name, model_params._arg_params, model_params._aux_params, output_dir)

    update_extractors_with_extensions(mxnet_op_extractors)
graph = symbol2nx(model_nodes, model_params, argv.input)
- check_empty_graph(graph, 'symbol2nx. It may happen due to problems with loaded model')
+ graph.check_empty_graph('symbol2nx. It may happen due to problems with loaded model')
graph.__setattr__('name', output_model_name)
graph.graph['layout'] = 'NCHW'
graph.graph['cmd_params'] = argv
graph.graph['fw'] = 'mxnet'
graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
- graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 4
- graph = extract_node_attrs(graph, mxnet_op_extractor)
- check_softmax_node_inputs(graph)
-
- user_shapes, packed_outputs, _ = user_data_repack(graph, placeholder_shapes, outputs, None)
- output_op_nodes = add_output_ops(graph, packed_outputs)
- input_op_nodes = add_input_ops(graph, user_shapes, True)
-
- try:
- override_placeholder_shapes(graph, user_shapes, argv.batch)
- except ValueError as err:
- raise Error(
- 'The following error happened while processing input shapes: {}. ' +
- refer_to_faq_msg(54),
- str(err)
- ) from err
- check_empty_graph(graph, 'add_output_ops and add_input_ops')
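+    # the deprecated flag keeps IR v2 available; otherwise the current IR v5 is emitted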
+ graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5
+ extract_node_attrs(graph, mxnet_op_extractor)
+ # --------------------------------- LOAD END ------------------------------------------------------
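+    # front replacers rewrite the framework-level graph before any shape inference runs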
class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
- add_input_data_to_prior_boxes(graph, argv.input)
-
- graph = create_tensor_nodes(graph)
-
- graph_clean_up(graph)
- remove_output_ops(graph)
- mark_outputs(graph)
- remove_output_ops(graph)
-
- graph_clean_up(graph)
-
- log.debug("After removing specific nodes for output")
-
- print_graph_stat(graph)
-
- graph = partial_infer(graph)
- graph_clean_up(graph)
- check_empty_graph(graph, 'partial_infer')
-
- duplicate_shared_weights(graph)
-
- scale_input(graph, scale)
- add_mean_scale_values(graph, mean_scale_values)
-
- remove_op_nodes(graph, {'identity': True})
-
- graph_clean_up(graph)
-
class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)
+
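+    # middle replacers are followed by explicit fusing passes: Pad fusing and
+    # folding of Mul/Add chains into ScaleShift or Power layers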
fuse_pad(graph)
    # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
    mark_unfused_nodes(graph, argv.finegrain_fusing)
graph_clean_up(graph)
convert_mul_add_to_power(graph)
- convert_add_to_scaleshift(graph) # scale = 1
- convert_mul_to_scaleshift(graph) # biases = 0
+ graph_clean_up(graph)
+    convert_add_or_mul_to_scaleshift(graph)  # scale = 1 for Add, biases = 0 for Mul
+ graph_clean_up(graph)
if argv.reverse_input_channels:
reverse_input_channels(graph)
class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)
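+    # Const ops are stripped and then re-created so that every constant input
+    # is represented by an explicit Const node when the IR is serialized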
+ for_graph_and_each_sub_graph_recursively(graph, remove_const_ops)
+ CreateConstNodesReplacement().find_and_replace_pattern(graph)
+
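+    # finally drop the artificial output (sink) ops from the graph and all sub-graphs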
+ for_graph_and_each_sub_graph_recursively(graph, remove_output_ops)
+
prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,
meta_info=meta_info)
return 0