[Static Runtime] Implement aten::nonzero out variant (#64126)
authorHarut Movsisyan <harutm@fb.com>
Tue, 31 Aug 2021 07:49:39 +0000 (00:49 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 31 Aug 2021 07:51:15 +0000 (00:51 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64126

Test Plan:
Confirm out variant is called:

```
> buck run //caffe2/benchmarks/static_runtime:static_runtime_cpptest -- --v=1
```

Reviewed By: mikeiovine

Differential Revision: D30617729

fbshipit-source-id: 752749638c8f467815efa57021cb3de5c728ab1b

benchmarks/static_runtime/test_scripts.h
benchmarks/static_runtime/test_static_runtime.cc
torch/csrc/jit/runtime/static/ops.cpp

index e26437f..37bb222 100644 (file)
@@ -752,6 +752,12 @@ const auto append_tensor_script = R"JIT(
       return lst
 )JIT";
 
+const auto nonzero_tensor = R"JIT(
+  def forward(self, input: Tensor):
+      a = torch.nonzero(input).clone()
+      return (a)
+)JIT";
+
 const std::string quantize_script = R"IR(
   graph(%input: Tensor, %weights: Tensor):
       %scale: float = prim::Constant[value=1.]()
index aa5cd35..8e498db 100644 (file)
@@ -1312,7 +1312,6 @@ TEST(StaticRuntime, IndividualOps_Cat) {
   testStaticRuntime(cat_script, args0, args1);
 }
 
-
 TEST(StaticRuntime, IndividualOps_Cumsum) {
   auto a = at::randn({2, 3});
   std::vector<IValue> args0{a, 0};
@@ -1333,3 +1332,11 @@ TEST(StaticRuntime, IndividualOps_CumsumDtype) {
   std::vector<IValue> args1{b, 1, dtype};
   testStaticRuntime(cumsum_script_dtype, args0, args1);
 }
+
+TEST(StaticRuntime, IndividualOps_Nonzero) {
+  auto a = at::randint(0, 2, {2, 3});
+  testStaticRuntime(nonzero_tensor, {a});
+
+  auto b = at::randint(0, 2, {4, 3, 2});
+  testStaticRuntime(nonzero_tensor, {a}, {b});
+}
index a73872b..0cc38b0 100644 (file)
@@ -1755,6 +1755,27 @@ REGISTER_OPERATOR_FUNCTOR(aten::cumsum, aten_cumsum, [](Node* n) -> SROperator {
   };
 });
 
+REGISTER_OPERATOR_FUNCTOR(
+    aten::nonzero,
+    aten_nonzero,
+    [](Node* n) -> SROperator {
+      if (!n->matches(torch::schema("aten::nonzero(Tensor self) -> Tensor"))) {
+        LogAndDumpSchema(n);
+        return nullptr;
+      }
+      return [](ProcessedNode* p_node) {
+        const auto& input = p_node->Input(0).toTensor();
+        if (p_node->Output(0).isNone()) {
+          p_node->Output(0) = at::native::nonzero_cpu(input);
+          return;
+        }
+
+        auto& output = p_node->Output(0).toTensor();
+        fastResizeToZero(output);
+        at::native::nonzero_out_cpu(input, output);
+      };
+    });
+
 namespace {
 
 void check_cat_no_zero_dim(const std::vector<at::Tensor>& tensors) {