Fix Mish and SoftPlus value propagation functions (#2120)
authorAnton Chetverikov <Anton.Chetverikov@intel.com>
Fri, 11 Sep 2020 09:58:14 +0000 (12:58 +0300)
committerGitHub <noreply@github.com>
Fri, 11 Sep 2020 09:58:14 +0000 (12:58 +0300)
* Fix Mish and SoftPlus value propagation functions

* Add unit tests for SoftPlus & Mish operations value propagation functions

model-optimizer/extensions/ops/activation_ops.py
model-optimizer/extensions/ops/activation_test.py

index 162ebf7..49396f4 100644 (file)
@@ -233,13 +233,13 @@ class Log(Activation):
 class SoftPlus(Activation):
     op = 'SoftPlus'
     version = 'opset4'
-    operation = staticmethod(lambda x: np.ln(np.exp(x) + 1.0))
+    operation = staticmethod(lambda x: np.log(np.exp(x) + 1.0))
 
 
 class Mish(Activation):
     op = 'Mish'
     version = 'opset4'
-    operation = staticmethod(lambda x: x * np.tanh(np.ln(np.exp(x) + 1.0)))
+    operation = staticmethod(lambda x: x * np.tanh(np.log(np.exp(x) + 1.0)))
 
 
 class HSwish(Activation):
index fb53cb5..17270d7 100644 (file)
@@ -18,7 +18,7 @@ import unittest
 
 import numpy as np
 
-from extensions.ops.activation_ops import Elu
+from extensions.ops.activation_ops import Elu, SoftPlus, Mish
 from mo.graph.graph import Node
 from mo.utils.unittest.graph import build_graph
 
@@ -26,12 +26,13 @@ from mo.utils.unittest.graph import build_graph
 class TestActivationOp(unittest.TestCase):
     nodes_attributes = {
         'node_1': {
-            'shape': np.array([227, 227, 227, 227]),
+            'shape': np.array([4]),
             'value': None
         },
         'activation_node': {
             'op': 'Activation',
-            'kind': 'op'
+            'kind': 'op',
+            'operation': None
         },
         'node_3': {
             'shape': None
@@ -59,7 +60,7 @@ class TestActivationOp(unittest.TestCase):
         graph.graph['layout'] = 'NCHW'
         activation_node = Node(graph, 'activation_node')
         Elu.infer(activation_node)
-        exp_shape = np.array([227, 227, 227, 227])
+        exp_shape = np.array([4])
         res_shape = graph.node['node_3']['shape']
         res_value = graph.node['node_3']['value']
         exp_value = np.array([6., -0.98168436, -0.86466472, -0.63212056])
@@ -67,3 +68,63 @@ class TestActivationOp(unittest.TestCase):
             self.assertEqual(res_shape[i], value)
         for i, value in enumerate(exp_value):
             self.assertAlmostEqual(res_value[i], value)
+
+    def test_activation_softplus_infer(self):
+        graph = build_graph(self.nodes_attributes,
+                            [
+                                ('node_1', 'activation_node'),
+                                ('activation_node', 'node_3')
+                            ],
+                            {
+                                'node_1': {
+                                    'value': np.array([-1.0, 0.0, 1.0, 20.0])
+                                },
+                                'activation_node': {
+                                    'op': 'SoftPlus',
+                                    'operation': SoftPlus.operation,
+                                },
+                                'node_3': {
+                                    'value': None
+                                }
+                            })
+        graph.graph['layout'] = 'NCHW'
+        activation_node = Node(graph, 'activation_node')
+        SoftPlus.infer(activation_node)
+        exp_shape = np.array([4])
+        res_shape = graph.node['node_3']['shape']
+        res_value = graph.node['node_3']['value']
+        exp_value = np.array([0.3132617, 0.6931472, 1.3132617, 20.0])
+        for i, value in enumerate(exp_shape):
+            self.assertEqual(res_shape[i], value)
+        for i, value in enumerate(exp_value):
+            self.assertAlmostEqual(res_value[i], value)
+
+    def test_activation_mish_infer(self):
+        graph = build_graph(self.nodes_attributes,
+                            [
+                                ('node_1', 'activation_node'),
+                                ('activation_node', 'node_3')
+                            ],
+                            {
+                                'node_1': {
+                                    'value': np.array([-1.0, 0.0, 1.0, 20.0])
+                                },
+                                'activation_node': {
+                                    'op': 'Mish',
+                                    'operation': Mish.operation,
+                                },
+                                'node_3': {
+                                    'value': None
+                                }
+                            })
+        graph.graph['layout'] = 'NCHW'
+        activation_node = Node(graph, 'activation_node')
+        Mish.infer(activation_node)
+        exp_shape = np.array([4])
+        res_shape = graph.node['node_3']['shape']
+        res_value = graph.node['node_3']['value']
+        exp_value = np.array([-0.30340146, 0.0, 0.8650984, 20.0])
+        for i, value in enumerate(exp_shape):
+            self.assertEqual(res_shape[i], value)
+        for i, value in enumerate(exp_value):
+            self.assertAlmostEqual(res_value[i], value)