Publishing R3
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / fusing / resnet_optimization.py
1 import logging as log
2
3 import networkx as nx
4 import numpy as np
5
6 from mo.front.extractor import add_attrs_props
7 from mo.graph.graph import Node, unique_id, dump_graph_for_graphviz
8 from mo.middle.pattern_match import apply_pattern
9 from mo.front.extractor import update_ie_fields
10 from mo.middle.passes.fusing.helpers import backward_bfs, get_next_operation
11 from mo.utils.graph import pseudo_topological_sort
12 from mo.middle.passes.infer import partial_infer
13 from mo.front.common.partial_infer.pooling import pool_explicit_padding_infer
14 from mo.ops.pooling import Pooling
15
16
17 def _clean_fw_tensor_attrs(node: Node):
18     attrs = ['fw_tensor_debug_info']
19     for attr in attrs:
20         if node.has_valid(attr):
21             node[attr] = None
22
23
24 def _insert_pooling(graph: nx.MultiDiGraph, first_node: Node, second_node: Node, spatial_dims):
25     """
26     This function inserts point wise pooling layer between two nodes
27     """
28     log.debug("STRIDE PROP: Insert pooling between {} and {}".format(first_node.name, second_node.name))
29     stride_prop = second_node.stride_prop
30     assert len(graph.get_edge_data(first_node.id, second_node.id)) == 1
31     eattrs = graph.get_edge_data(first_node.id, second_node.id)[0]
32     graph.remove_edge(first_node.id, second_node.id)
33
34     pooling = Pooling(graph, dict(name='Pooling_', spatial_dims=spatial_dims, window=np.array([1, 1, 1, 1]),
35                                   output_spatial_shape=None,
36                                   stride=np.array(stride_prop), pad_spatial_shape=np.array([[0, 0], [0, 0]]),
37                                   pad=np.array([[0, 0], [0, 0], [0, 0], [0, 0]]), pool_method='avg',
38                                   is_partial_inferred=False))
39     pooling_data = pooling.create_node_with_data([first_node])
40
41     _clean_fw_tensor_attrs(pooling_data)
42
43     graph.add_edges_from([(pooling_data.id, second_node.id, eattrs)])
44
45
46 def _check_next_ops(next_ops: list):
47     """
48     This function checks list of operation to determine that all ops has same (not 1,1,1,1) stride_prop attr
49     """
50     stride_props = []
51     for op in next_ops:
52         if op.has_valid('stride_prop'):
53             stride_props.append(np.array(op.stride_prop))
54         else:
55             continue
56
57     status = not (len(next_ops) != len(stride_props) or (len(stride_props) > 0 and not all(
58         np.array_equal(x, stride_props[0]) and not np.array_equal(x, [1, 1, 1, 1]) for x in stride_props)))
59     return stride_props, status
60
61
62 def _simple_stride_prop(graph: nx.MultiDiGraph, node: Node, spatial_dims, supported=True):
63     """
64     This function handles stride propagation for op nodes. If node is in supported ops dict so this is supported operation and we
65     can propagate stride directly via this op (stride_prop will be set by using bottom stride_prop), otherwise we can't and
66     stride_prop attr will be set as 1,1,1,1
67     """
68     next_ops = get_next_operation(node)
69     stride_props, all_ops_are_valid = _check_next_ops(next_ops)
70
71     if not supported or not all_ops_are_valid:
72         # We have to insert pooling layers
73         for op in next_ops:
74             if op.has_valid('stride_prop') and not np.array_equal(op.stride_prop[spatial_dims], np.array([1, 1])) and \
75                     (op.has_valid('has_stride') == False or op.soft_get('has_stride') == False):
76                 _insert_pooling(graph, node.out_node(), op, spatial_dims)
77         # If Convolution is valid then set `stride_prop` to Convolution stride
78         node['stride_prop'] = np.array([1, 1, 1, 1])
79         return
80
81     for op in next_ops:
82         if op.soft_get('has_stride') == True:
83             op.stride = np.array([1, 1, 1, 1])
84             log.debug("STRIDE PROP: {} {} strides was moved upper via {}".format(op.type, op.name, node.name))
85
86     node['stride_prop'] = np.array(stride_props[0]) if len(stride_props) > 0 else np.array([1, 1, 1, 1])
87     node['is_partial_inferred'] = False
88     _clean_fw_tensor_attrs(node.out_node())
89
90
91 def _conv_stride_prop(graph: nx.MultiDiGraph, node: Node, spatial_dims, supported=True):
92     """
93     This function handles convolution stride propagation. There is two cases: conv->(op) and conv->conv. In first case
94     we propagate stride from op, and in second case we also change stride for second conv
95     """
96     next_ops = get_next_operation(node)
97     stride_props, all_ops_are_valid = _check_next_ops(next_ops)
98
99     def _check_convolution(node: Node):
100         return node.has_valid('kernel_spatial') and np.array_equal(node.kernel_spatial, np.array([1, 1]))
101
102     # Check that all ops are valid and have same values
103     if not all_ops_are_valid:
104         # We have to insert pooling layers
105         for op in next_ops:
106             if op.has_valid('stride_prop') and not np.array_equal(op.stride_prop[spatial_dims], np.array([1, 1])):
107                 # Insert pooling
108                 _insert_pooling(graph, node.out_node(), op, spatial_dims)
109     elif len(stride_props) > 0:
110         node.stride *= stride_props[0]
111         log.debug('STRIDE PROP: {} got new strides {}'.format(node.name, node.stride))
112         for op in next_ops:
113             if op.soft_get('has_stride') == True:
114                 op.stride = np.array([1, 1, 1, 1])
115         node['is_partial_inferred'] = False
116         _clean_fw_tensor_attrs(node.out_node())
117
118     # If Convolution is valid then set `stride_prop` to Convolution stride
119     node['stride_prop'] = np.array(node.stride) if _check_convolution(node) else np.array([1, 1, 1, 1])
120
121
122 supported_ops = {
123     'ReLU': {'stride_prop': _simple_stride_prop, 'attrs': {}},
124     'Eltwise': {'stride_prop': _simple_stride_prop, 'attrs': {}},
125     'Convolution': {'stride_prop': _conv_stride_prop, 'attrs': {'has_stride': True}},
126 }
127
128
129 def _stride_propagation(graph: nx.MultiDiGraph, spatial_dims):
130     """
131     This function do stride propagation for all op nodes
132     """
133     nodes = [Node(graph, x) for x in pseudo_topological_sort(graph, reverse=True) if Node(graph, x).kind == 'op']
134
135     for node in nodes:
136         if node.soft_get('type') in supported_ops:
137             op = supported_ops[node.type]
138             # Add node attrs
139             for key in op['attrs'].keys():
140                 node[key] = op['attrs'][key]
141             op['stride_prop'](graph, node, spatial_dims, True)
142         else:
143             _simple_stride_prop(graph, node, spatial_dims, False)
144
145
146 def stride_optimization(graph: nx.MultiDiGraph):
147     """
148     This is main function for stride optimization pass
149     """
150     layout = graph.graph['layout']
151     if layout == 'NCHW':
152         spatial_dims = np.array([2, 3])
153     elif layout == 'NHWC':
154         spatial_dims = np.array([1, 2])
155     else:
156         log.warning('STRIDE PROP: layout {} is not supported'.format(layout))
157         return
158     _stride_propagation(graph, spatial_dims)
159
160     nodes = [Node(graph, x) for x in pseudo_topological_sort(graph) if
161              Node(graph, x).soft_get('is_partial_inferred') == False]
162     for node in nodes:
163         node.infer(node)