[NNC] Add Softplus operator (#64589)
author    Animesh Jain <anijain@fb.com>
          Wed, 8 Sep 2021 17:48:09 +0000 (10:48 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
          Wed, 8 Sep 2021 17:49:58 +0000 (10:49 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64589

Adds a lowering for the softplus operator in NNC and enables elementwise fusion for it as well.
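
For reference, the lowering uses the same numerically stable formulation as aten::softplus: where beta * x exceeds the threshold, the result is x itself rather than log1p(exp(beta * x)) / beta, since exp would overflow in that regime. A minimal Python sketch of the computation (the helper name softplus_reference is ours, for illustration only):

    import torch

    def softplus_reference(x: torch.Tensor, beta: float = 1.0,
                           threshold: float = 20.0) -> torch.Tensor:
        # Mirrors the NNC lowering in kernel.cpp below: where beta * x
        # exceeds the threshold, select x itself (exp overflows there);
        # otherwise select log1p(exp(beta * x)) / beta.
        beta_x = beta * x
        return torch.where(beta_x > threshold, x,
                           torch.log1p(torch.exp(beta_x)) / beta)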

Test Plan: Added a test in test_jit_fuser_te.py
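
A rough sketch of how a fuser test can exercise the new lowering (hypothetical; the actual test adds F.softplus to an existing parametrized list of elementwise ops, as the diff below shows):

    import torch
    import torch.nn.functional as F

    def fn(x):
        # an elementwise chain the TE fuser can compile into one kernel
        return F.softplus(x) * 2.0

    scripted = torch.jit.script(fn)
    x = torch.randn(8, 8)
    for _ in range(3):  # warm up the profiling executor so fusion triggers
        scripted(x)
    assert torch.allclose(scripted(x), fn(x))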

Reviewed By: bertmaher

Differential Revision: D30736449

fbshipit-source-id: 6c5fc3bceb5cef2322ecd4449f827e4af018ea93

test/test_jit_fuser_te.py
torch/csrc/jit/passes/tensorexpr_fuser.cpp
torch/csrc/jit/tensorexpr/kernel.cpp

diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index a082ce5..b7830f9 100644
@@ -1255,6 +1255,7 @@ class TestTEFuser(JitTestCase):
             F.hardtanh,
             F.hardsigmoid,
             F.hardswish,
+            F.softplus,
             torch.sqrt,
             torch.rsqrt,
             F.gelu,
@@ -2015,6 +2016,7 @@ works_list = [
     'nn.functional.hardshrink',
     'nn.functional.hardsigmoid',
     'nn.functional.hardswish',
+    'nn.functional.softplus',
     'nn.functional.hardtanh',
     'nn.functional.leaky_relu',
     'nn.functional.relu',
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
index 75305d6..b505ce8 100644
@@ -140,6 +140,7 @@ const OperatorSet& supported_eltwise_set() {
       "aten::sigmoid(Tensor self) -> Tensor",
       "aten::relu(Tensor self) -> Tensor",
       "aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor",
+      "aten::softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor",
       "aten::relu6(Tensor self) -> Tensor",
       "aten::gelu(Tensor self) -> Tensor",
       "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor",
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index 8a8aee7..c5f7e99 100644
@@ -734,6 +734,7 @@ std::vector<ExprHandle> TensorExprKernel::inferSizesForValue(
     case aten::hardtanh:
     case aten::hardsigmoid:
     case aten::hardswish:
+    case aten::softplus:
     case aten::sqrt:
     case aten::rsqrt:
     case aten::abs:
@@ -2028,6 +2029,27 @@ Tensor tensorexpr::computeOperandValue(
           });
     } break;
 
+    case aten::softplus: {
+      return computeThreeOperand(
+          "aten_softplus",
+          inputs,
+          outputShape,
+          outputType,
+          [](const ExprHandle& a,
+             const ExprHandle& beta,
+             const ExprHandle& threshold) {
+            auto beta_promoted = Cast::make(a.dtype(), beta);
+            auto threshold_promoted = Cast::make(a.dtype(), threshold);
+            auto beta_a = beta_promoted * a;
+            return CompareSelect::make(
+                beta_a,
+                threshold_promoted,
+                a,
+                log1p(exp(beta_a)) / beta_promoted,
+                kGT);
+          });
+    } break;
+
     case aten::hardsigmoid: {
       return computeOneOperand(
           "aten_hardsigmoid",