Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / tf / partial_infer / tf.py
index a7247b9..ef35889 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@ from google.protobuf import text_format
 
 from mo.front.extractor import node_defs_to_str
 from mo.front.tf.extractors.utils import tf_dtype_extractor, tf_tensor_shape, get_tf_node_port
-from mo.graph.graph import Node, get_sorted_inputs, get_inputs, create_sub_graph_copy
+from mo.graph.graph import Node
 from mo.utils.graph import node_incoming_neighbourhood, node_outcoming_neighbourhood
 
 
@@ -41,7 +41,7 @@ def tf_native_tf_node_infer(node: Node):
     # Also the sub-graph contains names of the output nodes of the node to perform native infer.
     nodes_to_extract = node_incoming_neighbourhood(node.graph, node.id, 10) + node_outcoming_neighbourhood(node.graph,
                                                                                                            node.id, 1)
-    tmp_graph = create_sub_graph_copy(node.graph, nodes_to_extract)
+    tmp_graph = node.graph.create_sub_graph_copy(nodes_to_extract)
 
     tmp_node_attrs = tmp_graph.node[node.id]
     tmp_node = Node(tmp_graph, node.id)
@@ -82,7 +82,7 @@ def generate_feed_dict(graph: tf.Graph, node: Node):
     """
     all_constants = True
     feed_dict = dict()
-    for in_data_node_name, edge_attrs in get_inputs(node.graph, node.id):
+    for in_data_node_name, edge_attrs in node.get_inputs():
         if 'control_flow_edge' in edge_attrs and edge_attrs['control_flow_edge']:
             continue
         value = node.in_node(edge_attrs['in']).value
@@ -198,7 +198,7 @@ def add_placeholders_to_subgraph(node: Node):
     :return: None
     """
     inputs_replacements = list()
-    for index, (in_data_node, edge_attrs) in enumerate(get_sorted_inputs(node)):
+    for index, (in_data_node, edge_attrs) in enumerate(node.get_sorted_inputs()):
         if 'control_flow_edge' in edge_attrs and edge_attrs['control_flow_edge']:
             continue