"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import networkx as nx
from mo.front.common.replacement import FrontReplacementOp
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
from mo.utils.error import Error
op = "BlockLSTM"
enabled = True
- def nodes_to_remove(self, graph: nx.MultiDiGraph, match: dict):
+ def nodes_to_remove(self, graph: Graph, match: dict):
# do not remove matched node
return []
- def replace_op(self, graph: nx.MultiDiGraph, node: Node):
+ @staticmethod
+ def find_key_by_input_port(u: Node, v: Node, p: int):
+ key = None
+ for k, edge_info in u.graph.get_edge_data(u.id, v.id).items():
+ if p == edge_info['in']:
+ return k
+ return key
+
+ def replace_op(self, graph: Graph, node: Node):
if node.use_peephole:
raise Error("BlockLSTM operation is not supported with `use_peephole`==True. Node: {}"
"".format(node.soft_get('name')))
{p: o.id for p, o in node.out_nodes().items()}))
log.debug("Cutting all inputs for peephole connection (5, 6, 7 input ports) off, as `use_peephole`=False")
- [graph.remove_edge(node.in_node(p).id, node.id) for p, input_data in node.in_nodes().items() if p in [5, 6, 7]]
+
+ for p, input_data in node.in_nodes().items():
+ if p in [5, 6, 7]:
+ key = self.find_key_by_input_port(node.in_node(p), node, p)
+ assert key is not None
+ graph.remove_edge(node.in_node(p).id, node.id, key=key)
log.debug("Cutting seq_len_max input off")
graph.remove_edge(node.in_node(0).id, node.id)