MO fusing activations (#1942)
authorEvgeny Lazarev <evgeny.lazarev@intel.com>
Thu, 27 Aug 2020 08:56:52 +0000 (11:56 +0300)
committerGitHub <noreply@github.com>
Thu, 27 Aug 2020 08:56:52 +0000 (11:56 +0300)
* Added HSwish operation

* Added HSwish fusing transformation

* Fixed BOM

* Added unit test for HSwish fusing transformation

* Fixed unit tests for transformations using 'build_graph_with_edge_attrs' function to build the graph

* Added fusion transformation for Swish operation

* Added fusing transformation for Softplus operation

* Added fusion transformation for Mish operation

* Added check for the node name in the unit tests

* Fixed Mish fusion pattern

* Updated Mish fusion transformation. Added unit test

* Updated HSwish fusing transformation

* Updated Swish fusion transformation and tests

* Fixed unit tests

13 files changed:
model-optimizer/automation/package_BOM.txt
model-optimizer/extensions/front/HSwish_fusing_test.py [new file with mode: 0644]
model-optimizer/extensions/front/HSwish_fusion.py [new file with mode: 0644]
model-optimizer/extensions/front/Mish_fusion.py [new file with mode: 0644]
model-optimizer/extensions/front/Mish_fusion_test.py [new file with mode: 0644]
model-optimizer/extensions/front/Softplus_fusion.py [new file with mode: 0644]
model-optimizer/extensions/front/Softplus_fusion_test.py [new file with mode: 0644]
model-optimizer/extensions/front/Swish_fusion.py [new file with mode: 0644]
model-optimizer/extensions/front/Swish_fusion_test.py [new file with mode: 0644]
model-optimizer/extensions/front/caffe/axpy_test.py
model-optimizer/extensions/front/tf/fifo_replacer_test.py
model-optimizer/extensions/ops/activation_ops.py
model-optimizer/mo/utils/unittest/graph.py

index bd74604..4167d68 100644 (file)
@@ -127,6 +127,7 @@ extensions/front/freeze_placeholder_value.py
 extensions/front/GeLUMerger_Erf.py
 extensions/front/GeLUMerger_Tanh.py
 extensions/front/global_pooling_to_reduce.py
+extensions/front/HSwish_fusion.py
 extensions/front/image_scaler.py
 extensions/front/input_cut.py
 extensions/front/instance_normalization.py
@@ -150,6 +151,7 @@ extensions/front/LayerNorm.py
 extensions/front/Log1p.py
 extensions/front/LogSoftmax.py
 extensions/front/MatMul_normalizer.py
+extensions/front/Mish_fusion.py
 extensions/front/MoveEmbeddedInputsToInputs.py
 extensions/front/mxnet/__init__.py
 extensions/front/mxnet/activation.py
@@ -326,11 +328,13 @@ extensions/front/reshape_dim_normalizer.py
 extensions/front/restore_ports.py
 extensions/front/scatter_normalizer.py
 extensions/front/softmax.py
+extensions/front/Softplus_fusion.py
 extensions/front/softsign_replacer.py
 extensions/front/split_normalizer.py
 extensions/front/SqueezeNormalize.py
 extensions/front/standalone_const_eraser.py
 extensions/front/sub.py
+extensions/front/Swish_fusion.py
 extensions/front/tf/__init__.py
 extensions/front/tf/activation_ext.py
 extensions/front/tf/argmax_ext.py
diff --git a/model-optimizer/extensions/front/HSwish_fusing_test.py b/model-optimizer/extensions/front/HSwish_fusing_test.py
new file mode 100644 (file)
index 0000000..b7cedba
--- /dev/null
@@ -0,0 +1,196 @@
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import unittest
+
+from extensions.front.HSwish_fusion import HSwishWithClamp, HSwishWithMinMax
+from mo.front.common.partial_infer.utils import float_array
+from mo.utils.ir_engine.compare_graphs import compare_graphs
+from mo.utils.unittest.graph import build_graph, const, regular_op, result, build_graph_with_edge_attrs
+
+ref_nodes = {**regular_op('input', {'type': 'Parameter'}),
+             **regular_op('hswish', {'type': 'HSwish', 'name': 'final_mul'}),
+             **result('result')
+             }
+ref_edges = [('input', 'hswish'), ('hswish', 'result')]
+
+
+class HSwishWithClampTest(unittest.TestCase):
+    nodes = {
+        **regular_op('input', {'type': 'Parameter'}),
+        **regular_op('add', {'op': 'Add'}),
+        **regular_op('relu6', {'op': 'Clamp'}),
+        **regular_op('mul', {'op': 'Mul'}),
+        **regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}),
+        **const('const_0', float_array([0.0])),
+        **const('const_3', float_array([3.0])),
+        **const('const_6', float_array([6.0])),
+        **const('const_1_6', float_array([1.0 / 6.0])),
+        **result('result'),
+    }
+
+    edges = [('input', 'mul', {'in': 0, 'out': 0}),
+             ('input', 'add', {'in': 0, 'out': 0}),
+             ('const_3', 'add', {'in': 1, 'out': 0}),
+             ('add', 'relu6', {'in': 0, 'out': 0}),
+             ('const_0', 'relu6', {'in': 1, 'out': 0}),
+             ('const_6', 'relu6', {'in': 2, 'out': 0}),
+             ('relu6', 'mul', {'in': 1, 'out': 0}),
+             ('mul', 'mul_2', {'in': 0, 'out': 0}),
+             ('const_1_6', 'mul_2', {'in': 1, 'out': 0}),
+             ('mul_2', 'result', {'in': 0, 'out': 0})]
+
+    def test_hswish_with_clamp(self):
+        graph = build_graph_with_edge_attrs(self.nodes, self.edges, {})
+
+        graph_ref = build_graph(ref_nodes, ref_edges)
+        graph.stage = 'front'
+
+        HSwishWithClamp().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
+        self.assertTrue(flag, resp)
+        self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and
+                        graph.get_op_nodes(name='final_mul')[0].op == 'HSwish')
+
+    def test_hswish_with_clamp_wrong_constant(self):
+        graph = build_graph_with_edge_attrs(self.nodes, self.edges, {'const_0': {'value': float_array([0.00001])}})
+
+        graph_ref = graph.copy()
+        graph.stage = 'front'
+
+        HSwishWithClamp().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
+        self.assertTrue(flag, resp)
+
+    def test_hswish_with_clamp_different_tensors(self):
+        graph = build_graph_with_edge_attrs({
+            **regular_op('input', {'type': 'Parameter'}),
+            **regular_op('input_2', {'type': 'Parameter'}),
+            **regular_op('add', {'op': 'Add'}),
+            **regular_op('relu6', {'op': 'Clamp'}),
+            **regular_op('mul', {'op': 'Mul'}),
+            **regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}),
+            **const('const_0', float_array([0.0])),
+            **const('const_3', float_array([3.0])),
+            **const('const_6', float_array([6.0])),
+            **const('const_1_6', float_array([1.0 / 6.0])),
+            **result('result'),
+        }, [('input', 'mul', {'in': 0, 'out': 0}),
+            ('input_2', 'add', {'in': 0, 'out': 0}),
+            ('const_3', 'add', {'in': 1, 'out': 0}),
+            ('add', 'relu6', {'in': 0, 'out': 0}),
+            ('const_0', 'relu6', {'in': 1, 'out': 0}),
+            ('const_6', 'relu6', {'in': 2, 'out': 0}),
+            ('relu6', 'mul', {'in': 1, 'out': 0}),
+            ('mul', 'mul_2', {'in': 0, 'out': 0}),
+            ('const_1_6', 'mul_2', {'in': 1, 'out': 0}),
+            ('mul_2', 'result', {'in': 0, 'out': 0})])
+
+        graph_ref = graph.copy()
+        graph.stage = 'front'
+
+        HSwishWithClamp().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
+        self.assertTrue(flag, resp)
+
+
+class HSwishWithMinMaxTest(unittest.TestCase):
+    nodes = {
+        **regular_op('input', {'type': 'Parameter'}),
+        **regular_op('add', {'op': 'Add'}),
+        **regular_op('max', {'op': 'Maximum'}),
+        **regular_op('min', {'op': 'Minimum'}),
+        **regular_op('mul', {'op': 'Mul'}),
+        **regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}),
+        **const('const_0', float_array([0.0])),
+        **const('const_3', float_array([3.0])),
+        **const('const_6', float_array([6.0])),
+        **const('const_1_6', float_array([1.0 / 6.0])),
+        **result('result'),
+    }
+
+    edges = [('input', 'mul', {'in': 1, 'out': 0}),
+             ('input', 'add', {'in': 0, 'out': 0}),
+             ('const_3', 'add', {'in': 1, 'out': 0}),
+             ('add', 'max', {'in': 0, 'out': 0}),
+             ('const_0', 'max', {'in': 1, 'out': 0}),
+             ('max', 'min', {'in': 0, 'out': 0}),
+             ('const_6', 'min', {'in': 1, 'out': 0}),
+             ('min', 'mul', {'in': 0, 'out': 0}),
+             ('mul', 'mul_2', {'in': 0, 'out': 0}),
+             ('const_1_6', 'mul_2', {'in': 1, 'out': 0}),
+             ('mul_2', 'result', {'in': 0, 'out': 0})]
+
+    def test_hswish_with_min_max(self):
+        graph = build_graph_with_edge_attrs(self.nodes, self.edges, {})
+
+        graph_ref = build_graph(ref_nodes, ref_edges)
+        graph.stage = 'front'
+
+        HSwishWithMinMax().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
+        self.assertTrue(flag, resp)
+        self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and
+                        graph.get_op_nodes(name='final_mul')[0].op == 'HSwish')
+
+    def test_hswish_with_min_max_wrong_constant(self):
+        graph = build_graph_with_edge_attrs(self.nodes, self.edges, {'const_0': {'value': float_array([0.00001])}})
+
+        graph_ref = graph.copy()
+        graph.stage = 'front'
+
+        HSwishWithMinMax().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
+        self.assertTrue(flag, resp)
+
+    def test_hswish_with_min_max_different_tensors(self):
+        graph = build_graph_with_edge_attrs({
+            **regular_op('input', {'type': 'Parameter'}),
+            **regular_op('input_2', {'type': 'Parameter'}),
+            **regular_op('add', {'op': 'Add'}),
+            **regular_op('max', {'op': 'Maximum'}),
+            **regular_op('min', {'op': 'Minimum'}),
+            **regular_op('mul', {'op': 'Mul'}),
+            **regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}),
+            **const('const_0', float_array([0.0])),
+            **const('const_3', float_array([3.0])),
+            **const('const_6', float_array([6.0])),
+            **const('const_1_6', float_array([1.0 / 6.0])),
+            **result('result'),
+        }, [('input_2', 'mul', {'in': 1, 'out': 0}),
+            ('input', 'add', {'in': 0, 'out': 0}),
+            ('const_3', 'add', {'in': 1, 'out': 0}),
+            ('add', 'max', {'in': 0, 'out': 0}),
+            ('const_0', 'max', {'in': 1, 'out': 0}),
+            ('max', 'min', {'in': 0, 'out': 0}),
+            ('const_6', 'min', {'in': 1, 'out': 0}),
+            ('min', 'mul', {'in': 0, 'out': 0}),
+            ('mul', 'mul_2', {'in': 0, 'out': 0}),
+            ('const_1_6', 'mul_2', {'in': 1, 'out': 0}),
+            ('mul_2', 'result', {'in': 0, 'out': 0})])
+
+        graph_ref = graph.copy()
+        graph.stage = 'front'
+
+        HSwishWithMinMax().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
+        self.assertTrue(flag, resp)
diff --git a/model-optimizer/extensions/front/HSwish_fusion.py b/model-optimizer/extensions/front/HSwish_fusion.py
new file mode 100644 (file)
index 0000000..8148692
--- /dev/null
@@ -0,0 +1,123 @@
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+import numpy as np
+
+from extensions.front.AttributedClampNormalizer import AttributedClampNormalizer
+from extensions.ops.activation_ops import HSwish
+from mo.front.common.replacement import FrontReplacementSubgraph
+from mo.front.subgraph_matcher import SubgraphMatch
+from mo.graph.graph import Graph, rename_nodes
+
+
+def replace_with_hswish(graph: Graph, match: [dict, SubgraphMatch]):
+    add = match['add']
+    mul = match['mul']
+    mul_2 = match['mul_2']
+
+    # determine the input port of Add and Mul which gets the 'input' node output
+    add_input_port_idx = int(add.in_port(0).get_connection().get_source().node.soft_get('op') == 'Const')
+    mul_input_port_idx = int(mul.in_port(0).get_connection().get_source().node.soft_get('op') in ['Clamp', 'Minimum'])
+
+    # check that the same tensor provided as input to Add and Mul
+    if add.in_port(add_input_port_idx).get_source() != mul.in_port(mul_input_port_idx).get_source():
+        return
+    mul_2_name = mul_2.soft_get('name', mul_2.id)
+
+    hswish = HSwish(graph, {}).create_node()
+    hswish.in_port(0).connect(add.in_port(add_input_port_idx).get_source())
+    mul_2.out_port(0).get_connection().set_source(hswish.out_port(0))
+
+    rename_nodes([(mul_2, mul_2_name + '/TBR'), (hswish, mul_2_name)])
+
+
+class HSwishWithClamp(FrontReplacementSubgraph):
+    """
+    The transformation looks for the pattern with ReLU6 (Clamp) defining the HSwish function:
+    HSwish(x) = x * Relu6(x + 3) / 6.0.
+    """
+    enabled = True
+
+    def run_after(self):
+        return [AttributedClampNormalizer]
+
+    def pattern(self):
+        return dict(
+            nodes=[
+                ('input', dict()),
+                ('add', dict(op='Add')),
+                ('const_0', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 0.0, atol=1e-6))),
+                ('const_3', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 3.0, atol=1e-6))),
+                ('const_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 6.0, atol=1e-6))),
+                ('const_1_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 1 / 6.0, atol=1e-6))),
+                ('clamp', dict(op='Clamp')),
+                ('mul', dict(op='Mul')),
+                ('mul_2', dict(op='Mul')),
+            ],
+            edges=[
+                ('input', 'add', {}),
+                ('input', 'mul', {}),
+                ('const_3', 'add', {}),
+                ('add', 'clamp', {'in': 0}),
+                ('const_0', 'clamp', {'in': 1}),
+                ('const_6', 'clamp', {'in': 2}),
+                ('clamp', 'mul', {}),
+                ('mul', 'mul_2', {}),
+                ('const_1_6', 'mul_2', {}),
+            ])
+
+    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
+        replace_with_hswish(graph, match)
+
+
+class HSwishWithMinMax(FrontReplacementSubgraph):
+    """
+    The transformation looks for the pattern with Min/Max defining the HSwish function:
+    HSwish(x) = x * Min(Max(x + 3, 0), 6) / 6.0.
+    """
+    enabled = True
+
+    def run_after(self):
+        return [AttributedClampNormalizer]
+
+    def pattern(self):
+        return dict(
+            nodes=[
+                ('input', dict()),
+                ('add', dict(op='Add')),
+                ('const_0', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 0.0, atol=1e-6))),
+                ('const_3', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 3.0, atol=1e-6))),
+                ('const_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 6.0, atol=1e-6))),
+                ('const_1_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 1 / 6.0, atol=1e-6))),
+                ('max', dict(op='Maximum')),
+                ('min', dict(op='Minimum')),
+                ('mul', dict(op='Mul')),
+                ('mul_2', dict(op='Mul')),
+            ],
+            edges=[
+                ('input', 'add', {'out': 0}),
+                ('input', 'mul', {'out': 0}),
+                ('const_3', 'add', {}),
+                ('add', 'max', {}),
+                ('const_0', 'max', {}),
+                ('max', 'min', {}),
+                ('const_6', 'min', {}),
+                ('min', 'mul', {}),
+                ('mul', 'mul_2', {}),
+                ('const_1_6', 'mul_2', {}),
+            ])
+
+    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
+        replace_with_hswish(graph, match)
diff --git a/model-optimizer/extensions/front/Mish_fusion.py b/model-optimizer/extensions/front/Mish_fusion.py
new file mode 100644 (file)
index 0000000..5d0bfa2
--- /dev/null
@@ -0,0 +1,61 @@
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from extensions.front.Softplus_fusion import SoftplusFusion
+from extensions.ops.activation_ops import Mish
+from mo.front.common.replacement import FrontReplacementSubgraph
+from mo.front.subgraph_matcher import SubgraphMatch
+from mo.graph.graph import Graph, rename_nodes
+
+
+class MishFusion(FrontReplacementSubgraph):
+    """
+    The transformation looks for the pattern with Softplus defining the Mish function: Mish(x) = x * tanh(SoftPlus(x)).
+    """
+    enabled = True
+
+    def run_after(self):
+        return [SoftplusFusion]
+
+    def pattern(self):
+        return dict(
+            nodes=[
+                ('mul', dict(op='Mul')),
+                ('tanh', dict(op='Tanh')),
+                ('softplus', dict(op='SoftPlus')),
+            ],
+            edges=[
+                ('softplus', 'tanh'),
+                ('tanh', 'mul'),
+            ])
+
+    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
+        mul = match['mul']
+        mul_name = mul.soft_get('name', mul.id)
+        softplus = match['softplus']
+
+        # determine the input port of Mul which gets the 'input' node output
+        input_port_idx = int(mul.in_port(0).get_connection().get_source().node.soft_get('op') == 'Tanh')
+
+        # check that the same tensor provided as input to Mul and SoftPlus
+        if mul.in_port(input_port_idx).get_source() != softplus.in_port(0).get_source():
+            return
+
+        mish = Mish(graph, {}).create_node()
+        mish.in_port(0).connect(mul.in_port(input_port_idx).get_source())
+        mul.out_port(0).get_connection().set_source(mish.out_port(0))
+
+        rename_nodes([(mul, mul_name + '/TBR'), (mish, mul_name)])
diff --git a/model-optimizer/extensions/front/Mish_fusion_test.py b/model-optimizer/extensions/front/Mish_fusion_test.py
new file mode 100644 (file)
index 0000000..c1d97c4
--- /dev/null
@@ -0,0 +1,79 @@
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import unittest
+
+from extensions.front.Mish_fusion import MishFusion
+from mo.utils.ir_engine.compare_graphs import compare_graphs
+from mo.utils.unittest.graph import build_graph, regular_op, result, build_graph_with_edge_attrs
+
+ref_nodes = {**regular_op('input', {'type': 'Parameter'}),
+             **regular_op('mish', {'type': 'Mish', 'name': 'final_mul'}),
+             **result('result')
+             }
+ref_edges = [('input', 'mish'), ('mish', 'result')]
+
+
+class MishFusionTest(unittest.TestCase):
+    nodes = {
+        **regular_op('input', {'type': 'Parameter'}),
+        **regular_op('softplus', {'op': 'SoftPlus'}),
+        **regular_op('tanh', {'op': 'Tanh'}),
+        **regular_op('mul', {'op': 'Mul', 'name': 'final_mul'}),
+        **result('result'),
+    }
+
+    edges = [('input', 'softplus', {'in': 0, 'out': 0}),
+             ('input', 'mul', {'in': 0, 'out': 0}),
+             ('softplus', 'tanh', {'in': 0, 'out': 0}),
+             ('tanh', 'mul', {'in': 1, 'out': 0}),
+             ('mul', 'result', {'in': 0, 'out': 0})]
+
+    def test_mish_fusion(self):
+        graph = build_graph_with_edge_attrs(self.nodes, self.edges, {})
+
+        graph_ref = build_graph(ref_nodes, ref_edges)
+        graph.stage = 'front'
+
+        MishFusion().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
+        self.assertTrue(flag, resp)
+        self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and
+                        graph.get_op_nodes(name='final_mul')[0].op == 'Mish')
+
+    def test_mish_fusion_different_source(self):
+        # check case when different tensors goes to Mul and SoftPlus
+        graph = build_graph_with_edge_attrs({
+            **regular_op('input', {'type': 'Parameter'}),
+            **regular_op('input_2', {'type': 'Parameter'}),
+            **regular_op('softplus', {'op': 'SoftPlus'}),
+            **regular_op('tanh', {'op': 'Tanh'}),
+            **regular_op('mul', {'op': 'Mul', 'name': 'final_mul'}),
+            **result('result'),
+        }, [('input', 'softplus', {'in': 0, 'out': 0}),
+            ('input_2', 'mul', {'in': 0, 'out': 0}),
+            ('softplus', 'tanh', {'in': 0, 'out': 0}),
+            ('tanh', 'mul', {'in': 1, 'out': 0}),
+            ('mul', 'result', {'in': 0, 'out': 0})], {})
+
+        graph_ref = graph.copy()
+        graph.stage = 'front'
+
+        MishFusion().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
+        self.assertTrue(flag, resp)
diff --git a/model-optimizer/extensions/front/Softplus_fusion.py b/model-optimizer/extensions/front/Softplus_fusion.py
new file mode 100644 (file)
index 0000000..1a70e5d
--- /dev/null
@@ -0,0 +1,54 @@
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+import numpy as np
+
+from extensions.ops.activation_ops import SoftPlus
+from mo.front.common.replacement import FrontReplacementSubgraph
+from mo.front.subgraph_matcher import SubgraphMatch
+from mo.graph.graph import Graph, rename_nodes
+
+
+class SoftplusFusion(FrontReplacementSubgraph):
+    """
+    The transformation looks for the pattern for the Softplus function: Softplus(x) = ln(1 + e^x)
+    """
+    enabled = True
+
+    def pattern(self):
+        return dict(
+            nodes=[
+                ('exp', dict(op='Exp')),
+                ('add', dict(op='Add')),
+                ('const_1', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 1.0, atol=1e-6))),
+                ('ln', dict(op='Log')),
+            ],
+            edges=[
+                ('exp', 'add', {}),
+                ('const_1', 'add', {}),
+                ('add', 'ln', {}),
+            ])
+
+    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
+        ln = match['ln']
+        exp = match['exp']
+
+        ln_name = ln.soft_get('name', ln.id)
+
+        softplus = SoftPlus(graph, {}).create_node()
+        softplus.in_port(0).connect(exp.in_port(0).get_source())
+        ln.out_port(0).get_connection().set_source(softplus.out_port(0))
+
+        rename_nodes([(ln, ln_name + '/TBR'), (softplus, ln_name)])
diff --git a/model-optimizer/extensions/front/Softplus_fusion_test.py b/model-optimizer/extensions/front/Softplus_fusion_test.py
new file mode 100644 (file)
index 0000000..eba085e
--- /dev/null
@@ -0,0 +1,70 @@
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import unittest
+
+from extensions.front.Softplus_fusion import SoftplusFusion
+from mo.front.common.partial_infer.utils import float_array
+from mo.utils.ir_engine.compare_graphs import compare_graphs
+from mo.utils.unittest.graph import build_graph, const, regular_op, result, build_graph_with_edge_attrs
+
+ref_nodes = {**regular_op('input', {'type': 'Parameter'}),
+             **regular_op('softplus', {'type': 'SoftPlus', 'name': 'final_log'}),
+             **result('result')
+             }
+ref_edges = [('input', 'softplus'), ('softplus', 'result')]
+
+
+class SoftplusFusionTest(unittest.TestCase):
+    nodes = {
+        **regular_op('input', {'type': 'Parameter'}),
+        **regular_op('exp', {'op': 'Exp'}),
+        **const('const_1', float_array([1.0])),
+        **regular_op('add', {'op': 'Add'}),
+        **regular_op('ln', {'op': 'Log', 'name': 'final_log'}),
+        **result('result'),
+    }
+
+    edges = [('input', 'exp', {'in': 0, 'out': 0}),
+             ('const_1', 'add', {'in': 0, 'out': 0}),
+             ('exp', 'add', {'in': 1, 'out': 0}),
+             ('add', 'ln', {'in': 0, 'out': 0}),
+             ('ln', 'result', {'in': 0, 'out': 0})]
+
+    def test_softplus_fusion_test(self):
+        graph = build_graph_with_edge_attrs(self.nodes, self.edges, {})
+
+        graph_ref = build_graph(ref_nodes, ref_edges)
+        graph.stage = 'front'
+
+        SoftplusFusion().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
+        self.assertTrue(flag, resp)
+        self.assertTrue(len(graph.get_op_nodes(name='final_log')) == 1 and
+                        graph.get_op_nodes(name='final_log')[0].op == 'SoftPlus')
+
+    def test_softplus_fusion_test_wrong_const(self):
+        graph = build_graph_with_edge_attrs(self.nodes, self.edges, {'const_1': {'value': float_array([0.9999])}})
+
+        graph_ref = graph.copy()
+        graph.stage = 'front'
+
+        SoftplusFusion().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
+        self.assertTrue(flag, resp)
+
diff --git a/model-optimizer/extensions/front/Swish_fusion.py b/model-optimizer/extensions/front/Swish_fusion.py
new file mode 100644 (file)
index 0000000..bd47af7
--- /dev/null
@@ -0,0 +1,100 @@
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from extensions.ops.activation_ops import Swish
+from mo.front.common.replacement import FrontReplacementSubgraph
+from mo.front.subgraph_matcher import SubgraphMatch
+from mo.graph.graph import Graph, rename_nodes
+
+
+class SwishWithSigmoidWithoutBeta(FrontReplacementSubgraph):
+    """
+    The transformation looks for the pattern with Sigmoid defining the Swish function: Swish(x) = x * Sigmoid(x)
+    """
+    enabled = True
+
+    def pattern(self):
+        return dict(
+            nodes=[
+                ('sigmoid', dict(op='Sigmoid')),
+                ('mul', dict(op='Mul')),
+            ],
+            edges=[
+                ('sigmoid', 'mul', {}),
+            ])
+
+    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
+        sigmoid = match['sigmoid']
+        mul = match['mul']
+        mul_name = mul.soft_get('name', mul.id)
+
+        # determine the input port of Mul which gets the 'input' node output
+        mul_input_port_idx = int(mul.in_port(0).get_connection().get_source().node.soft_get('op') == 'Sigmoid')
+
+        # check that the same tensor provided as input to Mul and Sigmoid
+        if mul.in_port(mul_input_port_idx).get_source() != sigmoid.in_port(0).get_source():
+            return
+
+        swish = Swish(graph, {}).create_node()
+        swish.in_port(0).connect(sigmoid.in_port(0).get_source())
+        mul.out_port(0).get_connection().set_source(swish.out_port(0))
+
+        rename_nodes([(mul, mul_name + '/TBR'), (swish, mul_name)])
+
+
+class SwishWithSigmoidWithBeta(FrontReplacementSubgraph):
+    """
+    The transformation looks for the pattern with Sigmoid defining the Swish function: Swish(x) = x * Sigmoid(x * beta)
+    """
+    enabled = True
+
+    def pattern(self):
+        return dict(
+            nodes=[
+                ('sigmoid', dict(op='Sigmoid')),
+                ('beta', dict()),
+                ('mul_beta', dict(op='Mul')),
+                ('mul', dict(op='Mul')),
+            ],
+            edges=[
+                ('beta', 'mul_beta', {}),
+                ('mul_beta', 'sigmoid', {}),
+                ('sigmoid', 'mul', {}),
+            ])
+
+    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
+        beta = match['beta']
+        mul = match['mul']
+        mul_beta = match['mul_beta']
+        mul_name = mul.soft_get('name', mul.id)
+
+        # determine the input port of Muls which get the 'input' node output
+        mul_beta_input_port_idx = int(mul_beta.in_port(0).get_connection().get_source().node.id == beta.id)
+        mul_input_port_idx = int(mul.in_port(0).get_connection().get_source().node.soft_get('op') == 'Sigmoid')
+
+        # check that the same tensor provided as input to Mul and MulBeta
+        if mul.in_port(mul_input_port_idx).get_source() != mul_beta.in_port(mul_beta_input_port_idx).get_source():
+            return
+
+        swish = Swish(graph, {}).create_node()
+        swish.in_port(0).connect(mul_beta.in_port(mul_beta_input_port_idx).get_source())
+
+        # connect Beta value
+        swish.in_port(1).connect(mul_beta.in_port(1 - mul_beta_input_port_idx).get_source())
+
+        mul.out_port(0).get_connection().set_source(swish.out_port(0))
+
+        rename_nodes([(mul, mul_name + '/TBR'), (swish, mul_name)])
diff --git a/model-optimizer/extensions/front/Swish_fusion_test.py b/model-optimizer/extensions/front/Swish_fusion_test.py
new file mode 100644 (file)
index 0000000..08144c8
--- /dev/null
@@ -0,0 +1,132 @@
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import unittest
+
+from extensions.front.Swish_fusion import SwishWithSigmoidWithoutBeta, SwishWithSigmoidWithBeta
+from mo.utils.ir_engine.compare_graphs import compare_graphs
+from mo.utils.unittest.graph import build_graph, regular_op, result, build_graph_with_edge_attrs
+
+ref_nodes = {**regular_op('input', {'type': 'Parameter'}),
+             **regular_op('swish', {'type': 'Swish', 'name': 'final_mul'}),
+             **result('result')
+             }
+ref_edges = [('input', 'swish'), ('swish', 'result')]
+
+
+class SwishWithSigmoidWithoutBetaTest(unittest.TestCase):
+    nodes = {
+        **regular_op('input', {'type': 'Parameter'}),
+        **regular_op('sigmoid', {'op': 'Sigmoid'}),
+        **regular_op('mul', {'op': 'Mul', 'name': 'final_mul'}),
+        **result('result'),
+    }
+
+    edges = [('input', 'mul', {'in': 0, 'out': 0}),
+             ('input', 'sigmoid', {'in': 0, 'out': 0}),
+             ('sigmoid', 'mul', {'in': 1, 'out': 0}),
+             ('mul', 'result', {'in': 0, 'out': 0})]
+
+    def test_swish_with_sigmoid_without_beta_test(self):
+        graph = build_graph_with_edge_attrs(self.nodes, self.edges, {})
+
+        graph_ref = build_graph(ref_nodes, ref_edges)
+        graph.stage = 'front'
+
+        SwishWithSigmoidWithoutBeta().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
+        self.assertTrue(flag, resp)
+        self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and
+                        graph.get_op_nodes(name='final_mul')[0].op == 'Swish')
+
+    def test_swish_with_sigmoid_without_beta_different_tensors(self):
+        graph = build_graph_with_edge_attrs({
+            **regular_op('input', {'type': 'Parameter'}),
+            **regular_op('input_2', {'type': 'Parameter'}),
+            **regular_op('sigmoid', {'op': 'Sigmoid'}),
+            **regular_op('mul', {'op': 'Mul', 'name': 'final_mul'}),
+            **result('result'),
+        }, [('input_2', 'mul', {'in': 0, 'out': 0}),
+            ('input', 'sigmoid', {'in': 0, 'out': 0}),
+            ('sigmoid', 'mul', {'in': 1, 'out': 0}),
+            ('mul', 'result', {'in': 0, 'out': 0})], {})
+
+        graph_ref = graph.copy()
+        graph.stage = 'front'
+
+        SwishWithSigmoidWithoutBeta().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
+        self.assertTrue(flag, resp)
+
+
+class SwishWithSigmoidWithBetaTest(unittest.TestCase):
+    nodes = {
+        **regular_op('input', {'type': 'Parameter'}),
+        **regular_op('beta', {'type': 'Parameter'}),
+        **regular_op('mul_beta', {'op': 'Mul'}),
+        **regular_op('sigmoid', {'op': 'Sigmoid'}),
+        **regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}),
+        **result('result'),
+    }
+
+    edges = [('input', 'mul_beta', {'in': 0, 'out': 0}),
+             ('input', 'mul_2', {'in': 0, 'out': 0}),
+             ('beta', 'mul_beta', {'in': 1, 'out': 0}),
+             ('mul_beta', 'sigmoid', {'in': 0, 'out': 0}),
+             ('sigmoid', 'mul_2', {'in': 1, 'out': 0}),
+             ('mul_2', 'result', {'in': 0, 'out': 0})]
+
+    def test_swish_with_sigmoid_with_beta_test(self):
+        graph = build_graph_with_edge_attrs(self.nodes, self.edges, {})
+
+        new_ref_nodes = ref_nodes.copy()
+        new_ref_nodes.update(**regular_op('beta', {'type': 'Parameter'}))
+
+        graph_ref = build_graph(new_ref_nodes, ref_edges + [('beta', 'swish')])
+        graph.stage = 'front'
+
+        SwishWithSigmoidWithBeta().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
+        self.assertTrue(flag, resp)
+        self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and
+                        graph.get_op_nodes(name='final_mul')[0].op == 'Swish')
+
+    def test_swish_with_sigmoid_with_beta_different_tensors(self):
+        graph = build_graph_with_edge_attrs({
+            **regular_op('input', {'type': 'Parameter'}),
+            **regular_op('input_2', {'type': 'Parameter'}),
+            **regular_op('beta', {'type': 'Parameter'}),
+            **regular_op('mul_beta', {'op': 'Mul'}),
+            **regular_op('sigmoid', {'op': 'Sigmoid'}),
+            **regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}),
+            **result('result'),
+        }, [('input', 'mul_beta', {'in': 0, 'out': 0}),
+            ('input_2', 'mul_2', {'in': 0, 'out': 0}),
+            ('beta', 'mul_beta', {'in': 1, 'out': 0}),
+            ('mul_beta', 'sigmoid', {'in': 0, 'out': 0}),
+            ('sigmoid', 'mul_2', {'in': 1, 'out': 0}),
+            ('mul_2', 'result', {'in': 0, 'out': 0})], {})
+
+        graph_ref = graph.copy()
+        graph.stage = 'front'
+
+        SwishWithSigmoidWithBeta().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
+        self.assertTrue(flag, resp)
index 9f26c94..fc462b6 100644 (file)
@@ -29,10 +29,10 @@ class TestAxpyReplacer(unittest.TestCase):
             'axpy': {'type': 'Axpy', 'kind': 'op', 'op': 'Axpy'},
             'node_4': {'kind': 'op', 'type': 'Identity', 'op': 'Parameter'}}
         edges = [
-            ('node_1', 'axpy', {'in': 0}),
-            ('node_2', 'axpy', {'in': 1}),
-            ('node_3', 'axpy', {'in': 2}),
-            ('axpy', 'node_4', {'in': 0})]
+            ('node_1', 'axpy', {'in': 0, 'out': 0}),
+            ('node_2', 'axpy', {'in': 1, 'out': 0}),
+            ('node_3', 'axpy', {'in': 2, 'out': 0}),
+            ('axpy', 'node_4', {'in': 0, 'out': 0})]
         graph = build_graph_with_edge_attrs(nodes, edges)
         node = Node(graph, 'axpy')
         replacer = AxpyToSSandAdd()
index fbe6c33..e4a6099 100644 (file)
@@ -61,9 +61,9 @@ class TestFIFOQueueReplacement(unittest.TestCase):
             'image_batch': {'op': 'Identity', 'data_type': np.float32, 'kind': 'op'},
         }
         edges_no_label = [
-            ('placeholder', 'batch_join', {'out': 0}),
-            ('batch_join/fifo_queue', 'batch_join', {'out': 0}),
-            ('batch_join', 'image_batch', {'out': 0})
+            ('placeholder', 'batch_join', {'out': 0, 'in': 0}),
+            ('batch_join/fifo_queue', 'batch_join', {'out': 0, 'in': 1}),
+            ('batch_join', 'image_batch', {'out': 0, 'in': 0})
         ]
 
         graph = build_graph_with_edge_attrs(nodes_no_label, edges_no_label)
index c6c3f3e..162ebf7 100644 (file)
@@ -115,8 +115,6 @@ class Atanh(Activation):
 
 
 class ReLU6(AttributedClamp):
-    op = 'ReLU6'
-
     def __init__(self, graph: Graph, attrs: dict):
         relu6_attrs = {'min': 0, 'max': 6}
         relu6_attrs.update(attrs)
@@ -244,6 +242,12 @@ class Mish(Activation):
     operation = staticmethod(lambda x: x * np.tanh(np.ln(np.exp(x) + 1.0)))
 
 
+class HSwish(Activation):
+    op = 'HSwish'
+    version = 'opset4'
+    operation = staticmethod(lambda x: x * np.minimum(np.maximum(x + 3.0, 0.0), 6.0) / 6.0)
+
+
 class Swish(Op):
     op = 'Swish'
 
index ae4edb5..b7af97f 100644 (file)
@@ -236,6 +236,18 @@ def build_graph_with_edge_attrs(nodes_attrs: dict, edges: list, update_attribute
             assert (node_name in graph.nodes())
             for attr, value in new_attrs.items():
                 graph.node[node_name][attr] = value
+
+    for node in graph.get_op_nodes():
+        # Add in_ports attribute
+        in_edges = node.in_edges()
+        for attr in in_edges.values():
+            node.add_input_port(idx=attr['in'])
+
+        # Add out_ports attribute
+        out_edges = node.out_edges()
+        for attr in out_edges.values():
+            node.add_output_port(idx=attr['out'])
+
     graph.graph['cmd_params'] = cli
     return graph