Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / PixelLinkReshape.py
index 9564b5d..9c6cceb 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
 """
 
 import logging as log
-import networkx as nx
 import numpy as np
 
 from copy import deepcopy
 
-from extensions.middle.AddReshapeAfterStridedSlice import AddReshapeAfterStridedSlice
+from extensions.middle.ConvertGroupedStridedSlice import ConvertGroupedStridedSlice
 from extensions.middle.FusePermutesSequence import FusePermutesSequence
 from extensions.middle.ShufflenetReshape import ReshapeSoftmaxReshape
+from mo.graph.graph import Graph
 from mo.middle.replacement import MiddleReplacementPattern
 from mo.ops.op import Op
 from mo.ops.permute import Permute
@@ -30,16 +30,17 @@ from mo.ops.permute import Permute
 
 class PixelLinkReshape(MiddleReplacementPattern):
     """
-      Transform adds Permutes around Reshapes that pack 4 dimensions in 2, than 
-      do Softmax and then unpack it back to 5 dims. 
+      Transform adds Permutes around Reshapes that pack 4 dimensions in 2, than
+      do Softmax and then unpack it back to 5 dims.
     """
     enabled = True
 
     def run_before(self):
-        return [FusePermutesSequence, ReshapeSoftmaxReshape, AddReshapeAfterStridedSlice]
+        return [FusePermutesSequence, ReshapeSoftmaxReshape, ConvertGroupedStridedSlice]
 
     def run_after(self):
-        return []
+        from extensions.middle.pass_separator import MiddleStart
+        return [MiddleStart]
 
     def pattern(self):
         return dict(nodes=[('reshape_split', dict(kind='op', type='Reshape')),
@@ -51,7 +52,7 @@ class PixelLinkReshape(MiddleReplacementPattern):
                            ('reshape_unpack', dict(kind='op', type='Reshape')),
                            ('reshape_unpack_data', dict(kind='data')),
                            ('strided_slice', dict(kind='op', op='StridedSlice')),
-                         ],
+                           ],
                     edges=[('reshape_split', 'reshape_split_data'),
                            ('reshape_split_data', 'reshape_pack'),
                            ('reshape_pack', 'reshape_data'),
@@ -84,7 +85,7 @@ class PixelLinkReshape(MiddleReplacementPattern):
         else:
             return False
 
-    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
+    def replace_pattern(self, graph: Graph, match: dict):
         if graph.graph['layout'] != 'NHWC':
             return
 
@@ -120,55 +121,72 @@ class PixelLinkReshape(MiddleReplacementPattern):
             attrs = deepcopy(graph.get_edge_data(node.id, out_node.id)[0])
             graph.remove_edge(node.id, out_node.id)
 
-            permute_after_node = permute_after.create_node_with_data([data_node], permute_after.attrs,
-                                                                     data_nodes=[out_node])
+            permute_after.create_node_with_data([data_node], permute_after.attrs,
+                                                data_nodes=[out_node])
             graph.add_edge(node.id, data_node.id, **attrs)
 
             # update softmax shape
             node_softmax = match['softmax']
             node_softmax.out_node(0).shape = out_node.shape
 
-            # revert strided slice and reshape
-            node_ss = match['strided_slice']
-            node_unpack = match['reshape_unpack']
-
-            unpack_out = node_unpack.out_node(0).id
-            ss_out = node_ss.out_node(0).id
-
-            #gather edge attributes
-            soft_reshape_attrs = deepcopy(graph.get_edge_data(node_softmax.out_node(0).id, node_unpack.id)[0])
-            reshape_data_attrs = deepcopy(graph.get_edge_data(node_unpack.id, unpack_out)[0])
-            reshape_ss_attrs = deepcopy(graph.get_edge_data(unpack_out, node_ss.id)[0])
-            ss_out_attrs = deepcopy(graph.get_edge_data(node_ss.id, ss_out)[0])
-
-            #remove all edges in Softmax->Reshape->StridedSlice chain
-            graph.remove_edge(node_softmax.out_node(0).id, node_unpack.id)
-            graph.remove_edge(node_unpack.id, unpack_out)
-            graph.remove_edge(unpack_out, node_ss.id)
-            graph.remove_edge(node_ss.id, ss_out)
-
-            #add new edges to get chain Softmax->StridedSlice->Reshape
-            graph.add_edge(node_softmax.out_node(0).id, node_ss.id, **soft_reshape_attrs)
-            graph.add_edge(node_ss.id, unpack_out, **reshape_data_attrs)
-            graph.add_edge(unpack_out, node_unpack.id, **reshape_ss_attrs)
-            graph.add_edge(node_unpack.id, ss_out, **ss_out_attrs)
-
-            #update output shape and parameters for StridedSlice
-            node_ss.out_node(0).shape = np.zeros(3)
-            node_ss.out_node(0).shape[0] = out_node.shape[0]
-            node_ss.out_node(0).shape[1] = 1
-            node_ss.out_node(0).shape[2] = out_node.shape[2]
-
-            old_slices = node_ss.slices.copy()
-            node_ss.slices = []
-            node_ss.slices.append(old_slices[0])
-            node_ss.slices.append(old_slices[-1])
-            node_ss.slices.append(slice(0, out_node.shape[2], 1))
-            node_ss.shrink_axis_mask = [False, False, False]
-            node_ss.new_axis_mask = [False, False, False]
-
-            #update Reshape attribute
-            node_unpack.dim = np.delete(node_unpack.dim, 4)
-            #prevent permute for reshape because it gives wrong result
-            node_unpack['nchw_layout'] = True
-            node_unpack.out_node(0)['nchw_layout'] = True
+            if ConvertGroupedStridedSlice.enabled is True:
+                # revert strided slice and reshape
+                node_ss = match['strided_slice']
+                node_unpack = match['reshape_unpack']
+
+                unpack_out = node_unpack.out_node(0).id
+                ss_out = node_ss.out_node(0).id
+
+                # gather edge attributes
+                soft_reshape_attrs = deepcopy(graph.get_edge_data(node_softmax.out_node(0).id, node_unpack.id)[0])
+                reshape_data_attrs = deepcopy(graph.get_edge_data(node_unpack.id, unpack_out)[0])
+                reshape_ss_attrs = deepcopy(graph.get_edge_data(unpack_out, node_ss.id)[0])
+                ss_out_attrs = deepcopy(graph.get_edge_data(node_ss.id, ss_out)[0])
+
+                # remove all edges in Softmax->Reshape->StridedSlice chain
+                graph.remove_edge(node_softmax.out_node(0).id, node_unpack.id)
+                graph.remove_edge(node_unpack.id, unpack_out)
+                graph.remove_edge(unpack_out, node_ss.id)
+                graph.remove_edge(node_ss.id, ss_out)
+
+                # add new edges to get chain Softmax->StridedSlice->Reshape
+                graph.add_edge(node_softmax.out_node(0).id, node_ss.id, **soft_reshape_attrs)
+                graph.add_edge(node_ss.id, unpack_out, **reshape_data_attrs)
+                graph.add_edge(unpack_out, node_unpack.id, **reshape_ss_attrs)
+                graph.add_edge(node_unpack.id, ss_out, **ss_out_attrs)
+
+                # update output shape and parameters for StridedSlice
+                node_ss.out_node(0).shape = np.zeros(3)
+                node_ss.out_node(0).shape[0] = out_node.shape[0]
+                node_ss.out_node(0).shape[1] = 1
+                node_ss.out_node(0).shape[2] = out_node.shape[2]
+
+                old_slices = node_ss.slices.copy()
+                node_ss.slices = []
+                node_ss.slices.append(old_slices[0])
+                node_ss.slices.append(old_slices[-1])
+                node_ss.slices.append(slice(0, out_node.shape[2], 1))
+                node_ss.shrink_axis_mask = np.array([0, 0, 0], dtype=np.int64)
+                node_ss.new_axis_mask = np.array([0, 0, 0], dtype=np.int64)
+                node_ss.ellipsis_mask = np.array([0, 0, 0], dtype=np.int64)
+                node_ss.begin_mask = np.array([0, 1, 0], dtype=np.int64)
+                node_ss.end_mask = np.array([0, 1, 0], dtype=np.int64)
+
+                # update Reshape attribute
+                node_unpack.dim = np.delete(node_unpack.dim, 4)
+                # prevent permute for reshape because it gives wrong result
+                node_unpack['nchw_layout'] = True
+                node_unpack.out_node(0)['nchw_layout'] = True
+            else:
+                # reshape unpack: permute correctly
+                node_unpack = match['reshape_unpack']
+                data_node = Op._create_data_node(graph, node.name + "/Permute_after_unpack_data", {'shape': node_unpack.out_node().shape})
+                permute_after_unpack = Permute(graph, dict(name=node.name + "/Permute_after_unpack",
+                                                           order=np.array([0, 3, 1, 2, 4])))
+                out_node = node_unpack.out_node(0)
+                out_node.shape = out_node.shape[np.array([0, 3, 1, 2, 4], dtype=np.int)]
+                attrs = deepcopy(graph.get_edge_data(node_unpack.id, out_node.id)[0])
+                graph.remove_edge(node_unpack.id, out_node.id)
+                permute_after.create_node_with_data([data_node], permute_after_unpack.attrs,
+                                                    data_nodes=[out_node])
+                graph.add_edge(node_unpack.id, data_node.id, **attrs)