"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import logging as log

import numpy as np

from mo.graph.graph import Node, Graph
from mo.middle.passes.fusing.helpers import get_next_operation
from mo.ops.pooling import Pooling
from mo.utils.graph import pseudo_topological_sort
def _clean_fw_tensor_attrs(node: Node):
    """
    Reset framework tensor debug attributes on a node.

    Every attribute listed below that currently holds a valid value is set to None, so that
    ``node.has_valid(attr)`` no longer reports it. This is called on data nodes produced by this pass
    (new pooling outputs, outputs of ops whose stride changed).

    :param node: node whose debug attributes are cleared (modified in place)
    """
    attrs = ['fw_tensor_debug_info']
    for attr in attrs:
        if node.has_valid(attr):
            node[attr] = None
def _insert_pooling(graph: Graph, first_node: Node, second_node: Node, spatial_dims):
    """
    Insert a point-wise max Pooling (window of 1, zero padding) between two nodes.

    The pooling's stride is ``second_node.stride_prop``. This lets the stride that ``second_node`` expects be
    applied right after ``first_node`` when it cannot be propagated further up.

    :param graph: graph being modified in place
    :param first_node: producer end of the edge (callers pass an op's output data node); becomes the pooling input
    :param second_node: consumer op whose incoming edge from ``first_node`` is re-routed through the pooling
    :param spatial_dims: indices of the spatial dimensions for the current layout
    """
    log.debug("STRIDE PROP: Insert pooling between {} and {}".format(first_node.name, second_node.name))
    stride_prop = second_node.stride_prop

    # get_edge_data returns None when there is no edge, otherwise a {edge_key: attrs} mapping.
    edges = graph.get_edge_data(first_node.id, second_node.id)
    assert edges is not None and len(edges) == 1, \
        'STRIDE PROP: expected exactly one edge between {} and {}'.format(first_node.id, second_node.id)
    # Take the single edge's attributes whatever its key is; the key is not guaranteed to be 0.
    eattrs = next(iter(edges.values()))
    graph.remove_edge(first_node.id, second_node.id)

    pooling = Pooling(graph, dict(name='Pooling_', spatial_dims=spatial_dims, window=np.array([1, 1, 1, 1]),
                                  output_spatial_shape=None,
                                  stride=np.array(stride_prop), pad_spatial_shape=np.array([[0, 0], [0, 0]]),
                                  pad=np.array([[0, 0], [0, 0], [0, 0], [0, 0]]), pool_method='max',
                                  is_partial_inferred=False))
    pooling_data = pooling.create_node_with_data([first_node])

    # Drop any framework tensor debug info from the freshly created data node.
    _clean_fw_tensor_attrs(pooling_data)

    # Re-attach the consumer to the pooling output, keeping the original edge attributes (e.g. input port).
    graph.add_edges_from([(pooling_data.id, second_node.id, eattrs)])
def _check_next_ops(next_ops: list):
    """
    Check whether a stride can be propagated through the given consumer operations.

    The check passes when every op has a valid ``stride_prop`` and all of them hold the same value, which must
    not be the identity stride ``[1, 1, 1, 1]``. An empty ``next_ops`` list passes trivially.

    :param next_ops: consumer operations of the node being processed
    :return: tuple ``(stride_props, status)``. ``stride_props`` is a list with a numpy copy of every valid
             ``stride_prop`` (in ``next_ops`` order). ``status`` is the boolean result of the check.
    """
    stride_props = [np.array(op.stride_prop) for op in next_ops if op.has_valid('stride_prop')]

    # Every consumer must carry a stride_prop, and all of them must agree on one non-identity value.
    # all() of an empty sequence is True, so stride_props[0] is only read when the list is non-empty.
    status = len(stride_props) == len(next_ops) and all(
        np.array_equal(x, stride_props[0]) and not np.array_equal(x, [1, 1, 1, 1]) for x in stride_props)
    return stride_props, status
def _simple_stride_prop(graph: Graph, node: Node, spatial_dims, supported=True):
    """
    Handle stride propagation for an op node that has no stride of its own.

    Propagation through the node happens when ``supported`` is True and all consumers of ``node`` agree on one
    non-identity ``stride_prop`` (see ``_check_next_ops``). In that case:
    - consumers marked ``has_stride`` get their stride reset to ``[1, 1, 1, 1]``;
    - the common value becomes this node's ``stride_prop``.

    Otherwise:
    - a point-wise Pooling is inserted before every consumer that expects a non-unit spatial ``stride_prop``
      and has no stride of its own;
    - this node's ``stride_prop`` is set to ``[1, 1, 1, 1]``.

    :param graph: graph being modified in place
    :param node: op node being processed
    :param spatial_dims: indices of the spatial dimensions for the current layout
    :param supported: whether a stride may be propagated through this op
    """
    next_ops = get_next_operation(node)
    stride_props, all_ops_are_valid = _check_next_ops(next_ops)

    if not supported or not all_ops_are_valid:
        # The stride cannot move through this node: apply it with pooling layers right after the node.
        for op in next_ops:
            if (op.has_valid('stride_prop')
                    and not np.array_equal(op.stride_prop[spatial_dims], np.array([1, 1]))
                    and (op.has_valid('has_stride') == False or op.soft_get('has_stride') == False)):
                _insert_pooling(graph, node.out_node(), op, spatial_dims)
        # Nothing is propagated above this node.
        node['stride_prop'] = np.array([1, 1, 1, 1])
        return

    # Move the common stride from the strided consumers up to this node.
    for op in next_ops:
        if op.soft_get('has_stride') == True:
            op.stride = np.array([1, 1, 1, 1])
            log.debug("STRIDE PROP: {} {} strides was moved upper via {}".format(op.type, op.name, node.name))

    node['stride_prop'] = np.array(stride_props[0]) if len(stride_props) > 0 else np.array([1, 1, 1, 1])
    # Mark the node as not inferred so its shapes are recomputed after the pass.
    node['is_partial_inferred'] = False
    _clean_fw_tensor_attrs(node.out_node())
def _conv_stride_prop(graph: Graph, node: Node, spatial_dims, supported=True):
    """
    Handle stride propagation for a Convolution node.

    There are two cases: conv->(op) and conv->conv. If all consumers agree on one non-identity ``stride_prop``,
    that stride is multiplied into this convolution's own stride. Consumers marked ``has_stride`` (e.g. a
    following convolution) then get their stride reset to ``[1, 1, 1, 1]``. Otherwise a point-wise Pooling is
    inserted before every consumer that expects a non-unit spatial ``stride_prop``.

    Finally, ``stride_prop`` of this node is set to its own stride if it is a 1x1 convolution, and to
    ``[1, 1, 1, 1]`` otherwise.

    :param graph: graph being modified in place
    :param node: Convolution node being processed
    :param spatial_dims: indices of the spatial dimensions for the current layout
    :param supported: not used here; kept so every handler in ``supported_ops`` has the same signature
    """
    next_ops = get_next_operation(node)
    stride_props, all_ops_are_valid = _check_next_ops(next_ops)

    def _check_convolution(conv: Node):
        # True only for convolutions with a 1x1 spatial kernel.
        return conv.has_valid('kernel_spatial') and np.array_equal(conv.kernel_spatial, np.array([1, 1]))

    # Check that all ops are valid and have same values
    if not all_ops_are_valid:
        # Consumers disagree (or one lacks stride_prop): apply their strides with pooling layers.
        for op in next_ops:
            if op.has_valid('stride_prop') and not np.array_equal(op.stride_prop[spatial_dims], np.array([1, 1])):
                _insert_pooling(graph, node.out_node(), op, spatial_dims)
    elif len(stride_props) > 0:
        # Rebind rather than multiply in place, so an array object that may be shared with other
        # attributes is not modified as a side effect.
        node.stride = node.stride * stride_props[0]
        log.debug('STRIDE PROP: {} got new strides {}'.format(node.name, node.stride))
        for op in next_ops:
            if op.soft_get('has_stride') == True:
                op.stride = np.array([1, 1, 1, 1])
        # Mark the node as not inferred so its shapes are recomputed after the pass.
        node['is_partial_inferred'] = False
        node['output_spatial_shape'] = False
        _clean_fw_tensor_attrs(node.out_node())

    # If Convolution is valid then set `stride_prop` to Convolution stride
    node['stride_prop'] = np.array(node.stride) if _check_convolution(node) else np.array([1, 1, 1, 1])
# Stride propagation handlers keyed by op type.
#   'stride_prop' - handler, called as handler(graph, node, spatial_dims, True) by _stride_propagation
#   'attrs'       - attributes assigned to the node before its handler runs
supported_ops = {
    'ReLU': {'stride_prop': _simple_stride_prop, 'attrs': {}},
    'Eltwise': {'stride_prop': _simple_stride_prop, 'attrs': {}},
    'Convolution': {'stride_prop': _conv_stride_prop, 'attrs': {'has_stride': True}},
}
def _stride_propagation(graph: Graph, spatial_dims):
    """
    Run stride propagation over every non-Const op node of the graph.

    Nodes are visited in reverse pseudo-topological order (consumers before producers), because each
    handler reads the ``stride_prop`` already computed for the node's consumers.

    Ops listed in ``supported_ops`` first get the extra attributes from that table, then their handler is
    called. Every other op is handled by ``_simple_stride_prop`` with ``supported=False``, which stops
    propagation at that op.

    :param graph: graph being modified in place
    :param spatial_dims: indices of the spatial dimensions for the current layout
    """
    # Build each Node wrapper once instead of once per filter condition.
    all_nodes = [Node(graph, x) for x in pseudo_topological_sort(graph, reverse=True)]
    nodes = [node for node in all_nodes if node.kind == 'op' and node.soft_get('type') != 'Const']

    for node in nodes:
        if node.soft_get('type') in supported_ops:
            handler = supported_ops[node.type]
            for key, value in handler['attrs'].items():
                node[key] = value
            handler['stride_prop'](graph, node, spatial_dims, True)
        else:
            _simple_stride_prop(graph, node, spatial_dims, False)
158 def stride_optimization(graph: Graph):
160 This is main function for stride optimization pass
162 layout = graph.graph['layout']
164 spatial_dims = np.array([2, 3])
165 elif layout == 'NHWC':
166 spatial_dims = np.array([1, 2])
168 log.warning('STRIDE PROP: layout {} is not supported'.format(layout))
170 _stride_propagation(graph, spatial_dims)
172 nodes = [Node(graph, x) for x in pseudo_topological_sort(graph) if
173 Node(graph, x).soft_get('is_partial_inferred') == False]