2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
import logging as log

from mo.front.common.replacement import FrontReplacementOp
from mo.graph.graph import Node, Graph
from mo.utils.error import Error
class BlockLSTM(FrontReplacementOp):
    """
    We prepare TensorFlow BlockLSTM op to be replaced with LSTMSequence op that will be repacked to TensorIterator later

    TensorFlow BlockLSTM op description:

        Op parameters:
         cell_clip:    Value to clip the 'cs' value to.
         use_peephole: Whether to use peephole weights.
         forget_bias:  The forget gate bias.

        Inputs:
         0: seq_len_max: Maximum time length actually used by this input. Outputs are padded with 0s beyond this length
         1: x:           The sequence input to the LSTM, shape (timelen, batch_size, num_inputs)
         2: cs_prev:     Value of the initial cell state
         3: h_prev:      Initial output of cell (to be used for peephole)
         4: w:           The weight matrix
         5: wci:         The weight matrix for input gate peephole connection
         6: wcf:         The weight matrix for forget gate peephole connection
         7: wco:         The weight matrix for output gate peephole connection
         8: b:           The bias vector

        Outputs:
         0: i:  The input gate over the whole time sequence
         1: cs: The cell state before the tanh over the whole time sequence
         2: f:  The forget gate over the whole time sequence
         3: o:  The output gate over the whole time sequence
         4: ci: The cell input over the whole time sequence
         5: co: The cell after the tanh over the whole time sequence
         6: h:  The output h vector over the whole time sequence

    Limitations (the replacement does not support):
    - peephole connection, so we check `use_peephole`!=True and cut `wci`, `wco`, `wcf` off
    - cell_clip parameter, so we check `cell_clip==-1`, which means we do not clip
    - consumers of outputs 0, 2, 3, 4, 5; only `cs` (1) and `h` (6) may be used
    """
    # NOTE(review): `op`/`enabled` were not visible in the damaged SOURCE; restored per the
    # FrontReplacementOp convention so the replacer matches TF `BlockLSTM` nodes — confirm.
    op = "BlockLSTM"
    enabled = True

    # Peephole weight inputs (wci, wcf, wco) that are cut off because peepholes are unsupported.
    _PEEPHOLE_PORTS = (5, 6, 7)
    # TF outputs that have no LSTMSequence counterpart; consuming any of them is an error.
    _UNSUPPORTED_OUTPUT_PORTS = (0, 2, 3, 4, 5)
    # (TF input port, error message) pairs, checked in this order before re-wiring.
    _REQUIRED_INPUTS = (
        (1, "Sequence input to the BlockLSTM is required (1 port). Node {}"),
        (2, "Value of the initial cell state is required (2 port). Node {}"),
        (3, "Initial output of cell is required input to BlockLSTM (3 port). Node {}"),
        (4, "The weight matrix is required input to BlockLSTM (4 port) . Node {}"),
        (8, "The bias vector is required input to BlockLSTM (8 port). Node {}"),
    )
    # Re-wiring of input edges to the LSTMSequence layout:
    #   TF input port   Description                         MO input port
    #        1          x: input sequence                         0
    #        4          w: weights                                1
    #        8          b: biases                                 2
    #        3          h_prev: initial output of cell            3
    #        2          cs_prev: initial cell state               4
    _INPUT_PORT_MAP = {1: 0, 4: 1, 8: 2, 3: 3, 2: 4}

    def nodes_to_remove(self, graph: Graph, match: dict):
        # do not remove matched node: replace_op re-wires it in place
        return []

    @staticmethod
    def find_key_by_input_port(u: Node, v: Node, p: int):
        """
        Return the multigraph key of the edge u -> v that enters `v` through input port `p`.

        :param u: source node of the edge
        :param v: destination node of the edge
        :param p: input port index of `v`
        :return: the edge key, or None if no u -> v edge uses port `p` (including when u and v are not connected)
        """
        # get_edge_data returns None when there is no u -> v edge at all
        edges = u.graph.get_edge_data(u.id, v.id) or {}
        for key, edge_info in edges.items():
            if edge_info.get('in') == p:
                return key
        return None

    def _cut_input(self, graph: Graph, node: Node, port: int):
        """Remove the edge feeding `node` through input `port`, selected by key so parallel edges are kept."""
        source = node.in_node(port)
        key = self.find_key_by_input_port(source, node, port)
        assert key is not None, "No edge found for input port {} of node {}".format(port, node.id)
        graph.remove_edge(source.id, node.id, key=key)

    def replace_op(self, graph: Graph, node: Node):
        """
        Re-wire the matched BlockLSTM `node` in place so its edges follow the LSTMSequence port layout.

        :param graph: graph that contains `node`
        :param node: the matched BlockLSTM node
        :return: an empty list, since no output edge of the matched node is replaced
        :raises Error: if peephole connections or cell clipping are requested, if a required input
                       is not connected, or if an unsupported output port is consumed
        """
        name = node.soft_get('name')
        if node.use_peephole:
            raise Error("BlockLSTM operation is not supported with `use_peephole`==True. Node: {}"
                        "".format(name))

        if node.cell_clip != -1:
            raise Error("Clipping is not supported for BlockLSTM operation. `cell_clip`={!s} for node: {}"
                        "".format(node.cell_clip, name))

        log.debug("Start BlockLSTM->LSTMSequence translation for node: {} with parameters:\n"
                  "`cell_clip`={!s}, `use_peephole`=={!s}, `forget_bias`={!s}\n"
                  "inputs: {},\noutputs:{}".format(name, node.cell_clip, node.use_peephole,
                                                   node.forget_bias, {p: i.id for p, i in node.in_nodes().items()},
                                                   {p: o.id for p, o in node.out_nodes().items()}))

        log.debug("Cutting all inputs for peephole connection (5, 6, 7 input ports) off, as `use_peephole`=False")
        # in_nodes() builds a fresh dict, so removing edges while iterating it is safe
        for port in node.in_nodes():
            if port in self._PEEPHOLE_PORTS:
                self._cut_input(graph, node, port)

        log.debug("Cutting seq_len_max input off")
        # remove by key: a keyless remove_edge on a multigraph may drop a different parallel edge
        self._cut_input(graph, node, 0)

        inputs = node.in_edges()
        for port, message in self._REQUIRED_INPUTS:
            if port not in inputs:
                raise Error(message.format(node.id))
        # the edge attribute dicts are live, so updating 'in' re-targets the edge's input port
        for tf_port, mo_port in self._INPUT_PORT_MAP.items():
            inputs[tf_port]['in'] = mo_port

        log.debug("Checking for unsupported outputs usage (output ports: 0, 2, 3, 4, 5)")
        for port in node.out_nodes():
            if port in self._UNSUPPORTED_OUTPUT_PORTS:
                raise Error("Output port {} of BlockLSTM node {} is not supported".format(port, node.id))

        # Re-wiring of output edges:
        #   TF output port   Description                    MO output port
        #        6           h: output h vector                   0
        #        1           cs: cell state before the tanh       1
        outputs = node.out_edges()
        # the h output may have no consumer; only re-target it when it is connected
        if 6 in outputs:
            outputs[6]['out'] = 0

        # do not replace any output edge
        return []