[Static Runtime] Implement aten::cumsum out variant (#64159)
authorHarut Movsisyan <harutm@fb.com>
Mon, 30 Aug 2021 23:16:45 +0000 (16:16 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 30 Aug 2021 23:18:22 +0000 (16:18 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64159

Test Plan:
Confirm out variant is called for both versions:

```
> buck run //caffe2/benchmarks/static_runtime:static_runtime_cpptest -- --v=1
```

Reviewed By: mikeiovine

Differential Revision: D30622819

fbshipit-source-id: a2c8c7f969dae5f507718fb3d513e1fb4f026736

benchmarks/static_runtime/test_scripts.h
benchmarks/static_runtime/test_static_runtime.cc
torch/csrc/jit/runtime/static/ops.cpp

index 7fdb113..e26437f 100644 (file)
@@ -801,3 +801,13 @@ const std::string cat_script = R"IR(
       %ret: Tensor = aten::cat(%ten_list2, %dim)
       return (%ret)
 )IR";
+
+const auto cumsum_script = R"JIT(
+   def forward(self, a: Tensor, dim: int):
+      return torch.cumsum(a, dim).clone()
+)JIT";
+
+const auto cumsum_script_dtype = R"JIT(
+   def forward(self, a: Tensor, dim: int, dtype: int):
+      return torch.cumsum(a, dim, dtype=dtype).clone()
+)JIT";
index b7201ba..aa5cd35 100644 (file)
@@ -1311,3 +1311,25 @@ TEST(StaticRuntime, IndividualOps_Cat) {
   std::vector<IValue> args1{c, d, 1};
   testStaticRuntime(cat_script, args0, args1);
 }
+
+
+TEST(StaticRuntime, IndividualOps_Cumsum) {
+  auto a = at::randn({2, 3});
+  std::vector<IValue> args0{a, 0};
+  testStaticRuntime(cumsum_script, args0);
+
+  auto b = at::randn({4, 3});
+  std::vector<IValue> args1{b, 1};
+  testStaticRuntime(cumsum_script, args0, args1);
+}
+
+TEST(StaticRuntime, IndividualOps_CumsumDtype) {
+  auto a = at::randn({1, 2});
+  auto dtype = at::ScalarType::Float;
+  std::vector<IValue> args0{a, 0, dtype};
+  testStaticRuntime(cumsum_script_dtype, args0);
+
+  auto b = at::randn({3, 4});
+  std::vector<IValue> args1{b, 1, dtype};
+  testStaticRuntime(cumsum_script_dtype, args0, args1);
+}
index cf91f33..a73872b 100644 (file)
@@ -1733,6 +1733,28 @@ REGISTER_OPERATOR_FUNCTOR(aten::cat, aten_cat, [](Node* n) -> SROperator {
   };
 });
 
+REGISTER_OPERATOR_FUNCTOR(aten::cumsum, aten_cumsum, [](Node* n) -> SROperator {
+  if (!n->matches(torch::schema(
+          "aten::cumsum(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"))) {
+    LogAndDumpSchema(n);
+    return nullptr;
+  }
+  return [](ProcessedNode* p_node) {
+    const auto& input = p_node->Input(0).toTensor();
+    const auto dim = p_node->Input(1).toInt();
+    const auto dtype = p_node->Input(2).toOptional<c10::ScalarType>();
+
+    if (p_node->Output(0).isNone()) {
+      p_node->Output(0) = at::cpu::cumsum(input, dim, dtype);
+      return;
+    }
+
+    auto& output = p_node->Output(0).toTensor();
+    fastResizeToZero(output);
+    at::cpu::cumsum_out(output, input, dim, dtype);
+  };
+});
+
 namespace {
 
 void check_cat_no_zero_dim(const std::vector<at::Tensor>& tensors) {