From e6e7f5158a684aa241a498385b18798b35e68a55 Mon Sep 17 00:00:00 2001 From: Anton Chetverikov Date: Fri, 11 Sep 2020 12:58:14 +0300 Subject: [PATCH] Fix Mish and SoftPlus value propagation functions (#2120) * Fix Mish and SoftPlus value propagation functions * Add unit tests for SoftPlus & Mish operations value propagation functions --- model-optimizer/extensions/ops/activation_ops.py | 4 +- model-optimizer/extensions/ops/activation_test.py | 69 +++++++++++++++++++++-- 2 files changed, 67 insertions(+), 6 deletions(-) diff --git a/model-optimizer/extensions/ops/activation_ops.py b/model-optimizer/extensions/ops/activation_ops.py index 162ebf7..49396f4 100644 --- a/model-optimizer/extensions/ops/activation_ops.py +++ b/model-optimizer/extensions/ops/activation_ops.py @@ -233,13 +233,13 @@ class Log(Activation): class SoftPlus(Activation): op = 'SoftPlus' version = 'opset4' - operation = staticmethod(lambda x: np.ln(np.exp(x) + 1.0)) + operation = staticmethod(lambda x: np.log(np.exp(x) + 1.0)) class Mish(Activation): op = 'Mish' version = 'opset4' - operation = staticmethod(lambda x: x * np.tanh(np.ln(np.exp(x) + 1.0))) + operation = staticmethod(lambda x: x * np.tanh(np.log(np.exp(x) + 1.0))) class HSwish(Activation): diff --git a/model-optimizer/extensions/ops/activation_test.py b/model-optimizer/extensions/ops/activation_test.py index fb53cb5..17270d7 100644 --- a/model-optimizer/extensions/ops/activation_test.py +++ b/model-optimizer/extensions/ops/activation_test.py @@ -18,7 +18,7 @@ import unittest import numpy as np -from extensions.ops.activation_ops import Elu +from extensions.ops.activation_ops import Elu, SoftPlus, Mish from mo.graph.graph import Node from mo.utils.unittest.graph import build_graph @@ -26,12 +26,13 @@ from mo.utils.unittest.graph import build_graph class TestActivationOp(unittest.TestCase): nodes_attributes = { 'node_1': { - 'shape': np.array([227, 227, 227, 227]), + 'shape': np.array([4]), 'value': None }, 'activation_node': { 'op': 'Activation', - 'kind': 'op' + 'kind': 'op', + 'operation': None }, 'node_3': { 'shape': None @@ -59,7 +60,7 @@ class TestActivationOp(unittest.TestCase): graph.graph['layout'] = 'NCHW' activation_node = Node(graph, 'activation_node') Elu.infer(activation_node) - exp_shape = np.array([227, 227, 227, 227]) + exp_shape = np.array([4]) res_shape = graph.node['node_3']['shape'] res_value = graph.node['node_3']['value'] exp_value = np.array([6., -0.98168436, -0.86466472, -0.63212056]) @@ -67,3 +68,63 @@ class TestActivationOp(unittest.TestCase): self.assertEqual(res_shape[i], value) for i, value in enumerate(exp_value): self.assertAlmostEqual(res_value[i], value) + + def test_activation_softplus_infer(self): + graph = build_graph(self.nodes_attributes, + [ + ('node_1', 'activation_node'), + ('activation_node', 'node_3') + ], + { + 'node_1': { + 'value': np.array([-1.0, 0.0, 1.0, 20.0]) + }, + 'activation_node': { + 'op': 'SoftPlus', + 'operation': SoftPlus.operation, + }, + 'node_3': { + 'value': None + } + }) + graph.graph['layout'] = 'NCHW' + activation_node = Node(graph, 'activation_node') + SoftPlus.infer(activation_node) + exp_shape = np.array([4]) + res_shape = graph.node['node_3']['shape'] + res_value = graph.node['node_3']['value'] + exp_value = np.array([0.3132617, 0.6931472, 1.3132617, 20.0]) + for i, value in enumerate(exp_shape): + self.assertEqual(res_shape[i], value) + for i, value in enumerate(exp_value): + self.assertAlmostEqual(res_value[i], value) + + def test_activation_mish_infer(self): + graph = build_graph(self.nodes_attributes, + [ + ('node_1', 'activation_node'), + ('activation_node', 'node_3') + ], + { + 'node_1': { + 'value': np.array([-1.0, 0.0, 1.0, 20.0]) + }, + 'activation_node': { + 'op': 'Mish', + 'operation': Mish.operation, + }, + 'node_3': { + 'value': None + } + }) + graph.graph['layout'] = 'NCHW' + activation_node = Node(graph, 'activation_node') + Mish.infer(activation_node) + exp_shape = np.array([4]) + res_shape = graph.node['node_3']['shape'] + res_value = graph.node['node_3']['value'] + exp_value = np.array([-0.30340146, 0.0, 0.8650984, 20.0]) + for i, value in enumerate(exp_shape): + self.assertEqual(res_shape[i], value) + for i, value in enumerate(exp_value): + self.assertAlmostEqual(res_value[i], value) -- 2.7.4