Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / fake_const.py
index 2a487ef..0ba7579 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
 """
 import logging as log
 
-import networkx as nx
+import numpy as np
 
+from mo.front.common.partial_infer.utils import int64_array
 from mo.front.common.replacement import FrontReplacementOp
 from mo.front.tf.extractors.utils import tf_dtype_extractor
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
 from mo.ops.const import Const
 
 
@@ -27,7 +28,7 @@ class FakeConstToConst(FrontReplacementOp):
     op = "FakeConst"
     enabled = True
 
-    def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
+    def replace_sub_graph(self, graph: Graph, match: dict):
         node = match['op']
         if not node.has_valid('value'):
             log.debug("No value in FakeConst node {}".format(node.id))
@@ -35,7 +36,7 @@ class FakeConstToConst(FrontReplacementOp):
         node_value = node.value
         extracted_attrs = {
             'data_type': tf_dtype_extractor(node.pb.attr['dtype'].type),
-            'shape': node_value.shape,
+            'shape': int64_array(node_value.shape),
             'value': node_value
         }
         Const.update_node_stat(node, extracted_attrs)