From: Mike Iovine Date: Thu, 2 Sep 2021 15:12:48 +0000 (-0700) Subject: [Static Runtime] Add sign/abs/log1p/mul fusion pass (#64209) X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~475 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=616fd9219da18bcfe69da8b0c3a96dd2c6298066;p=platform%2Fupstream%2Fpytorch.git [Static Runtime] Add sign/abs/log1p/mul fusion pass (#64209) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64209 Add a new fusion pass that transforms the following pattern: ``` graph(%input): %0 : Tensor = aten::sign(%input) %1 : Tensor = aten::abs(%input) %2 : Tensor = aten::log1p(%1) %res : Tensor = aten::mul(%0, %2) return (%res) ``` into a single op: ``` graph(%input): %res : Tensor = static_runtime::signed_log1p(%input) return (%res) ``` The intent is to reduce the number of passes over the tensor. However, enabling this pass actually causes a performance regression, probably due to a lack of vectorization in the fused implementation. Because of this issue, this diff **does not** enable this pass. Followup: navahgar will add an NNC kernel which is faster than the unfused version and enable this pass. We still need this version as a fallback since the NNC kernel will not support all dtypes. Test Plan: `buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest -- SignedLog1p` Test passed with new graph pass disabled and enabled. 
Reviewed By: hlu1 Differential Revision: D30559929 fbshipit-source-id: e4e080cb2e6a705cfdde1fc98bee92b723f8132a --- diff --git a/benchmarks/static_runtime/test_scripts.h b/benchmarks/static_runtime/test_scripts.h index 99b73db..b17dded 100644 --- a/benchmarks/static_runtime/test_scripts.h +++ b/benchmarks/static_runtime/test_scripts.h @@ -827,3 +827,14 @@ const auto cumsum_script_dtype = R"JIT( def forward(self, a: Tensor, dim: int, dtype: int): return torch.cumsum(a, dim, dtype=dtype).clone() )JIT"; + +const std::string signed_log1p_script = R"IR( + graph(%input): + %0 : Tensor = aten::sign(%input) + %1 : Tensor = aten::abs(%input) + %2 : Tensor = aten::log1p(%1) + %3 : Tensor = aten::mul(%0, %2) + %none : NoneType = prim::Constant() + %res : Tensor = aten::clone(%3, %none) + return (%res) +)IR"; diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc index 16941da..5eb3dfe 100644 --- a/benchmarks/static_runtime/test_static_runtime.cc +++ b/benchmarks/static_runtime/test_static_runtime.cc @@ -1356,3 +1356,11 @@ TEST(StaticRuntime, IndividualOps_Nonzero) { auto b = at::randint(0, 2, {4, 3, 2}); testStaticRuntime(nonzero_tensor, {a}, {b}); } + +TEST(StaticRuntime, SignedLog1p) { + std::vector args1 = {at::randn({2, 2})}; + testStaticRuntime(signed_log1p_script, args1, {}, true); + + std::vector args2 = {at::randn({3, 3, 3})}; + testStaticRuntime(signed_log1p_script, args1, args2, true); +} diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index 7ede15c..62f5bb2 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -1837,5 +1837,68 @@ REGISTER_OPERATOR_FUNCTOR( } }; }); + +namespace { + +// This template and its specialization help us avoid compiler warnings +// about taking the absolute value of an unsigned type in signed_log1p +template +T abs_if_signed(T val) { + return std::abs(val); +} + +template <> +unsigned char 
abs_if_signed(unsigned char val) { + return val; +} + +// Computes f(x) = sign(x) * ln(1 + |x|) for each x in the input tensor +void signed_log1p_out(at::Tensor& out, const at::Tensor& input) { + at::native::resize_(out, input.sizes(), c10::nullopt); + + const auto input_contig = input.expect_contiguous(); + auto output_contig = out.expect_contiguous(); + + AT_DISPATCH_ALL_TYPES(input.scalar_type(), "signed_log1p_kernel", [&]() { + const auto input_data = input_contig->data_ptr(); + auto output_data = output_contig->data_ptr(); + const auto N = input.numel(); + + for (const auto i : c10::irange(N)) { + const int sign = input_data[i] < 0 ? -1 : 1; + output_data[i] = std::log1p(abs_if_signed(input_data[i])) * sign; + } + }); +} + +at::Tensor signed_log1p(const at::Tensor& input) { + auto out = create_empty_from(input); + signed_log1p_out(out, input); + return out; +} + +} // namespace + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_OPERATOR_FUNCTOR( + static_runtime::signed_log1p, + static_runtime_signed_log1p, + [](Node* n) -> SROperator { + if (!n->matches(torch::schema( + "static_runtime::signed_log1p(Tensor x) -> Tensor"))) { + LogAndDumpSchema(n); + return nullptr; + } + return [](ProcessedNode* p_node) { + const auto& input = p_node->Input(0).toTensor(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = signed_log1p(input); + } else { + auto& out = p_node->Output(0).toTensor(); + fastResizeToZero(out); + signed_log1p_out(out, input); + } + }; + }); } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp index 5099dc1..0eaebfd 100644 --- a/torch/csrc/jit/runtime/static/passes.cpp +++ b/torch/csrc/jit/runtime/static/passes.cpp @@ -306,6 +306,28 @@ TORCH_LIBRARY_FRAGMENT(static_runtime, m) { m.def(torch::schema( "static_runtime::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? 
bias=None, float eps=1e-05, bool cudnn_enable=True) -> (Tensor, Tensor, Tensor)", c10::AliasAnalysisKind::PURE_FUNCTION)); + m.def("static_runtime::signed_log1p(Tensor input) -> Tensor"); +} + +void FuseSignLog1P(std::shared_ptr& graph) { + std::string pattern = R"IR( + graph(%input): + %0 : Tensor = aten::sign(%input) + %1 : Tensor = aten::abs(%input) + %2 : Tensor = aten::log1p(%1) + %res : Tensor = aten::mul(%0, %2) + return (%res) + )IR"; + + std::string fused_pattern = R"IR( + graph(%input): + %res : Tensor = static_runtime::signed_log1p(%input) + return (%res) + )IR"; + + SubgraphRewriter fuse; + fuse.RegisterRewritePattern(pattern, fused_pattern); + fuse.runOnGraph(graph); } bool HasInplaceOp(std::shared_ptr& graph, const AliasDb& alias_db) { diff --git a/torch/csrc/jit/runtime/static/passes.h b/torch/csrc/jit/runtime/static/passes.h index a42bc97..0904d37 100644 --- a/torch/csrc/jit/runtime/static/passes.h +++ b/torch/csrc/jit/runtime/static/passes.h @@ -20,5 +20,7 @@ TORCH_API bool HasInplaceOp( std::shared_ptr& graph, const AliasDb& alias_db); +TORCH_API void FuseSignLog1P(std::shared_ptr& graph); + } // namespace jit } // namespace torch