Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64589
Adds a softplus operator lowering for NNC and enables element-wise fusion for it.
Test Plan: Added a test in test_jit_fuser.py
Reviewed By: bertmaher
Differential Revision: D30736449
fbshipit-source-id: 6c5fc3bceb5cef2322ecd4449f827e4af018ea93
F.hardtanh,
F.hardsigmoid,
F.hardswish,
+ F.softplus,
torch.sqrt,
torch.rsqrt,
F.gelu,
'nn.functional.hardshrink',
'nn.functional.hardsigmoid',
'nn.functional.hardswish',
+ 'nn.functional.softplus',
'nn.functional.hardtanh',
'nn.functional.leaky_relu',
'nn.functional.relu',
"aten::sigmoid(Tensor self) -> Tensor",
"aten::relu(Tensor self) -> Tensor",
"aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor",
+ "aten::softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor",
"aten::relu6(Tensor self) -> Tensor",
"aten::gelu(Tensor self) -> Tensor",
"aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor",
case aten::hardtanh:
case aten::hardsigmoid:
case aten::hardswish:
+ case aten::softplus:
case aten::sqrt:
case aten::rsqrt:
case aten::abs:
});
} break;
+ // Lowering for aten::softplus(self, beta=1, threshold=20):
+ //   softplus(x) = log1p(exp(beta * x)) / beta
+ // with the standard numerical-stability shortcut: when beta * x exceeds
+ // `threshold`, the result is taken to be x itself (exp would overflow and
+ // softplus is ~identity in that regime anyway).
+ case aten::softplus: {
+ return computeThreeOperand(
+ "aten_softplus",
+ inputs,
+ outputShape,
+ outputType,
+ [](const ExprHandle& a,
+ const ExprHandle& beta,
+ const ExprHandle& threshold) {
+ // Promote the scalar beta/threshold arguments to the input's dtype
+ // so the arithmetic below is performed in a single type.
+ auto beta_promoted = Cast::make(a.dtype(), beta);
+ auto threshold_promoted = Cast::make(a.dtype(), threshold);
+ auto beta_a = beta_promoted * a;
+ // select(beta*a > threshold, a, log1p(exp(beta*a)) / beta)
+ return CompareSelect::make(
+ beta_a,
+ threshold_promoted,
+ a,
+ log1p(exp(beta_a)) / beta_promoted,
+ kGT);
+ });
+ } break;
+
case aten::hardsigmoid: {
return computeOneOperand(
"aten_hardsigmoid",