Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / fusing / resnet_optimization.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 import numpy as np
20
21 from mo.graph.graph import Node, Graph
22 from mo.middle.passes.fusing.helpers import get_next_operation
23 from mo.ops.pooling import Pooling
24 from mo.utils.graph import pseudo_topological_sort
25
26
27 def _clean_fw_tensor_attrs(node: Node):
28     attrs = ['fw_tensor_debug_info']
29     for attr in attrs:
30         if node.has_valid(attr):
31             node[attr] = None
32
33
34 def _insert_pooling(graph: Graph, first_node: Node, second_node: Node, spatial_dims):
35     """
36     This function inserts point wise pooling layer between two nodes
37     """
38     log.debug("STRIDE PROP: Insert pooling between {} and {}".format(first_node.name, second_node.name))
39     stride_prop = second_node.stride_prop
40     assert len(graph.get_edge_data(first_node.id, second_node.id)) == 1
41     eattrs = graph.get_edge_data(first_node.id, second_node.id)[0]
42     graph.remove_edge(first_node.id, second_node.id)
43
44     pooling = Pooling(graph, dict(name='Pooling_', spatial_dims=spatial_dims, window=np.array([1, 1, 1, 1]),
45                                   output_spatial_shape=None,
46                                   stride=np.array(stride_prop), pad_spatial_shape=np.array([[0, 0], [0, 0]]),
47                                   pad=np.array([[0, 0], [0, 0], [0, 0], [0, 0]]), pool_method='max',
48                                   is_partial_inferred=False))
49     pooling_data = pooling.create_node_with_data([first_node])
50
51     _clean_fw_tensor_attrs(pooling_data)
52
53     graph.add_edges_from([(pooling_data.id, second_node.id, eattrs)])
54
55
56 def _check_next_ops(next_ops: list):
57     """
58     This function checks list of operation to determine that all ops has same (not 1,1,1,1) stride_prop attr
59     """
60     stride_props = []
61     for op in next_ops:
62         if op.has_valid('stride_prop'):
63             stride_props.append(np.array(op.stride_prop))
64         else:
65             continue
66
67     status = not (len(next_ops) != len(stride_props) or (len(stride_props) > 0 and not all(
68         np.array_equal(x, stride_props[0]) and not np.array_equal(x, [1, 1, 1, 1]) for x in stride_props)))
69     return stride_props, status
70
71
72 def _simple_stride_prop(graph: Graph, node: Node, spatial_dims, supported=True):
73     """
74     This function handles stride propagation for op nodes. If node is in supported ops dict so this is supported operation and we
75     can propagate stride directly via this op (stride_prop will be set by using bottom stride_prop), otherwise we can't and
76     stride_prop attr will be set as 1,1,1,1
77     """
78     next_ops = get_next_operation(node)
79     stride_props, all_ops_are_valid = _check_next_ops(next_ops)
80
81     if not supported or not all_ops_are_valid:
82         # We have to insert pooling layers
83         for op in next_ops:
84             if op.has_valid('stride_prop') and not np.array_equal(op.stride_prop[spatial_dims], np.array([1, 1])) and \
85                     (op.has_valid('has_stride') == False or op.soft_get('has_stride') == False):
86                 _insert_pooling(graph, node.out_node(), op, spatial_dims)
87         # If Convolution is valid then set `stride_prop` to Convolution stride
88         node['stride_prop'] = np.array([1, 1, 1, 1])
89         return
90
91     for op in next_ops:
92         if op.soft_get('has_stride') == True:
93             op.stride = np.array([1, 1, 1, 1])
94             log.debug("STRIDE PROP: {} {} strides was moved upper via {}".format(op.type, op.name, node.name))
95
96     node['stride_prop'] = np.array(stride_props[0]) if len(stride_props) > 0 else np.array([1, 1, 1, 1])
97     node['is_partial_inferred'] = False
98     _clean_fw_tensor_attrs(node.out_node())
99
100
101 def _conv_stride_prop(graph: Graph, node: Node, spatial_dims, supported=True):
102     """
103     This function handles convolution stride propagation. There is two cases: conv->(op) and conv->conv. In first case
104     we propagate stride from op, and in second case we also change stride for second conv
105     """
106     next_ops = get_next_operation(node)
107     stride_props, all_ops_are_valid = _check_next_ops(next_ops)
108
109     def _check_convolution(node: Node):
110         return node.has_valid('kernel_spatial') and np.array_equal(node.kernel_spatial, np.array([1, 1]))
111
112     # Check that all ops are valid and have same values
113     if not all_ops_are_valid:
114         # We have to insert pooling layers
115         for op in next_ops:
116             if op.has_valid('stride_prop') and not np.array_equal(op.stride_prop[spatial_dims], np.array([1, 1])):
117                 # Insert pooling
118                 _insert_pooling(graph, node.out_node(), op, spatial_dims)
119     elif len(stride_props) > 0:
120         node.stride *= stride_props[0]
121         log.debug('STRIDE PROP: {} got new strides {}'.format(node.name, node.stride))
122         for op in next_ops:
123             if op.soft_get('has_stride') == True:
124                 op.stride = np.array([1, 1, 1, 1])
125         node['is_partial_inferred'] = False
126         node['output_spatial_shape'] = False
127         _clean_fw_tensor_attrs(node.out_node())
128
129     # If Convolution is valid then set `stride_prop` to Convolution stride
130     node['stride_prop'] = np.array(node.stride) if _check_convolution(node) else np.array([1, 1, 1, 1])
131
132
133 supported_ops = {
134     'ReLU': {'stride_prop': _simple_stride_prop, 'attrs': {}},
135     'Eltwise': {'stride_prop': _simple_stride_prop, 'attrs': {}},
136     'Convolution': {'stride_prop': _conv_stride_prop, 'attrs': {'has_stride': True}},
137 }
138
139
140 def _stride_propagation(graph: Graph, spatial_dims):
141     """
142     This function do stride propagation for all op nodes
143     """
144     nodes = [Node(graph, x) for x in pseudo_topological_sort(graph, reverse=True) if
145              Node(graph, x).kind == 'op' and Node(graph, x).soft_get('type') != 'Const']
146
147     for node in nodes:
148         if node.soft_get('type') in supported_ops:
149             op = supported_ops[node.type]
150             # Add node attrs
151             for key in op['attrs'].keys():
152                 node[key] = op['attrs'][key]
153             op['stride_prop'](graph, node, spatial_dims, True)
154         else:
155             _simple_stride_prop(graph, node, spatial_dims, False)
156
157
158 def stride_optimization(graph: Graph):
159     """
160     This is main function for stride optimization pass
161     """
162     layout = graph.graph['layout']
163     if layout == 'NCHW':
164         spatial_dims = np.array([2, 3])
165     elif layout == 'NHWC':
166         spatial_dims = np.array([1, 2])
167     else:
168         log.warning('STRIDE PROP: layout {} is not supported'.format(layout))
169         return
170     _stride_propagation(graph, spatial_dims)
171
172     nodes = [Node(graph, x) for x in pseudo_topological_sort(graph) if
173              Node(graph, x).soft_get('is_partial_inferred') == False]
174     for node in nodes:
175         node.infer(node)