[Static Runtime] Add sign/abs/lop1p/mul fusion pass (#64209)
authorMike Iovine <mikeiovine@fb.com>
Thu, 2 Sep 2021 15:12:48 +0000 (08:12 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 2 Sep 2021 15:31:40 +0000 (08:31 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64209

Add a new fusion pass that turns transforms the following pattern:
```
graph(%input):
    %0 : Tensor = aten::sign(%input)
    %1 : Tensor = aten::abs(%input)
    %2 : Tensor = aten::log1p(%1)
    %res : Tensor = aten::mul(%0, %2)
    return (%res)
```
Into a single op:
```
graph(%input):
    %res : Tensor = static_runtim::signed_log1p(%input)
    return (%res)
```

The intent is to reduce the number of passes over the tensor. However, enabling this pass actually causes a performance regression, probably due to a lack of vectorization in the fused implementation. Because of this issue, this diff **does not** enable this pass.

Followup: navahgar will add an NNC kernel which is faster than the the unfused version and enable this pass. We still need this version as a fallback since the NNC kernel will not support all dtypes.

Test Plan:
`buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest -- SignedLog1p`

Test passed with new graph pass disabled and enabled.

Reviewed By: hlu1

Differential Revision: D30559929

fbshipit-source-id: e4e080cb2e6a705cfdde1fc98bee92b723f8132a

benchmarks/static_runtime/test_scripts.h
benchmarks/static_runtime/test_static_runtime.cc
torch/csrc/jit/runtime/static/ops.cpp
torch/csrc/jit/runtime/static/passes.cpp
torch/csrc/jit/runtime/static/passes.h

index 99b73db..b17dded 100644 (file)
@@ -827,3 +827,14 @@ const auto cumsum_script_dtype = R"JIT(
    def forward(self, a: Tensor, dim: int, dtype: int):
       return torch.cumsum(a, dim, dtype=dtype).clone()
 )JIT";
+
+const std::string signed_log1p_script = R"IR(
+  graph(%input):
+      %0 : Tensor = aten::sign(%input)
+      %1 : Tensor = aten::abs(%input)
+      %2 : Tensor = aten::log1p(%1)
+      %3 : Tensor = aten::mul(%0, %2)
+      %none : NoneType = prim::Constant()
+      %res : Tensor = aten::clone(%3, %none)
+      return (%res)
+)IR";
index 16941da..5eb3dfe 100644 (file)
@@ -1356,3 +1356,11 @@ TEST(StaticRuntime, IndividualOps_Nonzero) {
   auto b = at::randint(0, 2, {4, 3, 2});
   testStaticRuntime(nonzero_tensor, {a}, {b});
 }
+
+TEST(StaticRuntime, SignedLog1p) {
+  std::vector<IValue> args1 = {at::randn({2, 2})};
+  testStaticRuntime(signed_log1p_script, args1, {}, true);
+
+  std::vector<IValue> args2 = {at::randn({3, 3, 3})};
+  testStaticRuntime(signed_log1p_script, args1, args2, true);
+}
index 7ede15c..62f5bb2 100644 (file)
@@ -1837,5 +1837,68 @@ REGISTER_OPERATOR_FUNCTOR(
         }
       };
     });
+
+namespace {
+
+// This template and its specialization help us avoid compiler warnings
+// about taking the absolute value of an unsigned type in signed_log1p
+template <class T>
+T abs_if_signed(T val) {
+  return std::abs(val);
+}
+
+template <>
+unsigned char abs_if_signed<unsigned char>(unsigned char val) {
+  return val;
+}
+
+// Computes f(x) = sign(x) * ln(|1 + x|) for each x in the input tensor
+void signed_log1p_out(at::Tensor& out, const at::Tensor& input) {
+  at::native::resize_(out, input.sizes(), c10::nullopt);
+
+  const auto input_contig = input.expect_contiguous();
+  auto output_contig = out.expect_contiguous();
+
+  AT_DISPATCH_ALL_TYPES(input.scalar_type(), "signed_log1p_kernel", [&]() {
+    const auto input_data = input_contig->data_ptr<scalar_t>();
+    auto output_data = output_contig->data_ptr<float>();
+    const auto N = input.numel();
+
+    for (const auto i : c10::irange(N)) {
+      const int sign = input_data[i] < 0 ? -1 : 1;
+      output_data[i] = std::log1p(abs_if_signed(input_data[i])) * sign;
+    }
+  });
+}
+
+at::Tensor signed_log1p(const at::Tensor& input) {
+  auto out = create_empty_from(input);
+  signed_log1p_out(out, input);
+  return out;
+}
+
+} // namespace
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_OPERATOR_FUNCTOR(
+    static_runtime::signed_log1p,
+    static_runtime_signed_log1p,
+    [](Node* n) -> SROperator {
+      if (!n->matches(torch::schema(
+              "static_runtime::signed_log1p(Tensor x) -> Tensor"))) {
+        LogAndDumpSchema(n);
+        return nullptr;
+      }
+      return [](ProcessedNode* p_node) {
+        const auto& input = p_node->Input(0).toTensor();
+        if (p_node->Output(0).isNone()) {
+          p_node->Output(0) = signed_log1p(input);
+        } else {
+          auto& out = p_node->Output(0).toTensor();
+          fastResizeToZero(out);
+          signed_log1p_out(out, input);
+        }
+      };
+    });
 } // namespace jit
 } // namespace torch
index 5099dc1..0eaebfd 100644 (file)
@@ -306,6 +306,28 @@ TORCH_LIBRARY_FRAGMENT(static_runtime, m) {
   m.def(torch::schema(
       "static_runtime::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> (Tensor, Tensor, Tensor)",
       c10::AliasAnalysisKind::PURE_FUNCTION));
+  m.def("static_runtime::signed_log1p(Tensor input) -> Tensor");
+}
+
+void FuseSignLog1P(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string pattern = R"IR(
+    graph(%input):
+        %0 : Tensor = aten::sign(%input)
+        %1 : Tensor = aten::abs(%input)
+        %2 : Tensor = aten::log1p(%1)
+        %res : Tensor = aten::mul(%0, %2)
+        return (%res)
+  )IR";
+
+  std::string fused_pattern = R"IR(
+    graph(%input):
+        %res : Tensor = static_runtime::signed_log1p(%input)
+        return (%res)
+    )IR";
+
+  SubgraphRewriter fuse;
+  fuse.RegisterRewritePattern(pattern, fused_pattern);
+  fuse.runOnGraph(graph);
 }
 
 bool HasInplaceOp(std::shared_ptr<Graph>& graph, const AliasDb& alias_db) {
index a42bc97..0904d37 100644 (file)
@@ -20,5 +20,7 @@ TORCH_API bool HasInplaceOp(
     std::shared_ptr<Graph>& graph,
     const AliasDb& alias_db);
 
+TORCH_API void FuseSignLog1P(std::shared_ptr<Graph>& graph);
+
 } // namespace jit
 } // namespace torch