From 3c15822f5f4ab616eb6a519a0ff9b82fc7a3dc63 Mon Sep 17 00:00:00 2001
From: Harut Movsisyan
Date: Tue, 31 Aug 2021 00:49:39 -0700
Subject: [PATCH] [Static Runtime] Implement aten::nonzero out variant (#64126)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64126

Test Plan: Confirm out variant is called:
```
> buck run //caffe2/benchmarks/static_runtime:static_runtime_cpptest -- --v=1
```

Reviewed By: mikeiovine

Differential Revision: D30617729

fbshipit-source-id: 752749638c8f467815efa57021cb3de5c728ab1b
---
 benchmarks/static_runtime/test_scripts.h         |  6 ++++++
 benchmarks/static_runtime/test_static_runtime.cc |  9 ++++++++-
 torch/csrc/jit/runtime/static/ops.cpp            | 21 +++++++++++++++++++++
 3 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/benchmarks/static_runtime/test_scripts.h b/benchmarks/static_runtime/test_scripts.h
index e26437f..37bb222 100644
--- a/benchmarks/static_runtime/test_scripts.h
+++ b/benchmarks/static_runtime/test_scripts.h
@@ -752,6 +752,12 @@ const auto append_tensor_script = R"JIT(
     return lst
 )JIT";
 
+const auto nonzero_tensor = R"JIT(
+  def forward(self, input: Tensor):
+    a = torch.nonzero(input).clone()
+    return (a)
+)JIT";
+
 const std::string quantize_script = R"IR(
   graph(%input: Tensor, %weights: Tensor):
     %scale: float = prim::Constant[value=1.]()
diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index aa5cd35..8e498db 100644
--- a/benchmarks/static_runtime/test_static_runtime.cc
+++ b/benchmarks/static_runtime/test_static_runtime.cc
@@ -1312,7 +1312,6 @@ TEST(StaticRuntime, IndividualOps_Cat) {
   testStaticRuntime(cat_script, args0, args1);
 }
 
-
 TEST(StaticRuntime, IndividualOps_Cumsum) {
   auto a = at::randn({2, 3});
   std::vector<IValue> args0{a, 0};
@@ -1333,3 +1332,11 @@ TEST(StaticRuntime, IndividualOps_CumsumDtype) {
   std::vector<IValue> args1{b, 1, dtype};
   testStaticRuntime(cumsum_script_dtype, args0, args1);
 }
+
+TEST(StaticRuntime, IndividualOps_Nonzero) {
+  auto a = at::randint(0, 2, {2, 3});
+  testStaticRuntime(nonzero_tensor, {a});
+
+  auto b = at::randint(0, 2, {4, 3, 2});
+  testStaticRuntime(nonzero_tensor, {a}, {b});
+}
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index a73872b..0cc38b0 100644
--- a/torch/csrc/jit/runtime/static/ops.cpp
+++ b/torch/csrc/jit/runtime/static/ops.cpp
@@ -1755,6 +1755,27 @@ REGISTER_OPERATOR_FUNCTOR(aten::cumsum, aten_cumsum, [](Node* n) -> SROperator {
   };
 });
 
+REGISTER_OPERATOR_FUNCTOR(
+    aten::nonzero,
+    aten_nonzero,
+    [](Node* n) -> SROperator {
+      if (!n->matches(torch::schema("aten::nonzero(Tensor self) -> Tensor"))) {
+        LogAndDumpSchema(n);
+        return nullptr;
+      }
+      return [](ProcessedNode* p_node) {
+        const auto& input = p_node->Input(0).toTensor();
+        if (p_node->Output(0).isNone()) {
+          p_node->Output(0) = at::native::nonzero_cpu(input);
+          return;
+        }
+
+        auto& output = p_node->Output(0).toTensor();
+        fastResizeToZero(output);
+        at::native::nonzero_out_cpu(input, output);
+      };
+    });
+
 namespace {
 
 void check_cat_no_zero_dim(const std::vector<at::Tensor>& tensors) {
-- 
2.7.4