Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / fusing / fuse_linear_seq.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18 from collections import deque
19
20 import networkx as nx
21 import numpy as np
22
23 from mo.front.extractor import add_attrs_props
24 from mo.middle.passes.eliminate import graph_clean_up
25 from mo.utils.graph import pseudo_topological_sort
26 from mo.ops.lin_op import Mul, Add
27 from mo.middle.passes.eliminate import merge_data_nodes
28 from mo.ops.op import Op
29 from mo.graph.graph import Node, Graph
30 from mo.middle.passes.fusing.helpers import backward_bfs, forward_bfs, get_tensor_id, get_value_id
31
32
33 def _fuse_linear_sequence(graph: Graph, start_node: Node):
34     """
35     This function finds the sequence of Mul/Add operations and replaces this sequence with two ops (Mul->Add).
36     :param graph:
37     :param start_node: The first operation of the sequence
38     """
39     fnodes = [start_node]
40     while True:
41         node = fnodes[-1]
42         data_node = node.out_node()
43         if (len(data_node.out_nodes()) != 1):
44             break
45         if (data_node.out_node().op in ['Mul', 'Add']) and get_value_id(data_node.out_node()) is not None and data_node.out_node().soft_get('can_be_fused') == True:
46             fnodes.append(data_node.out_node())
47         else:
48             break
49
50     if len(fnodes) == 1 or (len(fnodes) == 2 and fnodes[0].op == 'Mul' and fnodes[1].op == 'Add'):
51         return False
52
53     input_shape = start_node.in_node(get_tensor_id(start_node)).shape
54
55     init_dims_cnt = len(input_shape) - 2 if graph.graph['layout'] == 'NCHW' else 1
56
57     mul = np.ones([1 for x in range(init_dims_cnt)])
58     add = np.zeros([1 for x in range(init_dims_cnt)])
59
60     first_mul_name = None
61     first_add_name = None
62
63     for idx in range(len(fnodes)):
64         node = fnodes[idx]
65         const_node = get_value_id(node)
66         if node.op == 'Mul':
67             if first_mul_name is None:
68                 first_mul_name = node.name
69             mul = mul * node.in_node(const_node).value
70             add = add * node.in_node(const_node).value
71         elif node.op == 'Add':
72             if first_add_name is None:
73                 first_add_name = node.name
74             add = add + node.in_node(const_node).value
75
76     # If mul is scalar we broadcast it to biases shape
77     if mul.shape != add.shape and len(mul.shape) == 1 and mul.shape[0] == 1:
78         mul = np.array([mul[0] for x in range(add.shape[0])])
79
80     assert (np.array_equal(fnodes[0].in_node(get_tensor_id(fnodes[0])).shape, fnodes[-1].out_node().shape))
81
82     mul_node = Mul(graph, dict(name=first_mul_name + '/Fused_Mul_' if first_mul_name is not None else ''))
83     add_node = Add(graph, dict(name=first_add_name + '/Fused_Add_' if first_add_name is not None else ''))
84
85     in_node = fnodes[0].in_node(get_tensor_id(fnodes[0]))
86     out_node = fnodes[-1].out_node()
87
88     graph.remove_edge(in_node.id, fnodes[0].id)
89     graph.remove_edge(fnodes[-1].id, out_node.id)
90
91     # Remove deleted subgraph
92     for node in fnodes:
93         for tmp_node in node.in_nodes().values():
94             # Remove node only if it has one consumer (for case with shared weights)
95             if len(tmp_node.out_nodes()) == 1:
96                 graph.remove_node(tmp_node.id)
97         for tmp_node in node.out_nodes().values():
98             graph.remove_node(tmp_node.id)
99         graph.remove_node(node.id)
100
101     """
102     Four cases considered below:
103         1. Mul and Add have valid values (mul value != 1 and add value != 0)
104         2. Only Mul has valid values, so we add only Mul node
105         3. Only Add has valid values, so we add only Add node
106         4. When Mul and Add has not valid values we just merge two data nodes
107     """
108     if any([x != 0 for x in np.nditer(add)]) and any([x != 1 for x in np.nditer(mul)]):
109         data_mul = Op.create_input_data_node(graph, "data_mul_", np.array(mul))
110         data_add = Op.create_input_data_node(graph, "data_add_", np.array(add))
111         add_node.create_node_with_data(inputs=[mul_node.create_node_with_data([in_node, data_mul]), data_add],
112                                        data_nodes=out_node)
113     elif any([x != 1 for x in np.nditer(mul)]):
114         data_mul = Op.create_input_data_node(graph, "data_mul_", np.array(mul))
115         mul_node.create_node_with_data(inputs=[in_node, data_mul], data_nodes=out_node)
116     elif any([x != 0 for x in np.nditer(add)]):
117         data_add = Op.create_input_data_node(graph, "data_add_", np.array(add))
118         add_node.create_node_with_data(inputs=[in_node, data_add], data_nodes=out_node)
119     else:
120         merge_data_nodes(graph,out_node, in_node)
121         graph.remove_node(in_node.id)
122
123     log.debug('Fused {} operations'.format(len(fnodes)))
124     return True
125
126
127 def fuse_mul_add_sequence(graph: Graph):
128     """
129     This function finds first valid Mul/Add node and pass it to fuse_linear_sequence where full sequence will be found
130     """
131     while True:
132         is_fused = False
133         for idx in list(pseudo_topological_sort(graph)):
134             if idx in graph:
135                 node = Node(graph, idx)
136                 if node.soft_get('op') in ['Mul','Add'] and get_value_id(node) is not None and node.soft_get('can_be_fused') == True:
137                     is_fused |= _fuse_linear_sequence(graph, node)
138         if not is_fused:
139             break