From a91e056f2ac6fe4fd42d4936ade0694d1e0521a5 Mon Sep 17 00:00:00 2001 From: Nikolay Korovaiko Date: Wed, 20 Feb 2019 09:11:11 -0800 Subject: [PATCH] add list methods: copy,extend (#17092) Summary: This PR adds the following methods to python's list. * copy * extend and tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/17092 Differential Revision: D14141817 Pulled By: Krovatkin fbshipit-source-id: c89207f0f25f3d1d4ad903ee634745615d61d576 --- test/test_jit.py | 42 +++++++++++++++++++++++++ torch/csrc/jit/register_prim_ops.cpp | 59 +++++++++++++++++++++++++++++++----- 2 files changed, 94 insertions(+), 7 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index e7f0235..f8c4b45 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -4112,6 +4112,48 @@ a") return len(a) == 0 self.checkScript(test_clear, ()) + def test_extend_list_mutable(self): + @torch.jit.script + def extend_list(a, b): + # type: (List[Tensor], List[Tensor]) -> List[Tensor] + + a.extend(b) + return a + + for l in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]: + for r in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]: + self.assertEqual(extend_list(l, r), l + r) + + def test_extend_list_immutable(self): + @torch.jit.script + def extend_list(a, b): + # type: (List[int], List[int]) -> List[int] + + a.extend(b) + return a + + for l in [[], [1], [1, 2, 3]]: + for r in [[], [1], [1, 2, 3]]: + self.assertEqual(extend_list(l, r), l + r) + + def test_copy_list_mutable(self): + @torch.jit.script + def copy_list(a): + # type: (List[Tensor]) -> List[Tensor] + return a.copy() + + for l in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]: + self.assertEqual(copy_list(l), l) + + def test_copy_list_immutable(self): + @torch.jit.script + def copy_list(a): + # type: (List[int]) -> List[int] + return a.copy() + + for l in [[], [1], [1, 2, 3]]: + self.assertEqual(copy_list(l), l) + def test_func_call(self): script = ''' def add(a, b): diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 222156c..f12af94 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -1043,6 +1043,32 @@ int listClear(Stack& stack) { return 0; } +template +Operation listExtend(const Node* node) { + return [](Stack& stack) { + TList a; + TList b; + pop(stack, a, b); + + auto& vec_a = a->elements(); + const auto& vec_b = b->elements(); + vec_a.insert(vec_a.end(), vec_b.cbegin(), vec_b.cend()); + return 0; + }; +} + +template +Operation listCopy(const Node* node) { + return [](Stack& stack) { + TList list; + pop(stack, list); + + const auto& vec = list->elements(); + auto out = vec; + push(stack, out); + return 0; + }; +} template Operation listSelect(const Node* node) { @@ -1328,6 +1354,15 @@ RegisterOperators reg2({ "(c) el) -> " decl_type "[](a!)", \ listAppend, c_type::ElemType>), \ Operator( \ + "aten::extend(" decl_type "[](a!) self, " decl_type \ + " [] other) -> ()", \ + listExtend>), \ + Operator( \ + "aten::copy(" decl_type \ + "[](a) self)" \ + " -> " decl_type "[]", \ + listCopy>), \ + Operator( \ "aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type \ " el) -> " decl_type "[](a!)", \ listSetItem, c_type::ElemType>), \ @@ -1335,10 +1370,10 @@ RegisterOperators reg2({ "aten::clear( " decl_type "[](a!) self) -> ()", \ listClear>), \ Operator( \ - "aten::pop(" decl_type "[](a!) self, int idx=-1) \ - -> " decl_type "(*)", \ - listPop>) - + "aten::pop(" decl_type \ + "[](a!) self, int idx=-1) \ + -> " decl_type "(*)", \ + listPop>) CREATE_MUTABLE_LIST_OPS("Tensor", TensorList), @@ -1352,6 +1387,15 @@ RegisterOperators reg2({ " el) -> " decl_type "[](a!)", \ listAppend, c_type::ElemType>), \ Operator( \ + "aten::extend(" decl_type "[](a!) self, " decl_type \ + " [] other) -> ()", \ + listExtend>), \ + Operator( \ + "aten::copy(" decl_type \ + "[](a) self)" \ + " -> " decl_type "[]", \ + listCopy>), \ + Operator( \ "aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type \ " el) -> " decl_type "[](a!)", \ listSetItem, c_type::ElemType>), \ @@ -1359,9 +1403,10 @@ RegisterOperators reg2({ "aten::clear( " decl_type "[](a!) self) -> ()", \ listClear>), \ Operator( \ - "aten::pop(" decl_type "[](a!) self, int idx=-1) \ - -> " decl_type, listPop>) - + "aten::pop(" decl_type \ + "[](a!) self, int idx=-1) \ + -> " decl_type, \ + listPop>) CREATE_IMMUTABLE_LIST_OPS("int", IntList), CREATE_IMMUTABLE_LIST_OPS("float", DoubleList), -- 2.7.4