"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
},
'output_node': {
'kind': 'data'
+ },
+ 'op_output': {
+ 'kind': 'data',
+ 'op': 'OpOutput',
}
}
def test_remove_out_data_for_memory(self):
- graph = build_graph(self.nodes, [('input_node', 'memory_node')])
- # Need for matching in pattern. The edge memory_node->out_node must contain only the attribute 'out' = 0
- # build_graph creates edge memory_node->out_node with attributes 'in' and 'out'
- graph.add_node('output_node', is_output=True, **self.nodes['output_node'])
- graph.add_edge('memory_node', 'output_node', out=0)
+ graph = build_graph(self.nodes,
+ [
+ ('input_node', 'memory_node'),
+ ('memory_node', 'output_node'),
+ ('output_node', 'op_output')
+ ])
KaldiRemoveMemoryOutputBackReplacementPattern().find_and_replace_pattern(graph)
self.assertNotIn('output_node', graph.node)
def test_do_not_remove_out_data_for_memory(self):
- graph = build_graph(self.nodes, [('input_node', 'memory_node')])
- graph.add_node('output_node', **self.nodes['output_node'])
- graph.add_edge('memory_node', 'output_node', out=0)
+ graph = build_graph(self.nodes,
+ [
+ ('input_node', 'memory_node'),
+ ('memory_node', 'output_node'),
+ ])
KaldiRemoveMemoryOutputBackReplacementPattern().find_and_replace_pattern(graph)
self.assertIn('output_node', graph.node)