import numpy as np
-from extensions.ops.activation_ops import Elu
+from extensions.ops.activation_ops import Elu, SoftPlus, Mish
from mo.graph.graph import Node
from mo.utils.unittest.graph import build_graph
class TestActivationOp(unittest.TestCase):
nodes_attributes = {
'node_1': {
- 'shape': np.array([227, 227, 227, 227]),
+ 'shape': np.array([4]),
'value': None
},
'activation_node': {
'op': 'Activation',
- 'kind': 'op'
+ 'kind': 'op',
+ 'operation': None
},
'node_3': {
'shape': None
graph.graph['layout'] = 'NCHW'
activation_node = Node(graph, 'activation_node')
Elu.infer(activation_node)
- exp_shape = np.array([227, 227, 227, 227])
+ exp_shape = np.array([4])
res_shape = graph.node['node_3']['shape']
res_value = graph.node['node_3']['value']
exp_value = np.array([6., -0.98168436, -0.86466472, -0.63212056])
self.assertEqual(res_shape[i], value)
for i, value in enumerate(exp_value):
self.assertAlmostEqual(res_value[i], value)
+
+ def test_activation_softplus_infer(self):
+ graph = build_graph(self.nodes_attributes,
+ [
+ ('node_1', 'activation_node'),
+ ('activation_node', 'node_3')
+ ],
+ {
+ 'node_1': {
+ 'value': np.array([-1.0, 0.0, 1.0, 20.0])
+ },
+ 'activation_node': {
+ 'op': 'SoftPlus',
+ 'operation': SoftPlus.operation,
+ },
+ 'node_3': {
+ 'value': None
+ }
+ })
+ graph.graph['layout'] = 'NCHW'
+ activation_node = Node(graph, 'activation_node')
+ SoftPlus.infer(activation_node)
+ exp_shape = np.array([4])
+ res_shape = graph.node['node_3']['shape']
+ res_value = graph.node['node_3']['value']
+ exp_value = np.array([0.3132617, 0.6931472, 1.3132617, 20.0])
+ for i, value in enumerate(exp_shape):
+ self.assertEqual(res_shape[i], value)
+ for i, value in enumerate(exp_value):
+ self.assertAlmostEqual(res_value[i], value)
+
+ def test_activation_mish_infer(self):
+ graph = build_graph(self.nodes_attributes,
+ [
+ ('node_1', 'activation_node'),
+ ('activation_node', 'node_3')
+ ],
+ {
+ 'node_1': {
+ 'value': np.array([-1.0, 0.0, 1.0, 20.0])
+ },
+ 'activation_node': {
+ 'op': 'Mish',
+ 'operation': Mish.operation,
+ },
+ 'node_3': {
+ 'value': None
+ }
+ })
+ graph.graph['layout'] = 'NCHW'
+ activation_node = Node(graph, 'activation_node')
+ Mish.infer(activation_node)
+ exp_shape = np.array([4])
+ res_shape = graph.node['node_3']['shape']
+ res_value = graph.node['node_3']['value']
+ exp_value = np.array([-0.30340146, 0.0, 0.8650984, 20.0])
+ for i, value in enumerate(exp_shape):
+ self.assertEqual(res_shape[i], value)
+ for i, value in enumerate(exp_value):
+ self.assertAlmostEqual(res_value[i], value)