"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
"""
import logging as log
-import networkx as nx
+import numpy as np
+from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementOp
from mo.front.tf.extractors.utils import tf_dtype_extractor
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
from mo.ops.const import Const
op = "FakeConst"
enabled = True
- def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
+ def replace_sub_graph(self, graph: Graph, match: dict):
node = match['op']
if not node.has_valid('value'):
log.debug("No value in FakeConst node {}".format(node.id))
node_value = node.value
extracted_attrs = {
'data_type': tf_dtype_extractor(node.pb.attr['dtype'].type),
- 'shape': node_value.shape,
+ 'shape': int64_array(node_value.shape),
'value': node_value
}
Const.update_node_stat(node, extracted_attrs)