[PYTORCH]celu, gelu, selu activations (#5263)
authorSamuel <siju.samuel@huawei.com>
Wed, 8 Apr 2020 03:45:41 +0000 (09:15 +0530)
committerGitHub <noreply@github.com>
Wed, 8 Apr 2020 03:45:41 +0000 (12:45 +0900)
python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

index 1f10e60..46068a4 100644 (file)
@@ -216,15 +216,44 @@ def _prelu():
 def _leaky_relu():
     def _impl(inputs, input_types):
         data = inputs[0]
-        alpha = int(inputs[1])
+        alpha = float(inputs[1])
         return _op.nn.leaky_relu(data, alpha)
     return _impl
 
 def _elu():
     def _impl(inputs, input_types):
         data = inputs[0]
-        alpha = _expr.const(int(inputs[1]), dtype='float32')
-        return alpha * _op.nn.relu(alpha - _op.exp(data)) + _op.nn.relu(data)
+        alpha = _expr.const(float(inputs[1]))
+        return alpha * _op.nn.relu(_expr.const(1.0) - _op.exp(data)) + _op.nn.relu(data)
+    return _impl
+
+def _celu():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        alpha = _expr.const(float(inputs[1]))
+        return alpha * _op.nn.relu(_expr.const(1.0) - _op.exp(data / alpha)) + _op.nn.relu(data)
+    return _impl
+
+def _gelu():
+    def _impl(inputs, input_types):
+        import math
+        data = inputs[0]
+
+        def _pow3(x):
+            return x * x * x
+        return _expr.const(0.5) * data * (_expr.const(1.0) +
+                                          _op.tanh(_expr.const(math.sqrt(2.0 / math.pi)) *
+                                                   (data + _expr.const(0.044715) * _pow3(data))))
+    return _impl
+
+def _selu():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        # https://pytorch.org/docs/stable/nn.html#selu
+        alpha = _expr.const(-1.6732632423543772848170429916717)
+        gamma = _expr.const(1.0507009873554804934193349852946)
+        return gamma * (alpha * _op.nn.relu(_expr.const(1.0)
+                                            - _op.exp(data)) + _op.nn.relu(data))
     return _impl
 
 def _log_sigmoid():
@@ -1066,6 +1095,9 @@ _convert_map = {
     "aten::prelu"                           : _prelu(),
     "aten::leaky_relu"                      : _leaky_relu(),
     "aten::elu"                             : _elu(),
+    "aten::celu"                            : _celu(),
+    "aten::gelu"                            : _gelu(),
+    "aten::selu"                            : _selu(),
     "aten::log_sigmoid"                     : _log_sigmoid(),
     "aten::adaptive_avg_pool2d"             : _adaptive_avg_pool_2d(),
     "aten::adaptive_max_pool2d"             : _adaptive_max_pool_2d(),
index fb3f18b..05bf7e4 100644 (file)
@@ -353,16 +353,43 @@ def test_forward_prelu():
 
 def test_forward_leakyrelu():
     torch.set_grad_enabled(False)
-    input_shape = [10, 10]
+    input_shape = [1, 3, 10, 10]
     input_data = torch.rand(input_shape).float()
+    verify_model(torch.nn.LeakyReLU().eval(), input_data=input_data)
     verify_model(torch.nn.LeakyReLU(negative_slope=0.05).eval(), input_data=input_data)
+    verify_model(torch.nn.LeakyReLU(negative_slope=1.0).eval(), input_data=input_data)
+    verify_model(torch.nn.LeakyReLU(negative_slope=1.25).eval(), input_data=input_data)
 
 def test_forward_elu():
     torch.set_grad_enabled(False)
-    input_shape = [10, 10]
+    input_shape = [1, 3, 10, 10]
     input_data = torch.rand(input_shape).float()
+    verify_model(torch.nn.ELU().eval(), input_data=input_data)
+    verify_model(torch.nn.ELU(alpha=0.3).eval(), input_data=input_data)
+    verify_model(torch.nn.ELU(alpha=1.0).eval(), input_data=input_data)
     verify_model(torch.nn.ELU(alpha=1.3).eval(), input_data=input_data)
 
+def test_forward_celu():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+    input_data = torch.rand(input_shape).float()
+    verify_model(torch.nn.CELU().eval(), input_data=input_data)
+    verify_model(torch.nn.CELU(alpha=0.3).eval(), input_data=input_data)
+    verify_model(torch.nn.CELU(alpha=1.0).eval(), input_data=input_data)
+    verify_model(torch.nn.CELU(alpha=1.3).eval(), input_data=input_data)
+
+def test_forward_gelu():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+    input_data = torch.rand(input_shape).float()
+    verify_model(torch.nn.GELU().eval(), input_data=input_data)
+
+def test_forward_selu():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+    input_data = torch.rand(input_shape).float()
+    verify_model(torch.nn.SELU().eval(), input_data=input_data)
+
 def test_forward_log_sigmoid():
     torch.set_grad_enabled(False)
     input_shape = [10, 10]
@@ -1131,6 +1158,9 @@ if __name__ == "__main__":
     test_forward_prelu()
     test_forward_leakyrelu()
     test_forward_elu()
+    test_forward_celu()
+    test_forward_gelu()
+    test_forward_selu()
     test_forward_log_sigmoid()
     test_forward_adaptiveavgpool()
     test_forward_maxpool2d()