Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / ReluQuantizeFuse.py
1 """
2  Copyright (c) 2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 from extensions.middle.BinarizeWeightsM1P1 import BinarizeWeightsM1P1
20 from mo.graph.graph import Graph
21 from mo.middle.passes.eliminate import remove_op_node_with_data_node
22 from mo.middle.replacement import MiddleReplacementPattern
23
24
25 class ReluQuantizeFuse(MiddleReplacementPattern):
26     """ Fuses ReLU --> Quantize sequence if possible
27
28         Relu --> Quantize fusion is possible if:
29             1. Relu is consumed to 0-th port of Quantize
30             2. Quantize ports 1 and 2 defines such input range that 0 is not included
31     """
32     enabled = True
33
34     def run_after(self):
35         return [BinarizeWeightsM1P1]
36
37     def run_before(self):
38         from extensions.middle.SharedWeightsDuplication import SharedWeightsDuplication
39         return [SharedWeightsDuplication]
40
41     def pattern(self):
42         return dict(
43             nodes=[
44                 ('relu', dict(op='Relu')),
45                 ('relued', dict()),
46                 ('quantize', dict(op='Quantize')),
47             ],
48             edges=[
49                 ('relu', 'relued'),
50                 ('relued', 'quantize', {'in': 0}),
51             ]
52         )
53
54     def replace_pattern(self, graph: Graph, match: dict):
55
56         quantize = match['quantize']
57
58         # Check for total number of ReLU consumers -- if something else consume its output it cannot be fused
59         if len(match['relu'].out_node().out_nodes()) > 1:
60             log.debug('ReluQuantizeFuse: cannot fuse because ReLU have multiple consumers')
61             return
62
63         # If the fusion is applicable, direct modifications to quantize 1-st and 2-nd inputs
64         # are performed. So the data nodes at those inputs shouldn't have more than 1 consumer
65         # maximum 2 consumers to the same quantize op (consumed by 1st and 2nd ports).
66         # TODO: relax this limitation and duplicate data nodes accordingly to modify the input range freely
67
68         # Provisional limitation that related to binary quantization
69         # TODO: Relax it beyond binarization case
70         if len(quantize.in_node(1).out_nodes()) != 2 or \
71                         len(quantize.in_node(2).out_nodes()) != 2 or \
72                         quantize.in_node(1).id != quantize.in_node(2).id or \
73                         quantize.levels != 2:
74             log.debug('ReluQuantizeFuse: cannot fuse because Quantize op has '
75                       'unexpected number of consumers for ports 1 and 2')
76             return
77
78         threshold = quantize.in_node(1)
79
80         # As we restricted to binarization case only, so we need to detect from
81         # which side of 0 Quantize threshold resides:
82         #   if the threshold > 0, it remains the same;
83         #   if the threshold == 0, it also remains the same;
84         #   if the threshold < 0, it should be modified to -infinity that means that all inputs map to output_high
85
86         modification_mask = threshold.value < 0
87         threshold.value[modification_mask] = float('-inf')
88
89         # Remove ReLU as it no longer needed
90         remove_op_node_with_data_node(graph, match['relu'])