[Static Runtime] Implement prim::TupleUnpack (#63243)
authorMike Iovine <mikeiovine@fb.com>
Mon, 16 Aug 2021 21:50:27 +0000 (14:50 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 16 Aug 2021 21:56:30 +0000 (14:56 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63243

Add `prim::TupleUnpack` native op to static runtime.

Test Plan: Unit test: `buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest`

Reviewed By: hlu1

Differential Revision: D30306955

fbshipit-source-id: 21923d6cbd5545c144ac051b3d48b37ec6e610cf

benchmarks/static_runtime/test_scripts.h
benchmarks/static_runtime/test_static_runtime.cc
torch/csrc/jit/runtime/static/native_ops.cpp

index 18a10e9..ebf3047 100644 (file)
@@ -669,3 +669,15 @@ const auto narrow_with_int_script = R"JIT(
   def forward(self, a: Tensor, dim: int, start: int, length: int):
       return a.narrow(dim, start, length).clone()
 )JIT";
+
+// TorchScript source that unpacks a 2-element tensor tuple and returns the
+// elements; exercised by the TupleUnpack static runtime test below.
+const auto two_tuple_unpack_script = R"JIT(
+  def forward(self, tup: Tuple[Tensor, Tensor]):
+      a, b = tup
+      return (a, b)
+)JIT";
+
+// Same as above but with a 3-element tuple, to check that the op handles a
+// varying number of outputs.
+const auto three_tuple_unpack_script = R"JIT(
+  def forward(self, tup: Tuple[Tensor, Tensor, Tensor]):
+      a, b, c = tup
+      return (a, b, c)
+)JIT";
index 6e0fd13..eac9b23 100644 (file)
@@ -593,11 +593,9 @@ TEST(StaticRuntime, IndividualOps_Full) {
   auto dtype = at::ScalarType::Int;
   auto cpu = at::Device(DeviceType::CPU);
   c10::List<int64_t> size0{4, 5};
-  std::vector<IValue> args{
-    size0, 4, dtype, at::kStrided, cpu, false};
+  std::vector<IValue> args{size0, 4, dtype, at::kStrided, cpu, false};
   c10::List<int64_t> size1{5, 6};
-  std::vector<IValue> args2{
-    size1, 5, dtype, at::kStrided, cpu, false};
+  std::vector<IValue> args2{size1, 5, dtype, at::kStrided, cpu, false};
   testStaticRuntime(full_script, args);
   testStaticRuntime(full_script, args, args2);
 }
@@ -1123,3 +1121,20 @@ TEST(StaticRuntime, IndividualOps_Narrow) {
   testStaticRuntime(narrow_with_int_script, args_a);
   testStaticRuntime(narrow_with_int_script, args_a, args_b);
 }
+
+TEST(StaticRuntime, InvidualOps_TupleUnpack) {
+  auto two_tup = c10::ivalue::Tuple::create({at::randn({1}), at::randn({1})});
+  auto two_tup_large =
+      c10::ivalue::Tuple::create({at::randn({2, 2}), at::randn({2, 2})});
+
+  auto three_tup = c10::ivalue::Tuple::create(
+      {at::randn({1}), at::randn({1}), at::randn({1})});
+  auto three_tup_large = c10::ivalue::Tuple::create(
+      {at::randn({2, 2}), at::randn({2, 2}), at::randn({2, 2})});
+
+  testStaticRuntime(two_tuple_unpack_script, {two_tup});
+  testStaticRuntime(two_tuple_unpack_script, {two_tup}, {two_tup_large});
+
+  testStaticRuntime(three_tuple_unpack_script, {three_tup});
+  testStaticRuntime(three_tuple_unpack_script, {three_tup}, {three_tup_large});
+}
index 97c4373..ca4d1fe 100644 (file)
@@ -55,6 +55,22 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
     });
 
 REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    prim::TupleUnpack,
+    prim_TupleUnpack,
+    [](Node* n) -> SROperator {
+      return [](ProcessedNode* p_node) {
+        const auto& elems = p_node->Input(0).toTuple()->elements();
+        const size_t num_outputs = p_node->outputs().size();
+        TORCH_CHECK(
+            num_outputs == elems.size(),
+            "Number of outputs must match number of tuple elements.")
+        for (size_t i = 0; i < num_outputs; ++i) {
+          p_node->Output(i) = elems[i];
+        }
+      };
+    });
+
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
     prim::DictConstruct,
     prim_DictConstruct,
     [](Node* n) -> SROperator {