From 078b8004a62a51f75e1fbd8d08eea359af6bb1d7 Mon Sep 17 00:00:00 2001
From: Mike Iovine <mikeiovine@fb.com>
Date: Mon, 16 Aug 2021 14:50:27 -0700
Subject: [PATCH] [Static Runtime] Implement prim::TupleUnpack (#63243)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63243

Add `prim::TupleUnpack` native op to static runtime.

Test Plan: Unit test: `buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest`

Reviewed By: hlu1

Differential Revision: D30306955

fbshipit-source-id: 21923d6cbd5545c144ac051b3d48b37ec6e610cf
---
 benchmarks/static_runtime/test_scripts.h         | 12 ++++++++++++
 benchmarks/static_runtime/test_static_runtime.cc | 23 +++++++++++++++++++----
 torch/csrc/jit/runtime/static/native_ops.cpp     | 16 ++++++++++++++++
 3 files changed, 47 insertions(+), 4 deletions(-)

diff --git a/benchmarks/static_runtime/test_scripts.h b/benchmarks/static_runtime/test_scripts.h
index 18a10e9..ebf3047 100644
--- a/benchmarks/static_runtime/test_scripts.h
+++ b/benchmarks/static_runtime/test_scripts.h
@@ -669,3 +669,15 @@ const auto narrow_with_int_script = R"JIT(
   def forward(self, a: Tensor, dim: int, start: int, length: int):
     return a.narrow(dim, start, length).clone()
 )JIT";
+
+const auto two_tuple_unpack_script = R"JIT(
+  def forward(self, tup: Tuple[Tensor, Tensor]):
+    a, b = tup
+    return (a, b)
+)JIT";
+
+const auto three_tuple_unpack_script = R"JIT(
+  def forward(self, tup: Tuple[Tensor, Tensor, Tensor]):
+    a, b, c = tup
+    return (a, b, c)
+)JIT";
diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index 6e0fd13..eac9b23 100644
--- a/benchmarks/static_runtime/test_static_runtime.cc
+++ b/benchmarks/static_runtime/test_static_runtime.cc
@@ -593,11 +593,9 @@ TEST(StaticRuntime, IndividualOps_Full) {
   auto dtype = at::ScalarType::Int;
   auto cpu = at::Device(DeviceType::CPU);
   c10::List<int64_t> size0{4, 5};
-  std::vector<IValue> args{
-      size0, 4, dtype, at::kStrided, cpu, false};
+  std::vector<IValue> args{size0, 4, dtype, at::kStrided, cpu, false};
   c10::List<int64_t> size1{5, 6};
-  std::vector<IValue> args2{
-      size1, 5, dtype, at::kStrided, cpu, false};
+  std::vector<IValue> args2{size1, 5, dtype, at::kStrided, cpu, false};
   testStaticRuntime(full_script, args);
   testStaticRuntime(full_script, args, args2);
 }
@@ -1123,3 +1121,20 @@ TEST(StaticRuntime, IndividualOps_Narrow) {
   testStaticRuntime(narrow_with_int_script, args_a);
   testStaticRuntime(narrow_with_int_script, args_a, args_b);
 }
+
+TEST(StaticRuntime, InvidualOps_TupleUnpack) {
+  auto two_tup = c10::ivalue::Tuple::create({at::randn({1}), at::randn({1})});
+  auto two_tup_large =
+      c10::ivalue::Tuple::create({at::randn({2, 2}), at::randn({2, 2})});
+
+  auto three_tup = c10::ivalue::Tuple::create(
+      {at::randn({1}), at::randn({1}), at::randn({1})});
+  auto three_tup_large = c10::ivalue::Tuple::create(
+      {at::randn({2, 2}), at::randn({2, 2}), at::randn({2, 2})});
+
+  testStaticRuntime(two_tuple_unpack_script, {two_tup});
+  testStaticRuntime(two_tuple_unpack_script, {two_tup}, {two_tup_large});
+
+  testStaticRuntime(three_tuple_unpack_script, {three_tup});
+  testStaticRuntime(three_tuple_unpack_script, {three_tup}, {three_tup_large});
+}
diff --git a/torch/csrc/jit/runtime/static/native_ops.cpp b/torch/csrc/jit/runtime/static/native_ops.cpp
index 97c4373..ca4d1fe 100644
--- a/torch/csrc/jit/runtime/static/native_ops.cpp
+++ b/torch/csrc/jit/runtime/static/native_ops.cpp
@@ -55,6 +55,22 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
     });
 
 REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    prim::TupleUnpack,
+    prim_TupleUnpack,
+    [](Node* n) -> SROperator {
+      return [](ProcessedNode* p_node) {
+        const auto& elems = p_node->Input(0).toTuple()->elements();
+        const size_t num_outputs = p_node->outputs().size();
+        TORCH_CHECK(
+            num_outputs == elems.size(),
+            "Number of outputs must match number of tuple elements.")
+        for (size_t i = 0; i < num_outputs; ++i) {
+          p_node->Output(i) = elems[i];
+        }
+      };
+    });
+
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
     prim::DictConstruct,
     prim_DictConstruct,
     [](Node* n) -> SROperator {
-- 
2.7.4