"""
 Copyright (c) 2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import logging as log

import numpy as np

from extensions.middle.CheckForCycle import CheckForCycle
from extensions.middle.DeleteControlFlowEdges import DeleteControlFlowEdges
from extensions.middle.DeleteNotExecutable import DeleteNotExecutable
from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.lin_op import Mul
from mo.ops.power import Power
class BinarizeWeightsM1P1(MiddleReplacementPattern):
    """ Convert weights to -1/+1 form.

        Applicable for convolutions and other operations that have 'weights' that combined with the input data
        by mean of multiplication operation. So any linear operator suits. Detect such operations by
        multiplication_transparent attribute -- if it is present and set to True, then multiplication term
        can be passed through the operation. If multiplication_transparent attribute is set to True for an operation,
        such operation should also have multiplication_transparent_ports that contain a list of pairs with
        port indices (in_port, out_port) that define which port pairs can pass multiplication through.

        For example for some convolutional operation which has 2 ports (input tensor and weights) and 1 output port
        this list includes [(0, 0), (1, 0)]. If convolutional operation also has biases at port 2, it is not included
        into this list because this port is not transparent for multiplication operation.

        multiplication_transparent_ports can be None if all possible input/output pairs are multiplication
        transparent.

        #TODO Describe how to apply multiplication at output ports -- this is not specified. In the current definition
        we can pass through only scalar multiplication, but we already require passing it channel-wise.
    """
    enabled = True

    def run_after(self):
        # Control-flow edges must already be removed; the pattern below matches data edges only.
        return [DeleteControlFlowEdges]

    def run_before(self):
        # CheckForCycle and DeleteNotExecutable run graph clean up which should not be run before weights binarization
        return [CheckForCycle, DeleteNotExecutable]

    def pattern(self):
        # Match: Quantize op --> its output data node --> a multiplication-transparent consumer.
        return dict(
            nodes=[
                ('quantize', dict(kind='op', op='Quantize')),
                ('quantized', dict()),
                ('operator', dict(kind='op', multiplication_transparent=True)),
            ],
            edges=[
                ('quantize', 'quantized'),
                ('quantized', 'operator'),
            ]
        )

    def replace_pattern(self, graph: Graph, match: dict):
        """Rewrite a 2-level Quantize feeding a linear operator into -1/+1 (or 0/+1) weights.

        The scalar output magnitude of the Quantize is divided out of its output range
        (so the quantized tensor becomes pure -1/+1 or 0/+1) and re-applied as a Mul
        after the operator, which is behavior-preserving for multiplication-transparent ops.
        Bails out (no graph change) whenever any applicability condition fails.
        """
        assert match['operator'].has('multiplication_transparent_ports')

        # The transformation is only valid when the quantized data is consumed exactly once.
        port = match['operator'].input_ports_with(match['quantized'])
        assert len(port) >= 1
        if len(port) > 1:
            log.debug('BinarizeWeightsM1P1 cannot apply transformation for data {} because it consumed more'
                      ' than once'.format(match['quantized'].name))
            return

        assert len(port) == 1
        port = port[0]
        # The consuming input port must be declared multiplication-transparent.
        applicable = [pair for pair in match['operator'].multiplication_transparent_ports if pair[0] == port]
        if len(applicable) == 0:
            return

        # Look at 3-rd and 4-th inputs of Quantize -- they have constants that should be passed through.
        # Assume that the constant that should be passed through is a scalar.
        quantize = match['quantize']
        output_low = quantize.in_node(3)
        output_high = quantize.in_node(4)

        # Both range bounds must be known at conversion time.
        if not output_low.has_valid('value') and not output_high.has_valid('value'):
            return

        output_low = output_low.value
        output_high = output_high.value

        # This pass is applicable for binarization only. Other intX variants are not relevant.
        if quantize.levels != 2:
            return

        # Recognize two cases: 0/+1 and -1/+1.
        zp1 = np.all(output_low == 0) or np.all(output_high == 0)
        m1p1 = np.all(-output_low == output_high)
        # Exactly one of the two forms must hold; all-zero ranges satisfy both and are rejected.
        if (not zp1 and not m1p1) or (zp1 and m1p1):
            log.debug('BinarizeWeightsM1P1 cannot apply transformation for data {} because it does\'t has one of'
                      ' 0/+1 or -1/+1 forms.'.format(match['quantized'].name))
            return

        # Make sure the multiplication term really is a single scalar value.
        if len(np.unique(output_low)) != 1 or len(np.unique(output_high)) != 1:
            log.debug('BinarizeWeightsM1P1 cannot apply transformation for data {} because output_low or output_high '
                      'cannot be interpreted as scalars.'.format(match['quantized'].name))
            return

        # TODO: Extract real scalar from 3rd and 4th inputs; reusing original tensors is dangerous because
        # it may have incompatible shape.

        # Pick the non-zero bound as the scale to be factored out of the quantized range.
        mult_term = quantize.in_node(3) if np.all(output_high == 0) else quantize.in_node(4)

        # Patch inflow path (by dividing by mult_term)
        # Put a new Power/Mul combination here:
        #       ---->---- (here)---> data ---> [3rd/4th ports]quantize ---> quantized ---> operator

        if len(match['quantized'].out_nodes()) > 1:
            log.debug('BinarizeWeightsM1P1: len(match[\'quantized\'].out_nodes()) > 1')
            return

        # 1/mult_term, used to normalize both output bounds of Quantize down to -1/+1 (or 0/+1).
        div_op = Power(graph, {'name': quantize.name + '/DivNormalize', 'power': -1.0})
        div_output = div_op.create_node_with_data([mult_term])

        for i in [3, 4]:
            match['quantize'].insert_node_with_data_before(
                match['quantize'].in_node(i),
                Mul,
                dict(name=quantize.name + '/MulNormalize'),
                additional_inputs=[div_output],
            )

        match['quantized'].value = None  # reset value because it will be recomputed
        match['quantize'].infer(match['quantize'])

        # Put a complimentary new Mul node here:   operator -->---(here)-----> operator.out_node()

        match['operator'].insert_node_with_data_after(
            match['operator'].out_node(),
            Mul,
            dict(name=match['operator'].name + '/MulNormalize'),
            additional_inputs=[mult_term],
        )

        # Disable 'operator' fusion with linear ops, otherwise it will annihilate changes that we just made
        match['operator']['can_be_fused'] = False