2 Copyright (c) 2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
20 from mo.graph.graph import Node
21 from mo.middle.replacement import MiddleReplacementPattern
# NOTE(review): this listing is an extract. Each line keeps its original line
# number as a prefix, indentation is lost, and several original lines are not
# shown (e.g. 27-32, 34-35, 41, 43, 46-48, 57, 61-62, 65-71, 73, 75, 77, 82-83,
# 87-90). Loop nesting therefore cannot be read off this text; the comments
# below hedge wherever a missing line matters. Verify against the full file.
24 class ConcatOptimization(MiddleReplacementPattern):
25 # This optimization reduces the number of edges between Concat operations,
26 # which significantly reduces memory consumption.
# Approach visible below: when a contiguous run of one Concat's inputs is
# exactly the input tuple of another Concat, the edges from that run are
# removed and replaced by a single edge from the other Concat's output node.
33 def find_and_replace_pattern(self, graph: nx.MultiDiGraph):
# Rewrites `graph` in place; no return statement appears in the visible lines.
# NOTE(review): `mp` and `used` are initialised on lines not shown (34-35),
# presumably as empty dicts. The imports providing `nx` and `log` are also
# outside this extract. As populated below:
#   mp:   tuple of input-node ids -> (Concat node id, id of its output node)
#   used: Concat node id -> {input node id: False until rerouted}
36 for node in graph.nodes():
37 node = Node(graph, node)
38 if node.kind == 'op' and node.soft_get('type') == 'Concat':
# Ids of the Concat's input nodes, ordered by input index 0..n-1.
39 in_nodes = tuple([node.in_node(idx).id for idx in range(len(node.in_nodes()))])
# (this Concat's id, id of the node returned by out_node())
40 out_node = (node.id, node.out_node().id)
# NOTE(review): the guarding condition (line 41) is not shown; since it reads
# mp[in_nodes], it presumably fires when another Concat already registered the
# same input tuple.
42 log.warning("Something is weird! {} and {}".format(node.id, mp[in_nodes]))
44 mp.update({in_nodes: out_node})
# Every input of this Concat starts out as "not yet rerouted".
45 used.update({node.id: {x: False for x in in_nodes}})
# NOTE(review): the loop that binds `key` and the `replacers` initialisation
# are on lines not shown (46-48); presumably `key` ranges over mp's keys and
# `replacers` starts empty for each key.
# Collect every contiguous slice key[i..j] with j > i (length >= 2) that is
# itself the input tuple of a different registered Concat.
49 for i in range(len(key)):
50 for j in range(i + 1, len(key)):
51 arr = tuple(key[i:j + 1])
52 if arr in mp.keys() and arr != key:
53 # print("Output of {} can be used as input for {} ({})".format(mp[arr][0], mp[key][0], len(arr)))
54 replacers.append((len(arr), arr))
# Sort descending by (length, tuple) so the longest slices are tried first.
56 replacers.sort(reverse=True)
# The Concat whose inputs are `key`; this is the node being rewired.
58 concat_id = mp[key][0]
59 for ln, arr in replacers:
60 # Check whether any input of this slice was already rerouted into this Concat by an earlier (longer) slice.
# NOTE(review): the loop binding `x` and the skip handling (lines 61-62,
# 65-71) are not shown; presumably `x` ranges over `arr`, and the slice is
# skipped if any check below succeeds.
63 if used[concat_id][x]:
64 # print("Sorry but {} input was already removed from {}".format(x, concat_id))
# Mark the slice's inputs as consumed so later, shorter slices cannot reuse them.
72 used[concat_id][x] = True
# Attributes of edge key 0 from the slice's first input. They are reused for
# the new edge below, so everything except 'out' is carried over from arr[0]'s
# edge.
74 edge_attrs = graph.get_edge_data(arr[0], concat_id)[0]
# NOTE(review): `in_node` is bound on a line not shown (75); presumably this
# removes the edge from each input in `arr`.
76 graph.remove_edge(in_node, concat_id)
# Output node of the Concat whose inputs are exactly `arr`.
78 new_input = mp[arr][1]
# NOTE(review): this is one more than new_input's current out-node count; if
# out ports are 0-based this skips an index. Confirm that this is intended.
79 out_port = len(Node(graph, new_input).out_nodes()) + 1
80 edge_attrs['out'] = out_port
81 graph.add_edge(new_input, concat_id, **edge_attrs)
# Renumber the Concat's input ports after the edge changes above.
# NOTE(review): whether this section runs once per applied slice or once per
# key cannot be determined, because indentation is lost.
84 concat_node = Node(graph, concat_id)
# `ln` rebinds the loop variable from line 59; its value is not used in any
# visible line.
85 ln = len(concat_node.in_nodes())
86 ports = [x for x in concat_node.in_nodes().keys()]
# NOTE(review): lines 87-90 and anything after 92 are not shown. They
# presumably sort `ports`, initialise `p_id`, loop `p` over `ports`, and
# increment `p_id`, so the input ports are compacted to consecutive indices in
# their original order.
91 in_node = concat_node.in_nodes()[p]
92 graph[in_node.id][concat_id][0]['in'] = p_id