Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / DepthToSpace.py
index 1e05c8a..6470b23 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  See the License for the specific language governing permissions and
  limitations under the License.
 """
-
-import logging as log
-from copy import deepcopy
-
-import networkx as nx
-
+from mo.front.common.partial_infer.utils import int64_array
+from mo.graph.graph import Graph
 from mo.middle.replacement import MiddleReplacementPattern
 from mo.ops.permute import Permute
 from mo.ops.reshape import Reshape
@@ -31,6 +27,14 @@ class DepthToSpace(MiddleReplacementPattern):
 
     enabled = True
 
+    def run_after(self):
+        from extensions.middle.pass_separator import MiddleStart
+        return [MiddleStart]
+
+    def run_before(self):
+        from extensions.middle.pass_separator import MiddleFinish
+        return [MiddleFinish]
+
     def pattern(self):
         return dict(
             nodes=[
@@ -43,7 +47,7 @@ class DepthToSpace(MiddleReplacementPattern):
                 ('op', 'out_data')
             ])
 
-    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
+    def replace_pattern(self, graph: Graph, match: dict):
         node = match['op']
 
         N, H, W, C = match['in_data'].shape
@@ -52,13 +56,17 @@ class DepthToSpace(MiddleReplacementPattern):
         graph.remove_edge(match['in_data'].id, match['op'].id)
         graph.remove_edge(match['op'].id, match['out_data'].id)
 
-        dim_6D = [N, block_size, block_size, int(C / (block_size ** 2)), H, W]
-        order_6D = [0, 3, 4, 1, 5, 2]
-        dim_4D = [N, int(H * block_size), int(W * block_size), int(C / (block_size ** 2))]
-
-        reshape_data_node = Reshape(graph=graph, attrs={'name': match['op'].id + '/Reshape_to_6D', 'dim': dim_6D}).create_node_with_data([match['in_data']])
-        permute_data_node = Permute(graph=graph, attrs={'name': match['op'].id + '/Permute', 'order': order_6D}).create_node_with_data([reshape_data_node])
-        reshape_node = Reshape(graph=graph, attrs={'infer': None, 'name': match['op'].id + '/Reshape_to_4D', 'dim': dim_4D}).create_node_with_data([permute_data_node], data_nodes=[match['out_data']])
+        dim_6D = int64_array([N, block_size, block_size, int(C / (block_size ** 2)), H, W])
+        order_6D = int64_array([0, 3, 4, 1, 5, 2])
+        dim_4D = int64_array([N, int(H * block_size), int(W * block_size), int(C / (block_size ** 2))])
+
+        reshape_data_node = Reshape(graph=graph, attrs={'name': match['op'].id + '/Reshape_to_6D',
+                                                        'dim': dim_6D}).create_node_with_data([match['in_data']])
+        permute_data_node = Permute(graph=graph, attrs={'name': match['op'].id + '/Permute',
+                                                        'order': order_6D}).create_node_with_data([reshape_data_node])
+        reshape_node = Reshape(graph=graph, attrs={'name': match['op'].id + '/Reshape_to_4D',
+                                                   'dim': dim_4D}).create_node_with_data([permute_data_node],
+                                                                                         data_nodes=[match['out_data']])
 
         reshape_data_node.in_node()['nchw_layout'] = True
         reshape_data_node['nchw_layout'] = True