@staticmethod
def infer(node):
name = node.soft_get('name', node.id)
- connected_in_ports = [port.idx for port in node.in_ports().values() if not port.disconnected()]
-
- assert len(connected_in_ports) == 3 and sorted(connected_in_ports) == [0, 1, 2], \
- 'Clamp should have exactly three inputs, but it has {}'.format(len(connected_in_ports))
+ min_input_connected = node.has_port('in', 1) and not node.in_port(1).disconnected()
+ max_input_connected = node.has_port('in', 2) and not node.in_port(2).disconnected()
input_value = node.in_port(0).data.get_value()
- min_value = node.in_port(1).data.get_value()
- max_value = node.in_port(2).data.get_value()
+ min_value = node.in_port(1).data.get_value() if min_input_connected else np.finfo(np.float32).min
+ max_value = node.in_port(2).data.get_value() if max_input_connected else np.finfo(np.float32).max
+
if input_value is not None and min_value is not None and max_value is not None:
assert np.all(max_value >= min_value), \
'Clamp max_value=={} is less than min_value=={} for node `{}`'.format(max_value, min_value, name)