"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import logging as log
-import networkx as nx
import numpy as np
+from mo.front.common.partial_infer.utils import int64_array
from mo.front.extractor import update_attrs
-from mo.graph.graph import Node, create_edge
-from mo.middle.passes.eliminate import remove_op_node_with_data_node, merge_data_nodes, graph_clean_up_tf, get_nodes_with_attributes
+from mo.graph.graph import Node, Graph
+from mo.middle.passes.eliminate import remove_op_node_with_data_node, merge_data_nodes, graph_clean_up_tf
from mo.middle.passes.fusing.helpers import get_next_operation
from mo.middle.pattern_match import apply_pattern
from mo.ops.op import PermuteAttrs, Op
from mo.utils.utils import refer_to_faq_msg
-def reshape_squeeze_transform(graph: nx.MultiDiGraph, match: dict):
+def reshape_squeeze_transform(graph: Graph, match: dict):
reshape = match['reshape']
output = match['output']
if output.shape is None:
# do not override value 'dim' if it is set. It may contain specific values like -1 and 0
reshape['dim'] = reshape.shape.copy()
update_attrs(reshape, 'shape_attrs', 'dim')
- if 'shape' in match:
- graph.remove_edge(match['shape'].node, match['reshape'].node)
-def convert_squeeze(graph: nx.MultiDiGraph):
+def convert_squeeze(graph: Graph):
apply_pattern(
graph,
nodes=[
)
-def convert_reshape(graph: nx.MultiDiGraph):
+def convert_reshape(graph: Graph):
apply_pattern(
graph,
nodes=[
return False
-def repack_fully_connected_weights_nhwc_to_nchw(graph: nx.MultiDiGraph):
+def repack_fully_connected_weights_nhwc_to_nchw(graph: Graph):
"""
    Repack weights of a FullyConnected layer as part of the nhwc_to_nchw translation when a Reshape
    that involves the dimensions being repacked appears right before the FullyConnected layer.
"""
- for node_id in get_nodes_with_attributes(graph, type='FullyConnected'):
+ for node_id in graph.get_nodes_with_attributes(type='FullyConnected'):
fc_node = Node(graph, node_id)
if not can_repack_fully_connected_weights_nhwc_to_nchw(fc_node):
weights.value = np.transpose(weights.value.reshape(tmp_shape), (2, 0, 1, 3)).reshape(weights.shape)
-def apply_nhwc_to_nchw_permutation(graph: nx.MultiDiGraph):
+def apply_nhwc_to_nchw_permutation(graph: Graph):
# Add NHWC to NCHW permutation for all data nodes (only for nodes without permutation)
if graph.graph['layout'] == 'NCHW':
return
PermuteAttrs.set_permutation(node, out_node, permutation)
-def merge_nodes_permutations(graph: nx.MultiDiGraph):
+def merge_nodes_permutations(graph: Graph):
    # Iterate over all data nodes and check all permutations for similarity.
    # If all permutations are equal, that permutation is set as an attribute of the data node;
    # otherwise an exception is raised.
node.permutation = None
-def permute_data_nodes_attrs(graph: nx.MultiDiGraph):
+def permute_data_nodes_attrs(graph: Graph):
# Iterate over all data nodes and apply permutation if exists
for node in graph.nodes():
node = Node(graph, node)
node.value = np.array(node.value.transpose(node.permutation.perm))
-def permute_op_nodes_attrs(graph: nx.MultiDiGraph):
+def permute_op_nodes_attrs(graph: Graph):
for node in graph.nodes():
node = Node(graph, node)
if node.kind == 'op' and node.has_valid('permute_attrs'):
raise Error('Can\'t permute attrs for node {}. Error message: {}'.format(node.id, e))
-def reverse_input_channels(graph: nx.MultiDiGraph):
+def reverse_input_channels(graph: Graph):
"""
    Searches for all type=Input nodes with 4D output tensors,
tracks tensors down through non-shape-changing ops to the first type=Convolution or other channel-dependent nodes
if conv.op == 'DepthwiseConv2dNative':
log.debug('out nodes: {}'.format(conv.out_node()))
bottoms = conv.out_node().out_nodes()
+ if len(bottoms) == 1 and bottoms[0].op == 'FakeQuantWithMinMaxVars':
+ bottoms = bottoms[0].out_node().out_nodes()
log.debug('bottoms: {}'.format(bottoms))
log.debug('assumed conv: name = {}, op = {}'.format(bottoms[0].name, bottoms[0].op))
if len(bottoms) > 0 and bottoms[0].op == 'Conv2D':
'complete the flip')
conv.in_node(1).value = np.flip(conv.in_node(1).value, conv.in_node(1).input_channel_dim)
+ conv.in_node(1).shape = int64_array(conv.in_node(1).value.shape)
log.debug('Applied reversing input channels for weights of convolution {}'.format(conv.id))
log.debug('Shape was (shape){}, (value.shape){}'.format(conv.in_node(1).shape, conv.in_node(1).value.shape))
log.debug('Flipped dim: {}'.format(conv.in_node(1).input_channel_dim))
-def conv_flatten_concat_action(graph: nx.MultiDiGraph, match: dict):
+def conv_flatten_concat_action(graph: Graph, match: dict):
assert graph.graph['layout'] == 'NHWC'
reshape_node = match['reshape']
reshape_data_node = match['reshape_data']
log.info('There is a FullyConnected layer after the node "{}" which weights will be repacked. So there is no '
'need to insert Permute'.format(reshape_node.soft_get('name')))
return
- assert len(graph.in_edges(reshape_node.id)) == 1
graph.remove_edge(conv_data_node.id, reshape_node.id)
permutation_order = PermuteAttrs.get_nchw_to_nhwc_permutation(len(conv_data_node.shape)).perm
new_permute_op = Permute(graph, {'order': permutation_order})
permute_data_node = new_permute_op.create_node_with_data([conv_data_node], dict(name=conv_name + '/Permute_'))
- create_edge(permute_data_node, reshape_node)
+ graph.create_edge(permute_data_node, reshape_node)
# Disable permutation for Reshape and Concat layers attributes
PermuteAttrs.set_permutation(reshape_node, reshape_data_node, None)
+ reshape_node['nchw_layout'] = True
-def conv_flatten_concat(graph: nx.MultiDiGraph):
+def conv_flatten_concat(graph: Graph):
apply_pattern(
graph,
nodes=[
)
-def fuse_sequence_of_reshapes(graph: nx.MultiDiGraph):
+def fuse_sequence_of_reshapes(graph: Graph):
for node in list(graph.nodes()):
- node = Node(graph, node)
- if not graph.has_node(node.id):
+ if not graph.has_node(node):
# data node can be already removed
continue
+ node = Node(graph, node)
if (
node.has_valid('type') and node.type == 'Reshape' and
len(node.out_nodes()) == 1 and node.out_node().has_valid('kind') and node.out_node().kind == 'data' and
# Remove Reshape1
log.debug('Second phase for Reshape: {}'.format(node.name))
remove_op_node_with_data_node(graph, node)
+
+ reshape_nodes = graph.get_op_nodes(op='Reshape')
+ for reshape_node in reshape_nodes:
+ in_ports = [port for port in reshape_node.in_ports().values() if not port.disconnected()]
+ assert len(in_ports) in [1, 2], "`Reshape` node must have 2 inputs or 1 input with `dim`"
+ if len(in_ports) == 2:
+ previous_dim_op = reshape_node.in_port(1).get_source().node.op
+ if previous_dim_op != 'Const':
+ continue
+ dim = reshape_node.in_port(1).get_connection().data.get_value()
+ else:
+ assert reshape_node.has_valid('dim'), "`Reshape` node with 1 input must have `dim` attribute"
+ dim = reshape_node.dim
+
+ in_shape = reshape_node.in_port(0).get_connection().data.get_shape()
+
+ if np.array_equal(dim, in_shape) and len(reshape_node.out_nodes()):
+ log.debug("Useless reshape with dim {} was deleted: {}".format(str(dim), reshape_node.name))
+ reshape_node.out_port(0).get_connection().set_source(reshape_node.in_port(0).get_source())