From 47a9e8ff320b638fcff0e25147e7f042740bf734 Mon Sep 17 00:00:00 2001 From: Mike Iovine Date: Thu, 19 Aug 2021 06:37:44 -0700 Subject: [PATCH] [Static Runtime] Support __getitem__ for lists (#63398) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63398 This change provides a native `__getitem__` implementation for lists to avoid overhead associated with falling back to the JIT interpreter. Test Plan: Unit tests: `buck test //caffe2/benchmarks/static_runtime:static_runtime_cpptest` Reviewed By: hlu1 Differential Revision: D30368464 fbshipit-source-id: e0e0971508cd5d9bcf6025606993dc24ecbf6764 --- benchmarks/static_runtime/test_scripts.h | 18 ++++++++++++--- benchmarks/static_runtime/test_static_runtime.cc | 21 +++++++++++++---- torch/csrc/jit/runtime/static/native_ops.cpp | 29 ++++++++++++++++-------- 3 files changed, 50 insertions(+), 18 deletions(-) diff --git a/benchmarks/static_runtime/test_scripts.h b/benchmarks/static_runtime/test_scripts.h index 6045a1c..8db8da2 100644 --- a/benchmarks/static_runtime/test_scripts.h +++ b/benchmarks/static_runtime/test_scripts.h @@ -632,24 +632,36 @@ const auto argmin_with_keep_dim_script = R"JIT( return torch.argmin(a, dim, True).clone() )JIT"; -const auto getitem_tensor_script = R"JIT( +const auto getitem_dict_tensor_script = R"JIT( def forward(self, key: Tensor): d = {key: 1} return d[key] )JIT"; -const auto getitem_int_script = R"JIT( +const auto getitem_dict_int_script = R"JIT( def forward(self, key: int): d = {key: 1} return d[key] )JIT"; -const auto getitem_str_script = R"JIT( +const auto getitem_dict_str_script = R"JIT( def forward(self, key: str): d = {key: 1} return d[key] )JIT"; +const auto getitem_list_int_script = R"JIT( + def forward(self, idx: int): + lst = [1, 2, 3] + return lst[idx] +)JIT"; + +const auto getitem_list_tensor_script = R"JIT( + def forward(self, tensor: Tensor, idx: int): + lst = [tensor, tensor] + return lst[idx] +)JIT"; + const auto transpose_script = R"JIT( def forward(self, a: Tensor, dim1: int, dim2: int): return torch.transpose(a, dim1, dim2).clone() diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc index 7af49d6..14d613f 100644 --- a/benchmarks/static_runtime/test_static_runtime.cc +++ b/benchmarks/static_runtime/test_static_runtime.cc @@ -1043,19 +1043,30 @@ TEST(StaticRuntime, IndividualOps_Argmin) { testStaticRuntime(argmin_with_keep_dim_script, args_a, args_b); } -TEST(StaticRuntime, IndividualOps_GetItem) { +TEST(StaticRuntime, IndividualOps_GetItem_Dict) { int int_key = 0; std::string str_key = "str"; // No need to test these multiple times, args are not tensors - testStaticRuntime(getitem_int_script, {int_key}); - testStaticRuntime(getitem_str_script, {str_key}); + testStaticRuntime(getitem_dict_int_script, {int_key}); + testStaticRuntime(getitem_dict_str_script, {str_key}); auto a = torch::tensor({1}); auto b = torch::tensor({1, 1}); - testStaticRuntime(getitem_tensor_script, {a}); - testStaticRuntime(getitem_tensor_script, {a}, {b}); + testStaticRuntime(getitem_dict_tensor_script, {a}); + testStaticRuntime(getitem_dict_tensor_script, {a}, {b}); +} + +TEST(StaticRuntime, IndividualOps_GetItem_List) { + testStaticRuntime(getitem_list_int_script, {1}); + testStaticRuntime(getitem_list_int_script, {-1}); + + auto a = torch::tensor({1}); + auto b = torch::tensor({1, 1}); + + testStaticRuntime(getitem_list_tensor_script, {a, 1}); + testStaticRuntime(getitem_list_tensor_script, {a, 1}, {b, -1}); } TEST(StaticRuntime, IndividualOps_Transpose) { diff --git a/torch/csrc/jit/runtime/static/native_ops.cpp b/torch/csrc/jit/runtime/static/native_ops.cpp index d84b1cd..616ad87 100644 --- a/torch/csrc/jit/runtime/static/native_ops.cpp +++ b/torch/csrc/jit/runtime/static/native_ops.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include namespace torch { @@ -100,17 +101,25 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR( if (n->inputs().size() != 2) { return nullptr; } - // TODO: make __getitem__ work for other container types - if (n->input(0)->type()->castRaw() == nullptr) { - return nullptr; + + if (n->input(0)->type()->castRaw()) { + return [](ProcessedNode* p_node) { + auto dict = p_node->Input(0).toGenericDict(); + auto key = p_node->Input(1); + auto value = dict.find(key); + TORCH_CHECK(value != dict.end(), "Key not in dict: ", key); + p_node->Output(0) = value->value(); + }; + } else if (n->input(0)->type()->castRaw()) { + return [](ProcessedNode* p_node) { + auto list = p_node->Input(0).toList(); + auto idx = p_node->Input(1).toInt(); + p_node->Output(0) = getItem(list, idx); + }; } - return [](ProcessedNode* p_node) { - auto dict = p_node->Input(0).toGenericDict(); - auto key = p_node->Input(1); - auto value = dict.find(key); - TORCH_CHECK(value != dict.end(), "Key not in dict: ", key); - p_node->Output(0) = value->value(); - }; + + // TODO(T98581096): make __getitem__ work for other container types + return nullptr; }); REGISTER_NATIVE_OPERATOR_FUNCTOR( -- 2.7.4