Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / fusing / resnet_optimization.py
index 8e6481a..6f78a39 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
 
 import logging as log
 
-import networkx as nx
 import numpy as np
 
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
 from mo.middle.passes.fusing.helpers import get_next_operation
 from mo.ops.pooling import Pooling
 from mo.utils.graph import pseudo_topological_sort
@@ -32,7 +31,7 @@ def _clean_fw_tensor_attrs(node: Node):
             node[attr] = None
 
 
-def _insert_pooling(graph: nx.MultiDiGraph, first_node: Node, second_node: Node, spatial_dims):
+def _insert_pooling(graph: Graph, first_node: Node, second_node: Node, spatial_dims):
     """
     This function inserts point wise pooling layer between two nodes
     """
@@ -70,7 +69,7 @@ def _check_next_ops(next_ops: list):
     return stride_props, status
 
 
-def _simple_stride_prop(graph: nx.MultiDiGraph, node: Node, spatial_dims, supported=True):
+def _simple_stride_prop(graph: Graph, node: Node, spatial_dims, supported=True):
     """
     This function handles stride propagation for op nodes. If node is in supported ops dict so this is supported operation and we
     can propagate stride directly via this op (stride_prop will be set by using bottom stride_prop), otherwise we can't and
@@ -99,7 +98,7 @@ def _simple_stride_prop(graph: nx.MultiDiGraph, node: Node, spatial_dims, suppor
     _clean_fw_tensor_attrs(node.out_node())
 
 
-def _conv_stride_prop(graph: nx.MultiDiGraph, node: Node, spatial_dims, supported=True):
+def _conv_stride_prop(graph: Graph, node: Node, spatial_dims, supported=True):
     """
     This function handles convolution stride propagation. There is two cases: conv->(op) and conv->conv. In first case
     we propagate stride from op, and in second case we also change stride for second conv
@@ -138,11 +137,12 @@ supported_ops = {
 }
 
 
-def _stride_propagation(graph: nx.MultiDiGraph, spatial_dims):
+def _stride_propagation(graph: Graph, spatial_dims):
     """
     This function do stride propagation for all op nodes
     """
-    nodes = [Node(graph, x) for x in pseudo_topological_sort(graph, reverse=True) if Node(graph, x).kind == 'op']
+    nodes = [Node(graph, x) for x in pseudo_topological_sort(graph, reverse=True) if
+             Node(graph, x).kind == 'op' and Node(graph, x).soft_get('type') != 'Const']
 
     for node in nodes:
         if node.soft_get('type') in supported_ops:
@@ -155,7 +155,7 @@ def _stride_propagation(graph: nx.MultiDiGraph, spatial_dims):
             _simple_stride_prop(graph, node, spatial_dims, False)
 
 
-def stride_optimization(graph: nx.MultiDiGraph):
+def stride_optimization(graph: Graph):
     """
     This is main function for stride optimization pass
     """