"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import logging as log
-import networkx as nx
import numpy as np
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
from mo.middle.passes.fusing.helpers import get_next_operation
from mo.ops.pooling import Pooling
from mo.utils.graph import pseudo_topological_sort
node[attr] = None
-def _insert_pooling(graph: nx.MultiDiGraph, first_node: Node, second_node: Node, spatial_dims):
+def _insert_pooling(graph: Graph, first_node: Node, second_node: Node, spatial_dims):
"""
This function inserts point wise pooling layer between two nodes
"""
return stride_props, status
-def _simple_stride_prop(graph: nx.MultiDiGraph, node: Node, spatial_dims, supported=True):
+def _simple_stride_prop(graph: Graph, node: Node, spatial_dims, supported=True):
"""
This function handles stride propagation for op nodes. If node is in supported ops dict so this is supported operation and we
can propagate stride directly via this op (stride_prop will be set by using bottom stride_prop), otherwise we can't and
_clean_fw_tensor_attrs(node.out_node())
-def _conv_stride_prop(graph: nx.MultiDiGraph, node: Node, spatial_dims, supported=True):
+def _conv_stride_prop(graph: Graph, node: Node, spatial_dims, supported=True):
"""
This function handles convolution stride propagation. There is two cases: conv->(op) and conv->conv. In first case
we propagate stride from op, and in second case we also change stride for second conv
}
-def _stride_propagation(graph: nx.MultiDiGraph, spatial_dims):
+def _stride_propagation(graph: Graph, spatial_dims):
"""
This function do stride propagation for all op nodes
"""
- nodes = [Node(graph, x) for x in pseudo_topological_sort(graph, reverse=True) if Node(graph, x).kind == 'op']
+ nodes = [Node(graph, x) for x in pseudo_topological_sort(graph, reverse=True) if
+ Node(graph, x).kind == 'op' and Node(graph, x).soft_get('type') != 'Const']
for node in nodes:
if node.soft_get('type') in supported_ops:
_simple_stride_prop(graph, node, spatial_dims, False)
-def stride_optimization(graph: nx.MultiDiGraph):
+def stride_optimization(graph: Graph):
"""
This is main function for stride optimization pass
"""