"""
 Copyright (c) 2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import logging as log

from extensions.middle.BinarizeWeightsM1P1 import BinarizeWeightsM1P1
from mo.graph.graph import Graph
from mo.middle.passes.eliminate import remove_op_node_with_data_node
from mo.middle.replacement import MiddleReplacementPattern
class ReluQuantizeFuse(MiddleReplacementPattern):
    """ Fuses ReLU --> Quantize sequence if possible.

        Relu --> Quantize fusion is possible if:
            1. Relu is consumed to 0-th port of Quantize
            2. Quantize ports 1 and 2 define such an input range that 0 is not included
    """
    enabled = True

    def run_after(self):
        return [BinarizeWeightsM1P1]

    def run_before(self):
        # Local import to avoid a circular dependency at module load time.
        from extensions.middle.SharedWeightsDuplication import SharedWeightsDuplication
        return [SharedWeightsDuplication]

    def pattern(self):
        # Relu -> (data node 'relued') -> port 0 of Quantize
        return dict(
            nodes=[
                ('relu', dict(op='Relu')),
                ('relued', dict()),
                ('quantize', dict(op='Quantize')),
            ],
            edges=[
                ('relu', 'relued'),
                ('relued', 'quantize', {'in': 0}),
            ]
        )

    def replace_pattern(self, graph: Graph, match: dict):
        """Remove the matched ReLU when the Quantize threshold makes it redundant.

        The Quantize input range (ports 1 and 2) is modified in place so that
        negative thresholds map everything to output_high, after which the ReLU
        is dropped from the graph. Bails out (no change) when fusion is unsafe.
        """
        quantize = match['quantize']

        # Check for total number of ReLU consumers -- if something else consumes
        # its output it cannot be fused.
        if len(match['relu'].out_node().out_nodes()) > 1:
            log.debug('ReluQuantizeFuse: cannot fuse because ReLU have multiple consumers')
            return

        # If the fusion is applicable, direct modifications to quantize 1-st and 2-nd inputs
        # are performed. So the data nodes at those inputs shouldn't have more than 1 consumer
        # maximum 2 consumers to the same quantize op (consumed by 1st and 2nd ports).
        # TODO: relax this limitation and duplicate data nodes accordingly to modify the input range freely

        # Provisional limitation that related to binary quantization
        # TODO: Relax it beyond binarization case
        # NOTE(review): the final `quantize.levels != 2` clause restores the
        # binarization-only restriction the surrounding comments describe --
        # confirm against upstream history.
        if len(quantize.in_node(1).out_nodes()) != 2 or \
                len(quantize.in_node(2).out_nodes()) != 2 or \
                quantize.in_node(1).id != quantize.in_node(2).id or \
                quantize.levels != 2:
            log.debug('ReluQuantizeFuse: cannot fuse because Quantize op has '
                      'unexpected number of consumers for ports 1 and 2')
            return

        threshold = quantize.in_node(1)

        # As we restricted to binarization case only, so we need to detect from
        # which side of 0 Quantize threshold resides:
        #   if the threshold > 0, it remains the same;
        #   if the threshold == 0, it also remains the same;
        #   if the threshold < 0, it should be modified to -infinity that means
        #   that all inputs map to output_high.
        # In-place elementwise update of the shared threshold value.
        modification_mask = threshold.value < 0
        threshold.value[modification_mask] = float('-inf')

        # Remove ReLU as it is no longer needed.
        remove_op_node_with_data_node(graph, match['relu'])