"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
limitations under the License.
"""
-import networkx as nx
import numpy as np
+from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementSubgraph
-from mo.graph.graph import replace_node, Node
+from mo.graph.graph import Node, Graph
from mo.utils.error import Error
]
)
- def nodes_to_remove(self, graph: nx.MultiDiGraph, match: dict):
+ def nodes_to_remove(self, graph: Graph, match: dict):
return [match['cast'].id, match['sparse_to_dense']]
- def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
+ def replace_sub_graph(self, graph: Graph, match: dict):
decoder_node = match['decoder']
graph.remove_edge(decoder_node.id, match['sparse_to_dense'].id)
graph.remove_edge(decoder_node.id, match['cast'].id)
- replace_node(match['sparse_to_dense'], decoder_node)
+ match['sparse_to_dense'].replace_node(decoder_node)
# update the TensorFlow infer function for the CTCGreedyDecoder to make necessary changes with the second input
decoder_node['old_infer'] = decoder_node.infer
new_value[:, 0] = 0
new_value = np.transpose(new_value)
sequence_length_node.value = new_value
- sequence_length_node.shape = sequence_length_node.value.shape
+ sequence_length_node.shape = int64_array(sequence_length_node.value.shape)
node.old_infer(node)