[Static Runtime] Add native op for aten::detach (#63625)
authorDon Jang <djang@fb.com>
Fri, 20 Aug 2021 07:43:40 +0000 (00:43 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 20 Aug 2021 07:46:27 +0000 (00:46 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63625

This change adds a static runtime's native op implementation for `aten::detach` op.

See the standard `aten::detach` implementation (https://codebrowser.bddppq.com/pytorch/pytorch/aten/src/ATen/native/TensorShape.cpp.html#_ZN2at6native6detachERKNS_6TensorE) for comparison.

Test Plan:
- Added `StaticRuntime.IndividualOps_Detach`.

- Observed

```
V0819 18:55:33.181188 3092034 impl.cpp:1398] Switch to native impl for node: %a.1 : Tensor = aten::detach(%input.1)
```

Reviewed By: hlu1

Differential Revision: D30443187

fbshipit-source-id: d6e0eadb1b817e0a126c4fc97526abc276ee8a17

benchmarks/static_runtime/test_scripts.h
benchmarks/static_runtime/test_static_runtime.cc
torch/csrc/jit/runtime/static/native_ops.cpp

index 8db8da2..9946c7a 100644 (file)
@@ -286,6 +286,18 @@ const auto to_script_4 = R"JIT(
       return (c)
 )JIT";
 
+// TorchScript used by StaticRuntime.IndividualOps_Detach: checks that
+// detach() returns an alias of its input (`input is a` should be True in
+// both the JIT interpreter and the Static Runtime).
+const auto detach_script_0 = R"JIT(
+  def forward(self, input: Tensor):
+      a = input.detach()
+      return input is a
+)JIT";
+
+// TorchScript used by StaticRuntime.IndividualOps_Detach: clones the
+// detached tensor so the test compares actual tensor contents rather than
+// just identity.
+const auto detach_script_1 = R"JIT(
+  def forward(self, input: Tensor):
+      a = input.detach()
+      return a.clone()
+)JIT";
+
 const std::string embedding_bag_default = R"JIT(
   def forward(self, a: Tensor, b: Tensor, c: Tensor):
       return torch.embedding_bag(a, b, c)
index 14d613f..ec703ef 100644 (file)
@@ -589,6 +589,17 @@ TEST(StaticRuntime, IndividualOps_to) {
   test_to(at::ScalarType::Half, false, true, c10::MemoryFormat::ChannelsLast);
 }
 
+// Compares the Static Runtime native aten::detach implementation against the
+// JIT interpreter on both detach scripts.
+TEST(StaticRuntime, IndividualOps_Detach) {
+  auto a = at::randn({4, 3, 1, 2});
+  auto b = at::randn({3, 2, 2});
+  std::vector<IValue> args{a};
+  std::vector<IValue> args2{b};
+  // The two-argument-set calls pass inputs with different shapes/ranks —
+  // presumably to exercise re-running the graph with resized inputs; confirm
+  // against testStaticRuntime's contract.
+  testStaticRuntime(detach_script_0, args);
+  testStaticRuntime(detach_script_0, args, args2);
+  testStaticRuntime(detach_script_1, args);
+  testStaticRuntime(detach_script_1, args, args2);
+}
+
 TEST(StaticRuntime, IndividualOps_Full) {
   auto dtype = at::ScalarType::Int;
   auto cpu = at::Device(DeviceType::CPU);
index 616ad87..61a6554 100644 (file)
@@ -356,6 +356,21 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::to, aten_to, [](Node* n) -> SROperator {
 });
 
REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    aten::detach,
+    aten_detach,
+    [](Node* n) -> SROperator {
+      // Only accept the canonical detach schema; anything else is logged for
+      // diagnostics and rejected (nullptr lets the caller fall back to the
+      // default operator implementation).
+      if (!n->matches(
+              torch::schema("aten::detach(Tensor(a) self) -> Tensor(a)"))) {
+        LogAndDumpSchema(n);
+        return nullptr;
+      }
+      return [](ProcessedNode* p_node) {
+        const auto& in0_t = p_node->Input(0).toTensor();
+        // NOTE(review): output is produced with at::native::alias() rather
+        // than at::native::detach() — presumably equivalent here because the
+        // Static Runtime executes without autograd state to strip; confirm.
+        p_node->Output(0) = at::native::alias(in0_t);
+      };
+    });
+
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
     prim::isinstance,
     prim_isinstance,
     [](Node* n) -> SROperator {