Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / common / partial_infer / concat.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 # Concat infer : N - number of inputs to concat
20 #                axis - dimension number for tensors concatenation
21 import numpy as np
22
23 from mo.front.caffe.extractors.utils import get_canonical_axis_index
24 from mo.ops.op import PermuteAttrs
25
26
27 def concat_infer(node):
28     if not node.has('axis'):
29         N = node.N
30         axis_input = node.in_node(N)
31         if axis_input.has_valid('value') and axis_input.value.size == 1:
32             node['axis'] = axis_input.value.item()
33             node.graph.remove_edge(axis_input.node, node.node)  # TODO add skip attribute instead of deleting
34         else:
35             return
36     else:
37         N = len(node.in_nodes())
38
39     shapes = [node.in_node(i).shape for i in range(N)]
40     if any(s is None for s in shapes):
41         return
42
43     shape = np.array(shapes[0])
44
45     axis = get_canonical_axis_index(shape, node.axis)
46     node.axis = axis
47
48     mask = np.zeros_like(shape, dtype=np.bool)
49     mask[axis] = True  # pylint: disable=unsupported-assignment-operation
50     not_mask = np.logical_not(mask)  # pylint: disable=assignment-from-no-return
51     for s in shapes[1:]:
52         if np.all(shape[not_mask] == s[not_mask]):  # TODO handle -1 in a special way
53             shape[mask] += s[mask]
54         else:
55             log.error('Concat input shapes do not match')
56             return
57
58     node.out_node(0).shape = shape
59     if len(shape) != 4:
60         # exclude it from NHWC to NCHW convertion
61         if 'axis' in node.dim_attrs:
62             node.dim_attrs.remove('axis')
63
64     PermuteAttrs.create_permute_attrs(node, attrs=[('axis','input:0')])
65
66     values = [node.in_node(i).value for i in range(N)]
67     if any(v is None for v in values):
68         return
69
70     node.out_node(0).value = np.concatenate(values, axis=node.axis)
71     node.out_node(0).shape = np.array(node.out_node(0).value.shape, dtype=np.int64)
72
73
74
75
76 def tf_pack_infer(node):
77     # Constant path is supported only
78     values = [node.in_node(i).value for i in range(node.N)]
79     if any(v is None for v in values):
80         return
81     node.out_node().value = np.stack(values, node.axis)
82     node.out_node().shape = np.array(node.out_node().value.shape, dtype=np.int64)
83