Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / GemmToFullyConnected.py
1 """
2  Copyright (c) 2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 import numpy as np
20
21 from typing import Dict
22 from mo.front.common.partial_infer.utils import assign_dims_to_weights
23 from mo.graph.graph import Graph, Node
24 from mo.middle.replacement import MiddleReplacementPattern
25 from mo.ops.lin_op import Add
26
27
28 class GemmToFullyConnected(MiddleReplacementPattern):
29     enabled = True
30     graph_condition = [lambda graph: graph.graph['fw'] == 'onnx']
31
32     def run_after(self):
33         from extensions.middle.pass_separator import MiddleStart
34         return [MiddleStart]
35
36     def run_before(self):
37         from extensions.middle.pass_separator import MiddleFinish
38         return [MiddleFinish]
39
40     def pattern(self):
41         return dict(
42             nodes=[
43                 ('gemm', dict(kind='op', op='Gemm')),
44                 ('output', dict(kind='data'))],
45             edges=[('gemm', 'output')]
46         )
47
48     def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
49         log.debug('GemmToFullyConnected is triggered')
50         gemm = match['gemm']
51         A = gemm.in_node(0)
52         B = gemm.in_node(1)
53         B_consumers = graph.out_edges(B.node)
54         C = gemm.in_node(2)
55
56         if not (B.value is not None and
57                 C.value is not None and
58                 A.shape is not None and
59                 not gemm.transpose_a and
60                 (len(B_consumers) == 1 or not gemm.transpose_b)):
61             log.warning('Cannot convert Gemm to FullyConnected')
62             return
63
64         if gemm.transpose_b:
65             # B.value = B.value.transpose()
66             # B.shape = np.array(B.value.shape, dtype=np.int64)
67             gemm.transpose_b = 0
68         else:
69             B.value = B.value.transpose()
70             B.shape = np.array(B.value.shape, dtype=np.int64)
71
72         gemm['out-size'] = gemm.out_port(0).data.get_shape()[-1]
73         gemm['type'] = 'FullyConnected'
74         gemm['channel_dims'] = len(match['output'].shape) - 1
75         gemm['bias_addable'] = True
76         gemm['input_channel_dim'] = 1  # MatMul weights in IO
77         gemm['output_channel_dim'] = 0
78         gemm['layout'] = 'NCHW'
79
80         gemm.in_port(1).bin = 'weights'
81
82         bias_node = Add(graph, {}).create_node()
83         gemm.out_port(0).get_connection().set_source(bias_node.out_port(0))
84         gemm.in_port(2).get_connection().set_destination(bias_node.in_port(1))
85         gemm.out_port(0).connect(bias_node.in_port(0))
86
87         assign_dims_to_weights(gemm.in_node(1), None, 1, 0, 2)
88         # Do not transpose weights in this pass, it will be done as a separate pass