Publishing R5 content (#72)
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / TF_lstm_cell_to_generic.py
1 """
2  Copyright (c) 2018 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import networkx as nx
18 import numpy as np
19
20 from extensions.middle.FusePermutesSequence import FusePermutesSequence
21 from mo.middle.replacement import MiddleReplacementPattern
22
23
24 class TensorFlowLSTMtoGeneric(MiddleReplacementPattern):
25     """
26     Resolves all differences in TensorFlow LSTMCell and Inference Engine LSTMCell:
27     - weights transposing
28     - shift_const value addition to biases
29     - extra inputs deletion
30     """
31     enabled = True
32
33     def run_after(self):
34         return []
35
36     def run_before(self):
37         return [
38             FusePermutesSequence,
39         ]
40
41     def pattern(self):
42         return dict(
43             nodes=[('lstm', dict(op='LSTMCell', tf=True))],
44             edges=[]
45         )
46
47     def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
48         weights_node = match['lstm'].in_node(3)
49         biases_node = match['lstm'].in_node(4)
50         node = match['lstm']
51         shift_const = node.shift_const
52
53         # make sure that the node is the only consumer or weights and biases
54         # to let us modify them without hassle
55         assert len(weights_node.out_nodes()) == 1
56         assert len(biases_node.out_nodes()) == 1
57
58         # Assign temporary shape for them for easier manipulation
59         # TF stores weights in IO order
60         input_size = node.in_node(0).shape[1]
61         hidden_size = node.in_node(1).shape[1]
62         weights = weights_node.value
63         biases = biases_node.value
64         assert weights.shape[0] == input_size + hidden_size, "weights.shape={} input_size={} hidden_size={}".format(
65             weights.shape, input_size, hidden_size)
66         assert weights.shape[1] == biases.shape[0] == 4 * hidden_size,\
67             "weights.shape={} biases.shape={} hidden_size={}".format(weights.shape, biases.shape, hidden_size)
68
69         weights = weights.reshape([
70             weights.shape[0],
71             4,  # gates
72             hidden_size
73         ])
74
75         biases = biases.reshape([
76             4,  # gates
77             hidden_size
78         ])
79
80         # Reorder gates icfo --> fico for both weights and biases
81         gate_reorder = [2, 0, 1, 3]
82         weights = np.take(weights, gate_reorder, axis=1)
83         biases = np.take(biases, gate_reorder, axis=0)
84
85         # shift_const.value should be added to the first 1/4th part of the biases (f-gate: 0)
86         # Note: in case of moving this code up before gate reordering, the addition
87         # should be applied at different place
88         biases[0] += shift_const
89
90         # Return to the original shapes
91         weights = weights.reshape([weights.shape[0], -1])
92         biases = biases.flatten()
93
94         # TF stores weights in IO, but IE requires it in OI: transpose
95         weights = weights.transpose()
96
97         weights_node.value = weights
98         weights_node.shape = np.array(weights.shape, dtype=np.int64)
99         biases_node.value = biases
100         biases_node.shape = np.array(biases.shape, dtype=np.int64)
101
102         # Cut all extra inputs off
103         for i in range(len(node.inputs), len(node.inputs) + len(node.extra_inputs)):
104             node.graph.remove_edge(node.in_node(i).id, node.id)