clamp.out_port(0).get_connection().set_source(min_node.out_port(0))
clamp.in_port(2).get_connection().set_destination(min_node.in_port(1))
assert min_node is not None or max_node is not None, 'Clamp node should have either min or max input used'
- rename_node(max_node if min_node is None else min_node, name)
+ rename_node(min_node if min_node is not None else max_node, name)
else:
a_clamp = AttributedClamp(graph, {'name': name, 'min': min_value, 'max': max_value}).create_node()
rename_node(a_clamp, name)
(flag, resp) = compare_graphs(graph, ref_graph, 'result')
self.assertTrue(flag, resp)
- def test_no_2nd_input(self):
+ def test_no_max_input(self):
nodes = {
**regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}),
**regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}),
(flag, resp) = compare_graphs(graph, ref_graph, 'result')
self.assertTrue(flag, resp)
- def test_no_1st_input(self):
+ def test_no_min_input(self):
nodes = {
**regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}),
**regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}),
@staticmethod
def infer(node):
name = node.soft_get('name', node.id)
- connected_in_ports = [port.idx for port in node.in_ports().values() if not port.disconnected()]
-
- assert len(connected_in_ports) == 3 and sorted(connected_in_ports) == [0, 1, 2], \
- 'Clamp should have exactly three inputs, but it has {}'.format(len(connected_in_ports))
+ min_input_connected = node.has_port('in', 1) and not node.in_port(1).disconnected()
+ max_input_connected = node.has_port('in', 2) and not node.in_port(2).disconnected()
input_value = node.in_port(0).data.get_value()
- min_value = node.in_port(1).data.get_value()
- max_value = node.in_port(2).data.get_value()
+ min_value = node.in_port(1).data.get_value() if min_input_connected else np.finfo(np.float32).min
+ max_value = node.in_port(2).data.get_value() if max_input_connected else np.finfo(np.float32).max
+
if input_value is not None and min_value is not None and max_value is not None:
assert np.all(max_value >= min_value), \
'Clamp max_value=={} is less than min_value=={} for node `{}`'.format(max_value, min_value, name)