Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / concat.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 from mo.front.common.replacement import FrontReplacementSubgraph
18 from mo.graph.graph import Graph
19
20
21 class Concat(FrontReplacementSubgraph):
22     enabled = True
23
24     def pattern(self):
25         return dict(
26             nodes=[('concat', dict(op='Concat', simple_concat=True))],
27             edges=[]
28         )
29
30     def replace_sub_graph(self, graph: Graph, match: dict):
31         """
32         There are Concat and ConcatV2 operations in TensorFlow
33         The main difference is incoming port of tensor representing axis of concatenation
34         In Concat it is the 0 port, in ConcatV2 it is the last port
35         To reuse ConcatV2 logic (infer) that already exists in the Model Optimizer here we renumber ports of Concat
36         """
37         in_edges = list(graph.in_edges(match['concat'].id, data=True))
38         for u, v, attrs in in_edges:
39             in_port = attrs['in']
40             attrs['in'] = len(in_edges) - 1 if in_port == 0 else attrs['in'] - 1
41         if match['concat'].has('axis'):
42             # we delete axis parameter here (it was set by default by Concat Op) to carefully get it from the last
43             # input in Concat infer function
44             del graph.node[match['concat'].id]['axis']