Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / BlockLSTM.py
index 3e1bed4..cd0247f 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -19,7 +19,7 @@ import logging as log
 import networkx as nx
 
 from mo.front.common.replacement import FrontReplacementOp
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
 from mo.utils.error import Error
 
 
@@ -61,11 +61,19 @@ class BlockLSTM(FrontReplacementOp):
     op = "BlockLSTM"
     enabled = True
 
-    def nodes_to_remove(self, graph: nx.MultiDiGraph, match: dict):
+    def nodes_to_remove(self, graph: Graph, match: dict):
         # do not remove matched node
         return []
 
-    def replace_op(self, graph: nx.MultiDiGraph, node: Node):
+    @staticmethod
+    def find_key_by_input_port(u: Node, v: Node, p: int):
+        key = None
+        for k, edge_info in u.graph.get_edge_data(u.id, v.id).items():
+            if p == edge_info['in']:
+                return k
+        return key
+
+    def replace_op(self, graph: Graph, node: Node):
         if node.use_peephole:
             raise Error("BlockLSTM operation is not supported with `use_peephole`==True. Node: {}"
                         "".format(node.soft_get('name')))
@@ -81,7 +89,12 @@ class BlockLSTM(FrontReplacementOp):
                                                    {p: o.id for p, o in node.out_nodes().items()}))
 
         log.debug("Cutting all inputs for peephole connection (5, 6, 7 input ports) off, as `use_peephole`=False")
-        [graph.remove_edge(node.in_node(p).id, node.id) for p, input_data in node.in_nodes().items() if p in [5, 6, 7]]
+
+        for p, input_data in node.in_nodes().items():
+            if p in [5, 6, 7]:
+                key = self.find_key_by_input_port(node.in_node(p), node, p)
+                assert key is not None
+                graph.remove_edge(node.in_node(p).id, node.id, key=key)
 
         log.debug("Cutting seq_len_max input off")
         graph.remove_edge(node.in_node(0).id, node.id)