[MO] Fix ONNX Clamp-11 shape infer with no min/max inputs (#2603)
authorPavel Esir <pavel.esir@intel.com>
Mon, 12 Oct 2020 06:55:45 +0000 (09:55 +0300)
committerGitHub <noreply@github.com>
Mon, 12 Oct 2020 06:55:45 +0000 (09:55 +0300)
model-optimizer/extensions/back/ClampNormalizer.py
model-optimizer/extensions/back/ClampNormalizer_test.py
model-optimizer/mo/ops/clamp.py

index 52fe3c7..a702b95 100644 (file)
@@ -63,7 +63,7 @@ class ClampNormalizer(BackReplacementPattern):
                     clamp.out_port(0).get_connection().set_source(min_node.out_port(0))
                 clamp.in_port(2).get_connection().set_destination(min_node.in_port(1))
             assert min_node is not None or max_node is not None, 'Clamp node should have either min or max input used'
-            rename_node(max_node if min_node is None else min_node, name)
+            rename_node(min_node if min_node is not None else max_node, name)
         else:
             a_clamp = AttributedClamp(graph, {'name': name, 'min': min_value, 'max': max_value}).create_node()
             rename_node(a_clamp, name)
index eb5b07e..e1b8c2e 100644 (file)
@@ -73,7 +73,7 @@ class AttributedClampNormalizerTests(unittest.TestCase):
         (flag, resp) = compare_graphs(graph, ref_graph, 'result')
         self.assertTrue(flag, resp)
 
-    def test_no_2nd_input(self):
+    def test_no_max_input(self):
         nodes = {
             **regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}),
             **regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}),
@@ -95,7 +95,7 @@ class AttributedClampNormalizerTests(unittest.TestCase):
         (flag, resp) = compare_graphs(graph, ref_graph, 'result')
         self.assertTrue(flag, resp)
 
-    def test_no_1st_input(self):
+    def test_no_min_input(self):
         nodes = {
             **regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}),
             **regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}),
index c22f59f..b2942bb 100644 (file)
@@ -71,14 +71,13 @@ class Clamp(Op):
     @staticmethod
     def infer(node):
         name = node.soft_get('name', node.id)
-        connected_in_ports = [port.idx for port in node.in_ports().values() if not port.disconnected()]
-
-        assert len(connected_in_ports) == 3 and sorted(connected_in_ports) == [0, 1, 2], \
-            'Clamp should have exactly three inputs, but it has {}'.format(len(connected_in_ports))
+        min_input_connected = node.has_port('in', 1) and not node.in_port(1).disconnected()
+        max_input_connected = node.has_port('in', 2) and not node.in_port(2).disconnected()
 
         input_value = node.in_port(0).data.get_value()
-        min_value = node.in_port(1).data.get_value()
-        max_value = node.in_port(2).data.get_value()
+        min_value = node.in_port(1).data.get_value() if min_input_connected else np.finfo(np.float32).min
+        max_value = node.in_port(2).data.get_value() if max_input_connected else np.finfo(np.float32).max
+
         if input_value is not None and min_value is not None and max_value is not None:
             assert np.all(max_value >= min_value), \
                 'Clamp max_value=={} is less than min_value=={} for node `{}`'.format(max_value, min_value, name)