Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / fusing / fuse_grouped_conv.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18 from collections import deque
19
20 import networkx as nx
21 import numpy as np
22
23 from mo.front.extractor import add_attrs_props
24 from mo.graph.graph import Node, Graph
25 from mo.middle.passes.eliminate import graph_clean_up
26 from mo.utils.graph import pseudo_topological_sort
27 from mo.middle.passes.fusing.helpers import get_next_operation, get_tensor_id
28
29
30 # TODO: unit tests
31 def concat_convolutions(graph: Graph, start_node: Node, last_node: Node):
32     """
33     This function converts group of convolutions into one
34     """
35
36     # Check that concatenation makes in the same order
37     conv_nodes = get_next_operation(start_node)
38     assert len(conv_nodes) == len(last_node.in_nodes())
39     gconv = conv_nodes[0]
40
41     for id in range(len(conv_nodes)):
42         conv = conv_nodes[id]
43         if conv.out_node().id != last_node.in_node(id).id:
44             return False
45         # Check that all convolutions have same weights shapes
46         if not np.array_equal(conv.in_node(1).shape, gconv.in_node(1).shape):
47             log.debug('Grouped convolutions fusion : convolutions have different weights shape')
48             return False
49
50     # Check that split and concat dims are valid
51     channel_dim = gconv.channel_dims[0]
52     if channel_dim != start_node.axis or channel_dim != last_node.axis:
53         log.debug('Grouped convolutions fusion : split or concat has wierd axis!')
54         return False
55
56     # Check that all convolutions has the same parameters
57     conv_attrs = ['pad', 'stride']
58     for attr in conv_attrs:
59         for id in range(len(conv_nodes)):
60             conv = conv_nodes[id]
61             if not np.array_equal(gconv[attr], conv[attr]):
62                 log.debug('Grouped convolutions fusion : attrs {} doesn\'t match'.format(attr))
63                 return False
64
65     # Check that all Convolutions has biases (if exists)
66     has_biases = False
67     for id in range(len(conv_nodes)):
68         conv = conv_nodes[id]
69         if len(conv.in_nodes()) == 3:
70             if not has_biases:
71                 has_biases = True
72         elif has_biases:
73             return False  # All convolution mast have biases
74
75     # Check that all biases have same shape
76     if has_biases:
77         for id in range(len(conv_nodes)):
78             conv = conv_nodes[id]
79             if conv.in_node(2).shape != gconv.in_node(2).shape:
80                 log.debug('Group convolutions fusion : convolutions have different biases shape {} and {}'.format(
81                     conv.in_node(2).shape, gconv.in_node(2).shape))
82                 return False
83
84     graph.remove_edge(gconv.in_node(0).id, gconv.id)
85     graph.remove_edge(gconv.id, gconv.out_node().id)
86
87     input = start_node.in_node(start_node.input_port)
88     output = last_node.out_node()
89
90     # Removing edges from data nodes to Split and Concat
91     graph.remove_edge(input.id, start_node.id)
92     graph.remove_edge(last_node.id, output.id)
93
94     # Add edges to grouped convolution
95     graph.add_edges_from([
96         (input.id, gconv.id, {'in': 0}),
97         (gconv.id, output.id, {'out': 0})
98     ])
99
100     # Concatenation of convolutions
101     weights_node = gconv.in_node(1)
102     bias_node = gconv.in_node(2) if has_biases else None
103
104     weights_value = np.array(weights_node.value)
105     bias_value = np.array(bias_node.value) if has_biases else None
106
107     feature_dim = 3 if graph.graph['layout'] == 'NHWC' else 1
108
109     for conv in conv_nodes[1:]:
110         weights_value = np.concatenate((weights_value, conv.in_node(1).value), axis=feature_dim)
111         if has_biases:
112             bias_value = np.concatenate((bias_value, conv.in_node(2).value), axis=-1)  # Not validated
113
114     weights_node.value = np.array(weights_value)
115     weights_node.shape = np.array(weights_value.shape)
116
117     if has_biases:
118         bias_node.value = np.array(bias_value)
119         bias_node.shape = np.array(bias_value.shape)
120
121     log.debug('Start node : {} Last node : {}  Nodes inside : {}'.format(start_node.id, last_node.id,
122                                                                          len(start_node.out_nodes())))
123     log.debug('Output shape : {}'.format(weights_value.shape))
124
125     gconv.group = len(conv_nodes)
126     gconv.output = weights_node.shape[feature_dim]
127     gconv.output_shape[feature_dim] = weights_node.shape[feature_dim]
128
129     return True
130
131
132 # TODO: unit tests
133 def grouped_convolutions_fusing(graph: Graph):
134     while True:
135         is_fused = False
136         graph_clean_up(graph, ['TFCustomSubgraphCall', 'Shape'])
137         nodes = pseudo_topological_sort(graph)
138         for idx in nodes:
139             node = Node(graph, idx)
140             if node.kind == 'op' and len(node.out_nodes()) > 1:
141                 if node.soft_get('can_be_fused') == False:
142                     continue
143
144                 is_valid_convolutions = True
145                 last_layer = None
146
147                 next_nodes = get_next_operation(node)
148                 # Check that all operation after this one are Convolutions
149                 # and all convolutions has same output
150                 if len(next_nodes) > 1 and all(_node.soft_get('type') in ['Convolution', 'Deconvolution'] for _node in next_nodes):
151                     for conv in next_nodes:
152                         conv_outputs = get_next_operation(conv)
153                         if conv.soft_get('can_be_fused') == False:
154                             is_valid_convolutions = False
155                         if len(conv_outputs) != 1:
156                             is_valid_convolutions = False
157                         if last_layer is None:
158                             last_layer = conv_outputs[0].id
159                         elif conv_outputs[0].id != last_layer:
160                             is_valid_convolutions = False
161
162                     if is_valid_convolutions:
163                         is_fused = concat_convolutions(graph, node, Node(graph, last_layer))
164                         if is_fused:
165                             break
166
167         if not is_fused:
168             break