 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
import logging as log

import numpy as np

from mo.front.common.partial_infer.utils import int64_array
from mo.ops.op import PermuteAttrs
from mo.utils.error import Error
def tf_matmul_infer(node):
    """Infer the output shape of a TF MatMul node from its two input shapes.

    On success writes ``node.out_node().shape`` (always rank 2: ``[M, N]``,
    outer batch dimensions are deliberately dropped — see warning below) and
    sets ``node['channel_dims']`` to the last axis index. On any failure it
    logs an error and returns without touching the output node.
    """
    assert len(node.in_nodes()) == 2

    shapes = [node.in_node(i).shape.copy() for i in range(2)]
    log.debug('matmul shapes: {}'.format(shapes))
    # Both inputs must have known shapes of rank >= 2 to form a matrix product.
    if any(s is None or len(s) < 2 for s in shapes):
        log.error("MatMul wasn't able to infer shape")
        return

    # Fold the transpose_a/transpose_b attributes into the input shapes by
    # registering a permutation that swaps the two innermost axes of the
    # corresponding input, and permuting the local shape copy accordingly.
    if node.transpose_a:
        if not node.in_node(0).has_valid('value'):
            log.error("MatMul wasn't able to infer shape")
            return
        perm = np.array(range(len(node.in_node(0).shape)), dtype=np.int64)
        perm[-1], perm[-2] = perm[-2], perm[-1]
        inv = PermuteAttrs.get_inverse_permutation(perm)
        permutation = PermuteAttrs.Permutation(perm=perm, inv=int64_array(inv))
        PermuteAttrs.set_permutation(node.in_node(0), node, permutation)
        shapes[0] = shapes[0][perm]

    if node.transpose_b:
        if not node.in_node(1).has_valid('value'):
            log.error("MatMul wasn't able to infer shape")
            return
        perm = np.array(range(len(node.in_node(1).shape)), dtype=np.int64)
        perm[-1], perm[-2] = perm[-2], perm[-1]
        inv = PermuteAttrs.get_inverse_permutation(perm)
        permutation = PermuteAttrs.Permutation(perm=perm, inv=int64_array(inv))
        PermuteAttrs.set_permutation(node.in_node(1), node, permutation)
        shapes[1] = shapes[1][perm]

    # A: [..., M, K] x B: [..., K, N] — batch dims must match exactly and the
    # contracted dims must agree.
    if any(shapes[0][:-2] != shapes[1][:-2]) or shapes[0][-1] != shapes[1][-2]:
        log.error("MatMul wasn't able to infer shape because input dimensions are not compatible")
        return
    # NOTE(review): this additionally requires every non-outer, non-inner dim
    # of input 0 to be 1 — confirm this restriction is intended for batched
    # MatMul.
    if any(shapes[0][1:-1] != 1):
        log.error("MatMul wasn't able to infer shapes because input[0] shape is invalid: {}".format(shapes[0]))
        return

    shape_tuple = (np.array([shapes[0][0]], dtype=np.int64), np.array([shapes[1][-1]], dtype=np.int64))
    if len(shapes[0]) > 2:
        # TODO Investigate case when MatMul have inputs with not matching output dimensions
        # It looks to be a practical case and if we add outer dimensions of the first argument
        # it will lead to incorrect model sometimes. TF documentation is unclear.
        log.warning('Ignored outer dimensions of input tensor for MatMul node: {}'.format(node.name))
        # shape_tuple = (shapes[0][:-2], *shape_tuple)

    log.debug('shape_tuple: {}'.format(shape_tuple))
    node.out_node().shape = np.concatenate(shape_tuple)
    node['channel_dims'] = node.out_node().shape.size - 1
    log.debug('matmul shape: {}'.format(node.out_node().shape))
def onnx_gemm_infer(node):
    """Infer the output shape of an ONNX Gemm node (alpha*A*B + beta*C).

    Reads the shapes of the three inputs (A, B, C), applies the
    transA/transB attributes and the C broadcast, and writes the resulting
    shape to ``node.out_node().shape``.

    Raises:
        Error: when A has more than 2 dimensions and transA is set — that
            combination is not supported.
    """
    assert len(node.in_nodes()) == 3
    shapeA = node.in_node(0).shape
    shapeB = node.in_node(1).shape
    shapeC = node.in_node(2).shape

    # Gemm contract: B is strictly 2-D, C is a vector or a matrix.
    assert shapeA.size >= 2 and shapeB.size == 2 and shapeC.size in [1, 2]

    if shapeA.size > 2 and node.transpose_a:
        raise Error(
            'ONNX Gemm operation do not support {} dimensional input with set transA key'.format(shapeA.size))

    # apply transposes and broadcasts
    if node.transpose_a:
        shapeA = shapeA[[1, 0]]
    if node.transpose_b:
        shapeB = shapeB[[1, 0]]
    if node.broadcast_c and shapeC.size == 1:
        # 1-D C is broadcast along the rows of A: output is [A_rows, C_len].
        shapeC = np.array([shapeA[0], shapeC[0]])

    # NOTE(review): the output shape is taken from (possibly broadcast) C;
    # when C is already 2-D it is assumed to match [A_rows, B_cols] — confirm
    # upstream validation guarantees this.
    node.out_node().shape = shapeC