"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import networkx as nx
-import numpy as np
-from mo.graph.graph import Node
+
+from mo.graph.graph import Node, Graph
from mo.ops.op import Op
+
# TODO: check all supported attributes in this file
class TensorIteratorInput(Op):
op = "TensorIteratorInput"
- def __init__(self, graph: nx.MultiDiGraph, attrs: dict):
+ def __init__(self, graph: Graph, attrs: dict):
mandatory_props = {
'op': __class__.op,
'axis': None,
'end': None,
'stride': None,
'part_size': None,
+ 'in_ports_count': 3,
+ 'out_ports_count': 1,
'infer': TensorIteratorInput.input_infer,
}
super().__init__(graph, mandatory_props, attrs)
class TensorIteratorOutput(Op):
op = "TensorIteratorOutput"
- def __init__(self, graph: nx.MultiDiGraph, attrs: dict):
+ def __init__(self, graph: Graph, attrs: dict):
mandatory_props = {
'op': __class__.op,
'axis': None,
'end': None,
'stride': None,
'part_size': None,
+ 'in_ports_count': 3,
+ 'out_ports_count': 1,
'infer': TensorIteratorOutput.input_infer,
}
super().__init__(graph, mandatory_props, attrs)
class TensorIteratorCondition(Op):
op = "TensorIteratorCondition"
- def __init__(self, graph: nx.MultiDiGraph, attrs: dict):
+ def __init__(self, graph: Graph, attrs: dict):
mandatory_props = {
'op': __class__.op,
+ 'in_ports_count': 2,
+ 'out_ports_count': 2,
'infer': TensorIteratorCondition.input_infer,
}
super().__init__(graph, mandatory_props, attrs)
- def supported_attrs(self):
- return ['time', 'iter']
-
@staticmethod
def input_infer(node: Node):
pass
class TensorIteratorBackEdge(Op):
op = 'TensorIteratorBackEdge'
- def __init__(self, graph: nx.MultiDiGraph, attrs: dict):
+ def __init__(self, graph: Graph, attrs: dict):
mandatory_props = {
'op': __class__.op,
+ 'in_ports_count': 3,
+ 'out_ports_count': 1,
'infer': TensorIteratorBackEdge.input_infer,
}
super().__init__(graph, mandatory_props, attrs)
@staticmethod
- def supported_attrs():
- return ['is_output']
-
- @staticmethod
def input_infer(node: Node):
pass