"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import argparse
import logging as log
-import numpy as np
-
-from extensions.front.freeze_placeholder_value import FreezePlaceholderValue
-from extensions.middle.FusePermutesSequence import FusePermutesSequence
+from extensions.back.CreateConstNodes import CreateConstNodesReplacement
from mo.front.caffe import custom_layers_mapping, loader
-from mo.front.caffe.extractor import caffe_extractor, common_caffe_fields, caffe_type_extractors
-from mo.front.common.register_custom_ops import check_for_duplicates
-from mo.front.common.register_custom_ops import update_extractors_with_extensions
-from mo.front.common.replacement import FrontReplacementSubgraph
-from mo.front.extractor import extract_node_attrs, add_output_ops, create_tensor_nodes, remove_output_ops, \
- add_input_ops, user_data_repack
-from mo.graph.graph import print_graph_stat, check_empty_graph
+from mo.front.caffe.extractor import caffe_type_extractors, caffe_extractor
+from mo.front.common.register_custom_ops import update_extractors_with_extensions, check_for_duplicates
+from mo.front.extractor import extract_node_attrs, remove_output_ops
+from mo.middle.passes.conv import convert_add_or_mul_to_scaleshift
from mo.middle.passes.conv import convert_muladd_to_scaleshift_or_power, \
- convert_matmul_to_fully_connected, batch_norm_fuse, convert_add_to_scaleshift, \
- convert_mul_to_scaleshift, \
- convert_multi_input_conv
-from mo.middle.passes.eliminate import graph_clean_up, remove_op_nodes
+ convert_matmul_to_fully_connected, batch_norm_fuse
+from mo.middle.passes.eliminate import graph_clean_up
+from mo.middle.passes.eliminate import remove_const_ops
from mo.middle.passes.fusing.decomposition import convert_bn_to_mul_add, convert_scale_shift_to_mul_add
from mo.middle.passes.fusing.fuse_linear_ops import fuse_linear_ops
from mo.middle.passes.fusing.fuse_linear_seq import fuse_mul_add_sequence
from mo.middle.passes.fusing.mark_unfused_nodes import mark_unfused_nodes
from mo.middle.passes.fusing.resnet_optimization import stride_optimization
-from mo.middle.passes.infer import add_mean_scale_values, scale_input, override_placeholder_shapes, mark_outputs, \
- partial_infer, convert_mul_add_to_power, override_batch
+from mo.middle.passes.infer import convert_mul_add_to_power
from mo.middle.passes.mean_scale_values import move_scaleshift_to_preprocess
-from mo.middle.passes.pool import mean_to_avgpool
from mo.middle.passes.shape import reverse_input_channels, fuse_sequence_of_reshapes
-from mo.middle.passes.shared_weights_duplication import duplicate_shared_weights
from mo.pipeline.common import prepare_emit_ir
from mo.utils import class_registration
+from mo.utils.cli_parser import get_meta_info
from mo.utils.error import Error
from mo.utils.find_inputs import find_inputs
from mo.utils.utils import refer_to_faq_msg
-from mo.utils.cli_parser import get_meta_info
-def driver(argv: argparse.Namespace, proto_file_name: str, model_file_name: str, output_model_name: str, outputs: list,
- output_dir: str,
- scale: float,
- user_shapes: [None, list, np.array] = None, mean_scale_values: [dict, list] = (), mean_file: str = "",
- mean_file_offsets: tuple = None,
- custom_layers_mapping_path: str = None):
+def driver(argv: argparse.Namespace, proto_file_name: str, model_file_name: str, output_model_name: str,
+ output_dir: str, mean_file: str = "",
+ mean_file_offsets: tuple = None, custom_layers_mapping_path: str = None):
meta_info = get_meta_info(argv)
- FusePermutesSequence.enabled = False
-
proto, model = loader.load_caffe_proto_model(proto_file_name, model_file_name)
update_extractors_with_extensions(
refer_to_faq_msg(11), str(e)) from e
log.debug("After caffe_pb_to_nx")
- print_graph_stat(graph)
- check_empty_graph(graph, 'load_caffe_proto_model')
+ graph.print_graph_stat()
+ graph.check_empty_graph('load_caffe_proto_model')
graph.__setattr__('proto_path', proto_file_name)
graph.__setattr__('caffemodel_path', model_file_name)
graph.graph['layout'] = 'NCHW'
graph.graph['cmd_params'] = argv
graph.graph['fw'] = 'caffe'
- graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 4
-
- extract_node_attrs(graph, lambda node: (True, common_caffe_fields(node)))
-
- log.debug("After adding specific nodes for outputs")
- print_graph_stat(graph)
+ graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5
custom_layers_map = custom_layers_mapping.load_layers_xml(custom_layers_mapping_path)
custom_layers_mapping.update_extractors(
argv.disable_omitting_optional if hasattr(argv, 'disable_omitting_optional') else False,
argv.enable_flattening_nested_params if hasattr(argv, 'enable_flattening_nested_params') else False
)
-
extract_node_attrs(graph, lambda node: caffe_extractor(node, check_for_duplicates(caffe_type_extractors)))
- log.debug("After extract_node_attr")
- print_graph_stat(graph)
-
- packed_user_shapes, packed_outputs, freeze_placeholder = user_data_repack(graph, user_shapes, outputs, argv.freeze_placeholder_with_value)
- if argv.freeze_placeholder_with_value is not None:
- FreezePlaceholderValue.enabled = True
- FreezePlaceholderValue.replacement_dict = freeze_placeholder
- class_registration.update_registration([FrontReplacementSubgraph])
- output_op_nodes = add_output_ops(graph, packed_outputs)
- input_op_nodes = add_input_ops(graph, packed_user_shapes, True)
- override_placeholder_shapes(graph, packed_user_shapes)
- override_batch(graph, argv.batch)
- graph_clean_up(graph)
- check_empty_graph(graph, 'add_output_ops and add_input_ops')
+ # --------------------------------- LOAD END ------------------------------------------------------
class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
-
- graph = create_tensor_nodes(graph)
-
- log.debug("After create_tensor_nodes")
- print_graph_stat(graph)
-
- remove_op_nodes(graph, {'op': 'Identity'})
- remove_output_ops(graph)
- graph_clean_up(graph)
-
- log.debug("After removing specific nodes for output")
- print_graph_stat(graph)
-
- # you need to pass required network outputs here
- # but we don't have a way yet, so just passing all discovered sinks
- mark_outputs(graph)
- graph_clean_up(graph)
- log.debug("After graph_cleanup")
- print_graph_stat(graph)
-
- graph = partial_infer(graph)
- log.debug("After partial_infer")
- print_graph_stat(graph)
- check_empty_graph(graph, 'partial_infer')
- duplicate_shared_weights(graph)
-
- input_op_nodes = add_input_ops(graph, packed_user_shapes, False)
- graph_clean_up(graph)
- check_empty_graph(graph, 'add_input_ops')
- scale_input(graph, scale)
-
- add_mean_scale_values(graph, mean_scale_values)
-
- log.debug("Split multi input convolutions")
- convert_multi_input_conv(graph)
-
- graph_clean_up(graph)
- log.debug("After graph_cleanup")
- print_graph_stat(graph)
-
- remove_op_nodes(graph, {'op': 'Dropout'})
- remove_op_nodes(graph, {'phase': 0})
- graph_clean_up(graph)
-
class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)
- mean_to_avgpool(graph)
-
# Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
mark_unfused_nodes(graph, argv.finegrain_fusing)
- #need this pass even without fusing to convert scale with 2 inputs
+ # need this pass even without fusing to convert scale with 2 inputs
convert_scale_shift_to_mul_add(graph)
graph_clean_up(graph)
convert_matmul_to_fully_connected(graph)
batch_norm_fuse(graph)
convert_mul_add_to_power(graph)
- convert_add_to_scaleshift(graph) # scale = 1
- convert_mul_to_scaleshift(graph) # biases = 0
-
graph_clean_up(graph)
+ convert_add_or_mul_to_scaleshift(graph) # scale = 1
+ graph_clean_up(graph)
+
log.debug("After graph_cleanup")
- print_graph_stat(graph)
+ graph.print_graph_stat()
if argv.reverse_input_channels:
reverse_input_channels(graph)
class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)
+ remove_const_ops(graph)
+ CreateConstNodesReplacement().find_and_replace_pattern(graph)
+
+ remove_output_ops(graph)
+
prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,
mean_data=mf,
input_names=input_names,