Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / shape.py
index 647502b..e98a2ac 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
 
 import logging as log
 
-import networkx as nx
 import numpy as np
 
+from mo.front.common.partial_infer.utils import int64_array
 from mo.front.extractor import update_attrs
-from mo.graph.graph import Node, create_edge
-from mo.middle.passes.eliminate import remove_op_node_with_data_node, merge_data_nodes, graph_clean_up_tf, get_nodes_with_attributes
+from mo.graph.graph import Node, Graph
+from mo.middle.passes.eliminate import remove_op_node_with_data_node, merge_data_nodes, graph_clean_up_tf
 from mo.middle.passes.fusing.helpers import get_next_operation
 from mo.middle.pattern_match import apply_pattern
 from mo.ops.op import PermuteAttrs, Op
@@ -30,7 +30,7 @@ from mo.utils.error import Error
 from mo.utils.utils import refer_to_faq_msg
 
 
-def reshape_squeeze_transform(graph: nx.MultiDiGraph, match: dict):
+def reshape_squeeze_transform(graph: Graph, match: dict):
     reshape = match['reshape']
     output = match['output']
     if output.shape is None:
@@ -42,11 +42,9 @@ def reshape_squeeze_transform(graph: nx.MultiDiGraph, match: dict):
         # do not override value 'dim' if it is set. It may contain specific values like -1 and 0
         reshape['dim'] = reshape.shape.copy()
     update_attrs(reshape, 'shape_attrs', 'dim')
-    if 'shape' in match:
-        graph.remove_edge(match['shape'].node, match['reshape'].node)
 
 
-def convert_squeeze(graph: nx.MultiDiGraph):
+def convert_squeeze(graph: Graph):
     apply_pattern(
         graph,
         nodes=[
@@ -57,7 +55,7 @@ def convert_squeeze(graph: nx.MultiDiGraph):
     )
 
 
-def convert_reshape(graph: nx.MultiDiGraph):
+def convert_reshape(graph: Graph):
     apply_pattern(
         graph,
         nodes=[
@@ -107,12 +105,12 @@ def can_repack_fully_connected_weights_nhwc_to_nchw(fc_node: Node):
         return False
 
 
-def repack_fully_connected_weights_nhwc_to_nchw(graph: nx.MultiDiGraph):
+def repack_fully_connected_weights_nhwc_to_nchw(graph: Graph):
     """
     Repack weights of FullyConnected layer as a part of nhwc_to_nchw translation if Reshape of
     that involves dimensions that we are repacking appears right before FullyConnected layer.
     """
-    for node_id in get_nodes_with_attributes(graph, type='FullyConnected'):
+    for node_id in graph.get_nodes_with_attributes(type='FullyConnected'):
         fc_node = Node(graph, node_id)
 
         if not can_repack_fully_connected_weights_nhwc_to_nchw(fc_node):
@@ -146,7 +144,7 @@ def repack_fully_connected_weights_nhwc_to_nchw(graph: nx.MultiDiGraph):
         weights.value = np.transpose(weights.value.reshape(tmp_shape), (2, 0, 1, 3)).reshape(weights.shape)
 
 
-def apply_nhwc_to_nchw_permutation(graph: nx.MultiDiGraph):
+def apply_nhwc_to_nchw_permutation(graph: Graph):
     # Add NHWC to NCHW permutation for all data nodes (only for nodes without permutation)
     if graph.graph['layout'] == 'NCHW':
         return
@@ -181,7 +179,7 @@ def apply_nhwc_to_nchw_permutation(graph: nx.MultiDiGraph):
                 PermuteAttrs.set_permutation(node, out_node, permutation)
 
 
-def merge_nodes_permutations(graph: nx.MultiDiGraph):
+def merge_nodes_permutations(graph: Graph):
     # Iterate over all data nodes and check all permutations for similarity
     # In case of equal permutations, this permutation will be set as attribute for data node
     # otherwise exception will be raised
@@ -228,7 +226,7 @@ def merge_nodes_permutations(graph: nx.MultiDiGraph):
             node.permutation = None
 
 
-def permute_data_nodes_attrs(graph: nx.MultiDiGraph):
+def permute_data_nodes_attrs(graph: Graph):
     # Iterate over all data nodes and apply permutation if exists
     for node in graph.nodes():
         node = Node(graph, node)
@@ -245,7 +243,7 @@ def permute_data_nodes_attrs(graph: nx.MultiDiGraph):
             node.value = np.array(node.value.transpose(node.permutation.perm))
 
 
-def permute_op_nodes_attrs(graph: nx.MultiDiGraph):
+def permute_op_nodes_attrs(graph: Graph):
     for node in graph.nodes():
         node = Node(graph, node)
         if node.kind == 'op' and node.has_valid('permute_attrs'):
@@ -255,7 +253,7 @@ def permute_op_nodes_attrs(graph: nx.MultiDiGraph):
                 raise Error('Can\'t permute attrs for node {}. Error message: {}'.format(node.id, e))
 
 
-def reverse_input_channels(graph: nx.MultiDiGraph):
+def reverse_input_channels(graph: Graph):
     """
     Searchers for all type=Input nodes with 4D output tensors,
     tracks tensors down through non-shape-changing ops to the first type=Convolution or other channel-dependent nodes
@@ -311,6 +309,8 @@ def reverse_input_channels(graph: nx.MultiDiGraph):
         if conv.op == 'DepthwiseConv2dNative':
             log.debug('out nodes: {}'.format(conv.out_node()))
             bottoms = conv.out_node().out_nodes()
+            if len(bottoms) == 1 and bottoms[0].op == 'FakeQuantWithMinMaxVars':
+                bottoms = bottoms[0].out_node().out_nodes()
             log.debug('bottoms: {}'.format(bottoms))
             log.debug('assumed conv: name = {}, op = {}'.format(bottoms[0].name, bottoms[0].op))
             if len(bottoms) > 0 and bottoms[0].op == 'Conv2D':
@@ -349,12 +349,13 @@ def reverse_input_channels(graph: nx.MultiDiGraph):
                     'complete the flip')
 
         conv.in_node(1).value = np.flip(conv.in_node(1).value, conv.in_node(1).input_channel_dim)
+        conv.in_node(1).shape = int64_array(conv.in_node(1).value.shape)
         log.debug('Applied reversing input channels for weights of convolution {}'.format(conv.id))
         log.debug('Shape was (shape){}, (value.shape){}'.format(conv.in_node(1).shape, conv.in_node(1).value.shape))
         log.debug('Flipped dim: {}'.format(conv.in_node(1).input_channel_dim))
 
 
-def conv_flatten_concat_action(graph: nx.MultiDiGraph, match: dict):
+def conv_flatten_concat_action(graph: Graph, match: dict):
     assert graph.graph['layout'] == 'NHWC'
     reshape_node = match['reshape']
     reshape_data_node = match['reshape_data']
@@ -370,18 +371,18 @@ def conv_flatten_concat_action(graph: nx.MultiDiGraph, match: dict):
         log.info('There is a FullyConnected layer after the node "{}" which weights will be repacked. So there is no '
                  'need to insert Permute'.format(reshape_node.soft_get('name')))
         return
-    assert len(graph.in_edges(reshape_node.id)) == 1
     graph.remove_edge(conv_data_node.id, reshape_node.id)
 
     permutation_order = PermuteAttrs.get_nchw_to_nhwc_permutation(len(conv_data_node.shape)).perm
     new_permute_op = Permute(graph, {'order': permutation_order})
     permute_data_node = new_permute_op.create_node_with_data([conv_data_node], dict(name=conv_name + '/Permute_'))
-    create_edge(permute_data_node, reshape_node)
+    graph.create_edge(permute_data_node, reshape_node)
     # Disable permutation for Reshape and Concat layers attributes
     PermuteAttrs.set_permutation(reshape_node, reshape_data_node, None)
+    reshape_node['nchw_layout'] = True
 
 
-def conv_flatten_concat(graph: nx.MultiDiGraph):
+def conv_flatten_concat(graph: Graph):
     apply_pattern(
         graph,
         nodes=[
@@ -419,12 +420,12 @@ def conv_flatten_concat(graph: nx.MultiDiGraph):
     )
 
 
-def fuse_sequence_of_reshapes(graph: nx.MultiDiGraph):
+def fuse_sequence_of_reshapes(graph: Graph):
     for node in list(graph.nodes()):
-        node = Node(graph, node)
-        if not graph.has_node(node.id):
+        if not graph.has_node(node):
             # data node can be already removed
             continue
+        node = Node(graph, node)
         if (
                 node.has_valid('type') and node.type == 'Reshape' and
                 len(node.out_nodes()) == 1 and node.out_node().has_valid('kind') and node.out_node().kind == 'data' and
@@ -439,3 +440,22 @@ def fuse_sequence_of_reshapes(graph: nx.MultiDiGraph):
                 # Remove Reshape1
                 log.debug('Second phase for Reshape: {}'.format(node.name))
                 remove_op_node_with_data_node(graph, node)
+
+    reshape_nodes = graph.get_op_nodes(op='Reshape')
+    for reshape_node in reshape_nodes:
+        in_ports = [port for port in reshape_node.in_ports().values() if not port.disconnected()]
+        assert len(in_ports) in [1, 2], "`Reshape` node must have 2 inputs or 1 input with `dim`"
+        if len(in_ports) == 2:
+            previous_dim_op = reshape_node.in_port(1).get_source().node.op
+            if previous_dim_op != 'Const':
+                continue
+            dim = reshape_node.in_port(1).get_connection().data.get_value()
+        else:
+            assert reshape_node.has_valid('dim'), "`Reshape` node with 1 input must have `dim` attribute"
+            dim = reshape_node.dim
+
+        in_shape = reshape_node.in_port(0).get_connection().data.get_shape()
+
+        if np.array_equal(dim, in_shape) and len(reshape_node.out_nodes()):
+            log.debug("Useless reshape with dim {} was deleted: {}".format(str(dim), reshape_node.name))
+            reshape_node.out_port(0).get_connection().set_source(reshape_node.in_port(0).get_source())