From 9a9b231c983aea34486d7a1407b4d680976634d9 Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Mon, 12 Oct 2020 09:55:45 +0300 Subject: [PATCH] [MO] Fix ONNX Clamp-11 shape infer with no min/max inputs (#2603) --- model-optimizer/extensions/back/ClampNormalizer.py | 2 +- model-optimizer/extensions/back/ClampNormalizer_test.py | 4 ++-- model-optimizer/mo/ops/clamp.py | 11 +++++------ 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/model-optimizer/extensions/back/ClampNormalizer.py b/model-optimizer/extensions/back/ClampNormalizer.py index 52fe3c7..a702b95 100644 --- a/model-optimizer/extensions/back/ClampNormalizer.py +++ b/model-optimizer/extensions/back/ClampNormalizer.py @@ -63,7 +63,7 @@ class ClampNormalizer(BackReplacementPattern): clamp.out_port(0).get_connection().set_source(min_node.out_port(0)) clamp.in_port(2).get_connection().set_destination(min_node.in_port(1)) assert min_node is not None or max_node is not None, 'Clamp node should have either min or max input used' - rename_node(max_node if min_node is None else min_node, name) + rename_node(min_node if min_node is not None else max_node, name) else: a_clamp = AttributedClamp(graph, {'name': name, 'min': min_value, 'max': max_value}).create_node() rename_node(a_clamp, name) diff --git a/model-optimizer/extensions/back/ClampNormalizer_test.py b/model-optimizer/extensions/back/ClampNormalizer_test.py index eb5b07e..e1b8c2e 100644 --- a/model-optimizer/extensions/back/ClampNormalizer_test.py +++ b/model-optimizer/extensions/back/ClampNormalizer_test.py @@ -73,7 +73,7 @@ class AttributedClampNormalizerTests(unittest.TestCase): (flag, resp) = compare_graphs(graph, ref_graph, 'result') self.assertTrue(flag, resp) - def test_no_2nd_input(self): + def test_no_max_input(self): nodes = { **regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}), **regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}), @@ -95,7 +95,7 @@ class AttributedClampNormalizerTests(unittest.TestCase): (flag, resp) = compare_graphs(graph, ref_graph, 'result') self.assertTrue(flag, resp) - def test_no_1st_input(self): + def test_no_min_input(self): nodes = { **regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}), **regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}), diff --git a/model-optimizer/mo/ops/clamp.py b/model-optimizer/mo/ops/clamp.py index c22f59f..b2942bb 100644 --- a/model-optimizer/mo/ops/clamp.py +++ b/model-optimizer/mo/ops/clamp.py @@ -71,14 +71,13 @@ class Clamp(Op): @staticmethod def infer(node): name = node.soft_get('name', node.id) - connected_in_ports = [port.idx for port in node.in_ports().values() if not port.disconnected()] - - assert len(connected_in_ports) == 3 and sorted(connected_in_ports) == [0, 1, 2], \ - 'Clamp should have exactly three inputs, but it has {}'.format(len(connected_in_ports)) + min_input_connected = node.has_port('in', 1) and not node.in_port(1).disconnected() + max_input_connected = node.has_port('in', 2) and not node.in_port(2).disconnected() input_value = node.in_port(0).data.get_value() - min_value = node.in_port(1).data.get_value() - max_value = node.in_port(2).data.get_value() + min_value = node.in_port(1).data.get_value() if min_input_connected else np.finfo(np.float32).min + max_value = node.in_port(2).data.get_value() if max_input_connected else np.finfo(np.float32).max + if input_value is not None and min_value is not None and max_value is not None: assert np.all(max_value >= min_value), \ 'Clamp max_value=={} is less than min_value=={} for node `{}`'.format(max_value, min_value, name) -- 2.7.4