Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / BlockLSTM.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 import networkx as nx
20
21 from mo.front.common.replacement import FrontReplacementOp
22 from mo.graph.graph import Node, Graph
23 from mo.utils.error import Error
24
25
26 class BlockLSTM(FrontReplacementOp):
27     """
28     We prepare TensorFlow BlockLSTM op to be replaced with LSTMSequence op that will be repacked to TensorIterator later
29
30     TensorFlow BlockLSTM op description:
31
32         Op parameters:
33          cell_clip:    Value to clip the 'cs' value to.
34          use_peephole: Whether to use peephole weights.
35          forget_bias:  The forget gate bias.
36
37         Inputs:
38          0: seq_len_max:  Maximum time length actually used by this input. Outputs are padded with 0s beyond this length
39          1: x:            The sequence input to the LSTM, shape (timelen, batch_size, num_inputs)
40          2: cs_prev:      Value of the initial cell state
41          3: h_prev:       Initial output of cell (to be used for peephole)
42          4: w:            The weight matrix
43          5: wci:          The weight matrix for input gate peephole connection
44          6: wcf:          The weight matrix for forget gate peephole connection
45          7: wco:          The weight matrix for output gate peephole connection
46          8: b:            The bias vector
47
48         Outputs:
49          0: i:            The input gate                    over the whole time sequence
50          1: cs:           The cell state before the tanh    over the whole time sequence
51          2: f:            The forget gate                   over the whole time sequence
52          3: o:            The output gate                   over the whole time sequence
53          4: ci:           The cell input                    over the whole time sequence
54          5: co:           The cell after the tanh           over the whole time sequence
55          6: h:            The output h vector               over the whole time sequence
56
57     Limitations:
58     - peephole connection, so we check `use_peephole`!=True and cut `wci`, `wco`, `wcf` off
59     - cell_clip parameter, so we check `cell_clip==-1`, which means we do not clip
60     """
61     op = "BlockLSTM"
62     enabled = True
63
64     def nodes_to_remove(self, graph: Graph, match: dict):
65         # do not remove matched node
66         return []
67
68     @staticmethod
69     def find_key_by_input_port(u: Node, v: Node, p: int):
70         key = None
71         for k, edge_info in u.graph.get_edge_data(u.id, v.id).items():
72             if p == edge_info['in']:
73                 return k
74         return key
75
76     def replace_op(self, graph: Graph, node: Node):
77         if node.use_peephole:
78             raise Error("BlockLSTM operation is not supported with `use_peephole`==True. Node: {}"
79                         "".format(node.soft_get('name')))
80
81         if node.cell_clip != -1:
82             raise Error("Clipping is not supported for BlockLSTM operation. `cell_clip`={!s} for node: {}"
83                         "".format(node.cell_clip, node.soft_get('name')))
84
85         log.debug("Start BlockLSTM->LSTMSequence translation for node: {} with parameters:\n"
86                   "`cell_clip`={!s}, `use_peephole`=={!s}, `forget_bias`={!s}\n"
87                   "inputs: {},\noutputs:{}".format(node.soft_get('name'), node.cell_clip, node.use_peephole,
88                                                    node.forget_bias, {p: i.id for p, i in node.in_nodes().items()},
89                                                    {p: o.id for p, o in node.out_nodes().items()}))
90
91         log.debug("Cutting all inputs for peephole connection (5, 6, 7 input ports) off, as `use_peephole`=False")
92
93         for p, input_data in node.in_nodes().items():
94             if p in [5, 6, 7]:
95                 key = self.find_key_by_input_port(node.in_node(p), node, p)
96                 assert key is not None
97                 graph.remove_edge(node.in_node(p).id, node.id, key=key)
98
99         log.debug("Cutting seq_len_max input off")
100         graph.remove_edge(node.in_node(0).id, node.id)
101
102         """
103         Reconnecting input edges of LSTMSequence:
104         TF input edges:             Description:                 MO input edges:
105               1                          input                        0
106               4                         weights                       1
107               8                         biases                        2
108               3               h_prev: initial output of cell          3
109               2               cs_prev: initial cell state             4
110         """
111         inputs = node.in_edges()
112         assert 1 in inputs, "Sequence input to the BlockLSTM is required (1 port). Node {}".format(node.id)
113         assert 2 in inputs, "Value of the initial cell state is required (2 port). Node {}".format(node.id)
114         assert 3 in inputs, "Initial output of cell is required input to BlockLSTM (3 port). Node {}".format(node.id)
115         assert 4 in inputs, "The weight matrix is required input to BlockLSTM (4 port) . Node {}".format(node.id)
116         assert 8 in inputs, "The bias vector is required input to BlockLSTM (8 port). Node {}".format(node.id)
117
118         inputs[3]['in'] = 3
119         inputs[1]['in'] = 0
120         inputs[4]['in'] = 1
121         inputs[2]['in'] = 4
122         inputs[8]['in'] = 2
123
124         log.debug("Checking for unsupported outputs usage (output ports: 0, 2, 3, 4, 5)")
125         for port, input_data in node.out_nodes().items():
126             if port in [0, 2, 3, 4, 5]:
127                 raise Error("Output port {} of BlockLSTM node {} is not supported".format(node.id, port))
128
129         """
130         Reconnecting output edges of LSTMSequence:
131         TF output edges:             Description:                 MO output edges:
132               6                     output h vector                     0
133               1                   cell state before the tanh            1
134         """
135
136         outputs = node.out_edges()
137         if 6 in outputs:
138             outputs[6]['out'] = 0
139
140         # do not replace any output edge
141         return []