From 913c1f83f49f9e1e2a494186cc0069d780cee852 Mon Sep 17 00:00:00 2001 From: Don Jang Date: Fri, 20 Aug 2021 00:43:40 -0700 Subject: [PATCH] [Static Runtime] Add native op for aten::detach (#63625) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63625 This change adds a static runtime's native op implementation for `aten::detach` op. See the standard `aten::detach`'s implementation (https://codebrowser.bddppq.com/pytorch/pytorch/aten/src/ATen/native/TensorShape.cpp.html#_ZN2at6native6detachERKNS_6TensorE ) for comparison. Test Plan: - Added `StaticRuntime.IndividualOps_Detach`. - Observed ``` V0819 18:55:33.181188 3092034 impl.cpp:1398] Switch to native impl for node: %a.1 : Tensor = aten::detach(%input.1) ``` Reviewed By: hlu1 Differential Revision: D30443187 fbshipit-source-id: d6e0eadb1b817e0a126c4fc97526abc276ee8a17 --- benchmarks/static_runtime/test_scripts.h | 12 ++++++++++++ benchmarks/static_runtime/test_static_runtime.cc | 11 +++++++++++ torch/csrc/jit/runtime/static/native_ops.cpp | 15 +++++++++++++++ 3 files changed, 38 insertions(+) diff --git a/benchmarks/static_runtime/test_scripts.h b/benchmarks/static_runtime/test_scripts.h index 8db8da2..9946c7a 100644 --- a/benchmarks/static_runtime/test_scripts.h +++ b/benchmarks/static_runtime/test_scripts.h @@ -286,6 +286,18 @@ const auto to_script_4 = R"JIT( return (c) )JIT"; +const auto detach_script_0 = R"JIT( + def forward(self, input: Tensor): + a = input.detach() + return input is a +)JIT"; + +const auto detach_script_1 = R"JIT( + def forward(self, input: Tensor): + a = input.detach() + return a.clone() +)JIT"; + const std::string embedding_bag_default = R"JIT( def forward(self, a: Tensor, b: Tensor, c: Tensor): return torch.embedding_bag(a, b, c) diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc index 14d613f..ec703ef 100644 --- a/benchmarks/static_runtime/test_static_runtime.cc +++ b/benchmarks/static_runtime/test_static_runtime.cc @@ -589,6 +589,17 @@ TEST(StaticRuntime, IndividualOps_to) { test_to(at::ScalarType::Half, false, true, c10::MemoryFormat::ChannelsLast); } +TEST(StaticRuntime, IndividualOps_Detach) { + auto a = at::randn({4, 3, 1, 2}); + auto b = at::randn({3, 2, 2}); + std::vector args{a}; + std::vector args2{b}; + testStaticRuntime(detach_script_0, args); + testStaticRuntime(detach_script_0, args, args2); + testStaticRuntime(detach_script_1, args); + testStaticRuntime(detach_script_1, args, args2); +} + TEST(StaticRuntime, IndividualOps_Full) { auto dtype = at::ScalarType::Int; auto cpu = at::Device(DeviceType::CPU); diff --git a/torch/csrc/jit/runtime/static/native_ops.cpp b/torch/csrc/jit/runtime/static/native_ops.cpp index 616ad87..61a6554 100644 --- a/torch/csrc/jit/runtime/static/native_ops.cpp +++ b/torch/csrc/jit/runtime/static/native_ops.cpp @@ -356,6 +356,21 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::to, aten_to, [](Node* n) -> SROperator { }); REGISTER_NATIVE_OPERATOR_FUNCTOR( + aten::detach, + aten_detach, + [](Node* n) -> SROperator { + if (!n->matches( + torch::schema("aten::detach(Tensor(a) self) -> Tensor(a)"))) { + LogAndDumpSchema(n); + return nullptr; + } + return [](ProcessedNode* p_node) { + const auto& in0_t = p_node->Input(0).toTensor(); + p_node->Output(0) = at::native::alias(in0_t); + }; + }); + +REGISTER_NATIVE_OPERATOR_FUNCTOR( prim::isinstance, prim_isinstance, [](Node* n) -> SROperator { -- 2.7.4