Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / ConcatOptimization.py
1 """
2  Copyright (c) 2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import networkx as nx
18 import logging as log
19
20 from mo.graph.graph import Node
21 from mo.middle.replacement import MiddleReplacementPattern
22
23
24 class ConcatOptimization(MiddleReplacementPattern):
25     # This optimization reduces number of edges between Concat operations
26     # that significantly reduce memory consumption
27
28     enabled = False
29
30     def run_after(self):
31         return []
32
33     def find_and_replace_pattern(self, graph: nx.MultiDiGraph):
34         mp = {}
35         used = {}
36         for node in graph.nodes():
37             node = Node(graph, node)
38             if node.kind == 'op' and node.soft_get('type') == 'Concat':
39                 in_nodes = tuple([node.in_node(idx).id for idx in range(len(node.in_nodes()))])
40                 out_node = (node.id, node.out_node().id)
41                 if in_nodes in mp:
42                     log.warning("Something is weird! {} and {}".format(node.id, mp[in_nodes]))
43                 else:
44                     mp.update({in_nodes: out_node})
45                     used.update({node.id: {x: False for x in in_nodes}})
46
47         for key in mp.keys():
48             replacers = []
49             for i in range(len(key)):
50                 for j in range(i + 1, len(key)):
51                     arr = tuple(key[i:j + 1])
52                     if arr in mp.keys() and arr != key:
53                         # print("Output of {} can be used as input for {} ({})".format(mp[arr][0], mp[key][0], len(arr)))
54                         replacers.append((len(arr), arr))
55
56             replacers.sort(reverse=True)
57
58             concat_id = mp[key][0]
59             for ln, arr in replacers:
60                 # Check that we can do it!!!
61                 we_can = True
62                 for x in arr:
63                     if used[concat_id][x]:
64                         # print("Sorry but {} input was already removed from {}".format(x, concat_id))
65                         we_can = False
66                         break
67
68                 if not we_can:
69                     continue
70
71                 for x in arr:
72                     used[concat_id][x] = True
73
74                 edge_attrs = graph.get_edge_data(arr[0], concat_id)[0]
75                 for in_node in arr:
76                     graph.remove_edge(in_node, concat_id)
77
78                 new_input = mp[arr][1]
79                 out_port = len(Node(graph, new_input).out_nodes()) + 1
80                 edge_attrs['out'] = out_port
81                 graph.add_edge(new_input, concat_id, **edge_attrs)
82
83                 # Renumber 'in' attrs
84                 concat_node = Node(graph, concat_id)
85                 ln = len(concat_node.in_nodes())
86                 ports = [x for x in concat_node.in_nodes().keys()]
87                 ports.sort()
88
89                 p_id = 0
90                 for p in ports:
91                     in_node = concat_node.in_nodes()[p]
92                     graph[in_node.id][concat_id][0]['in'] = p_id
93                     p_id += 1