Support ONNX Clamp-11 (#538)
authorMaxim Vafin <maxim.vafin@intel.com>
Mon, 25 May 2020 16:59:07 +0000 (19:59 +0300)
committerGitHub <noreply@github.com>
Mon, 25 May 2020 16:59:07 +0000 (19:59 +0300)
16 files changed:
model-optimizer/automation/package_BOM.txt
model-optimizer/extensions/back/ClampNormalizer.py [new file with mode: 0644]
model-optimizer/extensions/back/ClampNormalizer_test.py [new file with mode: 0644]
model-optimizer/extensions/front/AttributedClampNormalizer.py [new file with mode: 0644]
model-optimizer/extensions/front/AttributedClampNormalizer_test.py [new file with mode: 0644]
model-optimizer/extensions/front/kaldi/replace_lstm_node_pattern.py
model-optimizer/extensions/front/mxnet/clip_ext.py
model-optimizer/extensions/front/mxnet/cumsum.py
model-optimizer/extensions/front/onnx/clip_ext.py
model-optimizer/extensions/load/onnx/loader.py
model-optimizer/extensions/middle/ReluQuantizeFuse.py
model-optimizer/extensions/ops/activation_ops.py
model-optimizer/mo/front/onnx/extractors/utils.py
model-optimizer/mo/ops/clamp.py
model-optimizer/mo/ops/clamp_test.py
model-optimizer/mo/utils/ir_reader/layer_to_class.py

index 104d86f..dfbd2e6 100644 (file)
@@ -12,6 +12,7 @@ extensions/back/ActivationsNormalizer.py
 extensions/back/AvgPool.py
 extensions/back/blob_normalizer.py
 extensions/back/CellNormalizer.py
+extensions/back/ClampNormalizer.py
 extensions/back/compress_quantized_weights.py
 extensions/back/ConvolutionNormalizer.py
 extensions/back/CorrectName.py
@@ -72,6 +73,7 @@ extensions/back/UselessConcatRemoval.py
 extensions/front/__init__.py
 extensions/front/ArgMaxSqueeze.py
 extensions/front/ATenToEmbeddingBag.py
+extensions/front/AttributedClampNormalizer.py
 extensions/front/AttributedGatherNormalizer.py
 extensions/front/AttributedPadToPad.py
 extensions/front/binary_quantize_normalization.py
diff --git a/model-optimizer/extensions/back/ClampNormalizer.py b/model-optimizer/extensions/back/ClampNormalizer.py
new file mode 100644 (file)
index 0000000..52fe3c7
--- /dev/null
@@ -0,0 +1,71 @@
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from extensions.ops.elementwise import Minimum, Maximum
+from mo.back.replacement import BackReplacementPattern
+from mo.graph.graph import Graph, rename_node
+from mo.ops.clamp import AttributedClamp
+
+
+class ClampNormalizer(BackReplacementPattern):
+    """
+    Replaces Clamp with `min` and `max` as inputs with AttributedClamp with `min` and `max` as attributes.
+    """
+    enabled = True
+    force_clean_up = True
+
+    def pattern(self):
+        return dict(
+            nodes=[('clamp', dict(op='Clamp'))],
+            edges=[]
+        )
+
+    def replace_pattern(self, graph: Graph, match: dict):
+        clamp = match['clamp']
+        name = clamp.soft_get('name', clamp.id)
+
+        min_value = max_value = None
+        port_1_exist = clamp.has_port('in', 1) and not clamp.in_port(1).disconnected()
+        port_2_exist = clamp.has_port('in', 2) and not clamp.in_port(2).disconnected()
+        if port_1_exist and clamp.in_port(1).get_source().node.soft_get('type') == 'Const':
+            min_value = clamp.in_port(1).data.get_value()
+        if port_2_exist and clamp.in_port(2).get_source().node.soft_get('type') == 'Const':
+            max_value = clamp.in_port(2).data.get_value()
+
+        rename_node(clamp, name + '/TBR')
+        if min_value is None or max_value is None:
+            max_node = min_node = None
+            if port_1_exist:
+                max_node = Maximum(graph, {}).create_node()
+                clamp.in_port(0).get_connection().set_destination(max_node.in_port(0))
+                clamp.in_port(1).get_connection().set_destination(max_node.in_port(1))
+                clamp.out_port(0).get_connection().set_source(max_node.out_port(0))
+            if port_2_exist:
+                min_node = Minimum(graph, {}).create_node()
+                if max_node is not None:
+                    max_node.out_port(0).get_connection().set_source(min_node.out_port(0))
+                    max_node.out_port(0).connect(min_node.in_port(0))
+                else:
+                    clamp.in_port(0).get_connection().set_destination(min_node.in_port(0))
+                    clamp.out_port(0).get_connection().set_source(min_node.out_port(0))
+                clamp.in_port(2).get_connection().set_destination(min_node.in_port(1))
+            assert min_node is not None or max_node is not None, 'Clamp node should have either min or max input used'
+            rename_node(max_node if min_node is None else min_node, name)
+        else:
+            a_clamp = AttributedClamp(graph, {'name': name, 'min': min_value, 'max': max_value}).create_node()
+            rename_node(a_clamp, name)
+            clamp.in_port(0).get_connection().set_destination(a_clamp.in_port(0))
+            clamp.out_port(0).get_connection().set_source(a_clamp.out_port(0))
diff --git a/model-optimizer/extensions/back/ClampNormalizer_test.py b/model-optimizer/extensions/back/ClampNormalizer_test.py
new file mode 100644 (file)
index 0000000..eb5b07e
--- /dev/null
@@ -0,0 +1,118 @@
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+import unittest
+
+import numpy as np
+
+from extensions.back.ClampNormalizer import ClampNormalizer
+from mo.utils.ir_engine.compare_graphs import compare_graphs
+from mo.utils.unittest.graph import build_graph, regular_op_with_shaped_data, valued_const_with_data, result, connect
+
+
+class AttributedClampNormalizerTests(unittest.TestCase):
+
+    def test_2_inputs(self):
+        nodes = {
+            **regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}),
+            **regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}),
+            **regular_op_with_shaped_data('clamp', [1, 3, 20, 20],
+                                          {'type': 'Clamp', 'op': 'AttributedClamp', 'min': -3.5, 'max': 3.5}),
+            **valued_const_with_data('min', np.array(-3.5)),
+            **valued_const_with_data('max', np.array(3.5)),
+            **result('result'),
+        }
+        edges = [*connect('placeholder', '0:a_clamp'),
+                 *connect('min', '1:a_clamp'),
+                 *connect('max', '2:a_clamp'),
+                 *connect('a_clamp', 'result'),
+                 ]
+        graph = build_graph(nodes, edges)
+        ClampNormalizer().find_and_replace_pattern(graph)
+        ref_graph = build_graph(nodes, [*connect('placeholder', '0:clamp'), *connect('clamp', 'result')])
+
+        (flag, resp) = compare_graphs(graph, ref_graph, 'result')
+        self.assertTrue(flag, resp)
+
+    def test_all_dynamic_inputs(self):
+        nodes = {
+            **regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}),
+            **regular_op_with_shaped_data('min', [1, 3, 20, 20], {'type': 'Parameter'}),
+            **regular_op_with_shaped_data('max', [1, 3, 20, 20], {'type': 'Parameter'}),
+            **regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}),
+            **regular_op_with_shaped_data('maximum', [1, 3, 20, 20], {'type': 'Maximum', 'op': 'Maximum'}),
+            **regular_op_with_shaped_data('minimum', [1, 3, 20, 20], {'type': 'Minimum', 'op': 'Minimum'}),
+            **result('result'),
+        }
+        edges = [*connect('placeholder', '0:a_clamp'),
+                 *connect('min', '1:a_clamp'),
+                 *connect('max', '2:a_clamp'),
+                 *connect('a_clamp', 'result'),
+                 ]
+        graph = build_graph(nodes, edges)
+        ClampNormalizer().find_and_replace_pattern(graph)
+        ref_graph = build_graph(nodes, [*connect('placeholder', '0:maximum'),
+                                        *connect('min', '1:maximum'),
+                                        *connect('maximum', '0:minimum'),
+                                        *connect('max', '1:minimum'),
+                                        *connect('minimum', 'result')
+                                        ])
+
+        (flag, resp) = compare_graphs(graph, ref_graph, 'result')
+        self.assertTrue(flag, resp)
+
+    def test_no_2nd_input(self):
+        nodes = {
+            **regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}),
+            **regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}),
+            **regular_op_with_shaped_data('maximum', [1, 3, 20, 20], {'type': 'Maximum', 'op': 'Maximum'}),
+            **valued_const_with_data('min', np.array(-3.5)),
+            **result('result'),
+        }
+        edges = [*connect('placeholder', '0:a_clamp'),
+                 *connect('min', '1:a_clamp'),
+                 *connect('a_clamp', 'result'),
+                 ]
+        graph = build_graph(nodes, edges)
+        ClampNormalizer().find_and_replace_pattern(graph)
+        ref_graph = build_graph(nodes, [*connect('placeholder', '0:maximum'),
+                                        *connect('min', '1:maximum'),
+                                        *connect('maximum', 'result')
+                                        ])
+
+        (flag, resp) = compare_graphs(graph, ref_graph, 'result')
+        self.assertTrue(flag, resp)
+
+    def test_no_1st_input(self):
+        nodes = {
+            **regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}),
+            **regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}),
+            **regular_op_with_shaped_data('minimum', [1, 3, 20, 20], {'type': 'Minimum', 'op': 'Minimum'}),
+            **valued_const_with_data('max', np.array(3.5)),
+            **result('result'),
+        }
+        edges = [*connect('placeholder', '0:a_clamp'),
+                 *connect('max', '2:a_clamp'),
+                 *connect('a_clamp', 'result'),
+                 ]
+        graph = build_graph(nodes, edges)
+        ClampNormalizer().find_and_replace_pattern(graph)
+        ref_graph = build_graph(nodes, [*connect('placeholder', '0:minimum'),
+                                        *connect('max', '1:minimum'),
+                                        *connect('minimum', 'result')
+                                        ])
+
+        (flag, resp) = compare_graphs(graph, ref_graph, 'result')
+        self.assertTrue(flag, resp)
diff --git a/model-optimizer/extensions/front/AttributedClampNormalizer.py b/model-optimizer/extensions/front/AttributedClampNormalizer.py
new file mode 100644 (file)
index 0000000..0b55596
--- /dev/null
@@ -0,0 +1,46 @@
+"""
+ Copyright (C) 2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+import numpy as np
+
+from mo.front.common.replacement import FrontReplacementPattern
+from mo.front.tf.graph_utils import create_op_with_const_inputs
+from mo.graph.graph import Graph, rename_node
+from mo.ops.clamp import Clamp
+
+
+class AttributedClampNormalizer(FrontReplacementPattern):
+    """
+    This transformation converts AttributedClamp operation (min/max are specified as attribute) to Clamp
+    operation.
+    """
+    enabled = True
+
+    def find_and_replace_pattern(self, graph: Graph):
+        for attr_clamp in graph.get_op_nodes(op='AttributedClamp'):
+            original_name = attr_clamp.soft_get('name', attr_clamp.id)
+
+            rename_node(attr_clamp, original_name + '/TBR')
+            min_value = attr_clamp.soft_get('min', np.finfo(np.float32).min)
+            max_value = attr_clamp.soft_get('max', np.finfo(np.float32).max)
+            new_clamp = create_op_with_const_inputs(graph, Clamp,
+                                                    {1: np.array(min_value, dtype=np.float32),
+                                                     2: np.array(max_value, dtype=np.float32)},
+                                                    {'name': original_name})
+            rename_node(new_clamp, original_name)
+
+            attr_clamp.in_port(0).get_connection().set_destination(new_clamp.in_port(0))
+            attr_clamp.out_port(0).get_connection().set_source(new_clamp.out_port(0))
+            graph.remove_node(attr_clamp.id)
diff --git a/model-optimizer/extensions/front/AttributedClampNormalizer_test.py b/model-optimizer/extensions/front/AttributedClampNormalizer_test.py
new file mode 100644 (file)
index 0000000..0eb1bb7
--- /dev/null
@@ -0,0 +1,62 @@
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import unittest
+
+import numpy as np
+
+from extensions.front.AttributedClampNormalizer import AttributedClampNormalizer
+from mo.utils.ir_engine.compare_graphs import compare_graphs
+from mo.utils.unittest.graph import build_graph, const
+
+nodes_attributes = {
+    'placeholder': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+    'attr_clamp': {'type': 'Clamp', 'kind': 'op', 'op': 'AttributedClamp', 'name': 'attr_clamp',
+                   'min': np.array(-3.5, dtype=np.float32), 'max': np.array(3.5, dtype=np.float32)},
+    'result': {'type': 'Result', 'value': None, 'kind': 'op', 'op': 'Result'},
+
+    # new Clamp layer and inputs
+    'clamp': {'type': None, 'kind': 'op', 'op': 'Clamp'},
+    **const('min', np.array(-3.5, dtype=np.float32)),
+    **const('max', np.array(3.5, dtype=np.float32)),
+}
+
+
+class AttributedClampNormalizerTest(unittest.TestCase):
+    def test_1(self):
+        graph = build_graph(nodes_attributes,
+                            [('placeholder', 'attr_clamp', {'in': 0, 'out': 0}),
+                             ('attr_clamp', 'result', {'in': 0, 'out': 0}),
+                             ],
+                            {}, nodes_with_edges_only=True)
+
+        graph_ref = build_graph(nodes_attributes,
+                                [('placeholder', 'clamp', {'in': 0, 'out': 0}),
+                                 ('min', 'clamp', {'in': 1, 'out': 0}),
+                                 ('max', 'clamp', {'in': 2, 'out': 0}),
+                                 ('clamp', 'result')
+                                 ],
+                                {}, nodes_with_edges_only=True)
+
+        graph.graph['layout'] = 'NCHW'
+        graph.stage = 'front'
+
+        replacer = AttributedClampNormalizer()
+        replacer.find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
+        self.assertTrue(flag, resp)
+        self.assertTrue(graph.node[graph.get_nodes_with_attributes(op='Clamp')[0]]['name'] == 'attr_clamp')
index 05c8500..8f29bae 100644 (file)
@@ -22,13 +22,14 @@ from extensions.ops.split import Split
 from mo.front.caffe.extractors.utils import input_as_const
 from mo.front.common.partial_infer.utils import int64_array
 from mo.front.common.replacement import FrontReplacementOp
+from mo.front.tf.graph_utils import create_op_with_const_inputs
 from mo.graph.graph import Node, Graph, Port
 from mo.ops.assign import Assign
 from mo.ops.broadcast import Broadcast
 from mo.ops.clamp import Clamp
-from mo.ops.crop import Crop
 from mo.ops.concat import Concat
 from mo.ops.const import Const
+from mo.ops.crop import Crop
 from mo.ops.read_value import ReadValue
 from mo.ops.result import Result
 from mo.ops.scale_shift import ScaleShiftOp
@@ -238,10 +239,10 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
         join_forget_remember_sum.in_port(1).connect(join_remember_candidates_mul.out_port(0))
 
         # (7)Eltwise(sum) -> Clamp
-        join_forget_clamp = Clamp(graph, {'name': 'join_forget_clamp',
-                                          'max': node.clip_value,
-                                          'min': -node.clip_value}).create_node()
-        join_forget_clamp.in_port(0).connect(join_forget_remember_sum.out_port(0))
+        join_forget_clamp = create_op_with_const_inputs(graph, Clamp, {1: np.array(-node.clip_value, dtype=np.float32),
+                                                                       2: np.array(node.clip_value, dtype=np.float32)},
+                                                        {'name': 'join_forget_clamp'},
+                                                        join_forget_remember_sum)
         #
         # Clamp -> (2)Memory(state)
         next_lstm_state = Assign(graph, {'name': 'next_lstm_state',
index dbb0af8..ad8b2ba 100644 (file)
@@ -16,7 +16,7 @@
 from mo.front.extractor import FrontExtractorOp
 from mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
 from mo.graph.graph import Node
-from mo.ops.clamp import Clamp
+from mo.ops.clamp import AttributedClamp
 
 
 class ClipExt(FrontExtractorOp):
@@ -27,5 +27,5 @@ class ClipExt(FrontExtractorOp):
     def extract(cls, node: Node):
         attrs = get_mxnet_layer_attrs(node.symbol_dict)
 
-        Clamp.update_node_stat(node, {'min': attrs.float('a_min', None), 'max': attrs.float('a_max', None),})
+        AttributedClamp.update_node_stat(node, {'min': attrs.float('a_min', None), 'max': attrs.float('a_max', None)})
         return cls.enabled
index db98c65..84717e6 100644 (file)
@@ -38,7 +38,7 @@ class CumSumFrontReplacer(FrontReplacementOp):
 
         node.in_port(0).get_connection().set_destination(cumsum_node.in_port(0))
         if node.has_valid('mx_out_type') and node['mx_out_type'] is not None:
-            rename_node(node=cumsum_node, name=name + '/Clamp')
+            rename_node(node=cumsum_node, name=name + '/CumSum')
             convert = Cast(graph, {'name': name, 'dst_type': node['mx_out_type']}).create_node()
             rename_node(convert, name)
             cumsum_node.out_port(0).connect(convert.in_port(0))
index 748d1fb..1883b78 100644 (file)
  See the License for the specific language governing permissions and
  limitations under the License.
 """
+import numpy as np
 
 from mo.front.extractor import FrontExtractorOp
-from mo.front.onnx.extractors.utils import onnx_attr
-from mo.ops.clamp import Clamp
+from mo.front.onnx.extractors.utils import onnx_attr, get_onnx_opset_version
+from mo.ops.clamp import Clamp, AttributedClamp
 
 
 class ClipFrontExtractor(FrontExtractorOp):
@@ -25,9 +26,12 @@ class ClipFrontExtractor(FrontExtractorOp):
 
     @classmethod
     def extract(cls, node):
-        attrs = {
-            'min': onnx_attr(node, 'min', 'f', -3.4028234663852886e+38),
-            'max': onnx_attr(node, 'max', 'f', 3.4028234663852886e+38),
-        }
-        Clamp.update_node_stat(node, attrs)
+        if get_onnx_opset_version(node) < 11:
+            attrs = {
+                'min': onnx_attr(node, 'min', 'f', np.finfo(np.float32).min),
+                'max': onnx_attr(node, 'max', 'f', np.finfo(np.float32).max),
+            }
+            AttributedClamp.update_node_stat(node, attrs)
+        else:
+            Clamp.update_node_stat(node)
         return cls.enabled
index 6da5eac..c92058b 100644 (file)
@@ -63,6 +63,10 @@ class ONNXLoader(Loader):
         graph.graph['layout'] = 'NCHW'
         graph.graph['fw'] = 'onnx'
         graph.graph['feature_dim'] = 1
+        if hasattr(model_proto, 'opset_import'):
+            graph.graph['fw_opset_version'] = model_proto.opset_import[0].version   # pylint: disable=no-member
+        else:
+            graph.graph['fw_opset_version'] = None
 
         graph.check_empty_graph('protobuf2nx. It may happen due to problems with loaded model')
         extract_node_attrs(graph, lambda node: onnx_op_extractor(node, check_for_duplicates(onnx_op_extractors)))
index ca89b4c..3e025f0 100644 (file)
@@ -120,7 +120,11 @@ class ClampQuantizeMark(MiddleReplacementPattern):
     def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
         clamp = match['clamp']
         quantize = match['quantize']
-        clamp_min, clamp_max = clamp['min'], clamp['max']
+        clamp_min = clamp.in_port(1).data.get_value()
+        clamp_max = clamp.in_port(2).data.get_value()
+        if clamp_min is None or clamp_max is None:
+            log.debug('ReluQuantizeFuse: cannot fuse because Clamp op has dynamic input on the 1st or 2nd port')
+            return
 
         if not clamp.has_valid('quantized_to_fuse_count'):
             clamp['quantized_to_fuse_count'] = 0
index cd4282b..877acaa 100644 (file)
@@ -18,7 +18,7 @@ import numpy as np
 
 from mo.front.common.partial_infer.eltwise import eltwise_infer
 from mo.graph.graph import Graph, Node
-from mo.ops.clamp import Clamp
+from mo.ops.clamp import AttributedClamp
 from mo.ops.op import Op
 
 activation_ops = ['Sigmoid', 'Tanh', 'ReLU6', 'Exp', 'Elu', 'LogicalNot', 'Floor', 'Ceiling']
@@ -95,7 +95,7 @@ class Atan(Activation):
     operation = staticmethod(lambda x: np.arctan(x))
 
 
-class ReLU6(Clamp):
+class ReLU6(AttributedClamp):
     op = 'ReLU6'
 
     def __init__(self, graph: Graph, attrs: dict):
index baf9ea3..6bff1a7 100644 (file)
@@ -57,6 +57,10 @@ def get_onnx_autopad(auto_pad):
     return auto_pad
 
 
+def get_onnx_opset_version(node: Node):
+    return node.graph.graph.get('fw_opset_version', 0)
+
+
 def get_onnx_datatype_as_numpy(value):
     datatype_to_numpy = {
         1: np.float32,
index 8205bc9..bd6cd76 100644 (file)
@@ -19,13 +19,13 @@ from mo.graph.graph import Graph
 from mo.ops.op import Op
 
 
-class Clamp(Op):
-    op = 'Clamp'
+class AttributedClamp(Op):
+    op = 'AttributedClamp'
 
     def __init__(self, graph: Graph, attrs: dict):
         super().__init__(graph, {
-            'type': __class__.op,
-            'op': __class__.op,
+            'type': 'Clamp',
+            'op': self.op,
             'version': 'opset1',
             'infer': copy_shape_infer,
             'in_ports_count': 1,
@@ -37,3 +37,16 @@ class Clamp(Op):
             'max',
             'min'
         ]
+
+
+class Clamp(Op):
+    op = 'Clamp'
+
+    def __init__(self, graph: Graph, attrs: dict):
+        super().__init__(graph, {
+            'type': None,
+            'op': self.op,
+            'infer': copy_shape_infer,
+            'in_ports_count': 3,
+            'out_ports_count': 1,
+        }, attrs)
index 5cd7927..71ebffe 100644 (file)
@@ -19,7 +19,7 @@ import unittest
 import numpy as np
 
 from mo.front.common.partial_infer.elemental import copy_shape_infer
-from mo.ops.clamp import Clamp
+from mo.ops.clamp import AttributedClamp
 from mo.utils.unittest.graph import build_graph
 
 
@@ -41,7 +41,7 @@ class TestClampOp(unittest.TestCase):
                                 ('node_1', 'clamp_node'),
                                 ('clamp_node', 'node_3')
                             ])
-        clamp_node = Clamp(graph, self.nodes_attributes['clamp_node']).add_node()
+        clamp_node = AttributedClamp(graph, self.nodes_attributes['clamp_node']).add_node()
         self.assertEqual(clamp_node.type, 'Clamp')
-        self.assertEqual(clamp_node.op, 'Clamp')
+        self.assertEqual(clamp_node.op, 'AttributedClamp')
         self.assertEqual(clamp_node.infer, copy_shape_infer)
index 83b63f9..c9ad763 100644 (file)
@@ -28,6 +28,7 @@ from extensions.ops.scatter import Scatter
 from extensions.ops.split import Split, VariadicSplit
 from mo.front.common.partial_infer.utils import int64_array
 from mo.graph.graph import Graph, Node
+from mo.ops.clamp import AttributedClamp
 from mo.ops.convolution import Convolution
 from mo.ops.deconvolution import Deconvolution
 from mo.ops.op import Op
@@ -53,6 +54,7 @@ custom_ops = {
     'Split': Split,
     'Subtract': Sub,
     'VariadicSplit': VariadicSplit,
+    'Clamp': AttributedClamp,
 }