6 from mo.front.extractor import add_attrs_props
7 from mo.graph.graph import Node, unique_id, dump_graph_for_graphviz
8 from mo.middle.pattern_match import apply_pattern
9 from mo.front.extractor import update_ie_fields
10 from mo.middle.passes.fusing.helpers import backward_bfs, get_next_operation
11 from mo.utils.graph import pseudo_topological_sort
12 from mo.middle.passes.infer import partial_infer
13 from mo.front.common.partial_infer.pooling import pool_explicit_padding_infer
14 from mo.ops.pooling import Pooling
17 def _clean_fw_tensor_attrs(node: Node):
# Presumably removes or resets, on `node`, the framework tensor debug attributes listed in `attrs`
# (currently only 'fw_tensor_debug_info'); only the `node.has_valid(attr)` check is visible in this excerpt.
# Within this pass it is applied to newly created pooling data nodes and to `node.out_node()` after a node
# has been rewritten.
# NOTE(review): the loop that binds `attr` (presumably iterating `attrs`) and the statement executed when the
# attribute is valid are missing from this excerpt; restore both from the original file.
18 attrs = ['fw_tensor_debug_info']
20 if node.has_valid(attr):
def _insert_pooling(graph: nx.MultiDiGraph, first_node: Node, second_node: Node, spatial_dims):
    """
    Insert a point-wise (window 1x1x1x1) average Pooling between two connected nodes.

    The single edge ``first_node -> second_node`` is removed. A Pooling is then created with
    ``create_node_with_data([first_node])``. The node this returns (presumably the pooling's output data node)
    is connected to ``second_node`` using the attributes of the removed edge. The pooling stride is
    ``second_node.stride_prop``, so the window-1 pooling only subsamples the tensor before it reaches
    ``second_node``.

    :param graph: graph to modify in place.
    :param first_node: producer side of the edge to split.
    :param second_node: consumer side of the edge to split; must have a ``stride_prop`` attribute.
    :param spatial_dims: indices of the spatial dimensions, passed through to the Pooling op.
    :raises AssertionError: if ``first_node`` and ``second_node`` are not connected by exactly one edge.
    """
    log.debug("STRIDE PROP: Insert pooling between {} and {}".format(first_node.name, second_node.name))
    stride_prop = second_node.stride_prop

    # Look the parallel edges up once. `get_edge_data` returns None when the nodes are not connected; map that
    # to an empty dict so the check below fails with an AssertionError rather than a TypeError from len(None).
    edges = graph.get_edge_data(first_node.id, second_node.id) or {}
    assert len(edges) == 1, 'STRIDE PROP: expected exactly one edge between {} and {}, found {}'.format(
        first_node.id, second_node.id, len(edges))
    # Take the single edge's attributes whatever its key is. MultiDiGraph keys are not guaranteed to be 0
    # (e.g. after an earlier parallel edge was removed), so indexing the dict with [0] could raise KeyError.
    eattrs = next(iter(edges.values()))
    graph.remove_edge(first_node.id, second_node.id)

    pooling = Pooling(graph, dict(name='Pooling_', spatial_dims=spatial_dims, window=np.array([1, 1, 1, 1]),
                                  output_spatial_shape=None,
                                  stride=np.array(stride_prop), pad_spatial_shape=np.array([[0, 0], [0, 0]]),
                                  pad=np.array([[0, 0], [0, 0], [0, 0], [0, 0]]), pool_method='avg',
                                  is_partial_inferred=False))
    pooling_data = pooling.create_node_with_data([first_node])

    # Drop framework tensor debug attributes from the newly created node.
    _clean_fw_tensor_attrs(pooling_data)

    graph.add_edges_from([(pooling_data.id, second_node.id, eattrs)])
def _check_next_ops(next_ops: list):
    """
    Collect the ``stride_prop`` values of the given operations and check that they agree.

    :param next_ops: operations consuming the output of the node currently being processed.
    :return: tuple ``(stride_props, status)``.
        ``stride_props`` is a list holding ``np.array(op.stride_prop)`` for every operation in ``next_ops``
        with a valid ``stride_prop``, in ``next_ops`` order.
        ``status`` is True when every operation has a valid ``stride_prop``, all of them are equal, and none
        of them is the identity stride ``[1, 1, 1, 1]``. An empty ``next_ops`` yields ``([], True)``.
    """
    stride_props = [np.array(op.stride_prop) for op in next_ops if op.has_valid('stride_prop')]

    # A consumer without stride_prop blocks propagation through all of them.
    if len(stride_props) != len(next_ops):
        return stride_props, False

    identity = np.array([1, 1, 1, 1])
    status = all(np.array_equal(x, stride_props[0]) and not np.array_equal(x, identity) for x in stride_props)
    return stride_props, status
62 def _simple_stride_prop(graph: nx.MultiDiGraph, node: Node, spatial_dims, supported=True):
# Stride propagation for an op node other than Convolution. It is the handler of the ReLU/Eltwise entries and,
# with supported=False, the fallback used by _stride_propagation for op types without a handler.
# :param graph: graph into which Pooling nodes may be inserted.
# :param node: op node being processed; its `stride_prop` is set on both paths, and on the pass-through path
#     it is also marked `is_partial_inferred = False`.
# :param spatial_dims: indices of the spatial dimensions checked in the consumers' `stride_prop`.
# :param supported: False forces the blocking path (the stride is not propagated through this op).
# NOTE(review): this excerpt has lost lines of the original function. Missing are the docstring delimiters,
# the `for op in next_ops:` loop headers (`op` is used below without being bound) and, presumably, the
# `else:` that starts the pass-through path. The exact nesting of the statements below therefore cannot be
# recovered; restore it from the original file before editing the logic.
63 This function handles stride propagation for op nodes. If node is in supported ops dict so this is supported operation and we
64 can propagate stride directly via this op (stride_prop will be set by using bottom stride_prop), otherwise we can't and
65 stride_prop attr will be set as 1,1,1,1
68 next_ops = get_next_operation(node)
69 stride_props, all_ops_are_valid = _check_next_ops(next_ops)
# Blocking path: either this op cannot forward a stride, or its consumers do not all share one non-unit
# stride_prop.
71 if not supported or not all_ops_are_valid:
72 # We have to insert pooling layers
# Some consumers expect a non-unit spatial stride but do not own a stride themselves (`has_stride` absent or
# False). Each such consumer gets a 1x1 Pooling in front of it that applies that stride.
74 if op.has_valid('stride_prop') and not np.array_equal(op.stride_prop[spatial_dims], np.array([1, 1])) and \
75 (op.has_valid('has_stride') == False or op.soft_get('has_stride') == False):
76 _insert_pooling(graph, node.out_node(), op, spatial_dims)
77 # The stride cannot pass through this op, so it reports the identity stride to its producers
78 node['stride_prop'] = np.array([1, 1, 1, 1])
# Pass-through path (presumably the missing `else:` branch). A consumer that owns a stride (`has_stride` is
# True) gets it reset to 1, because this op forwards that stride to its producers instead.
82 if op.soft_get('has_stride') == True:
83 op.stride = np.array([1, 1, 1, 1])
84 log.debug("STRIDE PROP: {} {} strides was moved upper via {}".format(op.type, op.name, node.name))
# Forward the consumers' common stride_prop; with no consumers this is the identity stride.
86 node['stride_prop'] = np.array(stride_props[0]) if len(stride_props) > 0 else np.array([1, 1, 1, 1])
# stride_optimization collects the nodes whose is_partial_inferred is False.
87 node['is_partial_inferred'] = False
88 _clean_fw_tensor_attrs(node.out_node())
91 def _conv_stride_prop(graph: nx.MultiDiGraph, node: Node, spatial_dims, supported=True):
# Stride propagation for a Convolution node.
# If all consumers share one non-unit `stride_prop`, that stride is multiplied into this convolution's own
# `stride`, and consumers that own a stride get theirs reset to 1. Otherwise a 1x1 Pooling is inserted in front
# of every consumer that expects a non-unit spatial stride.
# Finally the node exports its (possibly updated) stride as `stride_prop`, but only for a 1x1 convolution;
# any other kernel exports the identity stride [1, 1, 1, 1].
# :param graph: graph into which Pooling nodes may be inserted.
# :param node: Convolution op node; `stride`, `stride_prop` and possibly `is_partial_inferred` are updated.
# :param spatial_dims: indices of the spatial dimensions checked in the consumers' `stride_prop`.
# :param supported: not referenced in the lines visible in this excerpt.
# NOTE(review): this excerpt has lost lines of the original function. Among them are the docstring delimiters,
# the `for op in next_ops:` loop headers (`op` is used below without being bound) and one line between the
# pooling condition and the `_insert_pooling` call. Restore them from the original file.
93 This function handles convolution stride propagation. There is two cases: conv->(op) and conv->conv. In first case
94 we propagate stride from op, and in second case we also change stride for second conv
96 next_ops = get_next_operation(node)
97 stride_props, all_ops_are_valid = _check_next_ops(next_ops)
# True for a 1x1 convolution (`kernel_spatial` present and equal to [1, 1]).
99 def _check_convolution(node: Node):
100 return node.has_valid('kernel_spatial') and np.array_equal(node.kernel_spatial, np.array([1, 1]))
102 # Check that all ops are valid and have same values
103 if not all_ops_are_valid:
104 # We have to insert pooling layers
# Unlike _simple_stride_prop, the consumer's `has_stride` flag is not consulted here.
106 if op.has_valid('stride_prop') and not np.array_equal(op.stride_prop[spatial_dims], np.array([1, 1])):
108 _insert_pooling(graph, node.out_node(), op, spatial_dims)
109 elif len(stride_props) > 0:
# NOTE(review): `*=` updates the `stride` value in place when it is a NumPy array. Any other holder of the
# same array object would see the change as well; verify the array is not shared.
110 node.stride *= stride_props[0]
111 log.debug('STRIDE PROP: {} got new strides {}'.format(node.name, node.stride))
113 if op.soft_get('has_stride') == True:
114 op.stride = np.array([1, 1, 1, 1])
# stride_optimization collects the nodes whose is_partial_inferred is False.
115 node['is_partial_inferred'] = False
116 _clean_fw_tensor_attrs(node.out_node())
118 # Only a 1x1 convolution hands its stride to its producers; any other kernel reports the identity stride
119 node['stride_prop'] = np.array(node.stride) if _check_convolution(node) else np.array([1, 1, 1, 1])
# Op types through which a stride can be propagated; presumably the body of the `supported_ops` dict used by
# _stride_propagation.
# - `stride_prop` is the handler called for nodes of that type.
# - `attrs` are written onto the node before that call.
# - `has_stride` marks ops that own a stride; the handlers reset such a consumer's stride to 1 once it has
#   been taken over further up the graph.
# NOTE(review): the dict's opening line and closing brace are missing from this excerpt.
123 'ReLU': {'stride_prop': _simple_stride_prop, 'attrs': {}},
124 'Eltwise': {'stride_prop': _simple_stride_prop, 'attrs': {}},
125 'Convolution': {'stride_prop': _conv_stride_prop, 'attrs': {'has_stride': True}},
129 def _stride_propagation(graph: nx.MultiDiGraph, spatial_dims):
# Run stride propagation over every op node of `graph`, dispatching on the node's `type`.
# The op nodes are collected once, up front, from `pseudo_topological_sort(graph, reverse=True)`. Presumably
# this makes consumers come before their producers, since the handlers read the consumers' `stride_prop`.
# Pooling nodes inserted while processing are not in that list and are therefore not visited.
# For a type listed in `supported_ops`, the entry's `attrs` are written onto the node and then its
# `stride_prop` handler is called with supported=True. Any other op goes through
# `_simple_stride_prop(..., False)`.
# NOTE(review): this excerpt has lost the `for node in nodes:` loop header, one line before the attrs loop and
# the `else:` before the fallback call; restore them from the original file.
131 This function do stride propagation for all op nodes
133 nodes = [Node(graph, x) for x in pseudo_topological_sort(graph, reverse=True) if Node(graph, x).kind == 'op']
136 if node.soft_get('type') in supported_ops:
137 op = supported_ops[node.type]
139 for key in op['attrs'].keys():
140 node[key] = op['attrs'][key]
141 op['stride_prop'](graph, node, spatial_dims, True)
143 _simple_stride_prop(graph, node, spatial_dims, False)
146 def stride_optimization(graph: nx.MultiDiGraph):
148 This is main function for stride optimization pass
150 layout = graph.graph['layout']
152 spatial_dims = np.array([2, 3])
153 elif layout == 'NHWC':
154 spatial_dims = np.array([1, 2])
156 log.warning('STRIDE PROP: layout {} is not supported'.format(layout))
158 _stride_propagation(graph, spatial_dims)
160 nodes = [Node(graph, x) for x in pseudo_topological_sort(graph) if
161 Node(graph, x).soft_get('is_partial_inferred') == False]