"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import logging as log
from collections import deque

import numpy as np

from mo.front.extractor import add_attrs_props
from mo.graph.graph import Node, Graph
from mo.middle.passes.eliminate import graph_clean_up
from mo.middle.passes.fusing.helpers import get_next_operation, get_tensor_id
from mo.utils.graph import pseudo_topological_sort
def concat_convolutions(graph: Graph, start_node: Node, last_node: Node):
    """Fuse a group of parallel (De)Convolutions into one grouped convolution.

    The expected sub-graph is ``start_node`` (a Split-like op) feeding several
    convolutions whose outputs are concatenated by ``last_node`` (a Concat op).
    The first convolution of the group is kept and becomes the grouped
    convolution; the weights (and biases, when present) of all convolutions
    are concatenated along the output-feature axis and the surviving node is
    re-wired directly between the Split input and the Concat output.

    :param graph: graph containing the sub-graph to fuse
    :param start_node: the Split-like operation that opens the group
    :param last_node: the Concat operation that closes the group
    :return: True when the fusion was performed, False when validation failed
    """
    # Check that concatenation consumes the convolutions in the same order in
    # which they consume the Split outputs — otherwise channels get shuffled.
    conv_nodes = get_next_operation(start_node)
    assert len(conv_nodes) == len(last_node.in_nodes())
    gconv = conv_nodes[0]  # this node is turned into the grouped convolution

    for port_idx, conv in enumerate(conv_nodes):
        if conv.out_node().id != last_node.in_node(port_idx).id:
            return False
        # Check that all convolutions have same weights shapes
        if not np.array_equal(conv.in_node(1).shape, gconv.in_node(1).shape):
            log.debug('Grouped convolutions fusion : convolutions have different weights shape')
            return False

    # Check that split and concat dims are valid
    channel_dim = gconv.channel_dims[0]
    if channel_dim != start_node.axis or channel_dim != last_node.axis:
        log.debug('Grouped convolutions fusion : split or concat has wierd axis!')
        return False

    # Check that all convolutions have the same spatial parameters
    conv_attrs = ['pad', 'stride']
    for attr in conv_attrs:
        for conv in conv_nodes:
            if not np.array_equal(gconv[attr], conv[attr]):
                log.debug('Grouped convolutions fusion : attrs {} doesn\'t match'.format(attr))
                return False

    # Either every convolution has a bias (third input) or none of them does.
    has_biases = len(conv_nodes[0].in_nodes()) == 3
    for conv in conv_nodes:
        if (len(conv.in_nodes()) == 3) != has_biases:
            return False  # All convolutions must have biases (if they exist)

    # Check that all biases have same shape
    if has_biases:
        for conv in conv_nodes:
            # np.array_equal avoids the ambiguous truth value that an
            # element-wise `!=` between numpy shape arrays would produce.
            if not np.array_equal(conv.in_node(2).shape, gconv.in_node(2).shape):
                log.debug('Group convolutions fusion : convolutions have different biases shape {} and {}'.format(
                    conv.in_node(2).shape, gconv.in_node(2).shape))
                return False

    # Detach the surviving convolution from its old input/output data nodes.
    graph.remove_edge(gconv.in_node(0).id, gconv.id)
    graph.remove_edge(gconv.id, gconv.out_node().id)

    input_data = start_node.in_node(start_node.input_port)
    output_data = last_node.out_node()

    # Removing edges from data nodes to Split and Concat
    graph.remove_edge(input_data.id, start_node.id)
    graph.remove_edge(last_node.id, output_data.id)

    # Add edges to grouped convolution
    graph.add_edges_from([
        (input_data.id, gconv.id, {'in': 0}),
        (gconv.id, output_data.id, {'out': 0}),
    ])

    # Concatenation of convolutions
    weights_node = gconv.in_node(1)
    bias_node = gconv.in_node(2) if has_biases else None

    weights_value = np.array(weights_node.value)
    bias_value = np.array(bias_node.value) if has_biases else None

    # Output-features axis of the weights: last axis for NHWC, axis 1 otherwise.
    feature_dim = 3 if graph.graph['layout'] == 'NHWC' else 1

    for conv in conv_nodes[1:]:
        weights_value = np.concatenate((weights_value, conv.in_node(1).value), axis=feature_dim)
        if has_biases:
            bias_value = np.concatenate((bias_value, conv.in_node(2).value), axis=-1)  # Not validated

    weights_node.value = np.array(weights_value)
    weights_node.shape = np.array(weights_value.shape)

    if has_biases:
        bias_node.value = np.array(bias_value)
        bias_node.shape = np.array(bias_value.shape)

    log.debug('Start node : {} Last node : {} Nodes inside : {}'.format(start_node.id, last_node.id,
                                                                        len(start_node.out_nodes())))
    log.debug('Output shape : {}'.format(weights_value.shape))

    gconv.group = len(conv_nodes)
    gconv.output = weights_node.shape[feature_dim]
    gconv.output_shape[feature_dim] = weights_node.shape[feature_dim]

    return True
def grouped_convolutions_fusing(graph: Graph):
    """Find and fuse every group of parallel convolutions in the graph.

    Repeatedly scans the graph for an op with several consumers where every
    consumer is a Convolution/Deconvolution and all of them feed one common
    output layer, then fuses such a group via :func:`concat_convolutions`.
    The scan restarts after each successful fusion because the fusion rewires
    the graph; the pass stops when a full scan performs no fusion.

    :param graph: graph to transform in place
    """
    while True:
        is_fused = False
        # Clean the graph up first so the topological sort walks live nodes only.
        graph_clean_up(graph, ['TFCustomSubgraphCall', 'Shape'])
        nodes = pseudo_topological_sort(graph)
        for idx in nodes:
            node = Node(graph, idx)
            if node.kind == 'op' and len(node.out_nodes()) > 1:
                # Skip nodes explicitly marked as non-fusable.
                if node.soft_get('can_be_fused') == False:
                    continue

                is_valid_convolutions = True
                last_layer = None

                next_nodes = get_next_operation(node)
                # Check that all operation after this one are Convolutions
                # and all convolutions has same output
                if len(next_nodes) > 1 and all(_node.soft_get('type') in ['Convolution', 'Deconvolution'] for _node in next_nodes):
                    for conv in next_nodes:
                        conv_outputs = get_next_operation(conv)
                        if conv.soft_get('can_be_fused') == False:
                            is_valid_convolutions = False
                        if len(conv_outputs) != 1:
                            is_valid_convolutions = False
                        if last_layer is None:
                            last_layer = conv_outputs[0].id
                        elif conv_outputs[0].id != last_layer:
                            is_valid_convolutions = False

                    if is_valid_convolutions:
                        is_fused = concat_convolutions(graph, node, Node(graph, last_layer))
                        if is_fused:
                            # Graph topology changed — restart the scan.
                            break

        if not is_fused:
            break