import numpy as np
+from mo.front.common.partial_infer.utils import int64_array
from mo.ops.op import Op
mandatory_props = dict(
type=__class__.op,
op=__class__.op,
- infer=__class__.infer
+ infer=__class__.infer,
+ in_ports_count=4,
+ out_ports_count=4,
)
super().__init__(graph, mandatory_props, attrs)
rois_num = node.max_detections_per_image
# boxes
node.out_node(0).shape = np.array([rois_num, 4], dtype=np.int64)
- try:
- # classes
- node.out_node(1).shape = np.array([rois_num], dtype=np.int64)
- # scores
- node.out_node(2).shape = np.array([rois_num], dtype=np.int64)
- # batch_ids
- node.out_node(3).shape = np.array([rois_num], dtype=np.int64)
- except Exception as ex:
- print(ex)
+ # classes, scores, batch indices
+ for port_ind in range(1, 4):
+ if not node.out_port(port_ind).disconnected():
+ node.out_port(port_ind).data.set_shape(int64_array([rois_num]))