Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / ops / tensor_iterator.py
index faaf9a7..c5bc888 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2017-2018 Intel Corporation
+ Copyright (c) 2017-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -20,7 +20,7 @@ import networkx as nx
 import numpy as np
 
 from mo.utils.error import Error
-from mo.graph.graph import Node, dict_includes
+from mo.graph.graph import Node, dict_includes, Graph
 from mo.ops.op import Op
 from mo.utils.utils import refer_to_faq_msg
 
@@ -32,14 +32,14 @@ class TensorIterator(Op):
     op = 'TensorIterator'
 
 
-    def __init__(self, graph: nx.MultiDiGraph, attrs: dict):
+    def __init__(self, graph: Graph, attrs: dict):
         mandatory_props = {
             'type': __class__.op,
             'op': __class__.op,
             'input_port_map': [],  # a list of dicts with such attrs as external_port_id, etc.
             'output_port_map': [],  # a list of dicts with such attrs as external_port_id, etc.
             'back_edges': [], # a list of dicts with such attrs as from_layer, from_port, etc.
-            'body': None,   # an nx.MultiDiGraph object with a body sub-graph
+            'body': None,   # an Graph object with a body sub-graph
             'sub_graphs': ['body'],  # built-in attribute with all sub-graphg
             'infer': __class__.infer
         }
@@ -96,14 +96,14 @@ class TensorIterator(Op):
 
 
     @staticmethod
-    def find_internal_layer_id(graph: nx.MultiDiGraph, virtual_id):
+    def find_internal_layer_id(graph: Graph, virtual_id):
         internal_nodes = list(filter(lambda d: dict_includes(d[1], {'internal_layer_id': virtual_id}), graph.nodes(data=True)))
         assert len(internal_nodes) == 1, 'Nodes: {}, virtual_id: {}'.format(internal_nodes, virtual_id)
         return  internal_nodes[0][0]
 
 
     @staticmethod
-    def find_internal_layer_and_port(graph: nx.MultiDiGraph, virtual_layer_id, virtual_port_id):
+    def find_internal_layer_and_port(graph: Graph, virtual_layer_id, virtual_port_id):
         internal_layer_id = __class__.find_internal_layer_id(graph, virtual_layer_id)
         internal_port_id = __class__.find_port_id(Node(graph, internal_layer_id), virtual_port_id, 'internal_port_id')
         return internal_layer_id, internal_port_id
@@ -111,11 +111,11 @@ class TensorIterator(Op):
 
     @staticmethod
     def generate_port_map(node: Node, src_port_map):
-        ''' Extract port_map attributes from node and node.body attributes.
+        """ Extract port_map attributes from node and node.body attributes.
         
             It iterates over src_port_map and substitude external_port_id, internal_port_id and
             internal_layer_id by real values queried from node ports and node.body attributes.
-        '''
+        """
         result_list = []
         for map_item in src_port_map:
             result = dict(map_item)