add clear functionality to list (#17050)
authorMichael Kösel <michaelkoesel@gmx.de>
Thu, 14 Feb 2019 21:42:27 +0000 (13:42 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 14 Feb 2019 22:38:11 +0000 (14:38 -0800)
Summary:
Add clear functionality to list. See #16662

```python
import torch

torch.jit.script
def foo():
    a = [1, 2, 3, 4]
a.clear()

    return a
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17050

Differential Revision: D14071799

Pulled By: driazati

fbshipit-source-id: 305551c16f7db127c43de0ad5885d9f10678e101

test/test_jit.py
torch/csrc/jit/register_prim_ops.cpp

index 53c09aa..c2b2238 100644 (file)
@@ -3970,6 +3970,24 @@ a")
 
         self.assertEqual(foo(), [1, 2, 3, 4])
 
+    @unittest.skipIf(sys.version_info < (3, 3), "clear not supported in version < 3.3")
+    def test_mutable_list_clear_empty(self):
+        def test_clear_empty():
+            a = torch.jit.annotate(List[int], [])
+            a.clear()
+
+            return len(a) == 0
+        self.checkScript(test_clear_empty, ())
+
+    @unittest.skipIf(sys.version_info < (3, 3), "clear not supported in version < 3.3")
+    def test_mutable_list_clear(self):
+        def test_clear():
+            a = [1, 2, 3, 4]
+            a.clear()
+
+            return len(a) == 0
+        self.checkScript(test_clear, ())
+
     def test_func_call(self):
         script = '''
         def add(a, b):
index 98863f5..fe96ba6 100644 (file)
@@ -1003,6 +1003,17 @@ Operation listAppend(const Node* node) {
   };
 }
 
+template <typename TList>
+int listClear(Stack& stack) {
+  TList a;
+  pop(stack, a);
+
+  a->elements().clear();
+  push(stack, a);
+
+  return 0;
+}
+
 template <typename T>
 Operation listSelect(const Node* node) {
   return [=](Stack& stack) {
@@ -1289,7 +1300,10 @@ RegisterOperators reg2({
       Operator(                                                             \
           "aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type      \
           " el) -> " decl_type "[](a!)",                                    \
-          listSetItem<Shared<c_type>, c_type::ElemType>)
+          listSetItem<Shared<c_type>, c_type::ElemType>),                   \
+      Operator(                                                             \
+          "aten::clear( " decl_type "[](a!) self) -> ()",                   \
+          listClear<Shared<c_type>>)
 
     CREATE_MUTABLE_LIST_OPS("Tensor", TensorList),
 
@@ -1305,7 +1319,10 @@ RegisterOperators reg2({
       Operator(                                                        \
           "aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type \
           " el) -> " decl_type "[](a!)",                               \
-          listSetItem<Shared<c_type>, c_type::ElemType>)
+          listSetItem<Shared<c_type>, c_type::ElemType>),              \
+      Operator(                                                        \
+          "aten::clear( " decl_type "[](a!) self) -> ()",              \
+          listClear<Shared<c_type>>)
 
     CREATE_IMMUTABLE_LIST_OPS("int", IntList),
     CREATE_IMMUTABLE_LIST_OPS("float", DoubleList),