"""
- Copyright (c) 2017-2018 Intel Corporation
+ Copyright (c) 2017-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import numpy as np
from mo.utils.error import Error
-from mo.graph.graph import Node, dict_includes
+from mo.graph.graph import Node, dict_includes, Graph
from mo.ops.op import Op
from mo.utils.utils import refer_to_faq_msg
op = 'TensorIterator'
- def __init__(self, graph: nx.MultiDiGraph, attrs: dict):
+ def __init__(self, graph: Graph, attrs: dict):
mandatory_props = {
'type': __class__.op,
'op': __class__.op,
'input_port_map': [], # a list of dicts with such attrs as external_port_id, etc.
'output_port_map': [], # a list of dicts with such attrs as external_port_id, etc.
'back_edges': [], # a list of dicts with such attrs as from_layer, from_port, etc.
- 'body': None, # an nx.MultiDiGraph object with a body sub-graph
+ 'body': None, # an Graph object with a body sub-graph
'sub_graphs': ['body'], # built-in attribute with all sub-graphg
'infer': __class__.infer
}
@staticmethod
- def find_internal_layer_id(graph: nx.MultiDiGraph, virtual_id):
+ def find_internal_layer_id(graph: Graph, virtual_id):
internal_nodes = list(filter(lambda d: dict_includes(d[1], {'internal_layer_id': virtual_id}), graph.nodes(data=True)))
assert len(internal_nodes) == 1, 'Nodes: {}, virtual_id: {}'.format(internal_nodes, virtual_id)
return internal_nodes[0][0]
@staticmethod
- def find_internal_layer_and_port(graph: nx.MultiDiGraph, virtual_layer_id, virtual_port_id):
+ def find_internal_layer_and_port(graph: Graph, virtual_layer_id, virtual_port_id):
internal_layer_id = __class__.find_internal_layer_id(graph, virtual_layer_id)
internal_port_id = __class__.find_port_id(Node(graph, internal_layer_id), virtual_port_id, 'internal_port_id')
return internal_layer_id, internal_port_id
@staticmethod
def generate_port_map(node: Node, src_port_map):
- ''' Extract port_map attributes from node and node.body attributes.
+ """ Extract port_map attributes from node and node.body attributes.
It iterates over src_port_map and substitude external_port_id, internal_port_id and
internal_layer_id by real values queried from node ports and node.body attributes.
- '''
+ """
result_list = []
for map_item in src_port_map:
result = dict(map_item)