"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
17 from mo.front.common.replacement import FrontReplacementSubgraph
18 from mo.graph.graph import Graph
class Concat(FrontReplacementSubgraph):
    """
    Front transform that renumbers the inputs of a TensorFlow Concat node into ConcatV2 order.

    TensorFlow Concat takes the concatenation axis on input port 0, while ConcatV2 takes it on
    the last port. Renumbering lets the existing ConcatV2 infer logic handle both ops.
    """
    # NOTE(review): presumably read by the replacer registry to decide whether this
    # transform runs — confirm against FrontReplacementSubgraph.
    enabled = True

    def pattern(self):
        """
        Describe the subgraph to match: a single Concat node flagged with simple_concat=True.

        :return: pattern description with the matched ``nodes`` and (no) ``edges``.
        """
        return dict(
            nodes=[('concat', dict(op='Concat', simple_concat=True))],
            edges=[]
        )

    def replace_sub_graph(self, graph: Graph, match: dict):
        """
        Renumber the input ports of a TensorFlow Concat node so they follow the ConcatV2 layout.

        There are Concat and ConcatV2 operations in TensorFlow. The main difference is the
        incoming port of the tensor representing the concatenation axis: in Concat it is
        port 0, in ConcatV2 it is the last port. To reuse the ConcatV2 logic (infer) that
        already exists in the Model Optimizer, the ports of Concat are renumbered here:
        port 0 moves to the last position and every other port shifts down by one.

        :param graph: graph being transformed; the ``'in'`` attribute of every incoming
                      edge of the matched node is rewritten.
        :param match: pattern match; ``match['concat']`` is the matched Concat node.
        """
        concat_node = match['concat']
        in_edges = list(graph.in_edges(concat_node.id, data=True))
        last_port = len(in_edges) - 1
        for _, _, attrs in in_edges:
            in_port = attrs['in']
            # The axis input (port 0) becomes the last input; data inputs shift down by one.
            attrs['in'] = last_port if in_port == 0 else in_port - 1
        if concat_node.has('axis'):
            # We delete the axis parameter here (it was set by default by the Concat op) so that
            # the Concat infer function carefully gets it from the last input instead.
            del graph.node[concat_node.id]['axis']