Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64209
Add a new fusion pass that transforms the following pattern:
```
graph(%input):
%0 : Tensor = aten::sign(%input)
%1 : Tensor = aten::abs(%input)
%2 : Tensor = aten::log1p(%1)
%res : Tensor = aten::mul(%0, %2)
return (%res)
```
Into a single op:
```
graph(%input):
%res : Tensor = static_runtime::signed_log1p(%input)
return (%res)
```
The intent is to reduce the number of passes over the tensor. However, enabling this pass actually causes a performance regression, probably due to a lack of vectorization in the fused implementation. Because of this issue, this diff **does not** enable this pass.
Followup: navahgar will add an NNC kernel that is faster than the unfused version and enable this pass. We still need this implementation as a fallback, since the NNC kernel will not support all dtypes.
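For illustration, here is a minimal sketch (not part of this diff) of exercising the pass directly on a parsed graph. It assumes the `FuseSignLog1P` entry point added below and `parseIR` from `torch/csrc/jit/ir/irparser.h`; the demo function name is arbitrary:
```
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/static/passes.h>

void fuse_signed_log1p_demo() {
  const std::string src = R"IR(
    graph(%input):
      %0 : Tensor = aten::sign(%input)
      %1 : Tensor = aten::abs(%input)
      %2 : Tensor = aten::log1p(%1)
      %res : Tensor = aten::mul(%0, %2)
      return (%res)
  )IR";
  auto graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(src, graph.get());
  torch::jit::FuseSignLog1P(graph);
  // After the rewrite, the four aten ops should be replaced by a single
  // static_runtime::signed_log1p node.
  graph->dump();
}
```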
Test Plan:
`buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest -- SignedLog1p`
The test passes with the new graph pass both disabled and enabled.
Reviewed By: hlu1
Differential Revision: D30559929
fbshipit-source-id: e4e080cb2e6a705cfdde1fc98bee92b723f8132a
def forward(self, a: Tensor, dim: int, dtype: int):
return torch.cumsum(a, dim, dtype=dtype).clone()
)JIT";
+
+const std::string signed_log1p_script = R"IR(
+ graph(%input):
+ %0 : Tensor = aten::sign(%input)
+ %1 : Tensor = aten::abs(%input)
+ %2 : Tensor = aten::log1p(%1)
+ %3 : Tensor = aten::mul(%0, %2)
+ %none : NoneType = prim::Constant()
+ %res : Tensor = aten::clone(%3, %none)
+ return (%res)
+)IR";
auto b = at::randint(0, 2, {4, 3, 2});
testStaticRuntime(nonzero_tensor, {a}, {b});
}
+
+TEST(StaticRuntime, SignedLog1p) {
+ std::vector<IValue> args1 = {at::randn({2, 2})};
+ testStaticRuntime(signed_log1p_script, args1, {}, true);
+
+ std::vector<IValue> args2 = {at::randn({3, 3, 3})};
+ testStaticRuntime(signed_log1p_script, args1, args2, true);
+}
}
};
});
+
+namespace {
+
+// This template and its specialization help us avoid compiler warnings
+// about taking the absolute value of an unsigned type in signed_log1p
+template <class T>
+T abs_if_signed(T val) {
+ return std::abs(val);
+}
+
+template <>
+unsigned char abs_if_signed<unsigned char>(unsigned char val) {
+ return val;
+}
+
+// Computes f(x) = sign(x) * ln(1 + |x|) for each x in the input tensor
+void signed_log1p_out(at::Tensor& out, const at::Tensor& input) {
+ at::native::resize_(out, input.sizes(), c10::nullopt);
+
+ const auto input_contig = input.expect_contiguous();
+ auto output_contig = out.expect_contiguous();
+
+ AT_DISPATCH_ALL_TYPES(input.scalar_type(), "signed_log1p_kernel", [&]() {
+    const auto input_data = input_contig->data_ptr<scalar_t>();
+    // The output is always written as float, regardless of the input dtype.
+    auto output_data = output_contig->data_ptr<float>();
+ const auto N = input.numel();
+
+ for (const auto i : c10::irange(N)) {
+ const int sign = input_data[i] < 0 ? -1 : 1;
+ output_data[i] = std::log1p(abs_if_signed(input_data[i])) * sign;
+ }
+ });
+}
+
+at::Tensor signed_log1p(const at::Tensor& input) {
+  // The kernel above writes float output, so the output tensor must be
+  // created as float rather than copying the input's dtype.
+  auto out = create_empty_from(input, at::kFloat);
+  signed_log1p_out(out, input);
+  return out;
+}
+
+} // namespace
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_OPERATOR_FUNCTOR(
+ static_runtime::signed_log1p,
+ static_runtime_signed_log1p,
+ [](Node* n) -> SROperator {
+ if (!n->matches(torch::schema(
+ "static_runtime::signed_log1p(Tensor x) -> Tensor"))) {
+ LogAndDumpSchema(n);
+ return nullptr;
+ }
+ return [](ProcessedNode* p_node) {
+ const auto& input = p_node->Input(0).toTensor();
+ if (p_node->Output(0).isNone()) {
+ p_node->Output(0) = signed_log1p(input);
+ } else {
+ auto& out = p_node->Output(0).toTensor();
+ fastResizeToZero(out);
+ signed_log1p_out(out, input);
+ }
+ };
+ });
} // namespace jit
} // namespace torch
m.def(torch::schema(
"static_runtime::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> (Tensor, Tensor, Tensor)",
c10::AliasAnalysisKind::PURE_FUNCTION));
+ m.def("static_runtime::signed_log1p(Tensor input) -> Tensor");
+}
+
+void FuseSignLog1P(std::shared_ptr<torch::jit::Graph>& graph) {
+ std::string pattern = R"IR(
+ graph(%input):
+ %0 : Tensor = aten::sign(%input)
+ %1 : Tensor = aten::abs(%input)
+ %2 : Tensor = aten::log1p(%1)
+ %res : Tensor = aten::mul(%0, %2)
+ return (%res)
+ )IR";
+
+ std::string fused_pattern = R"IR(
+ graph(%input):
+ %res : Tensor = static_runtime::signed_log1p(%input)
+ return (%res)
+ )IR";
+
+ SubgraphRewriter fuse;
+ fuse.RegisterRewritePattern(pattern, fused_pattern);
+ fuse.runOnGraph(graph);
}
bool HasInplaceOp(std::shared_ptr<Graph>& graph, const AliasDb& alias_db) {
TORCH_API bool HasInplaceOp(
    std::shared_ptr<Graph>& graph,
    const AliasDb& alias_db);
+TORCH_API void FuseSignLog1P(std::shared_ptr<Graph>& graph);
+
} // namespace jit
} // namespace torch