From ff2053dfa1d9d5e39215a297dfaf11b91f9fbb6e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Michael=20K=C3=B6sel?= Date: Thu, 14 Feb 2019 13:42:27 -0800 Subject: [PATCH] add clear functionality to list (#17050) Summary: Add clear functionality to list. See #16662 ```python import torch torch.jit.script def foo(): a = [1, 2, 3, 4] a.clear() return a ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/17050 Differential Revision: D14071799 Pulled By: driazati fbshipit-source-id: 305551c16f7db127c43de0ad5885d9f10678e101 --- test/test_jit.py | 18 ++++++++++++++++++ torch/csrc/jit/register_prim_ops.cpp | 21 +++++++++++++++++++-- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 53c09aa..c2b2238 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3970,6 +3970,24 @@ a") self.assertEqual(foo(), [1, 2, 3, 4]) + @unittest.skipIf(sys.version_info < (3, 3), "clear not supported in version < 3.3") + def test_mutable_list_clear_empty(self): + def test_clear_empty(): + a = torch.jit.annotate(List[int], []) + a.clear() + + return len(a) == 0 + self.checkScript(test_clear_empty, ()) + + @unittest.skipIf(sys.version_info < (3, 3), "clear not supported in version < 3.3") + def test_mutable_list_clear(self): + def test_clear(): + a = [1, 2, 3, 4] + a.clear() + + return len(a) == 0 + self.checkScript(test_clear, ()) + def test_func_call(self): script = ''' def add(a, b): diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 98863f5..fe96ba6 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -1003,6 +1003,17 @@ Operation listAppend(const Node* node) { }; } +template +int listClear(Stack& stack) { + TList a; + pop(stack, a); + + a->elements().clear(); + push(stack, a); + + return 0; +} + template Operation listSelect(const Node* node) { return [=](Stack& stack) { @@ -1289,7 +1300,10 @@ RegisterOperators reg2({ Operator( \ "aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type \ " el) -> " decl_type "[](a!)", \ - listSetItem, c_type::ElemType>) + listSetItem, c_type::ElemType>), \ + Operator( \ + "aten::clear( " decl_type "[](a!) self) -> ()", \ + listClear>) CREATE_MUTABLE_LIST_OPS("Tensor", TensorList), @@ -1305,7 +1319,10 @@ RegisterOperators reg2({ Operator( \ "aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type \ " el) -> " decl_type "[](a!)", \ - listSetItem, c_type::ElemType>) + listSetItem, c_type::ElemType>), \ + Operator( \ + "aten::clear( " decl_type "[](a!) self) -> ()", \ + listClear>) CREATE_IMMUTABLE_LIST_OPS("int", IntList), CREATE_IMMUTABLE_LIST_OPS("float", DoubleList), -- 2.7.4