"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
limitations under the License.
"""
-import networkx as nx
-
from mo.back.replacement import BackReplacementPattern
+from mo.graph.graph import Graph
class KaldiRemoveMemoryOutputBackReplacementPattern(BackReplacementPattern):
def pattern():
return dict(
nodes=[
- ('memory_node', dict(kind='op', op='Memory')),
- ('data_node', dict(kind='data'))
+ ('memory_node', dict(op='Memory')),
+ ('data_node', dict(kind='data')),
+ ('op_output', dict(op='OpOutput'))
],
edges=[
- ('memory_node', 'data_node', {'out': 0})
+ ('memory_node', 'data_node'),
+ ('data_node', 'op_output')
]
)
@staticmethod
- def replace_pattern(graph: nx.MultiDiGraph, match: dict):
+ def replace_pattern(graph: Graph, match: dict):
"""
Need to find the pattern: Memory -> Data -> OpOutput
Parameters
----------
- graph : nx.MultiDiGraph
+ graph : Graph
Graph with loaded model.
match : dict
Patterns which were found in graph structure.
memory = match['memory_node']
data = match['data_node']
- # Those Memory nodes that are not output ones, should not be replaced
- if not data.has_and_set('is_output'):
- return
graph.remove_edge(memory.id, data.id)
graph.remove_node(data.id)