Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / common / partial_infer / matmul.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 import numpy as np
20
21 from mo.front.common.partial_infer.utils import int64_array
22 from mo.ops.op import PermuteAttrs
23 from mo.utils.error import Error
24
25
26 def tf_matmul_infer(node):
27     assert (len(node.in_nodes()) == 2)
28
29     shapes = [node.in_node(i).shape.copy() for i in range(2)]
30     log.debug('matmul shapes: {}'.format(shapes))
31     if any(s is None or len(s) < 2 for s in shapes):
32         log.error("MatMul wasn't able to infer shape")
33         return
34
35     if node.transpose_a:
36         if not node.in_node(0).has_valid('value'):
37             log.error("MatMul wasn't able to infer shape")
38             return
39         else:
40             perm = np.array(range(len(node.in_node(0).shape)), dtype=np.int64)
41             perm[-1], perm[-2] = perm[-2], perm[-1]
42             inv = PermuteAttrs.get_inverse_permutation(perm)
43             permutation = PermuteAttrs.Permutation(perm=perm, inv=int64_array(inv))
44             PermuteAttrs.set_permutation(node.in_node(0), node, permutation)
45             shapes[0] = shapes[0][perm]
46
47     if node.transpose_b:
48         if not node.in_node(1).has_valid('value'):
49             log.error("MatMul wasn't able to infer shape")
50             return
51         else:
52             perm = np.array(range(len(node.in_node(1).shape)), dtype=np.int64)
53             perm[-1], perm[-2] = perm[-2], perm[-1]
54             inv = PermuteAttrs.get_inverse_permutation(perm)
55             permutation = PermuteAttrs.Permutation(perm=perm, inv=int64_array(inv))
56             PermuteAttrs.set_permutation(node.in_node(1), node, permutation)
57             shapes[1] = shapes[1][perm]
58
59     if any(shapes[0][:-2] != shapes[1][:-2]) or shapes[0][-1] != shapes[1][-2]:
60         log.error("MatMul wasn't able to infer shape because input dimensions are not compatible")
61         return
62     if any(shapes[0][1:-1] != 1):
63         log.error("MatMul wasn't able to infer shapes because input[0] shape is invalid: {}".format(shapes[0]))
64         return
65
66     shape_tuple = (np.array([shapes[0][0]], dtype=np.int64), np.array([shapes[1][-1]], dtype=np.int64))
67     if len(shapes[0]) > 2:
68         # TODO Investigate case when MatMul have inputs with not matching output dimensions
69         # It looks to be a practical case and if we add outer dimensions of the first argument
70         # it will lead to incorrect model sometimes. TF documentation is unclear.
71         log.warning('Ignored outer dimensions of input tensor for MatMul node: {}'.format(node.name))
72         # shape_tuple = (shapes[0][:-2], *shape_tuple)
73
74     log.debug('shape_tuple: {}'.format(shape_tuple))
75     node.out_node().shape = np.concatenate(shape_tuple)
76     node['channel_dims'] = node.out_node().shape.size - 1
77     log.debug('matmul shape: {}'.format(node.out_node().shape))
78
79
80 def onnx_gemm_infer(node):
81     assert (len(node.in_nodes()) == 3)
82     shapeA = node.in_node(0).shape
83     shapeB = node.in_node(1).shape
84     shapeC = node.in_node(2).shape
85
86     assert shapeA.size >= 2 and shapeB.size == 2 and shapeC.size in [1, 2]
87
88     if shapeA.size > 2 and node.transpose_a:
89         raise Error(
90             'ONNX Gemm operation do not support {} dimensional input with set transA key'.format(shapeA.size))
91
92     # apply transposes and broadcasts
93     if node.transpose_a:
94         shapeA = shapeA[[1, 0]]
95     if node.transpose_b:
96         shapeB = shapeB[[1, 0]]
97     if node.broadcast_c and shapeC.size == 1:
98         shapeC = np.array([shapeA[0], shapeC[0]])
99
100     node.out_node().shape = shapeC
101     return