Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / common / partial_infer / transpose.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 import numpy as np
20
21 from mo.front.extractor import update_attrs
22 from mo.ops.op import PermuteAttrs
23
24
25 def tf_transpose_infer(node):
26     if len(node.in_nodes()) != 2:
27         log.error("Transpose should take 2 inputs")
28         return
29
30     node_inp, node_order = (node.in_node(0), node.in_node(1))
31     order = node_order.value
32     in_shape = np.array(node_inp.shape)
33     node.graph.remove_edge(node_order.node, node.node)
34     node.order = np.array(order)
35     node.out_node().shape = in_shape[order]
36     if node_inp.has_valid('value'):
37         node.out_node().value = np.transpose(node_inp.value, axes=order)
38
39     PermuteAttrs.create_permute_attrs(node, attrs=[('order','input:0')])
40
41
42 def transpose_infer(node):
43     if node.order is None and (not node.has_valid('reverse_order') or (node.has_valid('reverse_order') and node.reverse_order == False)):
44         log.error('Cannot infer {} because order is None'.format(node.soft_get('name')))
45         return
46
47     if node.has_valid('reverse_order') and node.reverse_order and node.has_valid('order'):
48         log.error('Cannot infer {} due to both order and reverse_order was set'.format(node.soft_get('name')))
49         return
50
51     input_shape = node.in_node(0).shape
52
53     if node.has_valid('reverse_order') and node.reverse_order:
54         node.order = np.arange(len(input_shape))[::-1] # Reverse order
55
56     output_shape = np.array([input_shape[i] for i in node.order], dtype=np.int64)
57     node.out_node(0).shape = output_shape
58     if node.in_node().has_valid('value'):
59         node.out_node().value = np.transpose(node.in_node().value, axes=node.order)
60     PermuteAttrs.create_permute_attrs(node, attrs=[('order', 'input:0')])