Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / BinarizeWeightsM1P1.py
1 """
2  Copyright (c) 2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 import numpy as np
20
21 from extensions.middle.CheckForCycle import CheckForCycle
22 from extensions.middle.DeleteControlFlowEdges import DeleteControlFlowEdges
23 from extensions.middle.DeleteNotExecutable import DeleteNotExecutable
24 from mo.graph.graph import Graph
25 from mo.middle.replacement import MiddleReplacementPattern
26 from mo.ops.lin_op import Mul
27 from mo.ops.power import Power
28
29
30 class BinarizeWeightsM1P1(MiddleReplacementPattern):
31     """ Convert weights to -1/+1 form
32
33         Applicable for convolutions and other operations that have 'weights' that combined with the input data
34         by mean of multiplication operation. So any linear operator suits. Detect such operations by
35         multiplication_transparent attribute -- if it is presents and set to True, then multiplication term
36         can be passed through the operation. If multiplication_transparent attribute is set to True for an operation,
37         such operation should also has multiplication_transparent_ports that contain a list of pairs with
38         port indices (in_port, out_port) that defines which port pairs can pass multiplication through.
39
40         For example for some convolutional operation which has 2 ports (input tensor and weights) and 1 output port
41         this list includes [(0,0)(1,0)]. If convolutional operation also has biases at port 2, it is not included into
42         this list because this port is not transparent for multiplication operation.
43
44         multiplication_transparent_ports can be None if all possible input/output pairs are multiplication
45         transparent.
46
47         #TODO Describe how to apply multiplication at output ports -- this is not specified. In the current definition
48         we can pass through only scalar multiplication, but we already requre passing it channel-wise.
49     """
50     enabled = True
51
52     def run_after(self):
53         return [DeleteControlFlowEdges]
54
55     def run_before(self):
56         # CheckForCycle and DeleteNotExecutable run graph clean up which should not be run before weights binarization
57         return [CheckForCycle, DeleteNotExecutable]
58
59     def pattern(self):
60         return dict(
61             nodes=[
62                 ('quantize', dict(kind='op', op='Quantize')),
63                 ('quantized', dict()),
64                 ('operator', dict(kind='op', multiplication_transparent=True)),
65             ],
66             edges=[
67                 ('quantize', 'quantized'),
68                 ('quantized', 'operator'),
69             ]
70         )
71
72     def replace_pattern(self, graph: Graph, match: dict):
73         assert match['operator'].has('multiplication_transparent_ports')
74
75         port = match['operator'].input_ports_with(match['quantized'])
76         assert len(port) >= 1
77         if len(port) > 1:
78             log.debug('BinarizeWeightsM1P1 cannot apply transformation for data {} because it consumed more'
79                       ' than once'.format(match['quantized'].name))
80             return
81
82         assert len(port) == 1
83         port = port[0]
84         applicable = [pair for pair in match['operator'].multiplication_transparent_ports if pair[0] == port]
85         if len(applicable) == 0:
86             return
87
88         # Look at 3-rd and 4-th inputs of Quantize -- they have constants that should be passed through.
89         # Assume that the constant that should be passed through is a scalar.
90         quantize = match['quantize']
91         output_low = quantize.in_node(3)
92         output_high = quantize.in_node(4)
93
94         if not output_low.has_valid('value') and not output_high.has_valid('value'):
95             return
96
97         output_low = output_low.value
98         output_high = output_high.value
99
100         # This pass is applicable for binarization only. Other intX variants are not relevant.
101         if quantize.levels != 2:
102             return
103
104         # Recognize two cases: 0/+1 and -1/+1.
105         zp1 = np.all(output_low == 0) or np.all(output_high == 0)
106         m1p1 = np.all(-output_low == output_high)
107         if (not zp1 and not m1p1) or (zp1 and m1p1):
108             log.debug('BinarizeWeightsM1P1 cannot apply transformation for data {} because it does\'t has one of'
109                       ' 0/+1 or -1/+1 forms.'.format(match['quantized'].name))
110             return
111
112         # Recognize scalar
113         if len(np.unique(output_low)) != 1 or len(np.unique(output_high)) != 1:
114             log.debug('BinarizeWeightsM1P1 cannot apply transformation for data {} because output_low or output_high '
115                       'cannot be interpreted as scalars.'.format(match['quantized'].name))
116             return
117
118         # TODO: Extract real scalar from 3rd and 4th inputs; reusing original tensors is dangerous because
119         #       it may have incompatible shape.
120
121         mult_term = quantize.in_node(3) if np.all(output_high == 0) else quantize.in_node(4)
122
123         # Patch inflow path (by diving by mult_term)
124         # Put a new Power/Mul combination here:
125         #       ---->---- (here)---> data ---> [3rd/4th ports]quantize ---> quantized ---> operator
126
127         if len(match['quantized'].out_nodes()) > 1:
128             log.debug('BinarizeWeightsM1P1: len(match[\'quantized\'].out_nodes()) > 1')
129             return
130         div_op = Power(graph, {'name': quantize.name + '/DivNormalize', 'power': -1.0})
131         div_output = div_op.create_node_with_data([mult_term])
132
133         for i in [3, 4]:
134             match['quantize'].insert_node_with_data_before(
135                 match['quantize'].in_node(i),
136                 Mul,
137                 dict(name=quantize.name + '/MulNormalize'),
138                 additional_inputs=[div_output],
139             )
140
141         match['quantized'].value = None  # reset value because it will be recomputed
142         match['quantize'].infer(match['quantize'])
143
144         # Put a complimentary new Mul node here:   operator -->---(here)-----> operator.out_node()
145
146         match['operator'].insert_node_with_data_after(
147             match['operator'].out_node(),
148             Mul,
149             dict(name=match['operator'].name + '/MulNormalize'),
150             [mult_term],
151         )
152
153         # Disable 'operator' fusion with linear ops, otherwise it will annihilate changes that we just made
154         match['operator']['can_be_fused'] = False